Commit 985ba25e authored by Maxence Draguet's avatar Maxence Draguet
Browse files

Including the extra sampling step for ttbar fraction enforcement

parent a8e042af
......@@ -288,8 +288,8 @@ def RunPreparation(args, config):
pbar.close()
if n_jets_to_get > 0:
print(
"WARNING: Not enough selected jets from files,"
" only ", jets_loaded
"WARNING: Not enough selected jets from files," " only ",
jets_loaded,
)
......@@ -377,25 +377,9 @@ def RunMerging(args, config):
def RunUndersampling(args, config):
"""Applies required cuts to the samples and applies the downsampling."""
required_total_njets = int(config.njets)
np.random.seed(42)
Z_bool, tt_bool = False, False
if config.ttbar_frac > 0.0 and config.ttbar_frac < 1:
Z_bool, tt_bool = True, True
nbZ = int(required_total_njets * (1 - config.ttbar_frac))
nbtt = int(required_total_njets * config.ttbar_frac)
elif not(config.ttbar_frac > 0.0):
Z_bool, tt_bool = True, False
nbZ = int(config.njets)
nbtt = 0
elif not(config.ttbar_frac < 1):
Z_bool, tt_bool = False, True
nbtt = int(config.njets)
nbZ = 0
else:
print("Invalid value of ttbar_frac", config.ttbar_frac)
return
N_list = upt.GetNJetsPerIteration(config)
# TODO: switch to dask
if config.sampling_method == "count":
sampling_method = "count"
elif config.sampling_method == "weight":
......@@ -404,65 +388,11 @@ def RunUndersampling(args, config):
print("Unspecified sampling method, default is count")
sampling_method = "count"
# First step: load the files and apply the cuts
if Z_bool:
vec_Z, tnp_Zprime = upt.cutsample(f_Z, category=0, take_tracks=args.tracks)
f_Z.close()
Z_bjets = vec_Z[vec_Z["HadronConeExclTruthLabelID"] == 5]
Z_cjets = vec_Z[vec_Z["HadronConeExclTruthLabelID"] == 4]
Z_ujets = vec_Z[vec_Z["HadronConeExclTruthLabelID"] == 0]
Z_tjets = None # For tau jets ... to be integrated later
if args.tracks:
Z_btracks = tnp_Zprime[vec_Z["HadronConeExclTruthLabelID"] == 5]
Z_ctracks = tnp_Zprime[vec_Z["HadronConeExclTruthLabelID"] == 4]
Z_utracks = tnp_Zprime[vec_Z["HadronConeExclTruthLabelID"] == 0]
Z_ttracks = None # For tau tracks ... to be integrated later
# Free memory
vec_Z, tnp_Zprime = None, None
del vec_Z, del tnp_Zprime
nbZ_available = Z_bjets.shape[0]
if nbZ_available >
n_selected = np.random.choice(
np.where(sample["category"] == 0)[0],
nZ_required,
replace=False,
)
if tt_bool:
tt_bjets, tt_btracks = upt.cutsample(f_tt_bjets, take_tracks=args.tracks)
f_tt_bjets.close()
tt_cjets, tt_ctracks = upt.cutsample(f_tt_cjets, take_tracks=args.tracks)
f_tt_cjets.close()
tt_ujets, tt_utracks = upt.cutsample(f_tt_ujets, take_tracks=args.tracks)
f_tt_ujets.close()
nbtt_available = tt_bjets.shape[0]
# Perform a check to enforce the right fraction of each
if Z_bool and tt_bool:
if nbtt_available < nbtt:
if nbZ_available < nbZ:
# too few of both
ttbar_frac_available = nbtt_available / (nbZ_available + nbtt_available)
else:
elif Z_bool:
elif tt_bool:
if Z_bool:
# Second step: sample
Z_bjets, Z_cjets, Z_ujets, Z_tjets, Z_btracks, Z_ctracks, Z_utracks, Z_ttracks, downs = upt.RunSampling(
Z_bjets, Z_cjets, Z_ujets, Z_tjets, Z_btracks, Z_ctracks, Z_utracks, Z_ttracks,
sampling_method, tracks=args.tracks
)
# Downsample ttbar:
if tt_bool:
tt_bjets, tt_cjets, tt_ujets, tt_tjets, tt_btracks, tt_ctracks, tt_utracks, tt_ttracks, _ = upt.RunSampling(
tt_bjets, tt_cjets, tt_ujets, tt_tjets, tt_btracks, tt_ctracks, tt_utracks, tt_ttracks,
sampling_method, tracks=args.tracks
)
# initialise input files (they are not yet loaded to memory)
f_Z = h5py.File(config.f_z, "r")
f_tt_bjets = h5py.File(config.f_tt_bjets, "r")
f_tt_cjets = h5py.File(config.f_tt_cjets, "r")
f_tt_ujets = h5py.File(config.f_tt_ujets, "r")
for x in range(config.iterations):
print("Iteration", x + 1, "of", config.iterations)
......@@ -541,6 +471,7 @@ def RunUndersampling(args, config):
tnp_tt_u = np.delete(tnp_tt_u, indices_toremove_ujets, 0)
print("starting undersampling")
bjets = np.concatenate(
[vec_Z[vec_Z["HadronConeExclTruthLabelID"] == 5], vec_tt_bjets]
)
......@@ -550,12 +481,7 @@ def RunUndersampling(args, config):
ujets = np.concatenate(
[vec_Z[vec_Z["HadronConeExclTruthLabelID"] == 0], vec_tt_ujets]
)
downs = upt.UnderSampling(bjets, cjets, ujets)
b_indices, c_indices, u_indices = downs.GetIndices()
bjets = bjets[b_indices]
cjets = cjets[c_indices]
ujets = ujets[u_indices]
tjets = None
if args.tracks:
btrk = np.concatenate(
......@@ -563,26 +489,92 @@ def RunUndersampling(args, config):
tnp_Zprime[vec_Z["HadronConeExclTruthLabelID"] == 5],
tnp_tt_b,
]
)[b_indices]
)
ctrk = np.concatenate(
[
tnp_Zprime[vec_Z["HadronConeExclTruthLabelID"] == 4],
tnp_tt_c,
]
)[c_indices]
)
utrk = np.concatenate(
[
tnp_Zprime[vec_Z["HadronConeExclTruthLabelID"] == 0],
tnp_tt_u,
]
)[u_indices]
ttfrac = float(
len(bjets[bjets["category"] == 1])
+ len(cjets[cjets["category"] == 1])
+ len(ujets[ujets["category"] == 1])
) / float(len(bjets) + len(cjets) + len(ujets))
print("ttbar fraction:", round(ttfrac, 2))
)
ttrk = None
else:
btrk, ctrk, utrk, ttrk = None, None, None, None
# Do the sampling:
(
bjets,
cjets,
ujets,
tjets,
btrk,
ctrk,
utrk,
ttrk,
downs,
) = upt.RunSampling(
bjets,
cjets,
ujets,
tjets,
btrk,
ctrk,
utrk,
ttrk,
sampling_method,
tracks=args.tracks,
)
# Print some statistics on the sample formed
statistics_dict = upt.RunStatSamples(bjets, cjets, ujets, tjets)
if config.enforce_ttbar_frac:
# If one wants to enforce the ttbar fraction demanded
# Normally not required, except if target number of jets is above total available
bjets, bindices = upt.EnforceFraction(
bjets, config.ttbar_frac, statistics_dict, label="b"
)
cjets, cindices = upt.EnforceFraction(
cjets, config.ttbar_frac, statistics_dict, label="c"
)
ujets, uindices = upt.EnforceFraction(
ujets, config.ttbar_frac, statistics_dict, label="u"
)
if args.tracks:
if bindices is not None:
btrk = btrk[bindices]
if cindices is not None:
ctrk = ctrk[cindices]
if uindices is not None:
utrk = utrk[uindices]
# Need to re-sample to make sure flavours are still distributed as demanded.
(
bjets,
cjets,
ujets,
tjets,
btrk,
ctrk,
utrk,
ttrk,
_,
) = upt.RunSampling(
bjets,
cjets,
ujets,
tjets,
btrk,
ctrk,
utrk,
ttrk,
sampling_method,
tracks=args.tracks,
)
statistics_dict = upt.RunStatSamples(bjets, cjets, ujets, tjets)
out_file = config.GetFileName(x + 1, option="downsampled")
print("saving file:", out_file)
......
import numpy as np
from scipy.stats import binned_statistic_2d
from umami.preprocessing_tools.Cuts import GetCuts
class UnderSampling(object):
"""
The DownSampling is used to prepare the training dataset. It makes sure
that in each pT/eta bin the same amount of jets are filled.
The DownSampling is used to prepare the training dataset. It makes sure
that in each pT/eta bin the same amount of jets are filled.
"""
def __init__(self, bjets, cjets, ujets, tjets=None):
......@@ -14,7 +14,7 @@ class UnderSampling(object):
self.cjets = cjets
self.ujets = ujets
self.tjets = tjets
self.bool_tjets = (tjets is not None)
self.bool_tjets = tjets is not None
self.pt_bins = np.concatenate(
(np.linspace(0, 600000, 351), np.linspace(650000, 6000000, 84))
)
......@@ -26,16 +26,18 @@ class UnderSampling(object):
def GetIndices(self):
"""
Applies the UnderSampling to the given arrays.
Returns the indices for the jets to be used separately for b, c and
light jets (as well as taus, optionally).
Applies the UnderSampling to the given arrays.
Returns the indices for the jets to be used separately for b, c and
light jets (as well as taus, optionally).
"""
binnumbers_b, ind_b, stat_b = self.GetBins(self.bjets)
binnumbers_c, _, stat_c = self.GetBins(self.cjets)
binnumbers_u, _, stat_u = self.GetBins(self.ujets)
if self.bool_tjets:
binnumbers_t, _, stat_t = self.GetBins(self.tjets)
min_count_per_bin = np.amin([stat_b, stat_c, stat_u, stat_t], axis=0)
min_count_per_bin = np.amin(
[stat_b, stat_c, stat_u, stat_t], axis=0
)
else:
min_count_per_bin = np.amin([stat_b, stat_c, stat_u], axis=0)
......@@ -87,21 +89,23 @@ class UnderSampling(object):
np.sort(np.concatenate(bjet_indices)),
np.sort(np.concatenate(cjet_indices)),
np.sort(np.concatenate(ujet_indices)),
sorted_tjets
sorted_tjets,
)
def GetIndicesPerBin(self):
"""
Applies the UnderSampling to the given arrays.
Returns the indices for the jets to be used separately for b, c and
light jets (as well as taus, optionally).
Applies the UnderSampling to the given arrays.
Returns the indices for the jets to be used separately for b, c and
light jets (as well as taus, optionally).
"""
binnumbers_b, ind_b, stat_b = self.GetBins(self.bjets)
binnumbers_c, _, stat_c = self.GetBins(self.cjets)
binnumbers_u, _, stat_u = self.GetBins(self.ujets)
if self.bool_tjets:
binnumbers_t, _, stat_t = self.GetBins(self.tjets)
min_count_per_bin = np.amin([stat_b, stat_c, stat_u, stat_t], axis=0)
min_count_per_bin = np.amin(
[stat_b, stat_c, stat_u, stat_t], axis=0
)
else:
min_count_per_bin = np.amin([stat_b, stat_c, stat_u], axis=0)
......@@ -145,12 +149,7 @@ class UnderSampling(object):
)
)
return (
bjet_indices,
cjet_indices,
ujet_indices,
tjet_indices
)
return (bjet_indices, cjet_indices, ujet_indices, tjet_indices)
def GetBins(self, df):
statistic, xedges, yedges, binnumber = binned_statistic_2d(
......@@ -171,13 +170,13 @@ class UnderSampling(object):
class UnderSamplingProp(object):
"""
Alternative to the UnderSampling approach, this implements a
proportional sampler to prepare the training dataset. It makes sure
that in each pT/eta bin each category has the same ratio of jets.
This is especially suited if not enough statistics is available for
some of the labels.
For example, in bin X, if 1% of b, 2% of c, 3 % of l jets are found,
sampler will take 1% of all b, 1% of all c and 1% of all l in the bin.
Alternative to the UnderSampling approach, this implements a
proportional sampler to prepare the training dataset. It makes sure
that in each pT/eta bin each category has the same ratio of jets.
This is especially suited if not enough statistics is available for
some of the labels.
For example, in bin X, if 1% of b, 2% of c, 3 % of l jets are found,
sampler will take 1% of all b, 1% of all c and 1% of all l in the bin.
"""
def __init__(self, bjets, cjets, ujets, tjets=None):
......@@ -186,7 +185,7 @@ class UnderSamplingProp(object):
self.cjets = cjets
self.ujets = ujets
self.tjets = tjets
self.bool_tjets = (tjets is not None)
self.bool_tjets = tjets is not None
self.pt_bins = np.concatenate(
(np.linspace(0, 600000, 351), np.linspace(650000, 6000000, 84))
)
......@@ -198,16 +197,18 @@ class UnderSamplingProp(object):
def GetIndices(self):
"""
Applies the weighted UnderSampling to the given arrays.
Returns the indices for the jets to be used separately for b, c and
light jets (as well as taus, optionally).
Applies the weighted UnderSampling to the given arrays.
Returns the indices for the jets to be used separately for b, c and
light jets (as well as taus, optionally).
"""
binnumbers_b, ind_b, stat_b, total_b = self.GetBins(self.bjets)
binnumbers_c, _, stat_c, total_c = self.GetBins(self.cjets)
binnumbers_u, _, stat_u, total_u = self.GetBins(self.ujets)
if self.bool_tjets:
binnumbers_t, _, stat_t, total_t = self.GetBins(self.tjets)
min_weight_per_bin = np.amin([stat_b, stat_c, stat_u, stat_t], axis=0)
min_weight_per_bin = np.amin(
[stat_b, stat_c, stat_u, stat_t], axis=0
)
else:
min_weight_per_bin = np.amin([stat_b, stat_c, stat_u], axis=0)
......@@ -259,10 +260,9 @@ class UnderSamplingProp(object):
np.sort(np.concatenate(bjet_indices)),
np.sort(np.concatenate(cjet_indices)),
np.sort(np.concatenate(ujet_indices)),
sorted_tjets
sorted_tjets,
)
def GetBins(self, df):
statistic, xedges, yedges, binnumber = binned_statistic_2d(
x=df[self.pT_var_name],
......@@ -280,7 +280,12 @@ class UnderSamplingProp(object):
total_count = df.shape[0]
weighted_flatten_statistic = statistic.flatten() / total_count
return binnumber, bins_indices_flat, weighted_flatten_statistic, total_count
return (
binnumber,
bins_indices_flat,
weighted_flatten_statistic,
total_count,
)
class Weighting2D(object):
......@@ -353,30 +358,8 @@ class Weighting2D(object):
return binnumber, bins_indices_flat, statistic.flatten()
def GetNJetsPerIteration(config, sampling_method, N_b, N_c, N_u):
"""
Returns the number of jet per 'flavour' to sample per iteration.
"""
required_total_njets = int(config.njets)
if config.ttbar_frac > 0.0:
required_tt_jets = int(required_total_njets * config.ttbar_frac)
required_Zext_jets = int(required_total_njets - required_tt_jets)
if sampling_method == "count":
N_available = [N_b, N_c, N_u]
N_sorted_idx = sorted(range(len(N_available)), key=N_available.__getitem__)
smallest_available_N = N_available[N_sorted_idx[0]]
if required_njets > smallest_available_N:
print("Requiring more jets per flavour than are available")
# In this case, take as much as the limiting one you can
# and as much of the other as possible
if N_available[N_sorted_idx[1]] > 2 * smallest_available_N:
required_njets
else:
elif sampling_method == "weight":
if
def GetNJetsPerIteration(config, total_number_of_taus=0):
if config.iterations == 0:
raise ValueError("The iterations have to be >=1 and not 0.")
if config.ttbar_frac > 0.0:
......@@ -404,7 +387,7 @@ def GetNJetsPerIteration(config, sampling_method, N_b, N_c, N_u):
def GetScales(vec, w, varname, custom_defaults_vars):
"""
Calculates the weighted average and std for vector vec and weight w.
Calculates the weighted average and std for vector vec and weight w.
"""
if np.sum(w) == 0:
raise ValueError("Sum of weights has to be >0.")
......@@ -427,7 +410,7 @@ def GetScales(vec, w, varname, custom_defaults_vars):
def dict_in(varname, average, std, default):
"""
Creates dictionary entry containing scale and shift parameters.
Creates dictionary entry containing scale and shift parameters.
"""
return {
"name": varname,
......@@ -439,7 +422,7 @@ def dict_in(varname, average, std, default):
def Gen_default_dict(scale_dict):
"""
Generates default value dictionary from scale/shift dictionary.
Generates default value dictionary from scale/shift dictionary.
"""
default_dict = {}
for elem in scale_dict:
......@@ -448,49 +431,71 @@ def Gen_default_dict(scale_dict):
default_dict[elem["name"]] = elem["default"]
return default_dict
def _select_fraction_indices(sample, reduced_category, n_required):
    """
    Build the index array used by EnforceFraction.

    Draws n_required jets at random (without replacement) among the jets of
    ``sample`` whose "category" equals ``reduced_category`` and appends the
    indices of every jet of the other category (category is 0 or 1).

    Raises ValueError if n_required exceeds the number of jets of
    ``reduced_category`` actually present in ``sample``.
    """
    kept_category = 1 - reduced_category
    candidates = np.where(sample["category"] == reduced_category)[0]
    if n_required > len(candidates):
        # np.random.choice(..., replace=False) would raise ValueError here
        # anyway; raise the same exception type with an explicit message.
        raise ValueError(
            "Cannot select {} jets of category {}: only {} present in the "
            "sample".format(n_required, reduced_category, len(candidates))
        )
    drawn = np.random.choice(candidates, n_required, replace=False)
    return np.concatenate(
        [drawn, np.where(sample["category"] == kept_category)[0]]
    )


def EnforceFraction(
    sample, ttbar_frac, statistics_dict, label, tolerance=0.01
):
    """
    If the ttbar fraction obtained is off from the one expected (= ttbar_frac)
    by more than tolerance, further downsamples to reach the expected
    fraction.

    Requires a statistics_dict like the one produced by RunStatSamples: the
    entry statistics_dict[label] is unpacked as
    (achieved ttbar fraction, number of ttbar jets, number of Z' jets).
    The key in the dict corresponding to the sample must be stored in label.

    Jets with sample["category"] == 1 are counted as ttbar and jets with
    sample["category"] == 0 as Z'.

    Returns a tuple (sample, n_selected):
      - sample: the input sample, downsampled if the fraction was off.
      - n_selected: the indices (into the input sample) that were kept, or
        None when no downsampling was needed. The caller can apply the same
        indices to any array aligned with the input sample.

    Raises ValueError if more jets of a category are required than the
    sample contains.
    """
    n_selected = None
    # NOTE: this reseeds the *global* NumPy RNG on every call, so the draw is
    # reproducible; kept unchanged since callers may rely on that side effect.
    np.random.seed(42)
    (ttbar_frac_achieved, ntt, nZ) = statistics_dict[label]

    if ttbar_frac_achieved < (ttbar_frac - tolerance):  # Too much Z'
        # Keep all ttbar jets and only as many Z' jets as the fraction allows.
        nZ_required = int(ntt / ttbar_frac - ntt)
        if nZ_required > nZ:
            print(
                "Error, requiring {} Z while only {} available".format(
                    nZ_required, nZ
                )
            )
        n_selected = _select_fraction_indices(sample, 0, nZ_required)
    elif ttbar_frac_achieved > (ttbar_frac + tolerance):  # Too much ttbar
        # Keep all Z' jets and only as many ttbar jets as the fraction allows.
        ntt_required = int(ttbar_frac / (1 - ttbar_frac) * nZ)
        if ntt_required > ntt:
            print(
                "Error, requiring {} tt while only {} available".format(
                    ntt_required, ntt
                )
            )
        n_selected = _select_fraction_indices(sample, 1, ntt_required)

    if n_selected is None:
        # Fraction already within tolerance: nothing to do.
        return sample, n_selected

    sample = sample[n_selected]
    nXjets = len(sample)
    if nXjets == 0:
        # Nothing is left after enforcing the fraction (e.g. one category was
        # absent); fractions are undefined, so do not divide by zero.
        print("Further downsampled! 0 {} jets left".format(label))
        return sample, n_selected

    nX_tt = len(sample[sample["category"] == 1])
    ttfrac_X = float(nX_tt) / nXjets
    print(
        "Further downsampled! {} {} jets: {} ttbar (frac: {}) | {} Z'-ext (frac: {})".format(
            nXjets,
            label,
            nX_tt,
            round(ttfrac_X, 2),
            nXjets - nX_tt,
            round(1 - ttfrac_X, 2),
        )
    )
    return sample, n_selected
def RunStatSamples(bjets, cjets, ujets, tjets=None):
......@@ -508,33 +513,54 @@ def RunStatSamples(bjets, cjets, ujets, tjets=None):
if tjets is not None:
nt_tt = len(tjets[tjets["category"] == 1])
ntjets = len(tjets)
ttfrac = float(
nb_tt
+ nc_tt
+ nu_tt
+ nt_tt
) / float(nbjets + ncjets + nujets + ntjets)
ttfrac = float(nb_tt + nc_tt + nu_tt + nt_tt) / float(
nbjets + ncjets + nujets + ntjets
)
else:
ttfrac = float(
nb_tt
+ nc_tt
+ nu_tt
) / float(nbjets + ncjets + nujets)
ttfrac = float(nb_tt + nc_tt + nu_tt) / float(nbjets + ncjets + nujets)
print("ttbar fraction:", round(ttfrac, 2))
ttfrac_b = float(nb_tt)/nbjets
ttfrac_c = float(nc_tt)/ncjets
ttfrac_u = float(nu_tt)/nujets
if tjets is not None:
ttfrac_t = float(nt_tt)/ntjets
print("{} b jets: {} ttbar (frac: {}) | {} Z'-ext (frac: {})".format(nbjets, nb_tt,
round(ttfrac_b, 2), nbjets - nb_tt, round(1 - ttfrac_b, 2)))
print("{} c jets: {} ttbar (frac: {}) | {} Z'-ext (frac: {})".format(ncjets, nc_tt,
round(ttfrac_c, 2), ncjets - nc_tt, round(1 - ttfrac_c, 2)))
print("{} u jets: {} ttbar (frac: {}) | {} Z'-ext (frac: {})".format(nujets, nu_tt,
round(ttfrac_u, 2), nujets - nu_tt, round(1 - ttfrac_u, 2)))
ttfrac_b = float(nb_tt) / nbjets
ttfrac_c = float(nc_tt) / ncjets
ttfrac_u = float(nu_tt) / nujets
if tjets is not None:
ttfrac_t = float(nt_tt) / ntjets
print(
"{} b jets: {} ttbar (frac: {}) | {} Z'-ext (frac: {})".format(
nbjets,
nb_tt,
round(ttfrac_b, 2),
nbjets - nb_tt,
round(1 - ttfrac_b, 2),
)
)
print(
"{} c jets: {} ttbar (frac: {}) | {} Z'-ext (frac: {})".format(
ncjets,
nc_tt,
round(ttfrac_c, 2),
ncjets - nc_tt,
round(1 - ttfrac_c, 2),
)
)
print(
"{} u jets: {} ttbar (frac: {}) | {} Z'-ext (frac: {})".format(
nujets,
nu_tt,
round(ttfrac_u, 2),
nujets - nu_tt,
round(1 - ttfrac_u, 2),