diff --git a/freeforestml/model.py b/freeforestml/model.py
index 3f42c22e69850ea117657dbc185bc23a27b61a6d..ec395e4ad05a9f7915cd91691b0ea00ca928cb25 100644
--- a/freeforestml/model.py
+++ b/freeforestml/model.py
@@ -13,10 +13,12 @@ import tensorflow
 from freeforestml.variable import Variable
 from freeforestml.helpers import python_to_str, str_to_python
 
+
 class CrossValidator(ABC):
     """
     Abstract class of a cross validation method.
     """
+
     def __init__(self, k, mod_var=None, frac_var=None):
         """
         Creates a new cross validator. The argument k determines the number of
@@ -70,7 +72,7 @@ class CrossValidator(ABC):
         """
 
     @abstractmethod
-    def select_training(self, df, fold_i, for_predicting = False):
+    def select_training(self, df, fold_i, for_predicting=False):
         """
         Returns the index array to select all training events from the dataset for the
         given fold.
@@ -90,7 +92,7 @@ class CrossValidator(ABC):
         given fold.
         """
 
-    def select_cv_set(self, df, cv, fold_i, for_predicting = False):
+    def select_cv_set(self, df, cv, fold_i, for_predicting=False):
         """
         Returns the index array to select all events from the cross validator
         set specified with cv ('train', 'val', 'test') for the given fold.
@@ -99,7 +101,8 @@ class CrossValidator(ABC):
             raise ValueError("Argument 'cv' must be one of 'train', 'val', "
                              "'test', 'all'; but was %s." % repr(cv))
         if cv == "train":
-            selected = self.select_training(df, fold_i, for_predicting = for_predicting)
+            selected = self.select_training(
+                df, fold_i, for_predicting=for_predicting)
         elif cv == "val":
             selected = self.select_validation(df, fold_i)
         else:
@@ -110,9 +113,9 @@ class CrossValidator(ABC):
         """
         Returns and array of integers to specify which event was used
         for train/val/test in which fold. Mostly useful for the inference/predict
-        step. For cross validators with a high number of folds, so that an event 
-        is used in multiple folds for the training set, a single fold number is 
-        retrieved so that the folds are equally represented in the predicted 
+        step. For cross validators with a high number of folds, so that an event
+        is used in multiple folds for the training set, a single fold number is
+        retrieved so that the folds are equally represented in the predicted
         training data.
         """
         fold_info = np.zeros(len(df), dtype='bool') - 1
@@ -149,12 +152,14 @@ class CrossValidator(ABC):
             class_object = getattr(sys.modules[__name__], class_name)
             k = input_file[key].attrs["k"]
             mod_mode = input_file[key].attrs["mod_mode"]
-            variable = Variable.load_from_h5(path, os.path.join(key, "variable"))
+            variable = Variable.load_from_h5(
+                path, os.path.join(key, "variable"))
             if mod_mode:
                 return class_object(k=k, mod_var=variable)
             else:
                 return class_object(k=k, frac_var=variable)
 
+
 class ClassicalCV(CrossValidator):
     """
     Performs the k-fold cross validation on half of the data set. The other
@@ -182,9 +187,9 @@ class ClassicalCV(CrossValidator):
         else:
             variable = self.variable(df) % 1
             return (slice_id / (self.k * 2.0) <= variable) \
-                   & (variable < (slice_id + 1.0) / (self.k * 2))
+                & (variable < (slice_id + 1.0) / (self.k * 2))
 
-    def select_training(self, df, fold_i, for_predicting = False):
+    def select_training(self, df, fold_i, for_predicting=False):
         """
         Returns the index array to select all training events from the dataset for the
         given fold.
@@ -212,7 +217,7 @@ class ClassicalCV(CrossValidator):
         selected = np.zeros(len(df), dtype='bool')
         for slice_i in range(self.k, self.k * 2):
             selected = selected | self.select_slice(df, slice_i)
-            
+
         return selected
 
 
@@ -228,6 +233,7 @@ class NoTestCV(CrossValidator):
     used for the training or if real-time (non-hep) data is used as a "test"
     set.
     """
