From 99ae0e322fe6aa364450ea44eeaef1b289ceb461 Mon Sep 17 00:00:00 2001
From: Mintu Kumar <mintu.kumar@cern.ch>
Date: Thu, 11 Jul 2024 13:22:56 +0200
Subject: [PATCH] bTagSF implementation

---
 higgs_dna/systematics/__init__.py             |   3 +
 .../systematics/event_weight_systematics.py   | 268 +++++++++++++++++-
 higgs_dna/workflows/top.py                    |  28 +-
 scripts/pull_files.py                         |  55 +++-
 4 files changed, 343 insertions(+), 11 deletions(-)

diff --git a/higgs_dna/systematics/__init__.py b/higgs_dna/systematics/__init__.py
index ded59b41..2b9610cc 100644
--- a/higgs_dna/systematics/__init__.py
+++ b/higgs_dna/systematics/__init__.py
@@ -18,6 +18,7 @@ from .event_weight_systematics import (
     AlphaS,
     PartonShower,
     cTagSF,
+    bTagShapeSF,
     Zpt,
 )
 from .jet_systematics import (
@@ -164,6 +165,7 @@ weight_systematics = {
     "PreselSF": partial(PreselSF, is_correction=False),
     "TriggerSF": partial(TriggerSF, is_correction=False),
     "cTagSF": partial(cTagSF, is_correction=False),
+    "bTagShapeSF": partial(bTagShapeSF, is_correction=False),
     "AlphaS": partial(AlphaS),
     "PartonShower": partial(PartonShower),
     "LHEScale": None,
@@ -181,6 +183,7 @@ weight_corrections = {
     "PreselSF": partial(PreselSF, is_correction=True),
     "TriggerSF": partial(TriggerSF, is_correction=True),
     "cTagSF": partial(cTagSF, is_correction=True),
+    "bTagShapeSF": partial(bTagShapeSF, is_correction=True),
     "NNLOPS": partial(NNLOPS, is_correction=True),
     "Zpt": partial(Zpt, is_correction=True),
 }
diff --git a/higgs_dna/systematics/event_weight_systematics.py b/higgs_dna/systematics/event_weight_systematics.py
index fa7d21fe..33075ad4 100644
--- a/higgs_dna/systematics/event_weight_systematics.py
+++ b/higgs_dna/systematics/event_weight_systematics.py
@@ -707,6 +707,267 @@ def PartonShower(photons, events, weights, dataset_name, **kwargs):
     return weights
 
 
+def bTagShapeSF(events, weights, is_correction=True, year="2017", **kwargs):
+    avail_years = ["2016preVFP", "2016postVFP", "2017", "2018", "2022preEE", "2022postEE", "2023preBPix", "2023postBPix"]
+    if year not in avail_years:
+        print(f"\n WARNING: only scale corrections for the year strings {avail_years} are already implemented! \n Exiting. \n")
+        exit()
+    btag_systematics = [
+        "lf",
+        "hf",
+        "cferr1",
+        "cferr2",
+        "lfstats1",
+        "lfstats2",
+        "hfstats1",
+        "hfstats2",
+        "jes",
+    ]
+    inputFilePath = "JSONs/bTagSF/"
+    btag_correction_configs = {
+        "2016preVFP": {
+            "file": os.path.join(
+                inputFilePath , "2016preVFP_UL/btagging.json.gz"
+            ),
+            "method": "deepJet_shape",
+            "systs": btag_systematics,
+        },
+        "2016postVFP": {
+            "file": os.path.join(
+                inputFilePath , "2016postVFP_UL/btagging.json.gz"
+            ),
+            "method": "deepJet_shape",
+            "systs": btag_systematics,
+        },
+        "2017": {
+            "file": os.path.join(
+                inputFilePath , "2017_UL/btagging.json.gz"
+            ),
+            "method": "deepJet_shape",
+            "systs": btag_systematics,
+        },
+        "2018": {
+            "file": os.path.join(
+                inputFilePath , "2018_UL/btagging.json.gz"
+            ),
+            "method": "deepJet_shape",
+            "systs": btag_systematics,
+        },
+        "2022preEE":{
+            "file": os.path.join(
+                inputFilePath , "2022_Summer22/btagging.json.gz"
+            ),
+            "method": "deepJet_shape",
+            "systs": btag_systematics,
+        },
+        "2022postEE":{
+            "file": os.path.join(
+                inputFilePath , "2022_Summer22EE/btagging.json.gz"
+            ),
+            "method": "deepJet_shape",
+            "systs": btag_systematics,
+        },
+        "2023preBPix":{
+            "file": os.path.join(
+                inputFilePath , "2023_Summer23/btagging.json.gz"
+            ),
+            "method": "deepJet_shape",
+            "systs": btag_systematics,
+        },
+        "2023postBPix":{
+            "file": os.path.join(
+                inputFilePath , "2023_Summer23BPix/btagging.json.gz"
+            ),
+            "method": "deepJet_shape",
+            "systs": btag_systematics,
+        },
+    }
+    jsonpog_file = os.path.join(
+        os.path.dirname(__file__), btag_correction_configs[year]["file"]
+    )
+    evaluator = correctionlib.CorrectionSet.from_file(jsonpog_file)[
+        btag_correction_configs[year]["method"]
+    ]
+
+    dummy_sf = ak.ones_like(events["event"])
+
+    relevant_jets = events["sel_jets"][
+        np.abs(events["sel_jets"].eta) < 2.5
+    ]
+    # only calculate correction to nominal weight
+    # we will evaluate the scale factors relative to all jets to be multiplied
+    jet_pt = relevant_jets.pt
+    jet_eta = np.abs(relevant_jets.eta)
+    jet_hFlav = relevant_jets.hFlav
+    jet_btagDeepFlavB = relevant_jets.btagDeepFlavB
+
+    # Convert the jets into a one-dimensional array and store the original structure of the ak array in counts
+    flat_jet_pt = ak.flatten(jet_pt)
+    flat_jet_eta = ak.flatten(jet_eta)
+    flat_jet_btagDeepFlavB = ak.flatten(jet_btagDeepFlavB)
+    flat_jet_hFlav = ak.flatten(jet_hFlav)
+
+    counts = ak.num(jet_hFlav)
+
+    logger.info("Warning: you have to normalise b-tag weights afterwards so that they do not change the yield!")
+
+    if is_correction:
+
+        _sf = []
+        # Evaluate the scale factor per jet and unflatten the scale factors back into the original structure
+        _sf = ak.unflatten(
+            evaluator.evaluate(
+                "central",
+                flat_jet_hFlav,
+                flat_jet_eta,
+                flat_jet_pt,
+                flat_jet_btagDeepFlavB,
+            ),
+            counts
+        )
+        # Multiply the scale factors of all jets in an event
+        sf = ak.prod(_sf,axis=1)
+
+        sfs_up = [None for _ in btag_systematics]
+        sfs_down = [None for _ in btag_systematics]
+
+    else:
+        # only calculate correction to nominal weight
+        # replace by accessing partial weight!
+        _sf = []
+        # Evaluate the scale factor per jet and unflatten the scale factors back into the original structure
+        _sf_central = evaluator.evaluate(
+            "central",
+            flat_jet_hFlav,
+            flat_jet_eta,
+            flat_jet_pt,
+            flat_jet_btagDeepFlavB,
+        )
+        # Multiply the scale factors of all jets in an event
+
+        sf = ak.values_astype(dummy_sf, np.float)
+        sf_central = ak.prod(
+            ak.unflatten(_sf_central, counts),
+            axis=1
+        )
+
+        variations = {}
+
+        # Define a condition based on the jet flavour because the JSON files are defined for the 4 (c), 5 (b) and 0 (lf) flavour jets
+        flavour_condition = np.logical_or(jet_hFlav < 4,jet_hFlav > 5)
+        # Set the flavour to 0 (lf) if the jet flavour is neither 4 nor 5
+        jet_hFlav_JSONrestricted = ak.where(flavour_condition, 0 ,jet_hFlav)
+        flat_jet_hFlav_JSONrestricted = ak.flatten(jet_hFlav_JSONrestricted)
+        # We need a dummy sf array set to one to multiply in for the flavour-dependent systematic variations
+        flat_dummy_sf = ak.ones_like(flat_jet_hFlav_JSONrestricted)
+
+        for syst_name in btag_correction_configs[year]["systs"]:
+
+            # we will append the scale factors relative to all jets to be multiplied
+            _sfup = []
+            _sfdown = []
+            variations[syst_name] = {}
+
+            if "cferr" in syst_name:
+                # we need to remember which jets correspond to c (hadron flavour 4) jets
+                cjet_masks = flat_jet_hFlav_JSONrestricted == 4
+
+                flat_jet_hFlavC_JSONrestricted = ak.where(flat_jet_hFlav_JSONrestricted != 4, 4 ,flat_jet_hFlav_JSONrestricted)
+                _Csfup = evaluator.evaluate(
+                    "up_" + syst_name,
+                    flat_jet_hFlavC_JSONrestricted,
+                    flat_jet_eta,
+                    flat_jet_pt,
+                    flat_jet_btagDeepFlavB,
+                )
+
+                _Csfdown = evaluator.evaluate(
+                    "down_" + syst_name,
+                    flat_jet_hFlavC_JSONrestricted,
+                    flat_jet_eta,
+                    flat_jet_pt,
+                    flat_jet_btagDeepFlavB,
+                )
+                _Csfup = ak.where(
+                    cjet_masks,
+                    _Csfup,
+                    flat_dummy_sf,
+                )
+                _Csfdown = ak.where(
+                    cjet_masks,
+                    _Csfdown,
+                    flat_dummy_sf,
+                )
+                # Keep the central sf for non-c (light and b) jets and set it to 1 for c jets, which get the cferr-varied sf instead
+                _sfcentral_Masked_notC = ak.where(
+                    ~cjet_masks,
+                    _sf_central,
+                    flat_dummy_sf,
+                )
+                _sfup = ak.unflatten(np.multiply(_sfcentral_Masked_notC,_Csfup),counts)
+                _sfdown = ak.unflatten(np.multiply(_sfcentral_Masked_notC,_Csfdown),counts)
+            else:
+                # We need to remember which jets correspond to c (hadron flavour 4) jets
+                cjet_masks = flat_jet_hFlav_JSONrestricted == 4
+
+                flat_jet_hFlavNonC_JSONrestricted = ak.where(cjet_masks, 0 ,flat_jet_hFlav_JSONrestricted)
+
+                _NonCsfup = evaluator.evaluate(
+                    "up_" + syst_name,
+                    flat_jet_hFlavNonC_JSONrestricted,
+                    flat_jet_eta,
+                    flat_jet_pt,
+                    flat_jet_btagDeepFlavB,
+                )
+
+                _NonCsfdown = evaluator.evaluate(
+                    "down_" + syst_name,
+                    flat_jet_hFlavNonC_JSONrestricted,
+                    flat_jet_eta,
+                    flat_jet_pt,
+                    flat_jet_btagDeepFlavB,
+                )
+
+                _NonCsfup = ak.where(
+                    ~cjet_masks,
+                    _NonCsfup,
+                    flat_dummy_sf,
+                )
+                _NonCsfdown = ak.where(
+                    ~cjet_masks,
+                    _NonCsfdown,
+                    flat_dummy_sf,
+                )
+                # Keep the central sf for c jets and set it to 1 for non-c jets, which get the varied sf instead
+                _sfcentral_Masked_C = ak.where(
+                    cjet_masks,
+                    _sf_central,
+                    flat_dummy_sf,
+                )
+                _sfup = ak.unflatten(np.multiply(_sfcentral_Masked_C,_NonCsfup),counts)
+                _sfdown = ak.unflatten(np.multiply(_sfcentral_Masked_C,_NonCsfdown),counts)
+
+            sf_up = ak.prod(_sfup,axis=1)
+            sf_down = ak.prod(_sfdown,axis=1)
+            variations[syst_name]["up"] = sf_up
+            variations[syst_name]["down"] = sf_down
+        # coffea weights.add_multivariation() wants a list of arrays for the multiple up and down variations
+        # we divide by sf_central because the coffea processor stores the up and down variations by multiplying them with the central weights
+        sfs_up = [variations[syst_name]["up"] / sf_central for syst_name in btag_systematics]
+        sfs_down = [variations[syst_name]["down"] / sf_central for syst_name in btag_systematics]
+
+    weights.add_multivariation(
+        name="bTagSF",
+        weight=sf,
+        modifierNames=btag_systematics,
+        weightsUp=sfs_up,
+        weightsDown=sfs_down,
+        shift=False,
+    )
+
+    return weights
+
+
 def cTagSF(events, weights, is_correction=True, year="2017", **kwargs):
     """
     Add c-tagging reshaping SFs as from /https://github.com/higgs-charm/flashgg/blob/dev/cH_UL_Run2_withBDT/Systematics/scripts/applyCTagCorrections.py
@@ -821,11 +1082,8 @@ def cTagSF(events, weights, is_correction=True, year="2017", **kwargs):
         for nth in _sf:
             sf = sf * nth
 
-        sfs_up = []
-        sfs_down = []
-        for syst in ctag_systematics:
-            sfs_up.append(ak.values_astype(dummy_sf, np.float))
-            sfs_down.append(ak.values_astype(dummy_sf, np.float))
+        sfs_up = [ak.values_astype(dummy_sf, np.float) for _ in ctag_systematics]
+        sfs_down = [ak.values_astype(dummy_sf, np.float) for _ in ctag_systematics]
 
         weights.add_multivariation(
             name="cTagSF",
diff --git a/higgs_dna/workflows/top.py b/higgs_dna/workflows/top.py
index d801cf26..80891f55 100644
--- a/higgs_dna/workflows/top.py
+++ b/higgs_dna/workflows/top.py
@@ -72,6 +72,8 @@ class TopProcessor(HggBaseProcessor):  # type: ignore
             output_format=output_format
         )
 
+        self.el_iso_wp = "WP90"
+
     def process_extra(self, events: ak.Array) -> ak.Array:
         return events, {}
 
@@ -339,7 +341,12 @@ class TopProcessor(HggBaseProcessor):  # type: ignore
                             "hFlav": Jets.hadronFlavour
                             if self.data_kind == "mc"
                             else ak.zeros_like(Jets.pt),
-                            "btagDeepFlav_B": Jets.btagDeepFlavB,
+                            "btagPNetB": Jets.btagPNetB,
+                            "btagDeepFlavB": Jets.btagDeepFlavB,
+                            "btagRobustParTAK4B": Jets.btagRobustParTAK4B,
+                            "btagRobustParTAK4CvB": Jets.btagRobustParTAK4CvB,
+                            "btagRobustParTAK4CvL": Jets.btagRobustParTAK4CvL,
+                            "btagRobustParTAK4QG": Jets.btagRobustParTAK4QG,
                             "btagDeepFlav_CvB": Jets.btagDeepFlavCvB,
                             "btagDeepFlav_CvL": Jets.btagDeepFlavCvL,
                             "btagDeepFlav_QG": Jets.btagDeepFlavQG,
@@ -391,9 +398,10 @@ class TopProcessor(HggBaseProcessor):  # type: ignore
                     # adding selected jets to events to be used in ctagging SF calculation
                     events["sel_jets"] = jets
                     n_jets = ak.num(jets)
+                    diphotons["JetHT"] = ak.sum(jets.pt,axis=1)
 
-                    num_jets = 6
-                    jet_properties = ["pt", "eta", "phi", "mass", "charge", "btagDeepFlav_B"]
+                    num_jets = 8
+                    jet_properties = ["pt", "eta", "phi", "mass", "charge", "btagPNetB", "btagDeepFlavB", "btagRobustParTAK4B", "btagRobustParTAK4CvB", "btagRobustParTAK4CvL", "btagRobustParTAK4QG"]
                     for i in range(num_jets):
                         for prop in jet_properties:
                             key = f"jet{i+1}_{prop}"
@@ -458,7 +466,7 @@ class TopProcessor(HggBaseProcessor):  # type: ignore
                         return histos_etc
                     if self.data_kind == "mc":
                         # initiate Weight container here, after selection, since event selection cannot easily be applied to weight container afterwards
-                        event_weights = Weights(size=len(events[selection_mask]))
+                        event_weights = Weights(size=len(events[selection_mask]),storeIndividual=True)
 
                         # corrections to event weights:
                         for correction_name in correction_names:
@@ -478,7 +486,10 @@ class TopProcessor(HggBaseProcessor):  # type: ignore
                                     dataset_name=dataset_name,
                                     year=self.year[dataset_name][0],
                                 )
-
+                        metadata["sum_weight_central_wo_bTagSF"] = str(
+                            ak.sum(event_weights.partial_weight(exclude=["bTagSF"]))
+                        )
+                        diphotons["bTagWeight"] = event_weights.partial_weight(include=["bTagSF"])
                         # systematic variations of event weights go to nominal output dataframe:
                         if do_variation == "nominal":
                             for systematic_name in systematic_names:
@@ -533,6 +544,9 @@ class TopProcessor(HggBaseProcessor):  # type: ignore
                                         )
 
                         diphotons["weight_central"] = event_weights.weight()
+                        metadata["sum_weight_central"] = str(
+                            ak.sum(event_weights.weight())
+                        )
                         # Store variations with respect to central weight
                         if do_variation == "nominal":
                             if len(event_weights.variations):
@@ -543,6 +557,10 @@ class TopProcessor(HggBaseProcessor):  # type: ignore
                                 diphotons["weight_" + modifier] = event_weights.weight(
                                     modifier=modifier
                                 )
+                                if ("bTagSF" in modifier):
+                                    metadata["sum_weight_" + modifier] = str(
+                                        ak.sum(event_weights.weight(modifier=modifier))
+                                    )
 
                         # Multiply weight by genWeight for normalisation in post-processing chain
                         event_weights._weight = (
diff --git a/scripts/pull_files.py b/scripts/pull_files.py
index 2c0362ba..52a3de41 100644
--- a/scripts/pull_files.py
+++ b/scripts/pull_files.py
@@ -20,7 +20,7 @@ parser.add_argument(
     dest="target",
     help="Choose the target to download (default: %(default)s)",
     default="GoldenJson",
-    choices=["GoldenJSON", "cTag", "PhotonID", "PU", "SS", "JetMET", "CDFs", "JEC", "JER", "Material", "TriggerSF", "PreselSF", "eVetoSF", "Flows", "FNUF", "ShowerShape", "LooseMva","LowMass-DiPhotonMVA"],
+    choices=["GoldenJSON", "cTag", "bTag", "PhotonID", "PU", "SS", "JetMET", "CDFs", "JEC", "JER", "Material", "TriggerSF", "PreselSF", "eVetoSF", "Flows", "FNUF", "ShowerShape", "LooseMva","LowMass-DiPhotonMVA"],
 )
 
 parser.add_argument(
@@ -421,6 +421,52 @@ def get_eveto_json(logger, target_dir):
     fetch_file("eVetoSF", logger, from_to_dict, type="copy")
 
 
+def get_btag_json(logger, target_dir):
+    if target_dir is not None:
+        to_prefix = target_dir
+    else:
+        to_prefix = os.path.join(
+            os.path.dirname(__file__), "../higgs_dna/systematics/JSONs/bTagSF/"
+        )
+
+    from_to_dict = {
+        "2016preVFP": {
+            "from": "/cvmfs/cms.cern.ch/rsync/cms-nanoAOD/jsonpog-integration/POG/BTV/2016preVFP_UL/btagging.json.gz",
+            "to": f"{to_prefix}/2016preVFP_UL/btagging.json.gz",
+        },
+        "2016postVFP": {
+            "from": "/cvmfs/cms.cern.ch/rsync/cms-nanoAOD/jsonpog-integration/POG/BTV/2016postVFP_UL/btagging.json.gz",
+            "to": f"{to_prefix}/2016postVFP_UL/btagging.json.gz",
+        },
+        "2017": {
+            "from": "/cvmfs/cms.cern.ch/rsync/cms-nanoAOD/jsonpog-integration/POG/BTV/2017_UL/btagging.json.gz",
+            "to": f"{to_prefix}/2017_UL/btagging.json.gz",
+        },
+        "2018": {
+            "from": "/cvmfs/cms.cern.ch/rsync/cms-nanoAOD/jsonpog-integration/POG/BTV/2018_UL/btagging.json.gz",
+            "to": f"{to_prefix}/2018_UL/btagging.json.gz",
+        },
+        "2022preEE": {
+            "from": "/cvmfs/cms.cern.ch/rsync/cms-nanoAOD/jsonpog-integration/POG/BTV/2022_Summer22/btagging.json.gz",
+            "to": f"{to_prefix}/2022_Summer22/btagging.json.gz",
+        },
+        "2022postEE": {
+            "from": "/cvmfs/cms.cern.ch/rsync/cms-nanoAOD/jsonpog-integration/POG/BTV/2022_Summer22EE/btagging.json.gz",
+            "to": f"{to_prefix}/2022_Summer22EE/btagging.json.gz",
+        },
+        "2023preBPix": {
+            "from": "/cvmfs/cms.cern.ch/rsync/cms-nanoAOD/jsonpog-integration/POG/BTV/2023_Summer23/btagging.json.gz",
+            "to": f"{to_prefix}/2023_Summer23/btagging.json.gz",
+        },
+        "2023postBPix": {
+            "from": "/cvmfs/cms.cern.ch/rsync/cms-nanoAOD/jsonpog-integration/POG/BTV/2023_Summer23BPix/btagging.json.gz",
+            "to": f"{to_prefix}/2023_Summer23BPix/btagging.json.gz",
+        },
+    }
+    fetch_file("bTag", logger, from_to_dict, type="copy")
+
+
+
 def get_ctag_json(logger, target_dir):
     if target_dir is not None:
         to_prefix = target_dir
@@ -774,6 +824,7 @@ if __name__ == "__main__":
         get_mass_decorrelation_CDF(logger, args.target_dir)
         get_Flow_files(logger, args.target_dir)
         get_ctag_json(logger, args.target_dir)
+        get_btag_json(logger, args.target_dir)
         get_photonid_json(logger, args.target_dir)
         get_jetmet_json(logger, args.target_dir)
         get_jec_files(logger, args.target_dir)
@@ -798,6 +849,8 @@ if __name__ == "__main__":
         get_Flow_files(logger, args.target_dir)
     elif args.target == "cTag":
         get_ctag_json(logger, args.target_dir)
+    elif args.target == "bTag":
+        get_btag_json(logger, args.target_dir)
     elif args.target == "PhotonID":
         get_photonid_json(logger, args.target_dir)
     elif args.target == "JetMET":
-- 
GitLab