Resuming training with full training state restored
Currently, when a checkpoint path is given, `train_stage.py` loads the module from that checkpoint. However, this only restores the model parameters, not the full training state (optimizer and learning-rate scheduler states, the current epoch and global step, callback states). To resume from where training left off, the checkpoint path needs to be passed to `trainer.fit`.
For example:
```python
trainer.fit(stage_module, ckpt_path=checkpoint)
```
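To illustrate why restoring only the weights is not the same as resuming, here is a toy, pure-Python sketch (no Lightning involved; every name in it is hypothetical). It minimises a scalar function with SGD plus momentum, where the momentum buffer and epoch counter stand in for the optimizer and loop state that a full checkpoint carries. Resuming from the full state reproduces an uninterrupted run exactly, while restoring only the weight resets momentum and the epoch count, so training follows a different trajectory.

```python
import json

# Hypothetical toy trainer: minimise f(w) = (w - 3)^2 with SGD + momentum.
# "velocity" and "epoch" play the role of optimizer state and loop progress.
LR = 0.1
MOMENTUM = 0.9
TOTAL_EPOCHS = 10


def grad(w):
    # Derivative of (w - 3)^2.
    return 2.0 * (w - 3.0)


def fresh_state():
    return {"weight": 0.0, "velocity": 0.0, "epoch": 0}


def train(state, until_epoch):
    """Run one update per epoch from state["epoch"] up to until_epoch (in place)."""
    while state["epoch"] < until_epoch:
        state["velocity"] = MOMENTUM * state["velocity"] + grad(state["weight"])
        state["weight"] -= LR * state["velocity"]
        state["epoch"] += 1
    return state


# Reference: train for all epochs without interruption.
reference = train(fresh_state(), TOTAL_EPOCHS)

# Interrupted run: stop half-way and serialise the full state as a "checkpoint".
checkpoint = json.dumps(train(fresh_state(), TOTAL_EPOCHS // 2))

# Weights only (analogous to loading the module from the checkpoint):
# momentum and epoch start from scratch, so TOTAL_EPOCHS more updates run.
weights_only = fresh_state()
weights_only["weight"] = json.loads(checkpoint)["weight"]
weights_only = train(weights_only, TOTAL_EPOCHS)

# Full state (analogous to trainer.fit(..., ckpt_path=...)): continue from epoch 5.
full = train(json.loads(checkpoint), TOTAL_EPOCHS)

print("reference   :", reference)
print("full resume :", full)
print("weights only:", weights_only)
```

In this sketch `full` matches `reference` exactly, whereas `weights_only` ends at a different weight even though its epoch counter also reads 10.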