update `forward` function
The current neural network `forward` function takes `batch` as its only input. `batch` is a dictionary that holds many things, including input data and model configurations. Despite its convenience, it makes it hard to tell which entries are inputs to the model and which are configurations, and it prevents converting the model to TorchScript or ONNX.
The proposal is to define a `forward` function that takes a list of clearly type-annotated arguments, such as:
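```python
from typing import Optional
import torch


def forward(
    self,
    node_features: torch.Tensor,
    edge_index: torch.Tensor,
    edge_attr: Optional[torch.Tensor] = None,
):
    pass  # model computation on the explicit inputs goes here
```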
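A minimal sketch of how a model could look under this proposal, assuming PyTorch. `SimpleGNN` (a one-layer sum-aggregation GNN), the adapter `forward_from_batch`, and the `batch` keys `"node_features"`, `"edge_index"` and `"edge_attr"` are hypothetical illustrations, not the project's existing code. Configuration (layer sizes) moves into `__init__`, and `forward` receives only typed tensors, which is what lets `torch.jit.script` compile the module.

```python
from typing import Any, Dict, Optional

import torch
from torch import nn


class SimpleGNN(nn.Module):
    """Hypothetical one-layer GNN that sums neighbour messages into each node."""

    def __init__(self, in_dim: int, out_dim: int, edge_dim: int = 0) -> None:
        super().__init__()
        # Configuration is fixed at construction time, not passed to forward.
        # edge_dim must be 0 if forward is called without edge_attr.
        self.message = nn.Linear(in_dim + edge_dim, out_dim)
        self.update = nn.Linear(in_dim + out_dim, out_dim)

    def forward(
        self,
        node_features: torch.Tensor,  # [num_nodes, in_dim]
        edge_index: torch.Tensor,  # [2, num_edges], rows are (source, target)
        edge_attr: Optional[torch.Tensor] = None,  # [num_edges, edge_dim]
    ) -> torch.Tensor:
        src, dst = edge_index[0], edge_index[1]
        msg_in = node_features[src]  # features of each edge's source node
        if edge_attr is not None:  # TorchScript refines Optional[Tensor] here
            msg_in = torch.cat([msg_in, edge_attr], dim=-1)
        messages = self.message(msg_in)
        # Sum the incoming messages for each target node.
        agg = messages.new_zeros([node_features.size(0), messages.size(1)])
        agg = agg.index_add(0, dst, messages)
        return self.update(torch.cat([node_features, agg], dim=-1))


def forward_from_batch(model: nn.Module, batch: Dict[str, Any]) -> torch.Tensor:
    """Hypothetical adapter: unpacks a legacy `batch` dict into explicit arguments."""
    return model(batch["node_features"], batch["edge_index"], batch.get("edge_attr"))


torch.manual_seed(0)
model = SimpleGNN(in_dim=4, out_dim=8, edge_dim=3)
x = torch.randn(5, 4)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_attr = torch.randn(4, 3)

scripted = torch.jit.script(model)  # compiles because the inputs are explicit and typed
out = scripted(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([5, 8])
```

Keeping a thin adapter such as `forward_from_batch` would let existing training code that builds a `batch` dictionary keep working during the transition, while exported or scripted models only ever see the explicit tensor arguments.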