From 8b8edd5a2e1ecf251b0d3d1c2c9c0add7e2bbfe9 Mon Sep 17 00:00:00 2001
From: Marcel R <github.riga@icloud.com>
Date: Wed, 31 Mar 2021 20:13:55 +0200
Subject: [PATCH] Revert changes in dhi/util.py.

---
 dhi/util.py | 258 +++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 245 insertions(+), 13 deletions(-)

diff --git a/dhi/util.py b/dhi/util.py
index 6a3a2b96..5a92a765 100644
--- a/dhi/util.py
+++ b/dhi/util.py
@@ -12,8 +12,11 @@ import shutil
 import itertools
 import array
 import contextlib
+import tempfile
+import operator
 import logging
 
+import six
 
 # modules and objects from lazy imports
 _plt = None
@@ -113,6 +116,46 @@ def to_root_latex(s):
     return s
 
 
+shell_colors = {
+    "default": 39,
+    "black": 30,
+    "red": 31,
+    "green": 32,
+    "yellow": 33,
+    "blue": 34,
+    "magenta": 35,
+    "cyan": 36,
+    "light_gray": 37,
+    "dark_gray": 90,
+    "light_red": 91,
+    "light_green": 92,
+    "light_yellow": 93,
+    "light_blue": 94,
+    "light_magenta": 95,
+    "light_cyan": 96,
+    "white": 97,
+}
+
+
+def colored(msg, color=None, force=False):
+    """
+    Return the colored version of a string *msg*. Unless *force* is *True*, the *msg* string is
+    returned unchanged in case the output is not a tty. Simplified from law.util.colored.
+    """
+    if not force:
+        try:
+            tty = os.isatty(sys.stdout.fileno())
+        except:
+            tty = False
+
+        if not tty:
+            return msg
+
+    color = shell_colors.get(color, shell_colors["default"])
+
+    return "\033[{}m{}\033[0m".format(color, msg)
+
+
 def linspace(start, stop, steps, precision=7):
     """
     Same as np.linspace with *start*, *stop* and *steps* being directly forwarded but the generated
@@ -178,11 +221,12 @@ def minimize_1d(objective, bounds, start=None, niter=10, **kwargs):
 
 
 def create_tgraph(n, *args, **kwargs):
-    """ create_tgraph(n, *args, pad=None)
+    """create_tgraph(n, *args, pad=None, insert=None)
     Creates a ROOT graph with *n* points, where the type is *TGraph* for two, *TGraphErrors* for
     4 and *TGraphAsymmErrors* for six *args*. Each argument is converted to a python array with
     typecode "f". When *pad* is *True*, the graph is padded by one additional point on each side
-    with the same edge value.
+    with the same edge value. When *insert* is given, it should be a list of tuples with values
+    ``(index, values...)`` denoting the index, coordinates and errors of points to be inserted.
     """
     ROOT = import_ROOT()
 
@@ -202,10 +246,21 @@ def create_tgraph(n, *args, **kwargs):
             a = n * list(a)
         _args.append(list(a))
 
-    # apply edge padding when requested
-    if kwargs.get("pad"):
-        n += 2
-        _args = [(a[:1] + a + a[-1:]) for a in _args]
+    # apply edge padding when requested with a configurable width
+    pad = kwargs.get("pad")
+    if pad:
+        w = 1 if not isinstance(pad, int) else int(pad)
+        n += 2 * w
+        _args = [(w * a[:1] + a + w * a[-1:]) for a in _args]
+
+    # insert custom points
+    insert = kwargs.get("insert")
+    if insert:
+        for values in insert:
+            idx, values = values[0], values[1:]
+            for i, v in enumerate(values):
+                _args[i].insert(idx, v)
+            n += 1
 
     if n == 0:
         return cls(n)
@@ -336,15 +391,184 @@ def poisson_asym_errors(v):
     return err_up, err_down
 
 
+def unique_recarray(a, cols=None, sort=True, test_metric=None):
+    import numpy as np
+
+    metric, test_fn = test_metric or (None, None)
+
+    # use all columns by default, except for the optional test metric
+    if not cols:
+        cols = list(a.dtype.names)
+        if metric and metric in cols:
+            cols.remove(metric)
+    else:
+        cols = list(cols)
+
+    # get the indices of unique entries and sort them
+    indices = np.unique(a[cols], return_index=True)[1]
+
+    # by default, indices are ordered such that the columns used to identify duplicates are sorted
+    # so when sort is True, keep it that way, and otherwise sort indices to preserve the old order
+    if not sort:
+        indices = sorted(indices)
+
+    # create the unique array
+    b = np.array(a[indices])
+
+    # perform a check to see if removed values differ in a certain metric
+    if metric:
+        removed_indices = set(range(a.shape[0])) - set(indices)
+        for i in removed_indices:
+            # get the removed metric value
+            removed_metric = float(a[i][metric])
+            # get the kept metric value
+            j = six.moves.reduce(operator.and_, [b[c] == v for c, v in zip(cols, a[i][cols])])
+            j = np.argwhere(j).flatten()[0]
+            kept_metric = float(b[j][metric])
+            # call test_fn except when both values are nan
+            both_nan = np.isnan(removed_metric) and np.isnan(kept_metric)
+            if not both_nan and not test_fn(kept_metric, removed_metric):
+                raise Exception("duplicate entries identified by columns {} with '{}' values of {} "
+                    "(kept) and {} (removed at row {}) differ".format(
+                        cols, metric, kept_metric, removed_metric, i))
+
+    return b
+
+
+class TFileCache(object):
+    def __init__(self, logger=None):
+        super(TFileCache, self).__init__()
+
+        self.logger = logger or logging.getLogger(
+            "{}_{}".format(self.__class__.__name__, hex(id(self)))
+        )
+
+        # cache of files opened for reading
+        # abs_path -> {tfile: TFile}
+        self._r_cache = {}
+
+        # cache of files opened for writing
+        # abs_path -> {tmp_path: str, tfile: TFile, objects: [(tobj, towner, name), ...]}
+        self._w_cache = {}
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, err_type, err_value, traceback):
+        self.finalize(skip_write=err_type is not None)
+
+    def _clear(self):
+        self._r_cache.clear()
+        self._w_cache.clear()
+
+    def open_tfile(self, path, mode):
+        ROOT = import_ROOT()
+
+        abs_path = real_path(path)
+
+        if mode == "READ":
+            if abs_path not in self._r_cache:
+                # just open the file and cache it
+                tfile = ROOT.TFile(abs_path, mode)
+                self._r_cache[abs_path] = {"tfile": tfile}
+
+                self.logger.debug("opened tfile {} with mode {}".format(abs_path, mode))
+
+            return self._r_cache[abs_path]["tfile"]
+
+        else:
+            if abs_path not in self._w_cache:
+                # determine a temporary location
+                suffix = "_" + os.path.basename(abs_path)
+                tmp_path = tempfile.mkstemp(suffix=suffix)[1]
+                if os.path.exists(tmp_path):
+                    os.remove(tmp_path)
+
+                # copy the file when existing
+                if os.path.exists(abs_path):
+                    shutil.copy2(abs_path, tmp_path)
+
+                # open the file and cache it
+                tfile = ROOT.TFile(tmp_path, mode)
+                self._w_cache[abs_path] = {"tmp_path": tmp_path, "tfile": tfile, "objects": []}
+
+                self.logger.debug(
+                    "opened tfile {} with mode {} in temporary location {}".format(
+                        abs_path, mode, tmp_path
+                    )
+                )
+
+            return self._w_cache[abs_path]["tfile"]
+
+    def write_tobj(self, path, tobj, towner=None, name=None):
+        ROOT = import_ROOT()
+
+        if isinstance(path, ROOT.TFile):
+            # lookup the cache entry by the tfile reference
+            for data in self._w_cache.values():
+                if data["tfile"] == path:
+                    data["objects"].append((tobj, towner, name))
+                    break
+            else:
+                raise Exception("cannot write object {} unknown TFile {}".format(tobj, path))
+
+        else:
+            abs_path = real_path(path)
+            if abs_path not in self._w_cache:
+                raise Exception("cannot write object {} into unopened file {}".format(tobj, path))
+
+            self._w_cache[abs_path]["objects"].append((tobj, towner, name))
+
+    def finalize(self, skip_write=False):
+        if self._r_cache:
+            # close files opened for reading
+            for abs_path, data in self._r_cache.items():
+                if data["tfile"] and data["tfile"].IsOpen():
+                    data["tfile"].Close()
+            self.logger.debug(
+                "closed {} cached file(s) opened for reading".format(len(self._r_cache))
+            )
+
+        if self._w_cache:
+            # close files opened for reading, write objects and move to actual location
+            ROOT = import_ROOT()
+            ignore_level_orig = ROOT.gROOT.ProcessLine("gErrorIgnoreLevel;")
+            ROOT.gROOT.ProcessLine("gErrorIgnoreLevel = kFatal;")
+
+            for abs_path, data in self._w_cache.items():
+                if data["tfile"] and data["tfile"].IsOpen():
+                    if not skip_write:
+                        data["tfile"].cd()
+                        for tobj, towner, name in data["objects"]:
+                            if towner:
+                                towner.cd()
+                            args = (name,) if name else ()
+                            tobj.Write(*args)
+
+                    data["tfile"].Close()
+
+                    if not skip_write:
+                        shutil.move(data["tmp_path"], abs_path)
+                        self.logger.debug(
+                            "moving back temporary file {} to {}".format(data["tmp_path"], abs_path)
+                        )
+
+            self.logger.debug(
+                "closed {} cached file(s) opened for writing".format(len(self._w_cache))
+            )
+            ROOT.gROOT.ProcessLine("gErrorIgnoreLevel = {};".format(ignore_level_orig))
+
+        # clear
+        self._clear()
+
+
 class ROOTColorGetter(object):
     def __init__(self, **cache):
         super(ROOTColorGetter, self).__init__()
 
         self.cache = cache or {}
 
