Skip to content

Save InteractionGNN2 to Torch Script

Xiangyang Ju requested to merge mr-torchscript into dev

This MR is to update those Neural Networks used in Example_1 and Example_2 so that they can be saved to TorchScript format through the script scripts/save_full_model.py. One can run the following to perform the conversion for various models.

python scripts/save_full_model.py gnn_train.yaml

where the gnn_train.yaml is the training configuration.

Summary of the changes:

  • Add two notebooks for
    1. testing the feasibility of converting torch_geometric aggregation functions to TorchScript or ONNX,
    2. testing the feasibility of converting nn.Module to TorchScript.
  • Add the black version to the requirements.txt, because CI uses a specific version of black==22.8.0 and the results of the black --check . changes when the latest black version (24.3.0) is used.
  • Move the "input preparations" part (like stacking node features) from the training_step and other places in LightningModule to def get(self, idx) in the Dataset.

Follow-up developments/issues: #83 , #84.

Edited by Xiangyang Ju

Merge request reports