#!/usr/bin/env python
###############################################################################
# (c) Copyright 2021-2024 CERN for the benefit of the LHCb Collaboration      #
#                                                                             #
# This software is distributed under the terms of the GNU General Public      #
# Licence version 3 (GPL Version 3), copied verbatim in the file "COPYING".   #
#                                                                             #
# In applying this licence, CERN does not waive the privileges and immunities #
# granted to it by virtue of its status as an Intergovernmental Organization  #
# or submit itself to any jurisdiction.                                       #
###############################################################################
"""
Measure rates from the logfile of a Moore execution
"""
import argparse
import pandas
import logging
import numpy
import os
import re
import uproot


logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)

LOGGER = logging.getLogger(__name__)


def _rates_from_log(file, input_rate):
    """
    Build a DataFrame with the rates read from a log file.

    :param file: path to the Moore log file to parse.
    :param input_rate: input rate used to scale the per-line efficiency.
    :returns: DataFrame indexed by line name with columns "total", "passed"
        and "rate", sorted by decreasing rate.
    :raises RuntimeError: if a matched log line cannot be parsed or the file
        contains no rate information at all.
    """
    with open(file) as f:
        LOGGER.info(f"Reading file '{file}'")
        values = re.findall(r"\s*LAZY_AND: Hlt2.*DecisionWithOutput\s*#=[0-9]*\s*Sum=[0-9]*.*", f.read())

    data = {}
    for v in values:

        # "re.search" returns None when nothing matches, so ".group()" raises
        # AttributeError; catch only that instead of using a bare "except".
        try:
            name = re.search("Hlt2.*DecisionWithOutput", v).group().replace("DecisionWithOutput", "")
        except AttributeError as err:
            # The original message was missing the f-prefix, so "{v}" was
            # printed literally instead of being interpolated.
            raise RuntimeError(f"Unable to access the line name information from string '{v}'") from err

        try:
            this_total_events = int(re.search("#=[0-9]*", v).group().replace("#=", ""))
            this_passed = int(re.search("Sum=[0-9]*", v).group().replace("Sum=", ""))
        except AttributeError as err:
            raise RuntimeError(f"Unable to access information for line '{name}', extracted from string '{v}'") from err

        data[name] = dict(total=this_total_events, passed=this_passed)

    if not data:
        # Fail with a clear message; otherwise the column accesses below
        # would raise an obscure KeyError on an empty DataFrame.
        raise RuntimeError(f"No rate information found in file '{file}'")

    df = pandas.DataFrame.from_dict(data, orient="index")

    # rate = (line efficiency) * (input rate)
    df.loc[:,"rate"] = df["passed"] * input_rate / df["total"]

    df.sort_values("rate", ascending=False, inplace=True)

    return df


def rates_from_log(file, input_rate, output, nrows):
    """
    Compute and display the per-line rates extracted from a single log
    file, optionally writing the table to a markdown file.
    """
    rates = _rates_from_log(file, input_rate)

    # Widen the pandas display limits so the whole table ends up in the log.
    for option, value in (("display.max_colwidth", 100), ("display.max_rows", nrows)):
        pandas.set_option(option, value)

    LOGGER.info(f"Results:{os.linesep}{rates}")

    if output:
        rates.to_markdown(output)
        LOGGER.info(f"Rates have been written in '{output}'")


def compare_rates(before, after, input_rate, output, nrows):
    """
    Compare the rates obtained from two log files, taking the first one
    as the reference.
    """
    reference = _rates_from_log(before, input_rate)
    candidate = _rates_from_log(after, input_rate)

    # Union of the line names seen in either file, in alphabetical order.
    all_lines = sorted(set(reference.index).union(candidate.index))

    diff = pandas.DataFrame(index=all_lines, columns=["before", "after", "difference"], dtype=float)
    diff["before"] = reference["rate"]
    diff["after"] = candidate["rate"]
    diff["difference"] = diff["after"] - diff["before"]
    diff["difference (%)"] = 100.0 * diff["difference"] / diff["before"]

    diff.sort_values("after", ascending=False, inplace=True)
    diff.rename(columns={"before": "rate (before)", "after": "rate (after)"}, inplace=True)

    # Widen the pandas display limits so the whole table ends up in the log.
    for option, value in (("display.max_colwidth", 100), ("display.max_rows", nrows)):
        pandas.set_option(option, value)

    LOGGER.info(f"Results:{os.linesep}{diff}")

    if output:
        diff.to_markdown(output)
        LOGGER.info(f"Rates have been written in '{output}'")


