diff --git a/config/CI_test_BDT_binary_classification.cfg b/config/CI_test_BDT_binary_classification.cfg index d42c81a479dd6e3301e01903c22c3909cda2d317..2252156d01ae9a9f2b6dc9dfad36d9509ebf4df5 100644 --- a/config/CI_test_BDT_binary_classification.cfg +++ b/config/CI_test_BDT_binary_classification.cfg @@ -16,7 +16,6 @@ GENERAL DoYields = True ConfusionPlotScale = Linear StackPlotScale = Linear - SeparationPlotScale = Linear PermImportancePlotScale = Log ROCSampling = 25 Blinding = 9999.0 diff --git a/config/CI_test_BDT_multiclass_classification.cfg b/config/CI_test_BDT_multiclass_classification.cfg index 55e45a64e4d545dd119e406d51c0fd4cfc3556d8..e41ccbbd213e29ef9a12953107f18dff850576a8 100644 --- a/config/CI_test_BDT_multiclass_classification.cfg +++ b/config/CI_test_BDT_multiclass_classification.cfg @@ -18,7 +18,6 @@ GENERAL ConfusionPlotScale = Log StackPlotScale = Log StackPlot2DScale = Log - SeparationPlotScale = Log ROCSampling = 25 Blinding = 9999.0 SBBlinding = 0.5 diff --git a/config/CI_test_Regression.cfg b/config/CI_test_Regression.cfg index 50e47183194990d44277b2d843ec202e8afd8a02..adea56724cf2ae5767d37c551601fb2a5704509d 100644 --- a/config/CI_test_Regression.cfg +++ b/config/CI_test_Regression.cfg @@ -16,7 +16,6 @@ GENERAL DoYields = True ConfusionPlotScale = Log StackPlotScale = Log - SeparationPlotScale = Log Blinding = 9999.0 ###Variable Section### diff --git a/config/CI_test_binary_GNN_Hetero_classification.cfg b/config/CI_test_binary_GNN_Hetero_classification.cfg index b9dd074b7c5efb57e14e6e72e876bc44df94d36a..b58387b92a76794ebc85b12887f7b936a809c230 100644 --- a/config/CI_test_binary_GNN_Hetero_classification.cfg +++ b/config/CI_test_binary_GNN_Hetero_classification.cfg @@ -16,7 +16,6 @@ GENERAL DoYields = True ConfusionPlotScale = Log StackPlotScale = Linear - SeparationPlotScale = Log Blinding = 9999.0 ###Sample Section### diff --git a/config/CI_test_binary_GNN_classification.cfg b/config/CI_test_binary_GNN_classification.cfg index e3fee2df05277bbc32e42cee0f0fe6b2c8955f06..0fe6fc5fa73f3729265a9b12188d22ab957356d3 100644 --- a/config/CI_test_binary_GNN_classification.cfg +++ b/config/CI_test_binary_GNN_classification.cfg @@ -16,7 +16,6 @@ GENERAL DoYields = True ConfusionPlotScale = Log StackPlotScale = Linear - SeparationPlotScale = Log Blinding = 9999.0 ###Sample Section### diff --git a/config/CI_test_binary_classification.cfg b/config/CI_test_binary_classification.cfg index 87301c97c3c48564a2633381d7439974f07b495b..860d7f7aaaa6922ad1122ce405cd6a90554ab3dc 100644 --- a/config/CI_test_binary_classification.cfg +++ b/config/CI_test_binary_classification.cfg @@ -18,7 +18,6 @@ GENERAL ConfusionPlotScale = Log StackPlotScale = Log StackPlot2DScale = Log - SeparationPlotScale = Log ROCSampling = 25 Blinding = 9999.0 SBBlinding = 0.5 diff --git a/config/CI_test_multiclass_classification.cfg b/config/CI_test_multiclass_classification.cfg index 178a4959ace9bb53abce8d6a1372340ce7b77b51..d79ef149c0f430b06de95715adda53c2db8994b8 100644 --- a/config/CI_test_multiclass_classification.cfg +++ b/config/CI_test_multiclass_classification.cfg @@ -18,7 +18,6 @@ GENERAL ConfusionPlotScale = Log StackPlotScale = Log StackPlot2DScale = Log - SeparationPlotScale = Log ROCSampling = 25 Blinding = 9999.0 SBBlinding = 0.5 diff --git a/docs/Settings/index.md b/docs/Settings/index.md index 0124b7277410179a5bd01c24c0308e9327745fa7..a5508f3a5676a5c643f0ec7e10d4a30a1ab9eb83 100644 --- a/docs/Settings/index.md +++ b/docs/Settings/index.md @@ -27,7 +27,6 @@ The following tables contain config file options and the effect these options ha | CustomLabel | Custom text that is printed below the ATLAS label.| | PlotFormat | List of file endings to specify the file format of the produced plots. Allowed options are `pdf`, `png`, `eps`, `jpeg`, `gif` and `svg`.| | StackPlotScale | Can be `Linear` or `Log`. Determines the Y-axis scale for stack plots. | -| SeparationPlotScale | Can be `Linear` or `Log`. Determines the Y-axis scale for (1D) separation plots. | | ConfusionPlotScale | Can be `Linear` or `Log`. Determines the Y-axis scale for binary confusion plots. | | ConfusionMatrixNorm | Can be `row`, `column` or `None`. When `row` (`column`) is chosen, the rows (columns) of the confusion matrix are normalised to unity. If `None` is chosen, no normalisation is applied. | | PermImportancePlotScale | Can be `Linear` or `Log`. Determines the X-axis scale for permutation importance plots. | diff --git a/python/HelperModules/configparser.py b/python/HelperModules/configparser.py index 6a5c5b041656065527f8ec04833d74eaa0128856..196a89d97227c61a496965eb80aa45dc2a739ceb 100644 --- a/python/HelperModules/configparser.py +++ b/python/HelperModules/configparser.py @@ -80,7 +80,6 @@ class configparser(): opthandler.register("LossPlotYScale", "Linear") opthandler.register("StackPlotScale", "Linear") opthandler.register("StackPlot2DScale", "Linear") - opthandler.register("SeparationPlotScale", "Linear") opthandler.register("ConfusionPlotScale", "Linear") opthandler.register("ConfusionMatrixNorm", "row") opthandler.register("PermImportancePlotScale", "Linear") @@ -109,7 +108,6 @@ class configparser(): opthandler.register_check("LossPlotYScale", check="is_inlist", lst=["Log", "Linear"]) opthandler.register_check("StackPlotScale", check="is_inlist", lst=["Log", "Linear"]) opthandler.register_check("StackPlot2DScale", check="is_inlist", lst=["Log", "Linear"]) - opthandler.register_check("SeparationPlotScale", check="is_inlist", lst=["Log", "Linear"]) opthandler.register_check("ConfusionPlotScale", check="is_inlist", lst=["Log", "Linear"]) opthandler.register_check("ConfusionMatrixNorm", check="is_inlist", lst=["row", "column", "None"]) opthandler.register_check("PermImportancePlotScale", check="is_inlist", lst=["Log", "Linear"]) diff --git a/python/HelperModules/graphbuilder.py b/python/HelperModules/graphbuilder.py index 10d541cf11f1bfa0a2e2a269f666e01d37957783..a495688629b9c1590df119389670211f0e8d0724 100644 --- a/python/HelperModules/graphbuilder.py +++ b/python/HelperModules/graphbuilder.py @@ -7,7 +7,6 @@ import torch from HelperModules.messagehandler import ErrorMessage, InfoMessage from torch_geometric.data import Data, HeteroData - class graphbuilder(): """ Class to build graphs. @@ -54,7 +53,7 @@ class graphbuilder(): for node in self.m_gnodes: if node.get("PruneIfValue") is not None: if any(dfrow[f] == v for f, v in zip(node.get("Features"), node.get("PruneIfValue"))): - node.deactivate() + self.m_inactivelist.append(node.get("Name")) def _build_global(self, dfrow, data): """Build global features. @@ -110,12 +109,16 @@ class graphbuilder(): if not self._nodeisglobal(node.get("Type")) and self._nodeisactive(node.get("Name")): # source node not global and active source = node.get("Name") for target_ID, target in enumerate(node.get("Targets")): - if not self._nodeisglobal(self.m_map[target]["Node"].get("Type")) and self._nodeisactive(self.m_map[target]["Node"].get("Type")): # same for target + if self._nodeisactive(target) and not self._nodeisglobal(self.m_map[target]["Node"].get("Type")): # same for target edge_index[0] = np.append(edge_index[0], self.m_map[source]["Idx"]) edge_index[1] = np.append(edge_index[1], self.m_map[target]["Idx"]) - edge_attr.append(np.array([dfrow[feature] for feature in node.get("EdgeFeatures")[target_ID]])) + if node.get("EdgeFeatures"): + edge_attr.append(np.array([dfrow[feature] for feature in node.get("EdgeFeatures")[target_ID]])) try: - return torch.from_numpy(np.array(edge_index)).contiguous().long(), torch.from_numpy(np.array(edge_attr)) + if edge_attr: + return torch.from_numpy(np.array(edge_index)).contiguous().long(), torch.from_numpy(np.array(edge_attr)) + else: + return torch.from_numpy(np.array(edge_index)).contiguous().long(), None except TypeError as e: ErrorMessage(f"Conversion to torch tensor was not successful! {e}") return None @@ -184,14 +187,19 @@ class graphbuilder(): self.m_edgetypes.append(key) if key not in data.metadata()[1]: data[key].edge_index = [[self.m_map[source]["Idx"]], [self.m_map[target]["Idx"]]] - data[key].edge_attr = [[dfrow[feature] for feature in node.get("EdgeFeatures")[target_ID]]] + if node.get("EdgeFeatures"): + data[key].edge_attr = [[dfrow[feature] for feature in node.get("EdgeFeatures")[target_ID]]] else: data[key].edge_index[0].append(self.m_map[source]["Idx"]) data[key].edge_index[1].append(self.m_map[target]["Idx"]) - data[key].edge_attr.append([dfrow[feature] for feature in node.get("EdgeFeatures")[target_ID]]) + if node.get("EdgeFeatures"): + data[key].edge_attr.append([dfrow[feature] for feature in node.get("EdgeFeatures")[target_ID]]) for edgeType in data.metadata()[1]: data[edgeType].edge_index = torch.tensor(data[edgeType].edge_index, dtype=torch.long) - data[edgeType].edge_attr = torch.tensor(data[edgeType].edge_attr, dtype=torch.float) + if data[edgeType].edge_attr: + data[edgeType].edge_attr = torch.tensor(data[edgeType].edge_attr, dtype=torch.float) + else: + data[edgeType].edge_attr = None def _build_Homglobal(self, dfrow): """Build global features for homogeneous GNN diff --git a/python/PlotterModules/plothandler.py b/python/PlotterModules/plothandler.py index 7ed45489efbfaf44006d61f276122f961a1fcceb..afa713e3c090f12be6dc2b995395a52314b87415 100644 --- a/python/PlotterModules/plothandler.py +++ b/python/PlotterModules/plothandler.py @@ -475,6 +475,7 @@ class plothandler(): sep.draw() sep.setxlabel("Sig", output.get("Label")) sep.setylabel("Sig", "Fraction of Events") + sep.setscale("y", output.get("Scale")) sep.setLabels(self.m_cfgset.get("GENERAL").get("ATLASLabel"), self.m_cfgset.get("GENERAL").get("CMLabel"), self.m_cfgset.get("GENERAL").get("CustomLabel")) @@ -512,6 +513,7 @@ class plothandler(): sep.draw() sep.setxlabel("h1", variable.get("Label")) sep.setylabel("h1", "Sample") + sep.setscale("z", variable.get("Scale")) sep.setLabels(self.m_cfgset.get("GENERAL").get("ATLASLabel"), self.m_cfgset.get("GENERAL").get("CMLabel"), self.m_cfgset.get("GENERAL").get("CustomLabel")) diff --git a/python/PlotterModules/plots.py b/python/PlotterModules/plots.py index 3bc080186b0b9e4d67bebb974398c5dfe0fcb73c..c3e0304c9819d1351726cc32dddda23325e6e007 100644 --- a/python/PlotterModules/plots.py +++ b/python/PlotterModules/plots.py @@ -1117,6 +1117,11 @@ class confusionmatrix(basicplot): self.addhist("confmatrix", confmatrix_hist) def draw(self): + """TODO describe function + + :returns: + + """ self.m_hists["confmatrix"].GetXaxis().SetTitle("Prediction") self.m_hists["confmatrix"].GetYaxis().SetTitle("True Label") self.m_hists["confmatrix"].GetYaxis().SetTitleOffset(2.8) diff --git a/setup.sh b/setup.sh index e0b39526fdd1c402555c8f6776ab9902f51e819d..f3612ddf50dcf6d9c5afdda62a730f595be6c95f 100644 --- a/setup.sh +++ b/setup.sh @@ -37,8 +37,6 @@ echo "" alias mvatrainer="python3 $MVA_TRAINER_BASE_DIR/python/mva-trainer.py" -chmod +x $MVA_TRAINER_BASE_DIR/python/mva-trainer.py - # Determine whether Python3 is installed echo -e "SETUP\t Checking Python3 version ..." # Let's make sure we print the version...