Skip to content
Snippets Groups Projects
Commit 86481c0e authored by Michael Schenk's avatar Michael Schenk
Browse files

Merge branch 'develop' into 'master'

Fix EnvSpec check in envs.Metadata

See merge request !5
parents 647e8f89 7e410c48
No related branches found
No related tags found
1 merge request!5Fix EnvSpec check in envs.Metadata
Pipeline #10057576 passed
......@@ -17,11 +17,6 @@ repos:
args: ["--unsafe"]
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://gitlab.cern.ch/pre-commit-hook-mirrors/astral-sh/ruff-pre-commit
rev: v0.6.3
hooks:
- id: ruff
args: ["--fix", "--exit-non-zero-on-fix"]
- repo: https://gitlab.cern.ch/pre-commit-hook-mirrors/psf/black-pre-commit-mirror
rev: 24.8.0
hooks:
......
......@@ -15,7 +15,7 @@ from cernml.coi.registration import EnvSpec
try:
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata # type: ignore
import importlib_metadata # type: ignore[no-redef]
if t.TYPE_CHECKING:
# pylint: disable = ungrouped-imports, unused-import, import-error
......@@ -26,12 +26,12 @@ if t.TYPE_CHECKING:
LOG = getLogger(__name__)
BUILTIN_ENVS = [
BUILTIN_ENVS: t.List[str] = [
# "cern_awake_env.machine",
# "cern_awake_env.simulation",
# "cern_isolde_offline_env",
# "cern_leir_transfer_line_env",
#"cern_sps_splitter_opt_env",
# "cern_sps_splitter_opt_env",
# "cern_sps_tune_env",
# "cern_sps_zs_alignment_env",
# "linac3_lebt_tuning",
......@@ -50,10 +50,17 @@ class Metadata:
self, metadata_holder: t.Union[coi.Problem, t.Type[coi.Problem], EnvSpec]
) -> None:
self._metadata = dict(coi.Problem.metadata)
if isinstance(metadata_holder, EnvSpec):
self._metadata.update(metadata_holder.entry_point.metadata)
else:
self._metadata.update(metadata_holder.metadata)
metadata = getattr(metadata_holder, "metadata", None)
if metadata is None:
# No metadata available, assume it's an EnvSpec and
# `entry_point` is a `Problem` subclass.
metadata_holder = t.cast(
t.Type[coi.Problem], getattr(metadata_holder, "entry_point", None)
)
metadata = getattr(metadata_holder, "metadata", None)
if metadata is None:
raise TypeError("cannot find metadata: " + repr(metadata_holder))
self._metadata.update(metadata)
@property
def cancellable(self) -> bool:
......@@ -82,7 +89,11 @@ def iter_env_names(
for spec in coi.registry.all():
if machine and Metadata(spec).machine != machine:
continue
if superclass and not issubclass(spec.entry_point, superclass):
if (
superclass
and isinstance(spec.entry_point, type)
and not issubclass(spec.entry_point, superclass)
):
continue
yield spec.id
......@@ -114,7 +125,7 @@ def make_env_by_name(
kwargs["japc"] = make_japc()
if metadata.cancellable:
kwargs["cancellation_token"] = token
return spec.make(**kwargs)
return t.cast(coi.Problem, spec.make(**kwargs))
def get_custom_optimizers(spec: EnvSpec) -> t.Mapping[str, "Optimizer"]:
......@@ -123,8 +134,8 @@ def get_custom_optimizers(spec: EnvSpec) -> t.Mapping[str, "Optimizer"]:
This takes all endpoints into account: both the interface on the
environment itself and entry-points.
"""
optimizers = {}
if issubclass(spec.entry_point, coi.CustomOptimizerProvider):
optimizers: t.Dict[str, "Optimizer"] = {}
if coi.is_custom_optimizer_provider_class(spec.entry_point):
optimizers.update(spec.entry_point.get_optimizers())
entry_points = _get_entry_points(group="cernml.custom_optimizers", name=spec.id)
duplicate_names = set()
......@@ -166,10 +177,15 @@ def get_custom_policies(
env itself. This is because the env is not instantiated yet at the
time when this function runs.
"""
policies = {}
policies: t.Dict[str, t.Optional[coi.CustomPolicyProvider]] = {}
env_class = spec.entry_point
if issubclass(env_class, coi.CustomPolicyProvider):
policies.update(dict.fromkeys(env_class.get_policy_names(), None))
if isinstance(env_class, type) and issubclass(env_class, coi.CustomPolicyProvider):
policies.update(
dict.fromkeys(
t.cast(t.Type[coi.CustomPolicyProvider], env_class).get_policy_names(),
None,
)
)
entry_points = _get_entry_points(group="cernml.custom_policies", name=spec.id)
duplicate_names = set()
for ep in entry_points:
......@@ -206,4 +222,4 @@ def _get_entry_points(
if hasattr(all_entry_points, "select"):
return tuple(all_entry_points.select(group=group, name=name))
# Deprecated API:
return tuple(ep for ep in all_entry_points.get(group, ()) if ep.name == name)
return tuple(ep for ep in all_entry_points.get(group, ()) if ep.name == name) # type: ignore[attr-defined]
......@@ -77,7 +77,7 @@ def exception_dialog(
text,
parent=parent,
buttons=QtWidgets.QMessageBox.Close,
keywords=tuple(exc.exc_type.__name__ for exc in _iter_exc_chain(exception)),
keywords=_gather_keywords(exception),
)
dialog.setInformativeText("".join(exception.format_exception_only()))
dialog.setDetailedText("".join(exception.format()))
......@@ -190,6 +190,18 @@ class _TracebackHighlighter(QtGui.QSyntaxHighlighter):
self.setFormat(match.capturedStart(0), match.capturedLength(0), Qt.red)
def _gather_keywords(exception: t.Optional[TracebackException]) -> tuple[str, ...]:
res = []
for exc in _iter_exc_chain(exception):
name = getattr(exc, "exc_type_str", None)
if name is None:
name = getattr(exc.exc_type, "__name__", None)
if name is None:
continue
res.append(name)
return tuple(res)
def _iter_exc_chain(
exc: t.Optional[TracebackException],
) -> t.Iterator[TracebackException]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment