Skip to content
Snippets Groups Projects
Commit 5dd9228f authored by Georges Trad's avatar Georges Trad
Browse files

adding rbac feature (compatible with pyrbac)

parent d2519656
No related branches found
No related tags found
No related merge requests found
Pipeline #11789340 failed
__version__ = "0.2.18"
__version__ = "0.2.19"
__cmmnbuild_deps__ = [
{"product": "log4j", "groupId": "log4j"},
......
......@@ -47,13 +47,18 @@ import jpype
import cmmnbuild_dep_manager
import typing
# Python data descriptors
TrimHeader = namedtuple(
"TrimHeader", ["id", "beamProcesses", "createdDate", "description", "clientInfo"],
"TrimHeader",
["id", "beamProcesses", "createdDate", "description", "clientInfo"],
)
OpticTableItem = namedtuple("OpticTableItem", ["time", "id", "name"])
TrimTuple = namedtuple("TrimTuple", ["time", "data"])
Calibration = namedtuple("Calibration", ["field", "current", "fieldtype", "name"])
Calibration = namedtuple(
"Calibration", ["field", "current", "fieldtype", "name"]
)
PCInfo = namedtuple(
"PCInfo",
[
......@@ -70,17 +75,17 @@ PCInfo = namedtuple(
)
pcinfogetter={
"accelerationLimit":'getAccelerationLimit',
"decelerationLimit":'getDecelerationLimit',
"didtMin":'getDidtMin',
"didtMax":'getDidtMax',
"iMinOp":'getIMinOp',
"iNom":'getINom',
"iPNo":'getIPNo',
"iUlt":'getIUlt',
"polaritySwitch":'isPolaritySwitch'
}
pcinfogetter = {
"accelerationLimit": "getAccelerationLimit",
"decelerationLimit": "getDecelerationLimit",
"didtMin": "getDidtMin",
"didtMax": "getDidtMax",
"iMinOp": "getIMinOp",
"iNom": "getINom",
"iPNo": "getIPNo",
"iUlt": "getIUlt",
"polaritySwitch": "isPolaritySwitch",
}
Context = namedtuple("Context", ["timestamp", "name", "user"])
......@@ -99,13 +104,14 @@ def _build_TrimHeader(th):
def _toJavaDate(t):
"""Date from string, datetime, unixtimestamp to java date
"""
"""Date from string, datetime, unixtimestamp to java date"""
Date = jpype.java.util.Date
if isinstance(t, str):
return jpype.java.sql.Timestamp.valueOf(t)
elif isinstance(t, datetime.datetime):
return jpype.java.sql.Timestamp.valueOf(t.strftime("%Y-%m-%d %H:%M:%S.%f"))
return jpype.java.sql.Timestamp.valueOf(
t.strftime("%Y-%m-%d %H:%M:%S.%f")
)
elif t is None:
return None
elif isinstance(t, Date):
......@@ -121,8 +127,12 @@ def _toJavaList(lst):
return res
class PyRbacTokenLike(typing.Protocol):
def encode(self) -> bytes: ...
class BaseLSAClient(object):
def __init__(self, server="gpn", system_properties : dict = {}):
def __init__(self, server="gpn", system_properties: dict = {}):
self._mgr = cmmnbuild_dep_manager.Manager()
self._mgr.jvm_required()
......@@ -144,8 +154,11 @@ class BaseLSAClient(object):
if configuredServer is None:
self._System.setProperty("lsa.server", server)
elif configuredServer != server:
raise RuntimeError("LSA is already configured to connect to server '%s'. "
"Please restart python to change server to '%s'." % (configuredServer, server))
raise RuntimeError(
"LSA is already configured to connect to server '%s'. "
"Please restart python to change server to '%s'."
% (configuredServer, server)
)
@contextmanager
def java_api(self):
......@@ -158,6 +171,7 @@ class BaseLSAClient(object):
with self._mgr.imports():
# work-around to fire up JPypes forward converters - TODO: remove me in JPype 0.8!
from java.util import HashSet, HashMap, ArrayList, Date
ArrayList()
HashMap()
HashSet()
......@@ -174,6 +188,7 @@ class LSAClient(BaseLSAClient):
self._lsa = self._cern.lsa
self._client = self._lsa.client
self._domain = self._lsa.domain
self._rbac = self._cern.rbac
# Java classes
self._ContextService = self._client.ContextService
......@@ -196,44 +211,91 @@ class LSAClient(BaseLSAClient):
self._Parameter = self._domain.settings.Parameter
self._ParameterSettings = self._domain.settings.ParameterSettings
self._Setting = self._domain.settings.Setting
self._StandAloneBeamProcess = self._domain.settings.StandAloneBeamProcess
self._StandAloneBeamProcess = (
self._domain.settings.StandAloneBeamProcess
)
self._Knob = self._domain.settings.Knob
self._FunctionSetting = self._domain.settings.spi.FunctionSetting
self._ScalarSetting = self._domain.settings.spi.ScalarSetting
self._RbaToken = self._rbac.common.RbaToken
self._ClientTierTokenHolder = (
self._rbac.util.holder.ClientTierTokenHolder
)
self._RbaTokenLookup = self._rbac.util.lookup.RbaTokenLookup
self._ParametersRequestBuilder = (
self._domain.settings.factory.ParametersRequestBuilder
)
self._Device = self._domain.devices.Device
self._DeviceRequestBuilder = self._domain.devices.factory.DevicesRequestBuilder
self._DeviceRequestBuilder = (
self._domain.devices.factory.DevicesRequestBuilder
)
self._ParameterTreesRequestBuilder = (
self._domain.settings.factory.ParameterTreesRequestBuilder
)
self._ParameterTreesRequest = self._domain.settings.ParameterTreesRequest
self._ParameterTreesRequest = (
self._domain.settings.ParameterTreesRequest
)
self._ParameterTreesRequestTreeDirection = (
self._ParameterTreesRequest.TreeDirection
)
self._CalibrationFunctionTypes = self._domain.optics.CalibrationFunctionTypes
self._CalibrationFunctionTypes = (
self._domain.optics.CalibrationFunctionTypes
)
# non lsa classes
self._CernAccelerator = self._cern.accsoft.commons.domain.CernAccelerator
self._CernAccelerator = (
self._cern.accsoft.commons.domain.CernAccelerator
)
self._contextService = self._ServiceLocator.getService(self._ContextService)
self._contextService = self._ServiceLocator.getService(
self._ContextService
)
self._trimService = self._ServiceLocator.getService(self._TrimService)
self._settingService = self._ServiceLocator.getService(self._SettingService)
self._parameterService = self._ServiceLocator.getService(self._ParameterService)
self._contextService = self._ServiceLocator.getService(self._ContextService)
self._settingService = self._ServiceLocator.getService(
self._SettingService
)
self._parameterService = self._ServiceLocator.getService(
self._ParameterService
)
self._contextService = self._ServiceLocator.getService(
self._ContextService
)
self._lhcService = self._ServiceLocator.getService(self._LhcService)
self._hyperCycleService = self._ServiceLocator.getService(
self._HyperCycleService
)
self._knobService = self._ServiceLocator.getService(self._KnobService)
self._opticService = self._ServiceLocator.getService(self._OpticService)
self._deviceService = self._ServiceLocator.getService(self._DeviceService)
self._fidelService = self._ServiceLocator.getService(self._FidelService)
self._opticService = self._ServiceLocator.getService(
self._OpticService
)
self._deviceService = self._ServiceLocator.getService(
self._DeviceService
)
self._fidelService = self._ServiceLocator.getService(
self._FidelService
)
def set_rbac_token(self, token: typing.Union[bytes, PyRbacTokenLike]):
if token:
if isinstance(token, bytes):
token_bytes = token
else:
token_bytes = token.encode()
self._doSetTokenToJava(token_bytes)
return
def _doSetTokenToJava(self, token_bytes: bytes) -> None:
token = self._RbaToken(token_bytes)
self._ClientTierTokenHolder.setRbaToken(token)
if self._RbaTokenLookup.findRbaToken() is None:
raise RuntimeError(
"Could not reuse RBAC token, maybe it has expired?"
)
def _getContextFamily(self, name):
if isinstance(name, str):
......@@ -259,20 +321,22 @@ class LSAClient(BaseLSAClient):
else:
return self._hyperCycleService.findHyperCycle(hypercycle)
def findOperationalContexts(self, accelerator: str = 'sps'):
def findOperationalContexts(self, accelerator: str = "sps"):
accelerator = self._getAccelerator(accelerator)
cycles = self._contextService.findStandAloneCycles(accelerator)
cycles = filter(lambda cyc: str(cyc.getContextCategory()) == 'OPERATIONAL', cycles)
cycles = filter(
lambda cyc: str(cyc.getContextCategory()) == "OPERATIONAL", cycles
)
return sorted(map(str, cycles))
def findResidentContexts(self, accelerator: str = 'sps'):
def findResidentContexts(self, accelerator: str = "sps"):
accelerator = self._getAccelerator(accelerator)
cycles = self._contextService.findResidentContexts(accelerator)
return sorted(map(str, cycles))
def findActiveContexts(self, accelerator: str = 'sps'):
def findActiveContexts(self, accelerator: str = "sps"):
accelerator = self._getAccelerator(accelerator)
cycles = self._contextService.findActiveContexts(accelerator)
......@@ -316,7 +380,10 @@ class LSAClient(BaseLSAClient):
return str(self._getHyperCycle().getResidentBeamProcess(category))
def getResidentBeamProcesses(self):
return [str(p) for p in list(self._getHyperCycle().getResidentBeamProcesses())]
return [
str(p)
for p in list(self._getHyperCycle().getResidentBeamProcesses())
]
def findParameterNames(self, deviceName=None, groupName=None, regexp=""):
req = self._ParametersRequestBuilder()
......@@ -341,7 +408,7 @@ class LSAClient(BaseLSAClient):
return list(map(str, deviceList))
def findUserContextMappingHistory(
self, t1, t2, accelerator="lhc", contextFamily="beamprocess"
self, t1, t2, accelerator="lhc", contextFamily="beamprocess"
):
acc = self._getAccelerator(accelerator)
contextFamily = self._getContextFamily(contextFamily)
......@@ -351,13 +418,19 @@ class LSAClient(BaseLSAClient):
acc, contextFamily, t1, t2
)
out = [
(ct.getMappingTimestamp() / 1000.0, ct.getContextName(), ct.getUser())
(
ct.getMappingTimestamp() / 1000.0,
ct.getContextName(),
ct.getUser(),
)
for ct in res
]
return Context(*map(np.array, zip(*out)))
def findBeamProcessHistory(self, t1, t2, accelerator="lhc"):
cts = self.findUserContextMappingHistory(t1, t2, accelerator=accelerator)
cts = self.findUserContextMappingHistory(
t1, t2, accelerator=accelerator
)
import pytimber
db = pytimber.LoggingDB()
......@@ -381,7 +454,9 @@ class LSAClient(BaseLSAClient):
lst = self._parameterService.findParameters(req.build())
return lst
def _getRawTrimHeadersByBeamprocess(self, param, beamprocess, start=None, end=None):
def _getRawTrimHeadersByBeamprocess(
self, param, beamprocess, start=None, end=None
):
bp = self._getBeamProcess(beamprocess)
thrb = self._cern.lsa.domain.settings.TrimHeadersRequestBuilder()
thrb.beamProcesses(self._java.util.Collections.singleton(bp))
......@@ -439,7 +514,9 @@ class LSAClient(BaseLSAClient):
param.add(self._getParameter(pp))
return param
def _getTrimHeadersByBeamprocess(self, parameter, beamprocess, start=None, end=None):
def _getTrimHeadersByBeamprocess(
self, parameter, beamprocess, start=None, end=None
):
return [
_build_TrimHeader(th)
for th in self._getRawTrimHeadersByBeamprocess(
......@@ -456,33 +533,37 @@ class LSAClient(BaseLSAClient):
]
def getTrimHeaders(
self, parameter, beamprocess=None, cycle=None, start=None, end=None
self, parameter, beamprocess=None, cycle=None, start=None, end=None
):
if beamprocess is not None:
return self._getTrimHeadersByBeamprocess(parameter,
beamprocess=beamprocess,
start=start, end=end)
return self._getTrimHeadersByBeamprocess(
parameter, beamprocess=beamprocess, start=start, end=end
)
else:
return self._getTrimHeadersByCycle(parameter,
cycle=cycle,
start=start, end=end)
return self._getTrimHeadersByCycle(
parameter, cycle=cycle, start=start, end=end
)
def _getTrimsByBeamprocess(
self, parameter, beamprocess, start=None, end=None, part="value"
self, parameter, beamprocess, start=None, end=None, part="value"
):
parameterList = self._buildParameterList(parameter)
bp = self._getBeamProcess(beamprocess)
timestamps = {}
values = {}
for th in self._getRawTrimHeadersByBeamprocess(parameterList, bp, start, end):
for th in self._getRawTrimHeadersByBeamprocess(
parameterList, bp, start, end
):
csrb = (
self._cern.lsa.domain.settings.ContextSettingsRequestBuilder()
)
csrb.standAloneContext(bp)
csrb.parameters(parameterList)
csrb.at(th.getCreatedDate().toInstant())
contextSettings = self._settingService.findContextSettings(csrb.build())
contextSettings = self._settingService.findContextSettings(
csrb.build()
)
for pp in parameterList:
parameterSetting = contextSettings.getParameterSettings(pp)
if parameterSetting is None:
......@@ -497,7 +578,9 @@ class LSAClient(BaseLSAClient):
elif part == "target":
value = setting.getTargetScalarValue().getDouble()
elif part == "correction":
value = setting.getCorrectionScalarValue().getDouble()
value = (
setting.getCorrectionScalarValue().getDouble()
)
else:
raise ValueError("Invalid Setting Part: " + part)
elif type(setting) is self._FunctionSetting:
......@@ -524,7 +607,7 @@ class LSAClient(BaseLSAClient):
return out
def _getTrimsByCycle(
self, parameter, cycle, start=None, end=None, part="value"
self, parameter, cycle, start=None, end=None, part="value"
):
parameterList = self._buildParameterList(parameter)
cy = self._getCycle(cycle)
......@@ -532,13 +615,15 @@ class LSAClient(BaseLSAClient):
timestamps = {}
values = {}
for th in self._getRawTrimHeadersByCycle(
parameterList, cy, start, end
parameterList, cy, start, end
):
csrb = self._domain.settings.ContextSettingsRequestBuilder()
csrb.standAloneContext(cy)
csrb.parameters(parameterList)
csrb.at(th.getCreatedDate().toInstant())
contextSettings = self._settingService.findContextSettings(csrb.build())
contextSettings = self._settingService.findContextSettings(
csrb.build()
)
for pp in parameterList:
parameterSetting = contextSettings.getParameterSettings(pp)
if parameterSetting is None:
......@@ -557,7 +642,9 @@ class LSAClient(BaseLSAClient):
elif part == "target":
value = setting.getTargetScalarValue().getDouble()
elif part == "correction":
value = setting.getCorrectionScalarValue().getDouble()
value = (
setting.getCorrectionScalarValue().getDouble()
)
else:
raise ValueError("Invalid Setting Part: " + part)
elif type(setting) is self._FunctionSetting:
......@@ -584,18 +671,26 @@ class LSAClient(BaseLSAClient):
return out
def getTrims(
self, parameter, beamprocess=None, cycle=None, start=None, end=None, part="value"
self,
parameter,
beamprocess=None,
cycle=None,
start=None,
end=None,
part="value",
):
if beamprocess is not None:
return self._getTrimsByBeamprocess(parameter,
beamprocess=beamprocess,
start=start, end=end,
part=part)
return self._getTrimsByBeamprocess(
parameter,
beamprocess=beamprocess,
start=start,
end=end,
part=part,
)
else:
return self._getTrimsByCycle(parameter,
cycle=cycle,
start=start, end=end,
part=part)
return self._getTrimsByCycle(
parameter, cycle=cycle, start=start, end=end, part=part
)
def _getLastTrimByBeamprocess(self, parameter, beamprocess, part="value"):
th = self._getTrimHeadersByBeamprocess(parameter, beamprocess)[-1]
......@@ -612,18 +707,18 @@ class LSAClient(BaseLSAClient):
return TrimTuple(res.time[-1], res.data[-1])
def getLastTrim(
self, parameter, beamprocess=None, cycle=None, part="value"
self, parameter, beamprocess=None, cycle=None, part="value"
):
if beamprocess is not None:
return self._getLastTrimByBeamprocess(parameter,
beamprocess=beamprocess,
part=part)
return self._getLastTrimByBeamprocess(
parameter, beamprocess=beamprocess, part=part
)
else:
return self._getLastTrimByCycle(parameter,
cycle=cycle,
part=part)
return self._getLastTrimByCycle(parameter, cycle=cycle, part=part)
def _getLastTrimValueByBeamprocess(self, parameter, beamprocess, part="value"):
def _getLastTrimValueByBeamprocess(
self, parameter, beamprocess, part="value"
):
th = self._getTrimHeadersByBeamprocess(parameter, beamprocess)[-1]
res = self._getTrimsByBeamprocess(
parameter, beamprocess, part=part, start=th.createdDate
......@@ -638,16 +733,16 @@ class LSAClient(BaseLSAClient):
return res.data[-1]
def getLastTrimValue(
self, parameter, beamprocess=None, cycle=None, part="value"
self, parameter, beamprocess=None, cycle=None, part="value"
):
if beamprocess is not None:
return self._getLastTrimValueByBeamprocess(parameter,
beamprocess=beamprocess,
part=part)
return self._getLastTrimValueByBeamprocess(
parameter, beamprocess=beamprocess, part=part
)
else:
return self._getLastTrimValueByCycle(parameter,
cycle=cycle,
part=part)
return self._getLastTrimValueByCycle(
parameter, cycle=cycle, part=part
)
def getOpticTable(self, beamprocess):
bp = self._getBeamProcess(beamprocess)
......@@ -657,7 +752,9 @@ class LSAClient(BaseLSAClient):
0
].getOpticsTableItems()
return [
OpticTableItem(time=o.getTime(), id=o.getOpticId(), name=o.getOpticName())
OpticTableItem(
time=o.getTime(), id=o.getOpticId(), name=o.getOpticName()
)
for o in opticTable
]
......@@ -675,9 +772,13 @@ class LSAClient(BaseLSAClient):
self._ParameterTreesRequestTreeDirection.DEPENDENT_TREE
)
elif direction == "source":
req.setTreeDirection(self._ParameterTreesRequestTreeDirection.SOURCE_TREE)
req.setTreeDirection(
self._ParameterTreesRequestTreeDirection.SOURCE_TREE
)
else:
raise ValueError('invalid direction, expecting "dependent" or "source"')
raise ValueError(
'invalid direction, expecting "dependent" or "source"'
)
req.setParameter(self._getParameter(parameter))
tree = self._parameterService.findParameterTrees(req.build())
params = {}
......@@ -689,7 +790,9 @@ class LSAClient(BaseLSAClient):
def getOpticStrength(self, optic):
if not hasattr(optic, "name"):
optic = self._opticService.findOpticByName(optic)
out = [(st.logicalHWName, st.strength) for st in optic.getOpticStrengths()]
out = [
(st.logicalHWName, st.strength) for st in optic.getOpticStrengths()
]
return dict(out)
def _getOptics(self, name):
......@@ -717,14 +820,18 @@ class LSAClient(BaseLSAClient):
pcname = pcs[madname]
if full is True:
nl = self._java.util.Collections.singleton(pcname)
pcs = self._deviceService.findActualDevicesByLogicalHardwareName(nl)
pcs = self._deviceService.findActualDevicesByLogicalHardwareName(
nl
)
pcname = list(pcs[pcname])[0].toString()
return pcname
def findMadStrengthNameByPCName(self, pcname, full=False):
if full is True:
nl = self._java.util.Collections.singleton(pcname)
pcs = self._deviceService.findLogicalHardwaresByActualDeviceNames(nl)
pcs = self._deviceService.findLogicalHardwaresByActualDeviceNames(
nl
)
pcname = list(pcs[pcname])[0].toString()
nl = self._java.util.Collections.singleton(pcname)
madnames = self._deviceService.findMadStrengthNamesByLogicalNames(nl)
......@@ -747,23 +854,30 @@ class LSAClient(BaseLSAClient):
return dict((cc.getName(), cc) for cc in cals)
def dump_calibrations(self, outdir="calib"):
""" Dump all calibration in directory <outdir>
"""
"""Dump all calibration in directory <outdir>"""
os.mkdir(outdir)
for name, cc in self._get_calibrations():
ff = cc.getCalibrationFunctionByType(self._CalibrationFunctionTypes.B_FIELD)
ff = cc.getCalibrationFunctionByType(
self._CalibrationFunctionTypes.B_FIELD
)
if ff is not None:
field = ff.toXArray()
current = ff.toYArray()
fn = os.path.join(outdir, "%s.txt" % name)
print(fn)
fh = open(fn, "w")
fh.write("\n".join(["%s %s" % (i, f) for i, f in zip(current, field)]))
fh.write(
"\n".join(
["%s %s" % (i, f) for i, f in zip(current, field)]
)
)
fh.close()
def getPCInfo(self, pcname):
pc = self._deviceService.findPowerConverterInfo(pcname)
info = PCInfo(*(getattr(pc, pcinfogetter[nn])() for nn in PCInfo._fields))
info = PCInfo(
*(getattr(pc, pcinfogetter[nn])() for nn in PCInfo._fields)
)
return info
......@@ -772,6 +886,7 @@ class LSAClientGSI(BaseLSAClient):
Supports production, integration and development servers.
"""
def __init__(self, server="gsi-dev"):
"""Connect to either production, integration or development
LSA server at GSI. The server variable should correspondingly
......@@ -785,6 +900,6 @@ class LSAClientGSI(BaseLSAClient):
cfg_url = "https://websvcdev.acc.gsi.de/groups/cscoap/config/"
else:
raise ValueError(
'server should be one of "gsi-pro", "gsi-int" or "gsi-dev"')
'server should be one of "gsi-pro", "gsi-int" or "gsi-dev"'
)
super().__init__(server, {"csco.default.property.config.url": cfg_url})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment