Skip to content

Refactor `ModelBase`

Anthony Correia requested to merge anthonyc/refactor_model into main

Refactor ModelBase so that EmbeddingBase and TripletGNNBase share common methods.

In order to define a model inheriting from ModelBase, a few methods must be overridden.

# Inference method
def inference(self, batch: Data, **options) -> typing.Dict[str, torch.Tensor]

# Loss computation
def compute_loss(self, batch: Data, outputs: typing.Any) -> torch.Tensor

# Add log entry after a training step
def log_training_step(self, outputs: typing.Dict[str, typing.Any]) -> None

# Add log entry after a validation step
def log_validation_step(self, outputs: typing.Dict[str, typing.Any]) -> None

# OPTIONAL: Compute performance metrics (e.g., efficiency, purity, auc)
def evaluate_metrics(self, outputs: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]

Then, the following methods are defined for you

# By default, training / validation inference step is `inference`
# For the embedding, this is not the case so this method is overridden
def training_validation_inference(
    self, batch: Data, validation: bool = False
) -> typing.Dict[str, torch.Tensor]:
    outputs = self.inference(batch)
    outputs["loss"] = self.compute_loss(batch=batch, outputs=outputs)
    return outputs

def training_step(self, batch: Data, batch_idx: int) -> torch.Tensor | None:
    outputs = self.training_validation_inference(batch, validation=False)
    return outputs["loss"]

def validation_step(self, batch: Data, batch_idx: int) -> torch.Tensor:
    outputs = self.training_validation_inference(batch, validation=True)
    return outputs["loss"]
Edited by Anthony Correia

Merge request reports