diff --git a/etx4velo/pipeline/GNN/models/triplet_interaction_gnn.py b/etx4velo/pipeline/GNN/models/triplet_interaction_gnn.py
index fbd5980a9b0c4fbb5781bdedb95c4a5cdb795e5f..43c3b305b5bc6d635d1779dab72ee92a2ee05261 100644
--- a/etx4velo/pipeline/GNN/models/triplet_interaction_gnn.py
+++ b/etx4velo/pipeline/GNN/models/triplet_interaction_gnn.py
@@ -123,36 +123,29 @@ class TripletInteractionGNN(TripletGNNBase):
             )
 
         # Final edge output classification network
-        if self.only_e:
-            self.output_edge_classifier = make_mlp(
-                n_edge_hiddens,
-                [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
-                layer_norm=hparams["layernorm"],
-                output_activation=None,
-                hidden_activation=hparams["hidden_activation"],
-            )
-            self.output_triplet_classifier = make_mlp(
-                2 * n_edge_hiddens,
-                [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
-                layer_norm=hparams["layernorm"],
-                output_activation=None,
-                hidden_activation=hparams["hidden_activation"],
-            )
-        else:
-            self.output_edge_classifier = make_mlp(
-                2 * n_node_hiddens + n_edge_hiddens,
-                [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
-                layer_norm=hparams["layernorm"],
-                output_activation=None,
-                hidden_activation=hparams["hidden_activation"],
-            )
-            self.output_triplet_classifier = make_mlp(
-                3 * n_node_hiddens + 2 * n_edge_hiddens,
-                [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
-                layer_norm=hparams["layernorm"],
-                output_activation=None,
-                hidden_activation=hparams["hidden_activation"],
-            )
+        edge_classifier_input_size = (
+            n_edge_hiddens if self.only_e else 2 * n_node_hiddens + n_edge_hiddens
+        )
+        triplet_classifier_input_size = (
+            2 * n_edge_hiddens
+            if self.only_e
+            else 3 * n_node_hiddens + 2 * n_edge_hiddens
+        )
+
+        self.output_edge_classifier = make_mlp(
+            edge_classifier_input_size,
+            [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
+            layer_norm=hparams["layernorm"],
+            output_activation=None,
+            hidden_activation=hparams["hidden_activation"],
+        )
+        self.output_triplet_classifier = make_mlp(
+            triplet_classifier_input_size,
+            [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1],
+            layer_norm=hparams["layernorm"],
+            output_activation=None,
+            hidden_activation=hparams["hidden_activation"],
+        )
 
     def scatter_add(
         self, source: torch.Tensor, index: torch.Tensor, h: torch.Tensor