DRAFT: Enable reliable auto-resume
This MR aims to do one thing: fix the broken auto-resume behaviour that shows up especially in distributed training. It occurs because training is resumed mid-way through an epoch. We can avoid this by calling `Trainer.fit(ckpt_path="last")`, which resumes from the checkpoint we save at the end of the last completed epoch. But it turns out that to achieve this, we need to restructure some of the directory management. It's boring, but it works as laid out in the three cases below; a short sketch of the plain Lightning mechanism we rely on comes first, for context.
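The Lightning side is just two pieces: a `ModelCheckpoint(save_last=True)` callback that keeps `last.ckpt` pointing at the most recently saved end-of-epoch checkpoint, and `ckpt_path="last"` at fit time. A minimal sketch, assuming `stage_module` is the stage's `LightningModule` (constructed as in the workflows below) and the directory path is a placeholder:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# save_last=True keeps a last.ckpt copy of the most recently saved checkpoint
checkpoint_callback = ModelCheckpoint(save_last=True)
trainer = Trainer(default_root_dir="runs/example", callbacks=[checkpoint_callback])

# stage_module = stage_module_class(config)  # as in the workflows below

# ckpt_path="last" makes Lightning look up the most recent last.ckpt and restore
# the full training state (epoch counter, optimizer and scheduler state, loops),
# so resumption always starts at an epoch boundary rather than mid-epoch.
trainer.fit(stage_module, ckpt_path="last")
```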
ATTENTION: This workflow has been superseded by the following
All three workflows must guarantee that, upon resumption, PyTorch Lightning resumes from the latest checkpoint in the `default_root_dir`. This is done by pointing `checkpoint_resume_dir` at `default_root_dir` whenever that directory already exists, removing any HPC checkpoints present in it, and setting `checkpoint_path` to the latest checkpoint found in `checkpoint_resume_dir`.
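The workflows below lean on two helpers that are only named, not shown: removing Lightning's SLURM/HPC auto-saved checkpoints and locating the latest checkpoint in a directory. A minimal sketch of what they could look like; the function names and the `hpc_ckpt_*.ckpt` filename pattern are assumptions, not necessarily this repo's exact implementation:

```python
import glob
import os


def remove_hpc_checkpoints(directory: str) -> None:
    """Delete SLURM auto-requeue checkpoints (assumed to be named hpc_ckpt_*.ckpt),
    so resumption never picks up a mid-epoch state."""
    for path in glob.glob(os.path.join(directory, "hpc_ckpt_*.ckpt")):
        os.remove(path)


def find_latest_checkpoint(directory: str):
    """Return the most recently modified *.ckpt under `directory`, or None."""
    candidates = glob.glob(os.path.join(directory, "**", "*.ckpt"), recursive=True)
    return max(candidates, key=os.path.getmtime) if candidates else None
```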
Case 1: Neither a checkpoint nor a `checkpoint_resume_dir` is given, i.e. training starts from scratch:
First run: `checkpoint_resume_dir=None`, `checkpoint_path=None`, `default_root_dir` does not exist:

```python
default_root_dir = get_default_root_dir(config)
stage_module = stage_module_class(config)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)
```
Resumed run: `checkpoint_resume_dir=None`, `checkpoint_path=None`, `default_root_dir` exists:

```python
default_root_dir = get_default_root_dir(config)
if os.path.isdir(default_root_dir):
    checkpoint_resume_dir = default_root_dir
if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir
    checkpoint_path = find_latest_checkpoint(default_root_dir)
if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)
```
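`load_module` above rebuilds the stage module and its config from the checkpoint. A plausible sketch, assuming the config was saved via `save_hyperparameters()`; this is a guess at the implementation, not necessarily the MR's actual code:

```python
def load_module(checkpoint_path, stage_module_class):
    # LightningModule.load_from_checkpoint restores the module weights and hyperparameters.
    stage_module = stage_module_class.load_from_checkpoint(checkpoint_path)
    # Assumes the training config was passed to __init__ and stored in hparams.
    config = dict(stage_module.hparams)
    return stage_module, config
```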
Case 2: A checkpoint is given, i.e. training starts from a checkpoint saved at the end of some epoch:
First run: `checkpoint_resume_dir=None`, `checkpoint_path=/path/to/checkpoint`, `default_root_dir` does not exist:

```python
default_root_dir = get_default_root_dir(config)
if checkpoint_path and (not os.path.isdir(default_root_dir)):
    os.mkdir(default_root_dir)
    shutil.copyfile(checkpoint_path, os.path.join(default_root_dir, "resumed_checkpoint.ckpt"))
if os.path.isdir(default_root_dir):
    checkpoint_resume_dir = default_root_dir
if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir
    checkpoint_path = find_latest_checkpoint(default_root_dir)  # now "resumed_checkpoint.ckpt"
if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)
```
Resumed run: `checkpoint_resume_dir=None`, `checkpoint_path=/path/to/checkpoint`, `default_root_dir` exists:

```python
default_root_dir = get_default_root_dir(config)
if os.path.isdir(default_root_dir):
    checkpoint_resume_dir = default_root_dir
if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir
    checkpoint_path = find_latest_checkpoint(default_root_dir)  # now "last.ckpt"
if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)
```
Case 3: A `checkpoint_resume_dir` is given, i.e. training continues from the last checkpoint of a previous `default_root_dir` (e.g. to continue an interrupted run):
First run: `checkpoint_resume_dir=/path/to/dir`, `checkpoint_path=None`, `default_root_dir` does not exist:

```python
default_root_dir = get_default_root_dir(config)
if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir  # redirect default_root_dir to the provided dir; all subsequent checkpoints are saved here
    checkpoint_path = find_latest_checkpoint(default_root_dir)  # latest checkpoint left by the previous run
if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)
```
Resumed run: `checkpoint_resume_dir=/path/to/dir`, `checkpoint_path=None`, `default_root_dir` exists:

```python
default_root_dir = get_default_root_dir(config)
if os.path.isdir(default_root_dir):
    checkpoint_resume_dir = default_root_dir
if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir
    checkpoint_path = find_latest_checkpoint(default_root_dir)  # now "last.ckpt"
if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)
```
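All three cases collapse into one resolution routine. The sketch below just stitches the pseudocode above together end to end; `run_stage` and `remove_hpc_checkpoints` are illustrative names (see the helper sketch earlier), while `get_default_root_dir`, `find_latest_checkpoint`, `load_module` and `stage_module_class` are this repo's own helpers, assumed to behave as described above:

```python
import os
import shutil

from pytorch_lightning import Trainer


def run_stage(config, stage_module_class, checkpoint_path=None, checkpoint_resume_dir=None):
    default_root_dir = get_default_root_dir(config)

    # Case 2, first run: seed a fresh root dir with the explicitly given checkpoint.
    if checkpoint_path and not os.path.isdir(default_root_dir):
        os.mkdir(default_root_dir)
        shutil.copyfile(checkpoint_path,
                        os.path.join(default_root_dir, "resumed_checkpoint.ckpt"))

    # Cases 1 and 2, resumed runs: an existing root dir always takes priority.
    if os.path.isdir(default_root_dir):
        checkpoint_resume_dir = default_root_dir

    # Shared resume path (including Case 3): drop HPC checkpoints, redirect the
    # root dir, and resume from the latest checkpoint found there.
    if checkpoint_resume_dir:
        remove_hpc_checkpoints(checkpoint_resume_dir)
        default_root_dir = checkpoint_resume_dir
        checkpoint_path = find_latest_checkpoint(default_root_dir)

    # Rebuild the module from the checkpoint if we have one, otherwise start fresh.
    if checkpoint_path:
        stage_module, config = load_module(checkpoint_path, stage_module_class)
    else:
        stage_module = stage_module_class(config)

    trainer = Trainer(default_root_dir=default_root_dir)
    trainer.fit(stage_module, ckpt_path=checkpoint_path)
```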
- Explicit checkpoint path works as expected
- Explicit checkpoint dir works as expected
- SLURM batch resume works as expected (1 GPU)
- SLURM batch resume works as expected (4 GPUs)
- Local W&B works as expected
- No logger, no checkpoint, working as expected
- Ensure that infer stage works with explicit checkpoint
- Ensure that infer stage works automatically without checkpoint