From 7231a55815572f22f036e29580385dda70ac6c74 Mon Sep 17 00:00:00 2001
From: Chris Burr <christopher.burr@cern.ch>
Date: Tue, 17 Sep 2024 16:03:55 +0200
Subject: [PATCH] feat: Support getting bulk input queries for transformations

---
 .../DB/TransformationDB.py                    | 25 ++++++++++++++-----
 .../Service/TransformationManagerHandler.py   |  6 +++++
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/LHCbDIRAC/TransformationSystem/DB/TransformationDB.py b/src/LHCbDIRAC/TransformationSystem/DB/TransformationDB.py
index 2aa09fac59..a2d7bd1a86 100755
--- a/src/LHCbDIRAC/TransformationSystem/DB/TransformationDB.py
+++ b/src/LHCbDIRAC/TransformationSystem/DB/TransformationDB.py
@@ -18,6 +18,7 @@ more specific data processing databases
 import threading
 import copy
 import re
+from collections import defaultdict
 
 from DIRAC import gLogger, S_OK, S_ERROR
 from DIRAC.TransformationSystem.DB.TransformationDB import TransformationDB as DIRACTransformationDB
@@ -277,16 +278,25 @@ class TransformationDB(DIRACTransformationDB):
         return res
 
     def getBookkeepingQuery(self, transID, connection=False):
+        """Get the bookkeeping query parameters."""
+        result = self.getBookkeepingQueries([transID], connection)
+        if result["OK"]:
+            result["Value"] = result["Value"][transID]
+        return result
+
+    def getBookkeepingQueries(self, transIDs, connection=False):
         """Get the bookkeeping query parameters."""
         connection = self.__getConnection(connection)
-        req = "SELECT * FROM BkQueriesNew WHERE TransformationID=%d" % (int(transID))
+        req = (
+            "SELECT TransformationID, ParameterName, ParameterValue FROM BkQueriesNew "
+            f"WHERE TransformationID IN ({', '.join(str(t) for t in transIDs)})"
+        )
         res = self._query(req, conn=connection)
         if not res["OK"]:
             return res
-        if not res["Value"]:
-            return S_ERROR("BkQuery %d not found" % int(transID))
-        bkDict = {}
+        bkDict = defaultdict(dict)
         for row in res["Value"]:
+            transID = row[0]
             parameter = row[1]
             value = row[2]
             if value and value != "All":
@@ -299,8 +309,11 @@ class TransformationDB(DIRACTransformationDB):
                         value = [int(x) for x in value]
                     if not value:
                         continue
-                bkDict[parameter] = value
-        return S_OK(bkDict)
+                bkDict[transID][parameter] = value
+        for transID in transIDs:
+            if transID not in bkDict:
+                return S_ERROR("BkQuery %d not found" % int(transID))
+        return S_OK(dict(bkDict))
 
     def __insertExistingTransformationFiles(self, transID, fileTuplesList, connection=False):
         """extends DIRAC.__insertExistingTransformationFiles Does not add userSE
diff --git a/src/LHCbDIRAC/TransformationSystem/Service/TransformationManagerHandler.py b/src/LHCbDIRAC/TransformationSystem/Service/TransformationManagerHandler.py
index 10eb895c5a..1e0749e6fe 100644
--- a/src/LHCbDIRAC/TransformationSystem/Service/TransformationManagerHandler.py
+++ b/src/LHCbDIRAC/TransformationSystem/Service/TransformationManagerHandler.py
@@ -50,6 +50,12 @@ class TransformationManagerHandlerMixin:
     def export_getBookkeepingQuery(self, transID):
         return self.transformationDB.getBookkeepingQuery(transID)
 
+    types_getBookkeepingQueries = [list]
+
+    @classmethod
+    def export_getBookkeepingQueries(self, transIDs):
+        return self.transformationDB.getBookkeepingQueries(transIDs)
+
     types_getTransformationsWithBkQueries = [list]
 
     @classmethod
-- 
GitLab