Refactor `ModelBase`
Refactor ModelBase
so that EmbeddingBase
and TripletGNNBase
share common methods.
In order to define a model inheriting from ModelBase
, a few methods must be overridden.
# Inference method
def inference(self, batch: Data, **options) -> typing.Dict[str, torch.Tensor]
# Loss computation
def compute_loss(self, batch: Data, outputs: typing.Any) -> torch.Tensor
# Add log entry after a training step
def log_training_step(self, outputs: typing.Dict[str, typing.Any]) -> None
# Add log entry after a validation step
def log_validation_step(self, outputs: typing.Dict[str, typing.Any]) -> None
# OPTIONAL: Compute performance metrics (e.g., efficiency, purity, auc)
def evaluate_metrics(self, outputs: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]
Then, the following methods are defined for you
# By default, training / validation inference step is `inference`
# For the embedding, this is not the case so this method is overridden
def training_validation_inference(
self, batch: Data, validation: bool = False
) -> typing.Dict[str, torch.Tensor]:
outputs = self.inference(batch)
outputs["loss"] = self.compute_loss(batch=batch, outputs=outputs)
return outputs
def training_step(self, batch: Data, batch_idx: int) -> torch.Tensor | None:
outputs = self.training_validation_inference(batch, validation=False)
self.log_training_step(outputs=outputs)
return outputs["loss"]
def validation_step(self, batch: Data, batch_idx: int) -> torch.Tensor:
outputs = self.training_validation_inference(batch, validation=True)
self.log_validation_step(outputs=outputs)
return outputs["loss"]
Edited by Anthony Correia