diff --git a/dhi/util.py b/dhi/util.py index 6a3a2b961ee86cc8c13b56aa5f92c3f65d07a402..5a92a7652c1b4822137560a08cd0a2c46f5e078b 100644 --- a/dhi/util.py +++ b/dhi/util.py @@ -12,8 +12,11 @@ import shutil import itertools import array import contextlib +import tempfile +import operator import logging +import six # modules and objects from lazy imports _plt = None @@ -113,6 +116,46 @@ def to_root_latex(s): return s +shell_colors = { + "default": 39, + "black": 30, + "red": 31, + "green": 32, + "yellow": 33, + "blue": 34, + "magenta": 35, + "cyan": 36, + "light_gray": 37, + "dark_gray": 90, + "light_red": 91, + "light_green": 92, + "light_yellow": 93, + "light_blue": 94, + "light_magenta": 95, + "light_cyan": 96, + "white": 97, +} + + +def colored(msg, color=None, force=False): + """ + Return the colored version of a string *msg*. Unless *force* is *True*, the *msg* string is + returned unchanged in case the output is not a tty. Simplified from law.util.colored. + """ + if not force: + try: + tty = os.isatty(sys.stdout.fileno()) + except: + tty = False + + if not tty: + return msg + + color = shell_colors.get(color, shell_colors["default"]) + + return "\033[{}m{}\033[0m".format(color, msg) + + def linspace(start, stop, steps, precision=7): """ Same as np.linspace with *start*, *stop* and *steps* being directly forwarded but the generated @@ -178,11 +221,12 @@ def minimize_1d(objective, bounds, start=None, niter=10, **kwargs): def create_tgraph(n, *args, **kwargs): - """ create_tgraph(n, *args, pad=None) + """create_tgraph(n, *args, pad=None, insert=None) Creates a ROOT graph with *n* points, where the type is *TGraph* for two, *TGraphErrors* for 4 and *TGraphAsymmErrors* for six *args*. Each argument is converted to a python array with typecode "f". When *pad* is *True*, the graph is padded by one additional point on each side - with the same edge value. + with the same edge value. When *insert* is given, it should be a list of tuples with values + ``(index, values...)`` denoting the index, coordinates and errors of points to be inserted. """ ROOT = import_ROOT() @@ -202,10 +246,21 @@ def create_tgraph(n, *args, **kwargs): a = n * list(a) _args.append(list(a)) - # apply edge padding when requested - if kwargs.get("pad"): - n += 2 - _args = [(a[:1] + a + a[-1:]) for a in _args] + # apply edge padding when requested with a configurable width + pad = kwargs.get("pad") + if pad: + w = 1 if not isinstance(pad, int) else int(pad) + n += 2 * w + _args = [(w * a[:1] + a + w * a[-1:]) for a in _args] + + # insert custom points + insert = kwargs.get("insert") + if insert: + for values in insert: + idx, values = values[0], values[1:] + for i, v in enumerate(values): + _args[i].insert(idx, v) + n += 1 if n == 0: return cls(n) @@ -336,15 +391,184 @@ def poisson_asym_errors(v): return err_up, err_down +def unique_recarray(a, cols=None, sort=True, test_metric=None): + import numpy as np + + metric, test_fn = test_metric or (None, None) + + # use all columns by default, except for the optional test metric + if not cols: + cols = list(a.dtype.names) + if metric and metric in cols: + cols.remove(metric) + else: + cols = list(cols) + + # get the indices of unique entries and sort them + indices = np.unique(a[cols], return_index=True)[1] + + # by default, indices are ordered such that the columns used to identify duplicates are sorted + # so when sort is True, keep it that way, and otherwise sort indices to preserve the old order + if not sort: + indices = sorted(indices) + + # create the unique array + b = np.array(a[indices]) + + # perform a check to see if removed values differ in a certain metric + if metric: + removed_indices = set(range(a.shape[0])) - set(indices) + for i in removed_indices: + # get the removed metric value + removed_metric = float(a[i][metric]) + # get the kept metric value + j = six.moves.reduce(operator.and_, [b[c] == v for c, v in zip(cols, a[i][cols])]) + j = np.argwhere(j).flatten()[0] + kept_metric = float(b[j][metric]) + # call test_fn except when both values are nan + both_nan = np.isnan(removed_metric) and np.isnan(kept_metric) + if not both_nan and not test_fn(kept_metric, removed_metric): + raise Exception("duplicate entries identified by columns {} with '{}' values of {} " + "(kept) and {} (removed at row {}) differ".format( + cols, metric, kept_metric, removed_metric, i)) + + return b + + +class TFileCache(object): + def __init__(self, logger=None): + super(TFileCache, self).__init__() + + self.logger = logger or logging.getLogger( + "{}_{}".format(self.__class__.__name__, hex(id(self))) + ) + + # cache of files opened for reading + # abs_path -> {tfile: TFile} + self._r_cache = {} + + # cache of files opened for writing + # abs_path -> {tmp_path: str, tfile: TFile, objects: [(tobj, towner, name), ...]} + self._w_cache = {} + + def __enter__(self): + return self + + def __exit__(self, err_type, err_value, traceback): + self.finalize(skip_write=err_type is not None) + + def _clear(self): + self._r_cache.clear() + self._w_cache.clear() + + def open_tfile(self, path, mode): + ROOT = import_ROOT() + + abs_path = real_path(path) + + if mode == "READ": + if abs_path not in self._r_cache: + # just open the file and cache it + tfile = ROOT.TFile(abs_path, mode) + self._r_cache[abs_path] = {"tfile": tfile} + + self.logger.debug("opened tfile {} with mode {}".format(abs_path, mode)) + + return self._r_cache[abs_path]["tfile"] + + else: + if abs_path not in self._w_cache: + # determine a temporary location + suffix = "_" + os.path.basename(abs_path) + tmp_path = tempfile.mkstemp(suffix=suffix)[1] + if os.path.exists(tmp_path): + os.remove(tmp_path) + + # copy the file when existing + if os.path.exists(abs_path): + shutil.copy2(abs_path, tmp_path) + + # open the file and cache it + tfile = ROOT.TFile(tmp_path, mode) + self._w_cache[abs_path] = {"tmp_path": tmp_path, "tfile": tfile, "objects": []} + + self.logger.debug( + "opened tfile {} with mode {} in temporary location {}".format( + abs_path, mode, tmp_path + ) + ) + + return self._w_cache[abs_path]["tfile"] + + def write_tobj(self, path, tobj, towner=None, name=None): + ROOT = import_ROOT() + + if isinstance(path, ROOT.TFile): + # lookup the cache entry by the tfile reference + for data in self._w_cache.values(): + if data["tfile"] == path: + data["objects"].append((tobj, towner, name)) + break + else: + raise Exception("cannot write object {} unknown TFile {}".format(tobj, path)) + + else: + abs_path = real_path(path) + if abs_path not in self._w_cache: + raise Exception("cannot write object {} into unopened file {}".format(tobj, path)) + + self._w_cache[abs_path]["objects"].append((tobj, towner, name)) + + def finalize(self, skip_write=False): + if self._r_cache: + # close files opened for reading + for abs_path, data in self._r_cache.items(): + if data["tfile"] and data["tfile"].IsOpen(): + data["tfile"].Close() + self.logger.debug( + "closed {} cached file(s) opened for reading".format(len(self._r_cache)) + ) + + if self._w_cache: + # close files opened for reading, write objects and move to actual location + ROOT = import_ROOT() + ignore_level_orig = ROOT.gROOT.ProcessLine("gErrorIgnoreLevel;") + ROOT.gROOT.ProcessLine("gErrorIgnoreLevel = kFatal;") + + for abs_path, data in self._w_cache.items(): + if data["tfile"] and data["tfile"].IsOpen(): + if not skip_write: + data["tfile"].cd() + for tobj, towner, name in data["objects"]: + if towner: + towner.cd() + args = (name,) if name else () + tobj.Write(*args) + + data["tfile"].Close() + + if not skip_write: + shutil.move(data["tmp_path"], abs_path) + self.logger.debug( + "moving back temporary file {} to {}".format(data["tmp_path"], abs_path) + ) + + self.logger.debug( + "closed {} cached file(s) opened for writing".format(len(self._w_cache)) + ) + ROOT.gROOT.ProcessLine("gErrorIgnoreLevel = {};".format(ignore_level_orig)) + + # clear + self._clear() + + class ROOTColorGetter(object): def __init__(self, **cache): super(ROOTColorGetter, self).__init__() self.cache = cache or {} - def __getattr__(self, attr): - ROOT = import_ROOT() - + def _get_color(self, attr): if attr not in self.cache: self.cache[attr] = self.create_color(attr) elif not isinstance(self.cache[attr], int): @@ -352,8 +576,14 @@ class ROOTColorGetter(object): return self.cache[attr] + def __call__(self, *args, **kwargs): + return self._get_color(*args, **kwargs) + + def __getattr__(self, attr): + return self._get_color(attr) + def __getitem__(self, key): - return getattr(self, key) + return self._get_color(key) @classmethod def create_color(cls, obj): @@ -362,11 +592,13 @@ class ROOTColorGetter(object): if isinstance(obj, int): return obj elif isinstance(obj, str): - return getattr(ROOT, "k" + obj.capitalize()) + c = getattr(ROOT, "k" + obj.capitalize(), None) + if c is not None: + return c elif isinstance(obj, tuple) and len(obj) in [2, 3, 4]: c = ROOT.TColor.GetColor(*obj[:3]) if len(obj) >= 3 else obj[0] if len(obj) in [2, 4]: c = ROOT.TColor.GetColorTransparent(c, obj[-1]) return c - else: - raise AttributeError("cannot interpret '{}' as color".format(obj)) + + raise AttributeError("cannot interpret '{}' as color".format(obj))