Skip to content

Upgrade to Pytorch 2.0

Xiangyang Ju requested to merge xju-upgrade-pytorch into dev

The MR is to allow the code to run in Pytorch >= 2.0, PyG >= 2.4, and lighting >= 2.1.2. The code is tested with the Example_1 in the following environment.

torch:  2.1.0
pytorch_lightning:  2.1.2
pyg:  2.4.0
cuda: 12.2
cugraph:  23.10.00
torch_scatter:  2.1.2

It turns out not many changes were required. The keys in torch_geometric.data.Data is not a property but a method in PyG 2.4.0. Add a utility function to adapt that. In addition, I made some fixes.

To use the torch.compile in PyTorch 2.0, we need to factor the nn.Module out from the LightningModule because simply adding torch.compile(LightingModule) will break self.log functions. The change will help solve issue #59.

In the training configuration, the Lightning trainer does not like cpu accelerators with zero devices. So I changed configuration files from:

gpus: 0

to

accelerator: cpu
devices: 2

If people use GPUs, use the following:

accelerator: gpu
devices: 4
Edited by Xiangyang Ju

Merge request reports

Loading