+
     def __init__(self, mod_var=None, frac_var=None, k=10):
         """
         The parameter k defines the inverse fraction of the validation set.
@@ -248,9 +254,9 @@ class NoTestCV(CrossValidator):
         else:
             variable = self.variable(df) % 1
             return (slice_id / self.k <= variable) \
-                   & (variable < (slice_id + 1.0) / self.k)
+                & (variable < (slice_id + 1.0) / self.k)
 
-    def select_training(self, df, fold_i, for_predicting = False):
+    def select_training(self, df, fold_i, for_predicting=False):
         """
         Returns the index array to select all training events from the
         dataset. The fold_i parameter has no effect.
@@ -276,6 +282,7 @@ class NoTestCV(CrossValidator):
         selected = np.zeros(len(df), dtype='bool')
         return selected
 
+
 class BinaryCV(CrossValidator):
     """
     Defines a training set and a test set using a binary split. There is no
@@ -289,6 +296,7 @@ class BinaryCV(CrossValidator):
     retrain the model on the full half. The valiation performance contain in
     HepNet.history is the test performance.
     """
+
     def __init__(self, mod_var=None, frac_var=None, k=None):
         """
         k is set to 2. The argument k has no effect.
@@ -308,9 +316,9 @@ class BinaryCV(CrossValidator):
         else:
             variable = self.variable(df) % 1
             return (slice_id / self.k <= variable) \
-                   & (variable < (slice_id + 1.0) / self.k)
+                & (variable < (slice_id + 1.0) / self.k)
 
-    def select_training(self, df, fold_i, for_predicting = False):
+    def select_training(self, df, fold_i, for_predicting=False):
         """
         Returns the index array to select all training events from the dataset for the
         given fold.
@@ -345,6 +353,7 @@ class MixedCV(CrossValidator):
 
         Va=Validation, Tr=Training, Te=Test
     """
+
     def select_slice(self, df, slice_id):
         """
         Returns the index array to select all events from the dataset of a
@@ -358,13 +367,13 @@ class MixedCV(CrossValidator):
         else:
             variable = self.variable(df) % 1
             return (slice_id / self.k <= variable) \
-                   & (variable < (slice_id + 1.0) / self.k)
+                & (variable < (slice_id + 1.0) / self.k)
 
-    def select_training_slices(self, fold_i, for_predicting = False):
+    def select_training_slices(self, fold_i, for_predicting=False):
         """
-        Returns array with integers corresponding 
+        Returns array with integers corresponding
         to the data slices used in training fold_i.
-        If 'for_predicting' is set to True only one slice 
+        If 'for_predicting' is set to True only one slice
         is returned for each fold so that the folds are equally represented
         in the predicted training data.
         """
@@ -379,27 +388,29 @@ class MixedCV(CrossValidator):
                 all_slices_for_folds[-1].append(slice_i)
 
         # if we select the slices for training we are done
-        if not for_predicting: return all_slices_for_folds[fold_i]
-        
+        if not for_predicting:
+            return all_slices_for_folds[fold_i]
+
         # all_slices_for_folds looks e.g. like:
         # [[0, 1, 2], [0, 1, 4], [0, 3, 4], [2, 3, 4], [1, 2, 3]]
         # need to select array with uniq entries:
         # [0, 1, 2, 4, 3]
-        uniq_el = lambda ar: set(x for l in ar for x in l)
+        def uniq_el(ar): return set(x for l in ar for x in l)
         exclusive_slices = []
         for i, slices in enumerate(all_slices_for_folds):
             for sl in slices:
                 if sl not in exclusive_slices and sl in uniq_el(all_slices_for_folds[i:]):
                     exclusive_slices.append(sl)
         return [exclusive_slices[fold_i]]
-        
-    def select_training(self, df, fold_i, for_predicting = False):
+
+    def select_training(self, df, fold_i, for_predicting=False):
         """
         Returns the index array to select all training events from the dataset for the
         given fold.
         """
         selected = np.zeros(len(df), dtype='bool')
-        slices = self.select_training_slices(fold_i, for_predicting = for_predicting)
+        slices = self.select_training_slices(
+            fold_i, for_predicting=for_predicting)
         for slice_i in slices:
             selected = selected | self.select_slice(df, slice_i)
 
