diff --git a/nnfwtbn/__init__.py b/nnfwtbn/__init__.py index 88ab420c8d563375ae9c1f1d647cc629ab062bf3..fd97db6db609c951bf75f678597c93d613a2eb32 100644 --- a/nnfwtbn/__init__.py +++ b/nnfwtbn/__init__.py @@ -3,4 +3,4 @@ __version__ = "0.0.0" from nnfwtbn.variable import Variable, RangeBlindingStrategy from nnfwtbn.process import Process from nnfwtbn.cut import Cut -from nnfwtbn.plot import HistogramFactory, hist, confusion_matrix +from nnfwtbn.plot import HistogramFactory, hist, confusion_matrix, roc diff --git a/nnfwtbn/plot.py b/nnfwtbn/plot.py index bbd1d7e5e278672d1a756a5d2bde507e53327227..04913140cce0884082084a853750410c0803f534 100644 --- a/nnfwtbn/plot.py +++ b/nnfwtbn/plot.py @@ -236,7 +236,7 @@ def confusion_matrix(df, x_processes, y_processes, x_label, y_label, # Handle axes, figure if figure is None: - figure, axes = plt.subplots() + figure, axes = plt.subplots() #figsize=(5,5)) elif axes is None: axes = figure.subplots() @@ -259,5 +259,97 @@ def confusion_matrix(df, x_processes, y_processes, x_label, y_label, ), **kwds) axes.set_xlabel(x_label) axes.set_ylabel(y_label) + figure.subplots_adjust(top=0.85) + + if ATLAS is not None: + figure.text(0.18, 0.89, "ATLAS", + fontdict={"size": 18, "style": "italic", "weight": "bold"}) + if isinstance(ATLAS, str): + figure.text(0.35, 0.89, ATLAS, + fontdict={"size": 12, }) return data + + +def roc(df, signal_process, background_process, discriminant, steps=100, + selection=None, + min=None, max=None, axes=None, weight=None): + """ + Creates a ROC. + """ + # Wrap column string by variable + if discriminant is None: + discriminant = Variable("unity", lambda d: np.ones(len(d))) + elif isinstance(discriminant, str): + discriminant = Variable(discriminant, discriminant) + + if weight is None: + weight = Variable("unity", lambda d: np.ones(len(d))) + elif isinstance(weight, str): + weight = Variable(weight, weight) + + # Handle selection + if selection is None: + selection = Cut(lambda d: variable(d) * 0 == 0) + elif not isinstance(selection, Cut): + selection = Cut(selection) + + # Handle axes + if axes is None: + fig, axes = figure.subplots() + if min is None: + min = discriminant(df).min() + if max is None: + max = discriminant(df).max() + + df = df[selection(df)] + + signal_effs = [] + background_ieffs = [] + n_total_sig = weight(df[signal_process.selection(df)]).sum() + n_total_bkg = weight(df[background_process.selection(df)]).sum() + for cut_value in np.linspace(min, max, steps): + residual_df = df[discriminant(df) >= cut_value] + + n_total = weight(residual_df).sum() + if n_total == 0: + continue + + signal_df = residual_df[signal_process.selection(residual_df)] + background_df = residual_df[background_process.selection(residual_df)] + + n_signal = weight(signal_df).sum() + n_background = weight(background_df).sum() + + signal_effs.append(n_signal / n_total_sig) + background_ieffs.append(1 - n_background / n_total_bkg) + + data = pd.DataFrame({"Signal efficiency": signal_effs, + "1 - Background efficiency": background_ieffs}) + sns.lineplot(x="Signal efficiency", y="1 - Background efficiency", data=data, + ax=axes, ci=None, label=discriminant.name) + axes.plot([0, 1], [1, 0], color='gray', linestyle=':') + + axes.set_xlim((0, 1)) + axes.set_ylim((0, 1.3)) + axes.legend(loc=1, frameon=False) + + axes.tick_params("both", which="both", direction="in") + axes.tick_params("both", which="major", length=6) + axes.tick_params("both", which="minor", length=3) + axes.tick_params("x", which="both", top=True) + axes.tick_params("y", which="both", right=True) + axes.xaxis.set_minor_locator(AutoMinorLocator()) + axes.yaxis.set_minor_locator(AutoMinorLocator()) + + if ATLAS is not None: + axes.text(0.04, 0.89, "ATLAS", transform=axes.transAxes, + fontdict={"size": 18, "style": "italic", "weight": "bold"}) + if isinstance(ATLAS, str): + axes.text(0.25, 0.89, ATLAS, transform=axes.transAxes, + fontdict={"size": 12, }) + + return data + + +