diff --git a/pjlsa/__init__.py b/pjlsa/__init__.py index 9df7b3209940379a44a306aba24577c188b51901..734ef237293c63ead09da9f98a0602f8965d07ea 100644 --- a/pjlsa/__init__.py +++ b/pjlsa/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.18" +__version__ = "0.2.19" __cmmnbuild_deps__ = [ {"product": "log4j", "groupId": "log4j"}, diff --git a/pjlsa/_pjlsa.py b/pjlsa/_pjlsa.py index bf69a43f7715595d2aeb54b0e3177518ea72bbae..b07ad351ddb666e76dca02097801c62e12e4a673 100644 --- a/pjlsa/_pjlsa.py +++ b/pjlsa/_pjlsa.py @@ -47,13 +47,18 @@ import jpype import cmmnbuild_dep_manager +import typing + # Python data descriptors TrimHeader = namedtuple( - "TrimHeader", ["id", "beamProcesses", "createdDate", "description", "clientInfo"], + "TrimHeader", + ["id", "beamProcesses", "createdDate", "description", "clientInfo"], ) OpticTableItem = namedtuple("OpticTableItem", ["time", "id", "name"]) TrimTuple = namedtuple("TrimTuple", ["time", "data"]) -Calibration = namedtuple("Calibration", ["field", "current", "fieldtype", "name"]) +Calibration = namedtuple( + "Calibration", ["field", "current", "fieldtype", "name"] +) PCInfo = namedtuple( "PCInfo", [ @@ -70,17 +75,17 @@ PCInfo = namedtuple( ) -pcinfogetter={ - "accelerationLimit":'getAccelerationLimit', - "decelerationLimit":'getDecelerationLimit', - "didtMin":'getDidtMin', - "didtMax":'getDidtMax', - "iMinOp":'getIMinOp', - "iNom":'getINom', - "iPNo":'getIPNo', - "iUlt":'getIUlt', - "polaritySwitch":'isPolaritySwitch' - } +pcinfogetter = { + "accelerationLimit": "getAccelerationLimit", + "decelerationLimit": "getDecelerationLimit", + "didtMin": "getDidtMin", + "didtMax": "getDidtMax", + "iMinOp": "getIMinOp", + "iNom": "getINom", + "iPNo": "getIPNo", + "iUlt": "getIUlt", + "polaritySwitch": "isPolaritySwitch", +} Context = namedtuple("Context", ["timestamp", "name", "user"]) @@ -99,13 +104,14 @@ def _build_TrimHeader(th): def _toJavaDate(t): - """Date from string, datetime, unixtimestamp to java date - """ + """Date from string, datetime, unixtimestamp to java date""" Date = jpype.java.util.Date if isinstance(t, str): return jpype.java.sql.Timestamp.valueOf(t) elif isinstance(t, datetime.datetime): - return jpype.java.sql.Timestamp.valueOf(t.strftime("%Y-%m-%d %H:%M:%S.%f")) + return jpype.java.sql.Timestamp.valueOf( + t.strftime("%Y-%m-%d %H:%M:%S.%f") + ) elif t is None: return None elif isinstance(t, Date): @@ -121,8 +127,12 @@ def _toJavaList(lst): return res +class PyRbacTokenLike(typing.Protocol): + def encode(self) -> bytes: ... + + class BaseLSAClient(object): - def __init__(self, server="gpn", system_properties : dict = {}): + def __init__(self, server="gpn", system_properties: dict = {}): self._mgr = cmmnbuild_dep_manager.Manager() self._mgr.jvm_required() @@ -144,8 +154,11 @@ class BaseLSAClient(object): if configuredServer is None: self._System.setProperty("lsa.server", server) elif configuredServer != server: - raise RuntimeError("LSA is already configured to connect to server '%s'. " - "Please restart python to change server to '%s'." % (configuredServer, server)) + raise RuntimeError( + "LSA is already configured to connect to server '%s'. " + "Please restart python to change server to '%s'." + % (configuredServer, server) + ) @contextmanager def java_api(self): @@ -158,6 +171,7 @@ class BaseLSAClient(object): with self._mgr.imports(): # work-around to fire up JPypes forward converters - TODO: remove me in JPype 0.8! from java.util import HashSet, HashMap, ArrayList, Date + ArrayList() HashMap() HashSet() @@ -174,6 +188,7 @@ class LSAClient(BaseLSAClient): self._lsa = self._cern.lsa self._client = self._lsa.client self._domain = self._lsa.domain + self._rbac = self._cern.rbac # Java classes self._ContextService = self._client.ContextService @@ -196,44 +211,91 @@ class LSAClient(BaseLSAClient): self._Parameter = self._domain.settings.Parameter self._ParameterSettings = self._domain.settings.ParameterSettings self._Setting = self._domain.settings.Setting - self._StandAloneBeamProcess = self._domain.settings.StandAloneBeamProcess + self._StandAloneBeamProcess = ( + self._domain.settings.StandAloneBeamProcess + ) self._Knob = self._domain.settings.Knob self._FunctionSetting = self._domain.settings.spi.FunctionSetting self._ScalarSetting = self._domain.settings.spi.ScalarSetting + self._RbaToken = self._rbac.common.RbaToken + self._ClientTierTokenHolder = ( + self._rbac.util.holder.ClientTierTokenHolder + ) + self._RbaTokenLookup = self._rbac.util.lookup.RbaTokenLookup + self._ParametersRequestBuilder = ( self._domain.settings.factory.ParametersRequestBuilder ) self._Device = self._domain.devices.Device - self._DeviceRequestBuilder = self._domain.devices.factory.DevicesRequestBuilder + self._DeviceRequestBuilder = ( + self._domain.devices.factory.DevicesRequestBuilder + ) self._ParameterTreesRequestBuilder = ( self._domain.settings.factory.ParameterTreesRequestBuilder ) - self._ParameterTreesRequest = self._domain.settings.ParameterTreesRequest + self._ParameterTreesRequest = ( + self._domain.settings.ParameterTreesRequest + ) self._ParameterTreesRequestTreeDirection = ( self._ParameterTreesRequest.TreeDirection ) - self._CalibrationFunctionTypes = self._domain.optics.CalibrationFunctionTypes + self._CalibrationFunctionTypes = ( + self._domain.optics.CalibrationFunctionTypes + ) # non lsa classes - self._CernAccelerator = self._cern.accsoft.commons.domain.CernAccelerator + self._CernAccelerator = ( + self._cern.accsoft.commons.domain.CernAccelerator + ) - self._contextService = self._ServiceLocator.getService(self._ContextService) + self._contextService = self._ServiceLocator.getService( + self._ContextService + ) self._trimService = self._ServiceLocator.getService(self._TrimService) - self._settingService = self._ServiceLocator.getService(self._SettingService) - self._parameterService = self._ServiceLocator.getService(self._ParameterService) - self._contextService = self._ServiceLocator.getService(self._ContextService) + self._settingService = self._ServiceLocator.getService( + self._SettingService + ) + self._parameterService = self._ServiceLocator.getService( + self._ParameterService + ) + self._contextService = self._ServiceLocator.getService( + self._ContextService + ) self._lhcService = self._ServiceLocator.getService(self._LhcService) self._hyperCycleService = self._ServiceLocator.getService( self._HyperCycleService ) self._knobService = self._ServiceLocator.getService(self._KnobService) - self._opticService = self._ServiceLocator.getService(self._OpticService) - self._deviceService = self._ServiceLocator.getService(self._DeviceService) - self._fidelService = self._ServiceLocator.getService(self._FidelService) + self._opticService = self._ServiceLocator.getService( + self._OpticService + ) + self._deviceService = self._ServiceLocator.getService( + self._DeviceService + ) + self._fidelService = self._ServiceLocator.getService( + self._FidelService + ) + + def set_rbac_token(self, token: typing.Union[bytes, PyRbacTokenLike]): + if token: + if isinstance(token, bytes): + token_bytes = token + else: + token_bytes = token.encode() + self._doSetTokenToJava(token_bytes) + return + + def _doSetTokenToJava(self, token_bytes: bytes) -> None: + token = self._RbaToken(token_bytes) + self._ClientTierTokenHolder.setRbaToken(token) + if self._RbaTokenLookup.findRbaToken() is None: + raise RuntimeError( + "Could not reuse RBAC token, maybe it has expired?" + ) def _getContextFamily(self, name): if isinstance(name, str): @@ -259,20 +321,22 @@ class LSAClient(BaseLSAClient): else: return self._hyperCycleService.findHyperCycle(hypercycle) - def findOperationalContexts(self, accelerator: str = 'sps'): + def findOperationalContexts(self, accelerator: str = "sps"): accelerator = self._getAccelerator(accelerator) cycles = self._contextService.findStandAloneCycles(accelerator) - cycles = filter(lambda cyc: str(cyc.getContextCategory()) == 'OPERATIONAL', cycles) + cycles = filter( + lambda cyc: str(cyc.getContextCategory()) == "OPERATIONAL", cycles + ) return sorted(map(str, cycles)) - def findResidentContexts(self, accelerator: str = 'sps'): + def findResidentContexts(self, accelerator: str = "sps"): accelerator = self._getAccelerator(accelerator) cycles = self._contextService.findResidentContexts(accelerator) return sorted(map(str, cycles)) - def findActiveContexts(self, accelerator: str = 'sps'): + def findActiveContexts(self, accelerator: str = "sps"): accelerator = self._getAccelerator(accelerator) cycles = self._contextService.findActiveContexts(accelerator) @@ -316,7 +380,10 @@ class LSAClient(BaseLSAClient): return str(self._getHyperCycle().getResidentBeamProcess(category)) def getResidentBeamProcesses(self): - return [str(p) for p in list(self._getHyperCycle().getResidentBeamProcesses())] + return [ + str(p) + for p in list(self._getHyperCycle().getResidentBeamProcesses()) + ] def findParameterNames(self, deviceName=None, groupName=None, regexp=""): req = self._ParametersRequestBuilder() @@ -341,7 +408,7 @@ class LSAClient(BaseLSAClient): return list(map(str, deviceList)) def findUserContextMappingHistory( - self, t1, t2, accelerator="lhc", contextFamily="beamprocess" + self, t1, t2, accelerator="lhc", contextFamily="beamprocess" ): acc = self._getAccelerator(accelerator) contextFamily = self._getContextFamily(contextFamily) @@ -351,13 +418,19 @@ class LSAClient(BaseLSAClient): acc, contextFamily, t1, t2 ) out = [ - (ct.getMappingTimestamp() / 1000.0, ct.getContextName(), ct.getUser()) + ( + ct.getMappingTimestamp() / 1000.0, + ct.getContextName(), + ct.getUser(), + ) for ct in res ] return Context(*map(np.array, zip(*out))) def findBeamProcessHistory(self, t1, t2, accelerator="lhc"): - cts = self.findUserContextMappingHistory(t1, t2, accelerator=accelerator) + cts = self.findUserContextMappingHistory( + t1, t2, accelerator=accelerator + ) import pytimber db = pytimber.LoggingDB() @@ -381,7 +454,9 @@ class LSAClient(BaseLSAClient): lst = self._parameterService.findParameters(req.build()) return lst - def _getRawTrimHeadersByBeamprocess(self, param, beamprocess, start=None, end=None): + def _getRawTrimHeadersByBeamprocess( + self, param, beamprocess, start=None, end=None + ): bp = self._getBeamProcess(beamprocess) thrb = self._cern.lsa.domain.settings.TrimHeadersRequestBuilder() thrb.beamProcesses(self._java.util.Collections.singleton(bp)) @@ -439,7 +514,9 @@ class LSAClient(BaseLSAClient): param.add(self._getParameter(pp)) return param - def _getTrimHeadersByBeamprocess(self, parameter, beamprocess, start=None, end=None): + def _getTrimHeadersByBeamprocess( + self, parameter, beamprocess, start=None, end=None + ): return [ _build_TrimHeader(th) for th in self._getRawTrimHeadersByBeamprocess( @@ -456,33 +533,37 @@ class LSAClient(BaseLSAClient): ] def getTrimHeaders( - self, parameter, beamprocess=None, cycle=None, start=None, end=None + self, parameter, beamprocess=None, cycle=None, start=None, end=None ): if beamprocess is not None: - return self._getTrimHeadersByBeamprocess(parameter, - beamprocess=beamprocess, - start=start, end=end) + return self._getTrimHeadersByBeamprocess( + parameter, beamprocess=beamprocess, start=start, end=end + ) else: - return self._getTrimHeadersByCycle(parameter, - cycle=cycle, - start=start, end=end) + return self._getTrimHeadersByCycle( + parameter, cycle=cycle, start=start, end=end + ) def _getTrimsByBeamprocess( - self, parameter, beamprocess, start=None, end=None, part="value" + self, parameter, beamprocess, start=None, end=None, part="value" ): parameterList = self._buildParameterList(parameter) bp = self._getBeamProcess(beamprocess) timestamps = {} values = {} - for th in self._getRawTrimHeadersByBeamprocess(parameterList, bp, start, end): + for th in self._getRawTrimHeadersByBeamprocess( + parameterList, bp, start, end + ): csrb = ( self._cern.lsa.domain.settings.ContextSettingsRequestBuilder() ) csrb.standAloneContext(bp) csrb.parameters(parameterList) csrb.at(th.getCreatedDate().toInstant()) - contextSettings = self._settingService.findContextSettings(csrb.build()) + contextSettings = self._settingService.findContextSettings( + csrb.build() + ) for pp in parameterList: parameterSetting = contextSettings.getParameterSettings(pp) if parameterSetting is None: @@ -497,7 +578,9 @@ class LSAClient(BaseLSAClient): elif part == "target": value = setting.getTargetScalarValue().getDouble() elif part == "correction": - value = setting.getCorrectionScalarValue().getDouble() + value = ( + setting.getCorrectionScalarValue().getDouble() + ) else: raise ValueError("Invalid Setting Part: " + part) elif type(setting) is self._FunctionSetting: @@ -524,7 +607,7 @@ class LSAClient(BaseLSAClient): return out def _getTrimsByCycle( - self, parameter, cycle, start=None, end=None, part="value" + self, parameter, cycle, start=None, end=None, part="value" ): parameterList = self._buildParameterList(parameter) cy = self._getCycle(cycle) @@ -532,13 +615,15 @@ class LSAClient(BaseLSAClient): timestamps = {} values = {} for th in self._getRawTrimHeadersByCycle( - parameterList, cy, start, end + parameterList, cy, start, end ): csrb = self._domain.settings.ContextSettingsRequestBuilder() csrb.standAloneContext(cy) csrb.parameters(parameterList) csrb.at(th.getCreatedDate().toInstant()) - contextSettings = self._settingService.findContextSettings(csrb.build()) + contextSettings = self._settingService.findContextSettings( + csrb.build() + ) for pp in parameterList: parameterSetting = contextSettings.getParameterSettings(pp) if parameterSetting is None: @@ -557,7 +642,9 @@ class LSAClient(BaseLSAClient): elif part == "target": value = setting.getTargetScalarValue().getDouble() elif part == "correction": - value = setting.getCorrectionScalarValue().getDouble() + value = ( + setting.getCorrectionScalarValue().getDouble() + ) else: raise ValueError("Invalid Setting Part: " + part) elif type(setting) is self._FunctionSetting: @@ -584,18 +671,26 @@ class LSAClient(BaseLSAClient): return out def getTrims( - self, parameter, beamprocess=None, cycle=None, start=None, end=None, part="value" + self, + parameter, + beamprocess=None, + cycle=None, + start=None, + end=None, + part="value", ): if beamprocess is not None: - return self._getTrimsByBeamprocess(parameter, - beamprocess=beamprocess, - start=start, end=end, - part=part) + return self._getTrimsByBeamprocess( + parameter, + beamprocess=beamprocess, + start=start, + end=end, + part=part, + ) else: - return self._getTrimsByCycle(parameter, - cycle=cycle, - start=start, end=end, - part=part) + return self._getTrimsByCycle( + parameter, cycle=cycle, start=start, end=end, part=part + ) def _getLastTrimByBeamprocess(self, parameter, beamprocess, part="value"): th = self._getTrimHeadersByBeamprocess(parameter, beamprocess)[-1] @@ -612,18 +707,18 @@ class LSAClient(BaseLSAClient): return TrimTuple(res.time[-1], res.data[-1]) def getLastTrim( - self, parameter, beamprocess=None, cycle=None, part="value" + self, parameter, beamprocess=None, cycle=None, part="value" ): if beamprocess is not None: - return self._getLastTrimByBeamprocess(parameter, - beamprocess=beamprocess, - part=part) + return self._getLastTrimByBeamprocess( + parameter, beamprocess=beamprocess, part=part + ) else: - return self._getLastTrimByCycle(parameter, - cycle=cycle, - part=part) + return self._getLastTrimByCycle(parameter, cycle=cycle, part=part) - def _getLastTrimValueByBeamprocess(self, parameter, beamprocess, part="value"): + def _getLastTrimValueByBeamprocess( + self, parameter, beamprocess, part="value" + ): th = self._getTrimHeadersByBeamprocess(parameter, beamprocess)[-1] res = self._getTrimsByBeamprocess( parameter, beamprocess, part=part, start=th.createdDate @@ -638,16 +733,16 @@ class LSAClient(BaseLSAClient): return res.data[-1] def getLastTrimValue( - self, parameter, beamprocess=None, cycle=None, part="value" + self, parameter, beamprocess=None, cycle=None, part="value" ): if beamprocess is not None: - return self._getLastTrimValueByBeamprocess(parameter, - beamprocess=beamprocess, - part=part) + return self._getLastTrimValueByBeamprocess( + parameter, beamprocess=beamprocess, part=part + ) else: - return self._getLastTrimValueByCycle(parameter, - cycle=cycle, - part=part) + return self._getLastTrimValueByCycle( + parameter, cycle=cycle, part=part + ) def getOpticTable(self, beamprocess): bp = self._getBeamProcess(beamprocess) @@ -657,7 +752,9 @@ class LSAClient(BaseLSAClient): 0 ].getOpticsTableItems() return [ - OpticTableItem(time=o.getTime(), id=o.getOpticId(), name=o.getOpticName()) + OpticTableItem( + time=o.getTime(), id=o.getOpticId(), name=o.getOpticName() + ) for o in opticTable ] @@ -675,9 +772,13 @@ class LSAClient(BaseLSAClient): self._ParameterTreesRequestTreeDirection.DEPENDENT_TREE ) elif direction == "source": - req.setTreeDirection(self._ParameterTreesRequestTreeDirection.SOURCE_TREE) + req.setTreeDirection( + self._ParameterTreesRequestTreeDirection.SOURCE_TREE + ) else: - raise ValueError('invalid direction, expecting "dependent" or "source"') + raise ValueError( + 'invalid direction, expecting "dependent" or "source"' + ) req.setParameter(self._getParameter(parameter)) tree = self._parameterService.findParameterTrees(req.build()) params = {} @@ -689,7 +790,9 @@ class LSAClient(BaseLSAClient): def getOpticStrength(self, optic): if not hasattr(optic, "name"): optic = self._opticService.findOpticByName(optic) - out = [(st.logicalHWName, st.strength) for st in optic.getOpticStrengths()] + out = [ + (st.logicalHWName, st.strength) for st in optic.getOpticStrengths() + ] return dict(out) def _getOptics(self, name): @@ -717,14 +820,18 @@ class LSAClient(BaseLSAClient): pcname = pcs[madname] if full is True: nl = self._java.util.Collections.singleton(pcname) - pcs = self._deviceService.findActualDevicesByLogicalHardwareName(nl) + pcs = self._deviceService.findActualDevicesByLogicalHardwareName( + nl + ) pcname = list(pcs[pcname])[0].toString() return pcname def findMadStrengthNameByPCName(self, pcname, full=False): if full is True: nl = self._java.util.Collections.singleton(pcname) - pcs = self._deviceService.findLogicalHardwaresByActualDeviceNames(nl) + pcs = self._deviceService.findLogicalHardwaresByActualDeviceNames( + nl + ) pcname = list(pcs[pcname])[0].toString() nl = self._java.util.Collections.singleton(pcname) madnames = self._deviceService.findMadStrengthNamesByLogicalNames(nl) @@ -747,23 +854,30 @@ class LSAClient(BaseLSAClient): return dict((cc.getName(), cc) for cc in cals) def dump_calibrations(self, outdir="calib"): - """ Dump all calibration in directory <outdir> - """ + """Dump all calibration in directory <outdir>""" os.mkdir(outdir) for name, cc in self._get_calibrations(): - ff = cc.getCalibrationFunctionByType(self._CalibrationFunctionTypes.B_FIELD) + ff = cc.getCalibrationFunctionByType( + self._CalibrationFunctionTypes.B_FIELD + ) if ff is not None: field = ff.toXArray() current = ff.toYArray() fn = os.path.join(outdir, "%s.txt" % name) print(fn) fh = open(fn, "w") - fh.write("\n".join(["%s %s" % (i, f) for i, f in zip(current, field)])) + fh.write( + "\n".join( + ["%s %s" % (i, f) for i, f in zip(current, field)] + ) + ) fh.close() def getPCInfo(self, pcname): pc = self._deviceService.findPowerConverterInfo(pcname) - info = PCInfo(*(getattr(pc, pcinfogetter[nn])() for nn in PCInfo._fields)) + info = PCInfo( + *(getattr(pc, pcinfogetter[nn])() for nn in PCInfo._fields) + ) return info @@ -772,6 +886,7 @@ class LSAClientGSI(BaseLSAClient): Supports production, integration and development servers. """ + def __init__(self, server="gsi-dev"): """Connect to either production, integration or development LSA server at GSI. The server variable should correspondingly @@ -785,6 +900,6 @@ class LSAClientGSI(BaseLSAClient): cfg_url = "https://websvcdev.acc.gsi.de/groups/cscoap/config/" else: raise ValueError( - 'server should be one of "gsi-pro", "gsi-int" or "gsi-dev"') + 'server should be one of "gsi-pro", "gsi-int" or "gsi-dev"' + ) super().__init__(server, {"csco.default.property.config.url": cfg_url}) -