def check_histograms(file, regex, nrows):
    """
    Explore the monitoring histograms of HLT2 and do some general checks.

    :param file: ROOT file with the monitoring histograms.
    :param regex: regular expression selecting the directories (trigger
        lines) to inspect; None selects everything.
    :param nrows: maximum number of rows to display.
    """
    # "--regex" defaults to None on the command line and "re.compile(None)"
    # raises TypeError, so fall back to a match-everything pattern.
    comp = re.compile(regex if regex is not None else ".*")

    mean_candidates = "mean candidates"
    max_candidates = "max. candidates"
    possible_duplicated_candidates = "possible duplicated cand."
    overflow_candidates = "overflow multiple candidates"

    df = pandas.DataFrame(columns=[mean_candidates, possible_duplicated_candidates, max_candidates, overflow_candidates])
    with uproot.open(file) as f:
        for k in filter(comp.match, f.keys(filter_classname="TDirectory")):
            directory = f[k]
            n_cand = directory["n_candidates"]
            # Keep the under/overflow bins: they feed the "overflow" flag below.
            values_with_overflow, edges_with_overflow = n_cand.to_numpy(flow=True)

            # In-range bins only.
            edges = edges_with_overflow[1:-1]
            values = values_with_overflow[1:-1]

            if numpy.any(values > 0):
                centers = 0.5 * (edges[1:] + edges[:-1])
                df.loc[k, mean_candidates] = numpy.sum(values * centers) / numpy.sum(values)
                # Lower edge of the highest non-empty bin.
                df.loc[k, max_candidates] = edges[::-1][numpy.argmax(values[::-1] > 0) + 1]
                # Only every other bin populated — presumably candidates always
                # appear in multiples; TODO confirm the binning convention.
                df.loc[k, possible_duplicated_candidates] = (numpy.sum(values) > 0) and numpy.any(values[0::2] > 0) and numpy.allclose(values[1::2], 0.)
                df.loc[k, overflow_candidates] = values_with_overflow[0] > 0 or values_with_overflow[-1] > 0
            else:
                df.loc[k, :] = numpy.nan

    df.sort_values(max_candidates, ascending=False, inplace=True)

    pandas.set_option("display.max_colwidth", 100)
    pandas.set_option("display.max_rows", nrows)

    LOGGER.info(f"Results:{os.linesep}{df}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description=__doc__)

    # "required=True" makes argparse emit a proper usage error when no
    # subcommand is given, instead of the previous KeyError on "_function";
    # "dest" is needed for argparse to build that error message.
    subparsers = parser.add_subparsers(dest="command", required=True, help="Subcommand to run")

    rates_from_log_p = subparsers.add_parser("rates-from-log", description=rates_from_log.__doc__)
    rates_from_log_p.set_defaults(_function=rates_from_log)
    # Fixed copy-paste help text: this subcommand reads a Moore log file,
    # not a ROOT file.
    rates_from_log_p.add_argument("file", type=str, help="Log file of the Moore execution")

    compare_rates_p = subparsers.add_parser("compare-rates", description=compare_rates.__doc__)
    compare_rates_p.set_defaults(_function=compare_rates)
    compare_rates_p.add_argument("before", type=str, help="First log file")
    compare_rates_p.add_argument("after", type=str, help="Second log file")

    # Options shared by the two rate-related subcommands
    for p in rates_from_log_p, compare_rates_p:
        p.add_argument("--input-rate", type=float, required=True, help="Input rate")
        p.add_argument("--output", type=str, default=None, help="Output file containing the rate information")

    check_histograms_p = subparsers.add_parser("check-histograms", description=check_histograms.__doc__)
    check_histograms_p.set_defaults(_function=check_histograms)
    check_histograms_p.add_argument("file", type=str, help="ROOT file with the histograms")
    check_histograms_p.add_argument("--regex", type=str, default=None, help="Regular expression to get trigger line names")

    for p in rates_from_log_p, compare_rates_p, check_histograms_p:
        p.add_argument("--nrows", type=int, default=500, help="Number of rows to display")

    config = vars(parser.parse_args())

    # Drop the bookkeeping entries so only the subcommand's own arguments
    # are forwarded as keyword arguments.
    config.pop("command")
    function = config.pop("_function")

    function(**config)