Commit 4afd8357 authored by Chris Burr

Merge branch 'fix_tests' into 'master'

Fix test_b2oc

Closes #1

See merge request !3
parents d146f3e5 a418a955
......@@ -41,7 +41,7 @@ testing =
console_scripts =
apd-cache = apd.command:cmd_cache_ap_info
apd-list-pfns = apd.command:cmd_list_pfns
apd-list-samples = apd.command:cmd_list_samples
###############################################################################
# Linting
......
......@@ -9,57 +9,166 @@
# or submit itself to any jurisdiction. #
###############################################################################
# ("Charm", "D2HH", use_both_polarities=True)
# Calling "datasets" returns a list of PFNs corrosponding to the requested dataset
# Keyword arguments are interpreted as tags
# Combining all of the tags must give a unique dataset or an error is raised
# To get PFNs from multiple datasets, pass lists as the arguments; this is
# equivalent to combining the individual datasets, i.e.
# datasets(eventtype="27163904", datatype=[2017, 2018], polarity=["magup", "magdown"])
# is the same as:
# datasets(eventtype="27163904", datatype=2017, polarity="magup") +
# datasets(eventtype="27163904", datatype=2017, polarity="magdown") +
# datasets(eventtype="27163904", datatype=2018, polarity="magup") +
# datasets(eventtype="27163904", datatype=2018, polarity="magdown")
import itertools
import logging
import os
from apd.ap_info import SampleCollection, fetch_ap_info, load_ap_info
from apd.ap_info import (
SampleCollection,
fetch_ap_info,
iterable,
load_ap_info,
safe_casefold,
)
logger = logging.getLogger("apd")
def _validate_tags(tags, default_tags=None):
"""Method that checks the dictionary of tag names, values that should be used
to filter the data accordingly.
- Special cases ate handled: tags "name" and "version" as well as "data" and "mc"
(which are converted to a "config" value).
- tag values cannot be None
- tag values cannot be of type bytes
- int tag values are converted to string
"""
# Merging the default tags with the ones passed
effective_tags = tags
if default_tags:
for t, v in default_tags.items():
if t not in effective_tags:
effective_tags[t] = v
# Final dict that will be returned
cleaned = {}
# name and version are special tags in our case, we check their validity
if "name" in effective_tags:
raise Exception("name is not supported on AnalysisData objects")
version = effective_tags.get("version", None)
if version and iterable(version):
raise Exception("version argument doesn't support iterables")
# Special handling for the data and mc tags to avoid having to
# use the config tag
# The config tag is set according to the following table:
#
# | mc\data | True | False | None |
# |:-------:|:----------:|:----------:|:--------------:|
# | True | ValueError | mc | mc |
# | False | lhcb | ValueError | lhcb |
# | None | lhcb | mc | config not set |
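# Illustrative examples of this mapping (editor's sketch, not part of the original code):
#   _validate_tags({"mc": True})                -> {"config": "mc"}
#   _validate_tags({"data": True})              -> {"config": "lhcb"}
#   _validate_tags({"datatype": 2018})          -> {"datatype": "2018"}
#   _validate_tags({"data": True, "mc": True})  raises ValueError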
dataval = effective_tags.get("data", None)
mcval = effective_tags.get("mc", None)
config = None
# We only set the config if one of the options data or mc was specified
if dataval is None:
# In this case we check whether mc has been specified and use that
if mcval is not None:
if mcval:
config = "mc"
else:
config = "lhcb"
# dataval has been explicitly set to true
elif dataval:
if mcval:
raise ValueError("values of data= and mc= are inconsistent")
config = "lhcb"
# dataval has been explicitly set to false
else:
if mcval is not None and not mcval:
# mcval explicitly set to False in contradiction with dataval
raise ValueError("values of data= and mc= are inconsistent")
config = "mc"
# Check if config was set as well !
if config:
explicit_config = effective_tags.get("config", None)
if explicit_config is not None:
if explicit_config != config:
raise ValueError("cannot specify data or mc as well as config")
cleaned["config"] = config
# Applying other checks
for t, v in effective_tags.items():
# Ignore these as we translated them to config already
if t in ["data", "mc"]:
continue
if v is None:
raise TypeError(f"{t} value is None")
if isinstance(v, bytes):
raise TypeError(f"{t} value is of type {type(v)}")
if isinstance(v, int) and not isinstance(v, bool):
cleaned[t] = str(v)
else:
cleaned[t] = v
return cleaned
def sample_check(samples, tags):
"""Filter the SampleCollection and check that we have the
samples that we expect"""
# Fixing the dict to make sure each item is a list of tuple tag/value
# Fixing the dict to make sure each item is a list
ltags = {}
dimensions = tags.keys()
for tag, value in tags.items():
if not isinstance(value, (list, tuple)):
ltags[tag] = [(tag, value)]
if not iterable(value):
ltags[safe_casefold(tag)] = [safe_casefold(value)]
else:
ltags[tag] = [(tag, v) for v in value]
ltags[safe_casefold(tag)] = [safe_casefold(v) for v in value]
logger.debug("Checking samples for tags: %s", str(ltags))
# Take the cartesian product using itertools
# Cartesian product of all the lists
products = list(itertools.product(*ltags.values()))
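# Illustrative expansion (editor's sketch, not part of the original code): with
#   ltags = {"datatype": [2017, 2018], "polarity": ["magup"]}
# products is [(2017, "magup"), (2018, "magup")], i.e. one expected sample per combination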
# convert the list of list of tuples to a dictionary we can pass to the filter function
# argvals = [{a: b for (a, b) in bin} for bin in products]
argvals = [dict(bin) for bin in products]
hist = {p: 0 for p in products}
errors = []
for a in argvals:
logger.debug("Processing %s", {str(a)})
tmp = samples.filter(**a)
if len(tmp) != 1:
errors.append((a, len(tmp)))
# Iterating on the samples and increasing the count
for stags in samples.itertags():
coordinates = tuple(safe_casefold(stags[d]) for d in dimensions)
try:
hist[coordinates] = hist[coordinates] + 1
except KeyError as ke:
raise KeyError(
f"Encountered sample with tags {str(coordinates)} which does not match filtering criteria {str(dict(ltags))}"
) from ke
# Now checking whether we have one entry per bin
errors = []
for coordinate, sample_count in hist.items():
if sample_count != 1:
logger.debug("Error %d samples for %s", sample_count, {str(coordinate)})
errors.append((dict(zip(dimensions, coordinate)), sample_count))
return errors
class AnalysisData:
""" Class allowing to access the metadata for a specific analysis """
"""Class allowing to access the metadata for a specific analysis.
Default values for the tags to filter the data can be passed as arguments to the constructor
(as well as the required working group and analysis names)
e.g. datasets = AnalysisData("b2oc", "b02dkpi", polarity="magdown")
Invoking () returns a list of PFNs corresponding to the requested dataset
Keyword arguments are interpreted as tags
Combining all of the tags must give a unique dataset or an error is raised
To get PFNs from multiple datasets, pass lists as the arguments; this is
equivalent to combining the individual datasets, i.e.
datasets(eventtype="27163904", datatype=[2017, 2018], polarity=["magup", "magdown"])
is the same as:
datasets(eventtype="27163904", datatype=2017, polarity="magup") +
datasets(eventtype="27163904", datatype=2017, polarity="magdown") +
datasets(eventtype="27163904", datatype=2018, polarity="magup") +
datasets(eventtype="27163904", datatype=2018, polarity="magdown")
"""
def __init__(
self,
......@@ -67,20 +176,17 @@ class AnalysisData:
analysis,
metadata_cache=None,
api_url="https://lbap.app.cern.ch",
check_data=True,
**kwargs,
):
"""Constructor that can either fetch the data from the AP service
or load from the cache.
tags can be specified as keyword arguments to specify the data to be analyzed.
The class will check that the appropriate data is available.
"""
Constructor that can either fetch the data from the AP service or load it from a local cache.
Analysis Production tags can be passed as keyword arguments to specify the data to be analyzed.
"""
self.working_group = working_group
self.analysis = analysis
# Tags is a list of tags that can be used to restrict the samples that will be used
self.tags = kwargs
self.default_tags = _validate_tags(kwargs)
# self.samples is a SampleCollection filled in with the values
metadata_cache = metadata_cache or os.environ.get("APD_METADATA_CACHE_DIR", "")
......@@ -95,38 +201,83 @@ class AnalysisData:
logger.debug("Fetching Analysis Production data from %s", api_url)
self.samples = fetch_ap_info(working_group, analysis, None, api_url)
# Filter samples and check that we have what we expect
if check_data and self.tags:
errors = sample_check(self.samples, self.tags)
if len(errors) > 0:
txt = ",".join([f"{c} samples for {str(t)}" for (t, c) in errors])
logger.error("Error loading data: %s", txt)
raise ValueError("Error loading data: " + txt)
def __call__(self, *, version=None, name=None, return_pfns=True, **tags):
def __call__(
self, *, version=None, name=None, return_pfns=True, check_data=True, **tags
):
"""Main method that returns the dataset info.
The normal behaviour is to return the PFNs for the samples, but setting
return_pfns to False returns the SampleCollection instead"""
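# Illustrative call (editor's sketch, eventtype taken from the docstring example):
#   datasets(eventtype="27163904", return_pfns=False)  # returns the filtered SampleCollection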
# Cannot mix data from 2 versions in the same dataset
if not version:
version = self.default_tags.get("version", None)
if iterable(version):
raise Exception("version argument doesn't support iterables")
# Establishing the list of samples to run on
samples = self.samples
if version and name:
# No need to apply other tags, this specifies explicitly a specific dataset
# We return it straight away
logger.debug("Filtering for version/name %s/%s", name, version)
samples = samples.filter("version", version).filter("name", name)
if name and not version:
# This specifies a name across several versions of the AP, i.e. the same files processed
# with different versions of the scripts; I am not sure of the use case
samples = samples.filter("name", name)
if len(samples) != 1:
raise ValueError(f"{len(samples)} matching {name}, should be exactly 1")
for tagname, tagvalue in tags.items():
logger.debug("Filtering for %s = %s", tagname, tagvalue)
samples = samples.filter(**tags)
if name:
if version:
# No need to apply other tags, this specifies explicitly a specific dataset
# We return it straight away
logger.debug("Filtering for version/name %s/%s", name, version)
samples = samples.filter("version", version)
if len(samples) == 0:
raise KeyError(f"No version {version}")
samples = samples.filter("name", name)
if len(samples) == 0:
raise KeyError(f"No name {name}")
else:
# We check whether a version was specified in the default tags
samples = samples.filter("name", name)
if len(samples) != 1:
raise ValueError(
f"{len(samples)} matching {name}, should be exactly 1"
)
else:
# Merge the current tags with the default passed to the constructor
# and check that they are consistent
effective_tags = _validate_tags(tags, self.default_tags)
if version:
effective_tags["version"] = version
for tagname, tagvalue in effective_tags.items():
logger.debug("Filtering for %s = %s", tagname, tagvalue)
# Applying the filters in one go
samples = samples.filter(**effective_tags)
logger.debug("Matched %d samples", len(samples))
# Filter samples and check that we have what we expect
if check_data:
errors = sample_check(samples, effective_tags)
if len(errors) > 0:
error_txt = f"{len(errors)} problem(s) found"
for etags, ecount in errors:
error_txt += f"\n{str(etags)}: {ecount} samples"
if ecount > 0:
error_txt += " e.g. (3 samples printed)"
match_list = [
str(m)
for m in itertools.islice(
samples.filter(**etags).itertags(), 0, 3
)
]
error_txt += "".join(
["\n" + " " * 5 + str(m) for m in match_list]
)
logger.debug("Error loading data: %s", error_txt)
raise ValueError("Error loading data: " + error_txt)
if return_pfns:
return samples.PFNs()
return samples
def __str__(self):
txt = f"AnalysysProductions: {self.working_group} / {self.analysis}\n"
txt += str(self.samples)
return txt
......@@ -11,7 +11,9 @@
#
# Tool to load and interpret information from the AnalysisProductions data endpoint
#
import collections.abc
import json
import logging
import os
from pathlib import Path
......@@ -19,6 +21,22 @@ import requests
import apd.cern_sso
logger = logging.getLogger("apd")
def iterable(arg):
""" Version of Iterable that excludes str """
return isinstance(arg, collections.abc.Iterable) and not isinstance(
arg, (str, bytes)
)
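# Illustrative behaviour (editor's note, not part of the original code):
#   iterable([2017, 2018]) is True, iterable("magup") is False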
def safe_casefold(a):
""" casefold that can be called on any type, does nothing on non str """
if isinstance(a, str):
return a.casefold()
return a
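# Illustrative behaviour (editor's note, not part of the original code):
#   safe_casefold("MagUp") == "magup", safe_casefold(2018) == 2018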
class APDataDownloader:
def __init__(self, api_url="https://lbap.app.cern.ch"):
......@@ -29,7 +47,7 @@ class APDataDownloader:
return {"Authorization": f"Bearer {self._get_token()}"}
def _get_token(self):
""" Get the API token, authentification with the CERN SSO """
"""Get the API token, authentification with the CERN SSO"""
# Getting the token using apd.cern_sso.
# We have a copy of this module as it is not released on pypi.
# N.B. This requires a kerberos token for an account that belongs to lhcb-general
......@@ -68,7 +86,7 @@ class APDataDownloader:
def fetch_ap_info(
working_group, analysis, loader=None, api_url="https://lbap.app.cern.ch"
):
""" Fetch the API info from the service """
"""Fetch the API info from the service"""
if not loader:
loader = APDataDownloader(api_url)
......@@ -82,7 +100,7 @@ def fetch_ap_info(
def cache_ap_info(
cache_dir, working_group, analysis, loader=None, api_url="https://lbap.app.cern.ch"
):
""" Fetch the AP info and cache it locally """
"""Fetch the AP info and cache it locally"""
cache_dir = Path(cache_dir)
samples = fetch_ap_info(working_group, analysis, loader, api_url)
wgdir = cache_dir / working_group
......@@ -98,12 +116,19 @@ def cache_ap_info(
return samples
def _find_case_insensitive(mydir, filename):
for f in os.listdir(mydir):
if f.casefold() == filename.casefold():
return f
raise FileNotFoundError(f"{filename} in {mydir}")
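# Illustrative behaviour (editor's note, not part of the original code):
#   _find_case_insensitive(cache_dir, "B2OC") returns the on-disk name (e.g. "b2oc")
#   whose casefolded form matches, or raises FileNotFoundError if nothing matches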
def load_ap_info(cache_dir, working_group, analysis):
""" Load the API info from a cache file """
"""Load the API info from a cache file"""
cache_dir = Path(cache_dir)
wgdir = cache_dir / working_group
anadir = wgdir / analysis
datafile = wgdir / f"{analysis}.json"
wgdir = cache_dir / _find_case_insensitive(cache_dir, working_group)
anadir = wgdir / _find_case_insensitive(wgdir, analysis)
datafile = wgdir / _find_case_insensitive(wgdir, f"{analysis}.json")
tagsfile = anadir / "tags.json"
with open(datafile) as f:
data = json.load(f)
......@@ -113,7 +138,7 @@ def load_ap_info(cache_dir, working_group, analysis):
def load_ap_info_from_single_file(filename):
""" Load the API info from a cache file """
"""Load the API info from a cache file"""
if not os.path.exists(filename):
raise Exception("Please specify a valid file as metadata cache")
......@@ -148,6 +173,14 @@ class SampleCollection:
tags["name"] = sample["name"]
return tags
def __repr__(self):
return "\n".join(
[
f"{s['name']} {s['version']} | " + str(self._sampleTags(s))
for s in self.info
]
)
def __iter__(self):
for s in self.info:
yield s
......@@ -164,30 +197,33 @@ class SampleCollection:
if (len(args) != 0) and len(args) != 2:
raise ValueError(
"fileter method takes two positional arguments or keyword arguments"
"filter method takes two positional arguments or keyword arguments"
)
def _compare_tag(sample, ftag, fvalue):
"""Utility method than handles specific tags, but not iterables"""
return safe_casefold(self._sampleTags(sample).get(ftag)) == safe_casefold(
fvalue
)
def _filter1(samples, ftag, fvalue):
logger.debug("filtering samples for %s:%s", ftag, fvalue)
if callable(fvalue):
matching = [
sample
for sample in samples
if fvalue(self._sampleTags(sample).get(ftag, None))
if fvalue(safe_casefold(self._sampleTags(sample).get(ftag, None)))
]
elif isinstance(fvalue, (list, tuple)):
elif iterable(fvalue):
# We accumulate the samples matching each of the requested values
matching = []
for v in fvalue:
matching += [
sample
for sample in samples
if self._sampleTags(sample).get(ftag, None) == v
sample for sample in samples if _compare_tag(sample, ftag, v)
]
else:
matching = [
sample
for sample in samples
if self._sampleTags(sample).get(ftag, None) == fvalue
sample for sample in samples if _compare_tag(sample, ftag, fvalue)
]
return matching
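# Illustrative filter calls (editor's sketch, not part of the original code):
#   samples.filter(polarity=["magup", "magdown"])         # iterable value: union of the matches
#   samples.filter(eventtype=lambda e: e != "27163904")   # callable value: predicate on the tag value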
......@@ -199,7 +235,7 @@ class SampleCollection:
return SampleCollection(samples, self.tags)
def PFNs(self):
""" Collects the PFNs """
"""Collects the PFNs"""
pfns = []
for sample in self.info:
for pfnlist in sample["lfns"].values():
......
......@@ -13,6 +13,7 @@
#
import logging
import os
import sys
import click
import click_log
......@@ -24,6 +25,15 @@ logger = logging.getLogger("apd")
click_log.basic_config(logger)
def exception_handler(exception_type, exception, _):
# All your trace are belong to us!
# Print a concise "Type: message" line instead of a full traceback
print("%s: %s" % (exception_type.__name__, exception))
sys.excepthook = exception_handler
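# Illustrative effect (editor's note, not part of the original code): an uncaught
# ValueError("bad tags") is now reported as "ValueError: bad tags" rather than a traceback.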
@click.command()
@click.argument("cache_directory")
@click.argument("working_group")
......@@ -41,7 +51,7 @@ def cmd_cache_ap_info(cache_directory, working_group, analysis):
@click.argument("analysis")
@click.option(
"--cache_directory",
default=os.environ.get("APD_CACHE", None),
default=os.environ.get("APD_METADATA_CACHE_DIR", None),
help="Specify location of the cached analysis data files",
)
@click.option("--tag", default=None, help="Tag to filter datasets", multiple=True)
......@@ -57,11 +67,26 @@ def cmd_cache_ap_info(cache_directory, working_group, analysis):
@click.option(
"--datatype", default=None, help="datatype to filter the datasets", multiple=True
)
@click.option("--polarity", default=None, help="polarity to filter the datasets")
@click.option(
"--polarity", default=None, help="polarity to filter the datasets", multiple=True
)
@click.option("--name", default=None, help="dataset name")
@click.option("--version", default=None, help="dataset version")
@click_log.simple_verbosity_option(logger)
def cmd_list_pfns(
working_group, analysis, cache_directory, tag, value, eventtype, datatype, polarity
working_group,
analysis,
cache_directory,
tag,
value,
eventtype,
datatype,
polarity,
name,
version,
):
"""List the PFNs for the analysis, matching the tags specified.
This command checks that the arguments are not ambiguous."""
# Dealing with the cache
if not cache_directory:
......@@ -78,7 +103,90 @@ def cmd_list_pfns(
# Loading the data and filtering/displaying
datasets = AnalysisData(working_group, analysis, metadata_cache=cache_directory)
for f in datasets(
tag=tag, value=value, eventtype=eventtype, datatype=datatype, polarity=polarity
):
filter_tags = {}
if name is not None:
filter_tags["name"] = name
if version is not None:
filter_tags["version"] = version
if eventtype != ():
filter_tags["eventtype"] = eventtype
if datatype != ():
filter_tags["datatype"] = datatype
if polarity != ():
filter_tags["polarity"] = polarity
filter_tags |= dict(zip(tag, value))
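# Illustrative invocation (editor's sketch, hypothetical values): calling
#   apd-list-pfns b2oc b02dkpi --datatype 2018 --tag polarity --value magup
# leaves filter_tags as {"datatype": ("2018",), "polarity": "magup"} before the filtering below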
for f in datasets(**filter_tags):
click.echo(f)
@click.command()
@click.argument("working_group")
@click.argument("analysis")
@click.option(
"--cache_directory",
default=os.environ.get("APD_METADATA_CACHE_DIR", None),
help="Specify location of the cached analysis data files",
)
@click.option("--tag", default=None, help="Tag to filter datasets", multiple=True)
@click.option(
"--value",
default=None,
help="Tag value used if the name is specified",
multiple=True,
)
@click.option(
"--eventtype", default=None, help="eventtype to filter the datasets", multiple=True
)
@click.option(
"--datatype", default=None, help="datatype to filter the datasets", multiple=True
)
@click.option(
"--polarity", default=None, help="polarity to filter the datasets", multiple=True
)
@click.option("--name", default=None, help="dataset name")
@click.option("--version", default=None, help="dataset version")
@click_log.simple_verbosity_option(logger)
def cmd_list_samples(
working_group,
analysis,
cache_directory,
tag,
value,
eventtype,
datatype,
polarity,
name,
version,