diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 45037a52725db152566cfc89e86554db739943cc..baae9cce59054fa6b14cec380c3393f70b20d0ce 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -36,9 +36,9 @@ release_sdist: - mkdir not-the-source-dir && cd not-the-source-dir - pip install -e .. - python -c "import pyjapc; print(pyjapc.__version__)" - - pip install stubgenj -e ..[test] + - pip install stubgenj -e ..[test] types-pytz - python -m cmmnbuild_dep_manager resolve - - python -m stubgenj cern java.lang java.util org.mockito + - python -m stubgenj cern java.lang java.util org.mockito org.apache --classpath $(python -m cmmnbuild_dep_manager class_path) --output-dir $(python -c 'import site; print(site.getsitepackages()[0])') # NOTE: For successful mypy test, we must use Python >=3.8 diff --git a/docs/conf.py b/docs/conf.py index b0b37726da7da680dc7ed722cadf1ea4f0156dac..bbd4f05f4aea2d8d5be526df415ba4bff9b4378e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -8,6 +8,7 @@ extensions = [ 'sphinx.ext.viewcode', 'sphinx.ext.napoleon', 'sphinx.ext.autosummary', + 'sphinx.ext.autodoc.typehints', ] @@ -44,4 +45,6 @@ exclude_patterns = ['build'] html_theme = 'acc_py' -autoclass_content = 'both' \ No newline at end of file +autoclass_content = 'both' + +autodoc_typehints = "description" diff --git a/pyjapc/_japc.py b/pyjapc/_japc.py index 29e1d48d6c1ba53ad2808cbb21eff5c3765690dc..09499d4d946e7782182f8909aa8a8a3f877bca99 100644 --- a/pyjapc/_japc.py +++ b/pyjapc/_japc.py @@ -1,6 +1,7 @@ """PyJapc is a Python to FESA/LSA/INCA interface via JAPC. """ +from __future__ import annotations import datetime import base64 @@ -18,6 +19,11 @@ import cmmnbuild_dep_manager import jpype as jp import numpy as np +if typing.TYPE_CHECKING: + # Import the JPype CERN package for type-checking only (not importable at runtime). + import cern + + from . import _types as pyjapc_types _JPrimitiveTypes = ( jp.JChar, jp.JString, jp.JBoolean, # type: ignore @@ -30,18 +36,21 @@ import pytz from . import _jpype_utils - -_INSTANCE_DEFAULT = object() +# Define a sentinel which can be used as an alternative to None for +# keyword argument defaults (thereby allowing None to have a meaning for those +# arguments). This is typically used when you want to be able to distinguish +# "the default argument" from "the argument value is None". +_INSTANCE_DEFAULT: None = object() # type: ignore class PyJapc: def __init__( self, - selector="LHC.USER.ALL", - incaAcceleratorName="auto", - noSet=False, - timeZone="utc", - logLevel=None, + selector: pyjapc_types.Selector = "LHC.USER.ALL", + incaAcceleratorName: typing.Optional[str] = "auto", + noSet: bool = False, + timeZone: typing.Optional[pyjapc_types.Timezone] = "utc", + logLevel: typing.Optional[pyjapc_types.LogLevel] = None, ) -> None: """Start a JVM and load up the JAPC classes. @@ -61,7 +70,8 @@ class PyJapc: Injector Control Architecture (InCA) framework. It can be any of ``AD``, ``CTF3``, ``ISOLDE``, ``LEIR``, ``LHC``, ``LINAC4``, ``NORTH``, ``PS``, - ``PSB``, ``SCT``, ``SPS`` or can be even an empty string (``""``). + ``PSB``, ``SCT``, ``SPS`` or can be even an empty string (``""``). The valid values come + from the ``cern.accsoft.commons.domain.CernAccelerator`` class. You might need to call :meth:`rbacLogin()` to make full use of InCA. @@ -90,10 +100,13 @@ class PyJapc: # Dictionaries for caching often used objects # -------------------------------------------------------------------- # For caching all the "ParameterObjects" ever created - self._paramDict = dict() + self._paramDict: typing.Dict[str, pyjapc_types.ParameterTypes] = {} # For caching all the "SubscriptionHandles" ever created - self._subscriptionHandleDict = dict() + self._subscriptionHandleDict: typing.Dict[ + str, + typing.List[pyjapc_types.SubscriptionTypes], + ] = {} self._java_gc = _jpype_utils.JavaGCCollector() @@ -166,6 +179,7 @@ class PyJapc: # -------------------------------------------------------------------- # Store some user settings # -------------------------------------------------------------------- + self._selectedTimezone: typing.Optional[datetime.tzinfo] if timeZone == "utc": self._selectedTimezone = pytz.utc elif timeZone == "local": @@ -179,7 +193,6 @@ class PyJapc: if self._noSet: self.log.info("No SETs will be made as noSet=True") - # INCA accelerator name from timing domain lookup _incaAccFromTiming = { "ADE": "AD", @@ -206,7 +219,7 @@ class PyJapc: # This behaviour must remain consistent for PyJapc v2. self.rbacLogout() - def _setup_jvm(self, log_level): + def _setup_jvm(self, log_level: typing.Optional[pyjapc_types.LogLevel]) -> None: """Startup the JVM and the connection to Python (JPype).""" cmmnbuild_dep_manager.Manager(lvl=log_level).jvm_required() @@ -224,6 +237,7 @@ class PyJapc: # Map Python log levels (10, 20, ...) to Java levels (10000, 20000, ...) # but fall back to just passing the given level if we don't know what to do with it. if log_level in known_levels: + assert isinstance(log_level, int) java_level = known_levels[log_level] else: java_level = log4j.Level.toLevel(log_level) @@ -232,22 +246,25 @@ class PyJapc: log4j.Logger.getRootLogger().setLevel(log4j.Level.WARN) # Enable `*` Wildcard selectors (see https://wikis.cern.ch/display/JAPC/Wildcard+Selectors) - jp.java.lang.System.setProperty("default.wildcard.subscription.on", "true") + java = jp.JPackage('java') + java.lang.System.setProperty("default.wildcard.subscription.on", "true") @staticmethod - def enableInThisThread(): + def enableInThisThread() -> None: """Allows PyJapc object to be used in threads other than the main one. For more details please see here: http://jpype.readthedocs.io/en/latest/userguide.html#threading """ - jp.attachThreadToJVM() + jp.attachThreadToJVM() # type: ignore + # Type ignore caused by stubgenj overlap of types? + # See also https://gitlab.cern.ch/scripting-tools/stubgenj/-/issues/3 def _giveMeSelector( - self, - timingSelectorOverride=_INSTANCE_DEFAULT, - dataFilterOverride=_INSTANCE_DEFAULT, - ): + self, + timingSelectorOverride: typing.Optional[pyjapc_types.Selector] = _INSTANCE_DEFAULT, + dataFilterOverride: typing.Optional[pyjapc_types.DataFilter] = _INSTANCE_DEFAULT, + ) -> cern.japc.core.Selector: """Produce and return a JAPC selector object with certain overrides. Args: @@ -267,6 +284,12 @@ class PyJapc: else: timingSelector = self._selector.getId() + dataFilter: typing.Optional[ + typing.Union[ + pyjapc_types.DataFilter, + cern.japc.value.ParameterValue, + ] + ] if datafilter_set: dataFilter = dataFilterOverride else: @@ -288,7 +311,11 @@ class PyJapc: else: return self._selector - def setSelector(self, timingSelector, dataFilter=None): + def setSelector( + self, + timingSelector: pyjapc_types.Selector, + dataFilter: typing.Optional[pyjapc_types.DataFilter] = None, + ) -> None: """Sets the default selector and filter used for GET/SET/SUBS. This selector and filter is used if you don't specify an override in @@ -324,20 +351,20 @@ class PyJapc: """Return the current timing selector of the PyJapc instance""" return self._selector.getId() - def setDataFilter(self, dataFilter: typing.Dict[str, typing.Any]) -> None: + def setDataFilter(self, dataFilter: pyjapc_types.DataFilter) -> None: """Set the data filter for this PyJapc instance""" self._selector = self._giveMeSelector(dataFilterOverride=dataFilter) - def getDataFilter(self) -> typing.Dict[str, typing.Any]: + def getDataFilter(self) -> typing.Optional[pyjapc_types.DataFilter]: """Return the current data filter for this PyJapc instance""" - data_filter = self._selector.getDataFilter() - if data_filter: - data_filter = self._convertValToPy(data_filter) + j_data_filter = self._selector.getDataFilter() + if j_data_filter: + data_filter = self._convertValToPy(j_data_filter) else: data_filter = None return data_filter - def getUsers(self, machine) -> typing.List[str]: + def getUsers(self, machine: str) -> typing.List[str]: """Get a list of users for a particular machine. Note that you may need to set the environment variable TGM_NETWORK @@ -348,12 +375,20 @@ class PyJapc: """ cern = jp.JPackage("cern") TgmUtil = cern.japc.ext.tgm.TgmUtil - users = [user for user in TgmUtil.getLinesforMachineGroup(machine, TgmUtil.USER)] + users: typing.List[pyjapc_types.JString] = [ + user for user in TgmUtil.getLinesforMachineGroup(machine, TgmUtil.USER) + ] if not users: users = ["ALL"] return list(map(lambda x: '{0}.{1}.{2}'.format(machine, TgmUtil.USER, x), users)) - def rbacLogin(self, username=None, password=None, loginDialog=False, readEnv=True): + def rbacLogin( + self, + username: typing.Optional[str] = None, + password: typing.Optional[str] = None, + loginDialog: bool = False, + readEnv: bool = True + ) -> None: """Perform RBAC authentication. This is required to work with access-protected parameters. @@ -399,7 +434,9 @@ class PyJapc: if readEnv and env: try: self.log.info("Reusing RBAC token from environment") - token = cern.rbac.common.RbaToken(base64.b64decode(env)) + # type ignore because of bug in stubgenj (bytes -> List[int]) + token_bytes: typing.List[int] = base64.b64decode(env) # type: ignore + token = cern.rbac.common.RbaToken(token_bytes) cern.rbac.util.holder.ClientTierTokenHolder.setRbaToken(token) if cern.rbac.util.lookup.RbaTokenLookup.findRbaToken() is None: @@ -431,7 +468,12 @@ class PyJapc: self._rbaLoginService = None raise e - def _doLogin(self, byLoc, username, password): + def _doLogin( + self, + byLoc: bool, + username: typing.Optional[str], + password: typing.Optional[str], + ) -> None: cern = jp.JPackage("cern") self.rbacLogout() loginBuilder = cern.rbac.util.authentication.LoginServiceBuilder.newInstance() @@ -442,6 +484,7 @@ class PyJapc: loginBuilder.loginPolicy(cern.rbac.common.authentication.LoginPolicy.LOCATION) else: self.log.info("Performing explicit RBAC login as {0}".format(username)) + assert username is not None and password is not None loginBuilder.loginPolicy(cern.rbac.common.authentication.LoginPolicy.EXPLICIT) loginBuilder.userName(username) loginBuilder.userPassword(password) @@ -449,7 +492,7 @@ class PyJapc: self._rbaLoginService.loginNewUser() self.log.info("RBAC login successful") - def rbacLogout(self): + def rbacLogout(self) -> None: """Ends your RBAC session (if one is open) and returns your token.""" cern = jp.JPackage("cern") if self._rbaLoginService is not None: @@ -460,16 +503,23 @@ class PyJapc: self._rbaLoginService = None cern.rbac.util.holder.ClientTierTokenHolder.clear() - def rbacGetToken(self): + def rbacGetToken(self) -> cern.rbac.common.RbaToken: """Returns the RBAC token as a Java object (``cern.rbac.common.RbaToken``).""" cern = jp.JPackage("cern") return cern.rbac.util.lookup.RbaTokenLookup.findRbaToken() def rbacGetSerializedToken(self) -> str: """Returns a Base64 encoded serialization of the RBAC token.""" - return base64.b64encode(np.array(self.rbacGetToken().getEncoded(), dtype=np.uint8)).decode() + token = np.array(self.rbacGetToken().getEncoded(), dtype=np.uint8) + # Numpy array types don't yet fit well into base64.b64encode according to mypy. + # Trick it into believing it is just bytes until it knows better. + token_bytes: bytes = typing.cast(bytes, token) # type: ignore + return base64.b64encode(token_bytes).decode() - def _getDictKeyFromParameterName(self, parameterName): + def _getDictKeyFromParameterName( + self, + parameterName: typing.Union[str, pyjapc_types.SequenceOfStrings], + ) -> str: """parameterName can be a string or a list of strings returns a unique identifier, which can be used as dict key """ @@ -486,6 +536,18 @@ class PyJapc: )) return parameterKey + @typing.overload + def _getJapcPar( + self, + parameterName: str, + ) -> cern.japc.core.transaction.TransactionalParameter: ... + + @typing.overload + def _getJapcPar( + self, + parameterName: pyjapc_types.SequenceOfStrings, + ) -> cern.japc.core.group.ParameterGroup: ... + def _getJapcPar(self, parameterName): """Create the JAPC parameter object and return it. @@ -517,15 +579,15 @@ class PyJapc: def getParam( self, - parameterName, - getHeader=False, - noPyConversion=False, - unixtime=False, - onValueReceived=None, - onException=None, - timingSelectorOverride=_INSTANCE_DEFAULT, - dataFilterOverride=_INSTANCE_DEFAULT, - ): + parameterName: typing.Union[str, pyjapc_types.SequenceOfStrings], + getHeader: bool = False, + noPyConversion: bool = False, + unixtime: bool = False, + onValueReceived: typing.Optional[typing.Callable[[typing.Any], None]] = None, + onException: typing.Optional[typing.Callable[[str, str, typing.Any], None]] = None, + timingSelectorOverride: typing.Optional[pyjapc_types.Selector] = _INSTANCE_DEFAULT, + dataFilterOverride: typing.Optional[pyjapc_types.DataFilter] = _INSTANCE_DEFAULT, + ) -> typing.Any: """Fetch the value of a single FESA parameter or of a FESA ParameterGroup Args: @@ -649,14 +711,16 @@ class PyJapc: onException=onException) p.getValue(s, listener) - def setParam(self, - parameterName, - parameterValue, - checkDims=True, - dtype=None, - timingSelectorOverride=_INSTANCE_DEFAULT, - dataFilterOverride=_INSTANCE_DEFAULT, - ): + def setParam( + self, + parameterName: str, + # TODO: There are more types needed here. + parameterValue: typing.Union[bool, int, float, str, np.ndarray], + checkDims: bool = True, + dtype: typing.Optional[typing.Union[str, typing.Type]] = None, + timingSelectorOverride: typing.Optional[pyjapc_types.Selector] = _INSTANCE_DEFAULT, + dataFilterOverride: typing.Optional[pyjapc_types.DataFilter] = _INSTANCE_DEFAULT, + ) -> None: """Set the value of a device parameter. Args: @@ -778,7 +842,7 @@ class PyJapc: # Carry out the actual set (if not in safemode) # -------------------------------------------------------------------- if self._noSet: - self.log.warning("{0} would be set to:\n{1}".format(parameterName, parValNew.toString())) + self.log.warning("{0} would be set to:\n{1}".format(parameterName, str(parValNew))) else: s = self._giveMeSelector( timingSelectorOverride=timingSelectorOverride, @@ -817,7 +881,13 @@ class PyJapc: return (val, head) return val - def _convertPyToVal(self, pyVal, vdesc=None, checkDims=True, dtype=None): + def _convertPyToVal( + self, + pyVal: typing.Any, + vdesc: typing.Optional[cern.japc.value.ValueDescriptor] = None, + checkDims: bool = True, + dtype: typing.Any = None, + ) -> cern.japc.value.ParameterValue: """Converts anything Python (also dict()) to anything JAPC. It tries to do an array dimension check if vdesc is provided @@ -849,11 +919,13 @@ class PyJapc: # Also check array dimensions in Python # Can be MAP or SIMPLE if vdesc.getType().toString() == "Simple": + assert isinstance(vdesc, cern.japc.value.SimpleDescriptor) # Check input array shape against FESA if checkDims: self._checkDimVsJAPC(pyVal, vdesc) parValNew = self._convertPyToSimpleVal(pyVal, vdesc, dtype=dtype) elif vdesc.getType().toString() == "Map": + assert isinstance(vdesc, cern.japc.value.MapDescriptor) # Create a new MAP parValNew = self._mapParameterValueFactory.newValue() @@ -879,7 +951,11 @@ class PyJapc: return parValNew - def getParamInfo(self, parameterName, noPyConversion=False): + def getParamInfo( + self, + parameterName: str, + noPyConversion: bool = False, + ) -> typing.Union[str, cern.japc.core.ParameterDescriptor]: """Return a string description of the parameter. Args: @@ -896,19 +972,24 @@ class PyJapc: if noPyConversion: return p.getParameterDescriptor() else: - return p.getParameterDescriptor().toString() + return str(p.getParameterDescriptor()) def subscribeParam( self, - parameterName, - onValueReceived=None, - onException=None, - getHeader=False, - noPyConversion=False, - unixtime=False, - timingSelectorOverride=_INSTANCE_DEFAULT, - dataFilterOverride=_INSTANCE_DEFAULT, - ): + parameterName: typing.Union[str, pyjapc_types.SequenceOfStrings], + onValueReceived: typing.Optional[ + typing.Callable[ + [str, typing.Any, typing.Optional[typing.Dict[str, typing.Any]]], + None, + ], + ] = None, + onException: typing.Optional[typing.Callable[[str, str, typing.Any], None]] = None, + getHeader: bool = False, + noPyConversion: bool = False, + unixtime: bool = False, + timingSelectorOverride: typing.Optional[pyjapc_types.Selector] = _INSTANCE_DEFAULT, + dataFilterOverride: typing.Optional[pyjapc_types.DataFilter] = _INSTANCE_DEFAULT, + ) -> pyjapc_types.SubscriptionTypes: """Subscribe to a Parameter with a Python callback function. Args: @@ -1030,13 +1111,14 @@ class PyJapc: def getNextParamValue( self, - parameterName, *, - getHeader=False, - timingSelectorOverride=_INSTANCE_DEFAULT, - dataFilterOverride=_INSTANCE_DEFAULT, - n_values=None, + parameterName: str, + *, + getHeader: bool = False, + timingSelectorOverride: typing.Optional[pyjapc_types.Selector] = _INSTANCE_DEFAULT, + dataFilterOverride: typing.Optional[pyjapc_types.DataFilter] = _INSTANCE_DEFAULT, + n_values: typing.Optional[int] = None, timeout: float = 0, - ): + ) -> typing.Any: """Return the first non-first-update value that is acquired when subscribing to the given parameter. In order to get the *current* value use the :meth:`getParam` method. @@ -1129,7 +1211,11 @@ class PyJapc: raise TimeoutError('No new value available in the given time') time.sleep(0.0001) - def stopSubscriptions(self, parameterName=None, selector=None): + def stopSubscriptions( + self, + parameterName: typing.Optional[str] = None, + selector: typing.Optional[pyjapc_types.Selector] = None, + ) -> None: """Stop Monitoring on all previously subscribed parameters. Args: @@ -1141,7 +1227,11 @@ class PyJapc: for handler in self._filterSubscriptions(parameterName, selector): handler.stopMonitoring() - def clearSubscriptions(self, parameterName=None, selector=None): + def clearSubscriptions( + self, + parameterName: typing.Optional[str] = None, + selector: typing.Optional[pyjapc_types.Selector] = None, + ) -> None: """Clear the internal list of subscription handles. Call this to avoid that :meth:`startSubscriptions()` starts old and unwanted @@ -1167,19 +1257,29 @@ class PyJapc: # Python GC is done with them. self._java_gc.trigger() - def startSubscriptions(self, parameterName=None, selector=None): + def startSubscriptions( + self, + parameterName: typing.Optional[str] = None, + selector: typing.Optional[pyjapc_types.Selector] = None, + ) -> None: """Start Monitoring on all previously Subscribed Parameters. Args: parameterName (Optional[str]): If not ``None``, only the subscription - of this particular parameter will restarted. + of this particular parameter will started. selector (Optional[str]): If not ``None``, it augments the parameterName to stop subscription for the particular selector only. """ for handler in self._filterSubscriptions(parameterName, selector): handler.startMonitoring() - def _filterSubscriptions(self, parameterName, selector): + def _filterSubscriptions( + self, + parameterName: typing.Optional[str] = None, + selector: typing.Optional[pyjapc_types.Selector] = None, + ) -> typing.Generator[pyjapc_types.SubscriptionTypes, None, None]: + # TODO: We should support parameterName being a list of strings, + # and transform the argument into the proper key with _getDictKeyFromParameterName. if parameterName is not None: if selector is not None: key = self._transformSubscribeCacheKey( @@ -1390,17 +1490,31 @@ class PyJapc: return jValue - def _getSimpleValFromDesc(self, valueDescriptor): + def _getSimpleValFromDesc( + self, + valueDescriptor: typing.Optional[cern.japc.value.SimpleDescriptor], + ) -> cern.japc.value.SimpleParameterValue: """Return an empty `SimpleParameterValue` of the same type as `valueDescriptor` This can be filled with a value and then handed to a `ParameterValue` to do a `SET` with .setValue() """ cern = jp.JPackage("cern") - vdWrapper = jp.JObject(valueDescriptor, cern.japc.value.SimpleDescriptor) + if valueDescriptor is None: + # Type ignore reason because of jp.JObject. Problem with stubgenj + # https://gitlab.cern.ch/scripting-tools/stubgenj/-/issues/3. + # TODO: Decide if we really want to be building a null object here. + vdWrapper = jp.JObject(valueDescriptor, cern.japc.value.SimpleDescriptor) # type: ignore + else: + vdWrapper = valueDescriptor parValNew = self._simpleParameterValueFactory.newValue(vdWrapper) return parValNew - def _convertPyToSimpleVal(self, pyVal, valueDescriptor=None, dtype=None): + def _convertPyToSimpleVal( + self, + pyVal: typing.Any, + valueDescriptor: typing.Optional[cern.japc.value.SimpleDescriptor] = None, + dtype: typing.Any = None, + ) -> cern.japc.value.SimpleParameterValue: """Convert a numpy array/primitive to a JAPC SimpleParameterValue of different types. @@ -1423,6 +1537,8 @@ class PyJapc: cern = jp.JPackage("cern") + parValNew: cern.japc.value.SimpleParameterValue # type: ignore + # -------------------------------------------------------------------- # Special case: Numpy array to JAPC DiscreteFunction(List) # -------------------------------------------------------------------- @@ -1433,12 +1549,17 @@ class PyJapc: parValNew = cern.japc.value.spi.value.simple.DiscreteFunctionValue(df) elif ts == "DiscreteFunctionList": - # Allcoate JArray for DFs - dfa = jp.JArray(cern.japc.value.DiscreteFunction)(len(pyVal)) + # Allocate JArray for DFs + # Type ignore because of https://gitlab.cern.ch/scripting-tools/stubgenj/-/issues/3. + dfa = jp.JArray(cern.japc.value.DiscreteFunction)(len(pyVal)) # type: ignore # Iterate over first dimension of user data for i, funcDat in enumerate(pyVal): funcDat2 = np.array(funcDat, dtype="double") - dfa[i] = self._functionFactory.newDiscreteFunction(funcDat2[0, :], funcDat2[1, :]) + # Type ignore because stubgenj assumes a array[double] is List[double] + dfa[i] = self._functionFactory.newDiscreteFunction( + funcDat2[0, :], # type: ignore + funcDat2[1, :], # type: ignore + ) dfl = self._functionFactory.newDiscreteFunctionList(dfa) parValNew = cern.japc.value.spi.value.simple.DiscreteFunctionListValue(dfl) @@ -1463,7 +1584,10 @@ class PyJapc: return parValNew - def _convertPyToSimpleValFallback(self, pyVal): + def _convertPyToSimpleValFallback( + self, + pyVal: typing.Any, + ) -> cern.japc.value.SimpleParameterValue: """Conv. numpy array/primitive to a JAPC SimpleParameterValue We will guess what kind of `SimpleParameterValue` object to produce by looking at the Python type the user has provided @@ -1491,18 +1615,23 @@ class PyJapc: # Use complex numpy type as input e.g. float32 javaVarType = self._getJavaValue(pyVal.dtype.type, None) + # Type ignore because of https://gitlab.cern.ch/scripting-tools/stubgenj/-/issues/3. + arr_type = jp.JArray(javaVarType, 1) # type: ignore + # Convert the numpy array to a list and then to a 1D JArray (flattened) # Note that at some point JPype will be able to digest # numpy arrays directly and .tolist() will not be needed any more (faster) # 11.1.16: Checked and array conversion still does not work with JPype 0.6.1 - jArrayValues = jp.JArray(javaVarType, 1)(pyVal.flatten().tolist()) + jArrayValues = arr_type(pyVal.flatten().tolist()) if pyVal.ndim == 1: # Create the shiny new 1D JAPC ParameterValue object parValNew = self._simpleParameterValueFactory.newValue(jArrayValues) elif pyVal.ndim == 2: + # Type ignore because of https://gitlab.cern.ch/scripting-tools/stubgenj/-/issues/3. + arr_type = jp.JArray(jp.JInt) # type: ignore # Store array shape in JAPC friendly format - jArrayShape = jp.JArray(jp.JInt)(pyVal.shape) + jArrayShape = arr_type(pyVal.shape) # Create the shiny new 2D JAPC ParameterValue object parValNew = self._simpleParameterValueFactory.newValue(jArrayValues, jArrayShape) elif pyVal.ndim == 0: @@ -1615,19 +1744,30 @@ class PyJapc: return val - def _convertValToPy(self, val): + @typing.overload + def _convertValToPy(self, val: cern.japc.value.MapParameterValue) -> typing.Dict[str, typing.Any]: ... + + @typing.overload + def _convertValToPy(self, val: cern.japc.value.ParameterValue) -> typing.Any: + ... + + def _convertValToPy(self, val: cern.japc.value.ParameterValue) -> typing.Any: """Convert the Java JAPC ParameterValue Map or SimpleParameter object to a Python equivalent. """ if val is None: return None + cern = jp.JPackage("cern") + # TODO: Use isinstance here, rather than type-checking by string. t = val.getType().toString() # Can be "Map" or "Simple" if t == "Simple": + assert isinstance(val, cern.japc.value.SimpleParameterValue) return self._convertSimpleValToPy(val) elif t == "Map": + assert isinstance(val, cern.japc.value.MapParameterValue) # Do a quick and dirty conversion to Python dict() d = dict() for n in val.getNames(): diff --git a/pyjapc/_japc.pyi b/pyjapc/_japc.pyi deleted file mode 100644 index 48c4ec585985f075c48e72fc9c847f5c4af3682c..0000000000000000000000000000000000000000 --- a/pyjapc/_japc.pyi +++ /dev/null @@ -1,46 +0,0 @@ -# Stubs for pyjapc (Python 3) -# -# NOTE: This dynamically typed stub was automatically generated by stubgen. -# -# NOTE: This was manually edited after stubgen generation. - -from typing import Any, Optional, Dict, List, Union, Type, Callable -import logging -import datetime -import numpy as np - -DataFilter = Dict[str, Any] -Selector = str - -class PyJapc: - log: logging.Logger = ... - def __init__(self, selector: Selector = ..., incaAcceleratorName: Optional[str] = ..., noSet: bool = ..., timeZone: Optional[Union[str, datetime.tzinfo]] = ..., logLevel: Optional[int] = ...) -> None: ... - def __del__(self) -> None: ... - @staticmethod - def enableInThisThread() -> None: ... - def setSelector(self, timingSelector: Selector, dataFilter: Optional[DataFilter] = ...) -> None: ... - def getSelector(self) -> Optional[Selector]: ... - def setDataFilter(self, dataFilter: DataFilter) -> None: ... - def getDataFilter(self) -> DataFilter: ... - def getUsers(self, machine: str) -> List[str]: ... - def rbacLogin(self, username: Optional[str] = ..., password: Optional[str] = ..., loginDialog: bool = ..., readEnv: bool = ...) -> None: ... - def rbacLogout(self) -> None: ... - def rbacGetToken(self) -> str: ... - def rbacGetSerializedToken(self) -> str: ... - def getParam(self, parameterName: Union[str, List[str]], getHeader: bool = ..., noPyConversion: bool = ..., unixtime: bool = ..., onValueReceived: Optional[Callable[[Any], None]] = ..., onException: Optional[Callable[[str, str, Any], None]] = ..., timingSelectorOverride: Optional[Selector] = ..., dataFilterOverride: Optional[DataFilter] = ...) -> Any: ... - def setParam(self, parameterName: str, parameterValue: Union[bool, int, float, str, np.ndarray], checkDims: bool = ..., dtype: Optional[Union[str, Type]] = ..., timingSelectorOverride: Optional[Selector] = ..., dataFilterOverride: Optional[DataFilter] = ...) -> None: ... - def getParamInfo(self, parameterName: str, noPyConversion: bool = ...): ... - def subscribeParam(self, parameterName: str, onValueReceived: Optional[Callable[[str, Any, Optional[Dict[str, Any]]], None]] = ..., onException: Optional[Callable[[str, str, Any], None]] = ..., getHeader: bool = ..., noPyConversion: bool = ..., unixtime: bool = ..., timingSelectorOverride: Optional[Selector] = ..., dataFilterOverride: Optional[DataFilter] = ...): ... - def stopSubscriptions(self, parameterName: Optional[str] = ..., selector: Optional[Selector] = ...) -> None: ... - def clearSubscriptions(self, parameterName: Optional[str] = ..., selector: Optional[Selector] = ...) -> None: ... - def startSubscriptions(self, parameterName: Optional[str] = ..., selector: Optional[Selector] = ...) -> None: ... - - def getNextParamValue( - self, - parameterName: str, *, - getHeader: bool = ..., - timingSelectorOverride: Optional[Selector] = ..., - dataFilterOverride: Optional[DataFilter] = ..., - n_values: Optional[int] = ..., - timeout: float, - ) -> Any: ... diff --git a/pyjapc/_types.pyi b/pyjapc/_types.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5618ff290cae6d8fa3af53a005f3a63baeab3f87 --- /dev/null +++ b/pyjapc/_types.pyi @@ -0,0 +1,44 @@ +import typing + +import cern +import java + +import datetime + + +# Input types +JString = typing.Union[java.lang.String, str] + +DataFilter = typing.Dict[str, typing.Any] +Selector = str +Timezone = typing.Union[str, datetime.tzinfo] + +LogLevelNames = typing.Union[ + typing.Literal['CRITICAL'], + typing.Literal['FATAL'], + typing.Literal['ERROR'], + typing.Literal['WARN'], + typing.Literal['WARNING'], + typing.Literal['INFO'], + typing.Literal['DEBUG'], + typing.Literal['NOTSET'], +] + +LogLevel = typing.Union[int, LogLevelNames, str] + +# Note that we don't use Sequence[str] here as this includes str itself. +SequenceOfStrings = typing.Union[ + typing.List[str], + typing.Tuple[str, ...], +] + +# Common JAPC (Java) types +SubscriptionTypes = typing.Union[ + cern.japc.core.SubscriptionHandle, + cern.japc.core.group.GroupSubscriptionHandle, +] +ParameterTypes = typing.Union[ + cern.japc.core.Parameter, + cern.japc.core.group.ParameterGroup, +] + diff --git a/pyjapc/tests/test_pyjapc.py b/pyjapc/tests/test_pyjapc.py index c3e269f41e7f8dc7b0ee45a398de92f19dd8a016..39a9274f799a390324811e51f3dbea8a66c4db2e 100644 --- a/pyjapc/tests/test_pyjapc.py +++ b/pyjapc/tests/test_pyjapc.py @@ -409,12 +409,16 @@ def test_long_array_to_py_conversion(japc): assert isinstance(r, np.ndarray) assert r.dtype.type == np.int64 - # JPype currently converts longs to np.longlong. These cannot be converted by - # _getJavaValue, and therefore results in error when doing a set of a result - # from get. Confirm that this behaviour is addressed, and that JPype still - # exhibits this behaviour (fixed in https://github.com/jpype-project/jpype/pull/1039). + if jp.__version__ <= "1.3.0": + # JPype currently converts longs to np.longlong. These cannot be converted by + # _getJavaValue, and therefore results in error when doing a set of a result + # from get. Confirm that this behaviour is addressed, and that JPype still + # exhibits this behaviour (fixed in https://github.com/jpype-project/jpype/pull/1039). + expected_type = np.longlong + else: + expected_type = np.int64 arr = np.array(jp.JArray(jp.JLong)(arr)) - assert arr.dtype.type == np.longlong + assert arr.dtype.type == expected_type def test_float_array_to_py_conversion(japc):