@@ -580,6 +591,7 @@ class EstimatorNormalizer(Normalizer):
     def offsets(self):
         return -self.center / self. width
 
+
 def normalize_category_weights(df, categories, weight='weight'):
     """
     The categorical weight normalizer acts on the weight variable only. The
@@ -610,8 +622,9 @@ class HepNet:
     variables, the input weights, and the actual Keras model. A HEP net has no
     free parameters.
     """
+
     def __init__(self, keras_model, cross_validator, normalizer, input_list,
-                 output_list):
+                 output_list, wandb_log_func=None):
         """
         Creates a new HEP model. The keras model parameter must be a class that
         returns a new instance of the compiled model (The HEP net needs to
@@ -631,6 +644,7 @@ class HepNet:
         self.norm_cls = normalizer
         self.input_list = input_list
         self.output_list = output_list
+        self.wandb_log_func = wandb_log_func
         self.norms = []
         self.models = []
         self.history = pd.DataFrame()
@@ -680,7 +694,7 @@ class HepNet:
         elif isinstance(event_weight, str):
             event_weight = Variable(event_weight, event_weight)
 
-        ### Loop over folds:
+        # Loop over folds:
         self.norms = []
         self.models = []
         self.history = pd.DataFrame()
@@ -708,31 +722,49 @@ class HepNet:
             lazily_initialized_callbacks = []
             lazily_initialized_callbacks_names = []
             for cc in all_callbacks:
+                if isinstance(cc, dict) and "Z0Callback" in cc:
+                    # the dict value is called to build the callback: once on
+                    # the validation set ("val"), once on the training set ("train")
+                    for part_df, tag in ((validation_df, "val"),
+                                         (training_df, "train")):
+                        lazily_initialized_callbacks.append(cc["Z0Callback"](
+                            part_df[self.input_list], part_df[self.output_list],
+                            np.array(event_weight(part_df)), tag))
+
                 if cc == "Z0Callback":  # callback that retrieves significance
                     lazily_initialized_callbacks.append(Z0Callback(validation_df[self.input_list],
-                                                       validation_df[self.output_list],
-                                                       # only use event weights, no sample weights
-                                                       np.array(event_weight(validation_df)) ))
+                                                                   validation_df[self.output_list],
+                                                                   # only use event weights, no sample weights
+                                                                   np.array(event_weight(validation_df)), self.wandb_log_func))
                     lazily_initialized_callbacks_names.append(cc)
-            callbacks = [c for c in all_callbacks if not c in lazily_initialized_callbacks_names] + lazily_initialized_callbacks
-            
+                if cc == "MultiClassZ0Callback":  # callback that retrieves significance
+                    lazily_initialized_callbacks.append(MultiClassZ0Callback(
+                        validation_df[self.input_list], validation_df[self.output_list],
+                        # event weights only, no sample weights; wandb_log goes by keyword (4th positional is `targets`)
+                        np.array(event_weight(validation_df)), wandb_log=self.wandb_log_func))
+                    lazily_initialized_callbacks_names.append(cc)
+            callbacks = [
+                c for c in all_callbacks if c not in lazily_initialized_callbacks_names and not isinstance(c, dict)] + lazily_initialized_callbacks
+
             history = model.fit(training_df[self.input_list],
                                 training_df[self.output_list],
                                 validation_data=(
                                     validation_df[self.input_list],
                                     validation_df[self.output_list],
                                     np.array(train_weight(validation_df)),
-                                ),
-                                sample_weight=np.array(train_weight(training_df)),
-                                callbacks = callbacks, 
-                                **kwds)
+            ),
+                sample_weight=np.array(
+                                    train_weight(training_df)),
+                callbacks=callbacks,
+                **kwds)
 
             history = history.history
-            history['fold'] = np.ones(len(history['loss']), dtype='int') * fold_i
+            history['fold'] = np.ones(
+                len(history['loss']), dtype='int') * fold_i
             history['epoch'] = np.arange(len(history['loss']))
             self.history = pd.concat([self.history, pd.DataFrame(history)])
 
-    def predict(self, df, cv='val', retrieve_fold_info = False, **kwds):
+    def predict(self, df, cv='val', retrieve_fold_info=False, **kwds):
         """
         Calls predict() on the Keras model. The argument cv specifies the
         cross validation set to select: 'train', 'val', 'test'.
