diff --git a/etx4velo/pipeline/GNN/models/triplet_interaction_gnn.py b/etx4velo/pipeline/GNN/models/triplet_interaction_gnn.py index fbd5980a9b0c4fbb5781bdedb95c4a5cdb795e5f..43c3b305b5bc6d635d1779dab72ee92a2ee05261 100644 --- a/etx4velo/pipeline/GNN/models/triplet_interaction_gnn.py +++ b/etx4velo/pipeline/GNN/models/triplet_interaction_gnn.py @@ -123,36 +123,29 @@ class TripletInteractionGNN(TripletGNNBase): ) # Final edge output classification network - if self.only_e: - self.output_edge_classifier = make_mlp( - n_edge_hiddens, - [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1], - layer_norm=hparams["layernorm"], - output_activation=None, - hidden_activation=hparams["hidden_activation"], - ) - self.output_triplet_classifier = make_mlp( - 2 * n_edge_hiddens, - [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1], - layer_norm=hparams["layernorm"], - output_activation=None, - hidden_activation=hparams["hidden_activation"], - ) - else: - self.output_edge_classifier = make_mlp( - 2 * n_node_hiddens + n_edge_hiddens, - [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1], - layer_norm=hparams["layernorm"], - output_activation=None, - hidden_activation=hparams["hidden_activation"], - ) - self.output_triplet_classifier = make_mlp( - 3 * n_node_hiddens + 2 * n_edge_hiddens, - [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1], - layer_norm=hparams["layernorm"], - output_activation=None, - hidden_activation=hparams["hidden_activation"], - ) + edge_classifier_input_size = ( + n_edge_hiddens if self.only_e else 2 * n_node_hiddens + n_edge_hiddens + ) + triplet_classifier_input_size = ( + 2 * n_edge_hiddens + if self.only_e + else 3 * n_node_hiddens + 2 * n_edge_hiddens + ) + + self.output_edge_classifier = make_mlp( + edge_classifier_input_size, + [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1], + layer_norm=hparams["layernorm"], + output_activation=None, + hidden_activation=hparams["hidden_activation"], + ) + self.output_triplet_classifier = make_mlp( + triplet_classifier_input_size, + [n_hiddens] * hparams["nb_edge_classifier_layers"] + [1], + layer_norm=hparams["layernorm"], + output_activation=None, + hidden_activation=hparams["hidden_activation"], + ) def scatter_add( self, source: torch.Tensor, index: torch.Tensor, h: torch.Tensor