-    def __getattr__(self, attr):
-        ROOT = import_ROOT()
-
+    def _get_color(self, attr):
         if attr not in self.cache:
             self.cache[attr] = self.create_color(attr)
         elif not isinstance(self.cache[attr], int):
@@ -352,8 +576,14 @@ class ROOTColorGetter(object):
 
         return self.cache[attr]
 
+    def __call__(self, *args, **kwargs):
+        return self._get_color(*args, **kwargs)
+
+    def __getattr__(self, attr):
+        return self._get_color(attr)
+
     def __getitem__(self, key):
-        return getattr(self, key)
+        return self._get_color(key)
 
     @classmethod
     def create_color(cls, obj):
@@ -362,11 +592,13 @@ class ROOTColorGetter(object):
         if isinstance(obj, int):
             return obj
         elif isinstance(obj, str):
-            return getattr(ROOT, "k" + obj.capitalize())
+            c = getattr(ROOT, "k" + obj.capitalize(), None)
+            if c is not None:
+                return c
         elif isinstance(obj, tuple) and len(obj) in [2, 3, 4]:
             c = ROOT.TColor.GetColor(*obj[:3]) if len(obj) >= 3 else obj[0]
             if len(obj) in [2, 4]:
                 c = ROOT.TColor.GetColorTransparent(c, obj[-1])
             return c
-        else:
-            raise AttributeError("cannot interpret '{}' as color".format(obj))
+
+        raise AttributeError("cannot interpret '{}' as color".format(obj))
-- 
GitLab