@@ -752,7 +784,8 @@ class HepNet:
             norm = self.norms[fold_i]
 
             # identify fold
-            selected = self.cv.select_cv_set(df, cv, fold_i, for_predicting = True)
+            selected = self.cv.select_cv_set(
+                df, cv, fold_i, for_predicting=True)
 
             test_set |= selected
             out[selected] = model.predict(norm(df[selected][self.input_list]),
@@ -764,7 +797,7 @@ class HepNet:
         test_df = test_df.assign(**out)
 
         if retrieve_fold_info:
-            fold = {cv + "_fold" :  self.cv.retrieve_fold_info(df, cv)}
+            fold = {cv + "_fold":  self.cv.retrieve_fold_info(df, cv)}
             test_df = test_df.assign(**fold)
 
         return test_df
@@ -793,8 +826,8 @@ class HepNet:
             # the following error is thrown:
             # NotImplementedError: numpy() is only available when eager execution is enabled.
             group = output_file.create_group("models/default")
-            group.attrs["model_cls"] = np.string_(python_to_str(self.model_cls))
-
+            group.attrs["model_cls"] = np.string_(
+                python_to_str(self.model_cls))
 
             # save class name of default normalizer as string
             group = output_file.create_group("normalizers/default")
@@ -806,7 +839,8 @@ class HepNet:
         # save normalizer (only if already trained)
         if len(self.norms) == self.cv.k:
             for fold_i in range(self.cv.k):
-                self.norms[fold_i].save_to_h5(path, "normalizers/fold_{}".format(fold_i))
+                self.norms[fold_i].save_to_h5(
+                    path, "normalizers/fold_{}".format(fold_i))
 
         # save input/output lists
         pd.DataFrame(self.input_list).to_hdf(path, "input_list")
@@ -822,7 +856,8 @@ class HepNet:
         """
         # load default model and normalizer
         with h5py.File(path, "r") as input_file:
-            model = str_to_python(input_file["models/default"].attrs["model_cls"].decode())
+            model = str_to_python(
+                input_file["models/default"].attrs["model_cls"].decode())
             normalizer_class_name = input_file["normalizers/default"].attrs["norm_cls"].decode()
             normalizer = getattr(sys.modules[__name__], normalizer_class_name)
 
@@ -849,12 +884,14 @@ class HepNet:
                 else:
                     path_token.insert(-1, f"fold_{fold_i}")
 
-                model = tensorflow.keras.models.load_model(".".join(path_token))
+                model = tensorflow.keras.models.load_model(
+                    ".".join(path_token))
                 instance.models.append(model)
 
         # load normalizer
         for fold_i in range(cv.k):
-            norm = Normalizer.load_from_h5(path, "normalizers/fold_{}".format(fold_i))
+            norm = Normalizer.load_from_h5(
+                path, "normalizers/fold_{}".format(fold_i))
             if norm is not None:
                 instance.norms.append(norm)
 
@@ -873,7 +910,7 @@ class HepNet:
         The path_base argument should be a path or a name of the network. The
         names of the generated files are created by appending to path_base.
 
-		The optional expression can be used to inject the CAF expression when
+        The optional expression can be used to inject the CAF expression when
         the NN is used. The final json file will contain an entry KEY=VALUE if
         a variable matches the dict key.
         """
@@ -885,7 +922,8 @@ class HepNet:
                 arch_file.write(arch)
 
             # now save the weights as an HDF5 file
-            self.models[fold_i].save_weights('%s_wght_%d.h5' % (path_base, fold_i))
+            self.models[fold_i].save_weights(
+                '%s_wght_%d.h5' % (path_base, fold_i))
 
             with open("%s_vars_%d.json" % (path_base, fold_i), "w") \
                     as variable_file:
@@ -894,7 +932,7 @@ class HepNet:
                 offsets = [o / s for o, s in zip(offsets, scales)]
 
                 variables = [("%s=%s" % (v, expression[v]))
-                                if v in expression else v
+                             if v in expression else v
                              for v in self.input_list]
 
                 inputs = [dict(name=v, offset=o, scale=s)
@@ -910,36 +948,210 @@ class HepNet:
                       f"{path_base}_wght_{fold_i}.h5 "
                       f"> {path_base}_{fold_i}.json", file=script_file)
 
+
 class Z0Callback(tensorflow.keras.callbacks.Callback):
 
-    def __init__(self, X_valid=0, Y_valid=0, W_valid = 0):
+    def __init__(self, X_valid=0, Y_valid=0, W_valid=0, wandb_log=None):
         self.X_valid = np.array(X_valid)
         self.Y_valid = np.array(Y_valid)
         self.W_valid = np.array(W_valid)
         self.W_valid = self.W_valid.reshape((self.W_valid.shape[0], 1))
+        self.wandb_log = wandb_log
 
     def add_to_history(self, Z0):
         if "Z0" in self.model.history.history.keys():
             self.model.history.history["Z0"].append(Z0)
-        else: # first epoch
+        else:  # first epoch
             self.model.history.history["Z0"] = [Z0]
+        if self.wandb_log is not None:
+            self.wandb_log({"Z0": Z0})
 
     def on_epoch_end(self, epoch, logs=None):
-        
+
         y_pred = np.array(self.model.predict(self.X_valid, batch_size=4096))
         w_bkg = self.W_valid[self.Y_valid == 0]
         w_sig = self.W_valid[self.Y_valid == 1]
         y_bkg = y_pred[self.Y_valid == 0]
         y_sig = y_pred[self.Y_valid == 1]
 
-        c_sig , edges = np.histogram(y_sig, 20, weights=w_sig, range = (0, 1))
-        c_bkg , edges = np.histogram(y_bkg, 20, weights=w_bkg, range = (0, 1))
+        c_sig, edges = np.histogram(y_sig, 20, weights=w_sig, range=(0, 1))
+        c_bkg, edges = np.histogram(y_bkg, 20, weights=w_bkg, range=(0, 1))
 
-        Z0_func = lambda s, b: np.sqrt( 2*((s+b)*  np.log1p (s/b) - s))
-        z_list = [Z0_func(si, bi) for si, bi in zip(c_sig, c_bkg) if bi > 0 and si > 0]
-        Z0 = np.sqrt( np.sum( np.square(z_list) )  )
+        def Z0_func(s, b): return np.sqrt(2*((s+b) * np.log1p(s/b) - s))
+        z_list = [Z0_func(si, bi)
+                  for si, bi in zip(c_sig, c_bkg) if bi > 0 and si > 0]
+        Z0 = np.sqrt(np.sum(np.square(z_list)))
 
         self.add_to_history(Z0)
 
         print("\nINFO: Significance in epoch {} is Z0 = {}".format(epoch, Z0))
-        
+
+
+class MultiClassZ0Callback(tensorflow.keras.callbacks.Callback):
+
+    def __init__(self, X=0, Y=0, W=0, targets="", wandb_log=None, plot_hists=True):
+        self.X = np.array(X)
+        self.VBF_target = np.array(Y["VBF_target"])
+        self.ggF_target = np.array(Y["ggF_target"])
+        self.bkg_target = np.array(Y["bkg_target"])
+        self.W_valid = np.array(W)
+        self.W_valid = self.W_valid.reshape((self.W_valid.shape[0], 1))
+
+        self.wandb_log = wandb_log
+        self.plot_hists = plot_hists
+
+    def add_to_history(self, key, val):
+        if key in self.model.history.history.keys():
+            self.model.history.history[key].append(val)
+        else:  # first epoch
+            self.model.history.history[key] = [val]
+
+    def on_epoch_end(self, epoch, logs=None):
+        # predict. Will output a nx3 array
+        y_pred = np.array(self.model.predict(self.X, batch_size=4096))
+
+        # have shape (n,1)
+        w_VBF = self.W_valid[self.VBF_target == 1]
+        w_VBF_bkg = self.W_valid[self.VBF_target == 0]
+
+        w_ggF = self.W_valid[self.ggF_target == 1]
+        w_ggF_bkg = self.W_valid[self.ggF_target == 0]
+
+        # we want to normalize the weights to remove the dependence
+        # of this metric on the size of the val set that is used
+
+        # scale VBF signal to a total of 250 (presumably ~expected VBF yield in the common VBF/ggF SR; TODO confirm)
+        w_VBF = w_VBF / np.sum(w_VBF) * 250
+        # scale all non-VBF events to a total of 60000 (presumably ~expected bkg yield; TODO confirm)
+        w_VBF_bkg = w_VBF_bkg / np.sum(w_VBF_bkg) * 60000
+
+        # scale ggF signal to a total of 500 (presumably ~expected ggF yield in the common VBF/ggF SR; TODO confirm)
+        w_ggF = w_ggF / np.sum(w_ggF) * 500
+        # scale all non-ggF events to a total of 60000 (presumably ~expected bkg yield; TODO confirm)
+        w_ggF_bkg = w_ggF_bkg / np.sum(w_ggF_bkg) * 60000
+
+        # get predictions for individual process in arrays
+        # The order is as provided in the config files
+        # which for now is VBF, ggF, bkg
+        # shape (n,)
+        # VBF predictions
+        y_VBF = y_pred[self.VBF_target == 1, 0]
+        y_VBF_bkg = y_pred[self.VBF_target == 0, 0]
+        # ggF predictions
+        y_ggF = y_pred[self.ggF_target == 1, 1]
+        y_ggF_bkg = y_pred[self.ggF_target == 0, 1]
+
+        # reshape to (n, 1)
+        y_VBF = y_VBF.reshape((y_VBF.shape[0], 1))
+        y_VBF_bkg = y_VBF_bkg.reshape((y_VBF_bkg.shape[0], 1))
+        y_ggF = y_ggF.reshape((y_ggF.shape[0], 1))
+        y_ggF_bkg = y_ggF_bkg.reshape((y_ggF_bkg.shape[0], 1))
+
+        # make histograms contents
+        bins = 20
+        c_VBF, edges = np.histogram(y_VBF, bins, weights=w_VBF, range=(0, 1))
+        c_VBF_bkg, _ = np.histogram(
+            y_VBF_bkg, bins, weights=w_VBF_bkg, range=(0, 1))
+        c_ggF, _ = np.histogram(y_ggF, bins, weights=w_ggF, range=(0, 1))
+        c_ggF_bkg, _ = np.histogram(
+            y_ggF_bkg, bins, weights=w_ggF_bkg, range=(0, 1))
+
+        # get significance from histograms
+        Z0_VBF = self.get_Z0(c_VBF, c_VBF_bkg)
+        Z0_ggF = self.get_Z0(c_ggF, c_ggF_bkg)
+
+        print("\nINFO: Significance in epoch {} is Z0_VBF = {}, Z0_ggF = {}".format(
+            epoch, Z0_VBF, Z0_ggF))
+
+        # Add to hep net history
+        self.add_to_history(key="Z0_VBF", val=Z0_VBF)
+        self.add_to_history(key="Z0_ggF", val=Z0_ggF)
+
+        if self.wandb_log is not None:
+            self.wandb_log({"Z0_VBF": Z0_VBF, "Z0_ggF": Z0_ggF})
+
+            # plotly is imported here so that it is only needed when a wandb logger is set
+            import plotly.graph_objects as go
+            # plotly histogram does not support weights
+            histbins = 0.5 * (edges[:-1] + edges[1:])
+
+            fig = go.Figure(layout=go.Layout(bargap=0.0, barmode="overlay", barnorm="fraction", yaxis=go.layout.YAxis(
+                type="log", title="Events"), xaxis=go.layout.XAxis(title="VBF DNN output")))
+            fig.add_bar(x=histbins, y=c_VBF_bkg, opacity=0.6, name="Bkg")
+            fig.add_bar(x=histbins, y=c_VBF, opacity=0.6, name="Sig")
+            fig.add_annotation(text="Z0 = {:.2f}".format(Z0_VBF),
+                               showarrow=False, x=0.2, y=0.1)
+
+            fig2 = go.Figure(layout=go.Layout(bargap=0.0, barmode="overlay", barnorm="fraction", yaxis=go.layout.YAxis(
+                type="log", title="Events"), xaxis=go.layout.XAxis(title="ggF DNN output")))
+            fig2.add_bar(x=histbins, y=c_ggF_bkg, opacity=0.6, name="Bkg")
+            fig2.add_bar(x=histbins, y=c_ggF, opacity=0.6, name="Sig")
+            fig2.add_annotation(text="Z0 = {:.2f}".format(
+                Z0_ggF), showarrow=False, x=0.2, y=0.1)
+
+            # also make normed plots
+            c_VBF, _ = np.histogram(
+                y_VBF, bins, weights=w_VBF, range=(0, 1), density=1)
+            c_VBF_bkg, _ = np.histogram(
+                y_VBF_bkg, bins, weights=w_VBF_bkg, range=(0, 1), density=1)
+            c_ggF, _ = np.histogram(
+                y_ggF, bins, weights=w_ggF, range=(0, 1), density=1)
+            c_ggF_bkg, _ = np.histogram(
+                y_ggF_bkg, bins, weights=w_ggF_bkg, range=(0, 1), density=1)
+
+            fig3 = go.Figure(layout=go.Layout(bargap=0.0, barmode="overlay", barnorm="fraction", yaxis=go.layout.YAxis(
+                type="log", title="Normalized Events"), xaxis=go.layout.XAxis(title="VBF DNN output")))
+            fig3.add_bar(x=histbins, y=c_VBF_bkg, opacity=0.6, name="Bkg")
+            fig3.add_bar(x=histbins, y=c_VBF, opacity=0.6, name="Sig")
+
+            fig4 = go.Figure(layout=go.Layout(bargap=0.0, barmode="overlay", barnorm="fraction", yaxis=go.layout.YAxis(
+                type="log", title="Normalized Events"), xaxis=go.layout.XAxis(title="ggF DNN output")))
+            fig4.add_bar(x=histbins, y=c_ggF_bkg, opacity=0.6, name="Bkg")
+            fig4.add_bar(x=histbins, y=c_ggF, opacity=0.6, name="Sig")
+
+            self.wandb_log({"VBF DNN output": fig, "ggF DNN output": fig2,
+                            "VBF DNN output (norm)": fig3, "ggF DNN output (norm)": fig4})
+
+        if self.plot_hists:
+            self.plot(y_VBF, w_VBF, y_VBF_bkg, w_VBF_bkg,
+                      "test_plots/vbf_dnn_epoch{}.png".format(epoch), "VBF DNN output")
+            self.plot(y_ggF, w_ggF, y_ggF_bkg, w_ggF_bkg,
+                      "test_plots/ggF_dnn_epoch{}.png".format(epoch), "ggF DNN output")
+
+    def get_Z0(self, h1, h2):
+        z_list = [self.Z0_poisson(si, bi)
+                  for si, bi in zip(h1, h2) if bi > 0 and si > 0]
+        Z0 = np.sqrt(np.sum(np.square(z_list)))
+        return Z0
+
+    def Z0_poisson(self, s, b):
+        return np.sqrt(2*((s+b) * np.log1p(s/b) - s))
+
+    def plot(self, hs, ws, hb, wb, fname, xlabel, nbins=20):
+        import matplotlib.pyplot as plt
+        plt.rc('axes', labelsize=12)
+        # puts the legend in the best possible spot in the upper right corner (0.5,0.5,0.5,0.5)
+        plt.yscale("log")
+
+        plt.hist(hs, nbins,
+                 facecolor='red', alpha=1,
+                 color='red',
+                 range=(0, 1), density=1,
+                 weights=ws,
+                 histtype='step',  # bar or step
+                 )
+        plt.hist(hb, nbins,
+                 facecolor='blue', alpha=1,
+                 color='blue',
+                 range=(0, 1), density=1,
+                 weights=wb,
+                 histtype='step',  # bar or step
+                 )
+
+        ax = plt.gca()
+        ax.set_xlabel(xlabel, loc='right')
+        # create the target directory first; savefig raises FileNotFoundError if it is missing
+        os.makedirs(os.path.dirname(fname) or ".", exist_ok=True)
+        # dpi means dots per inch, essentially the resolution of the image
+        plt.savefig(fname, dpi=360)
+        plt.clf()