
DRAFT: Enable reliable auto-resume

Daniel Thomas Murnane requested to merge dmurnane_fix_distributed_autoresume into dev

This MR aims to do one thing: fix the broken auto-resume behaviour that appears especially when doing distributed training. It breaks because training is resumed mid-way through an epoch. We can avoid this by calling Trainer.fit(ckpt_path="last"), which uses the checkpoint saved at the end of the last completed epoch (a minimal sketch of this call is shown below). To achieve this, however, we need to restructure some of the directory management. It's boring, but it works like this: (see the workflow diagram attached to this MR)
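As a rough sketch of the core fix (assuming a PyTorch Lightning version where ckpt_path="last" is supported; model and default_root_dir stand in for the objects built by the stage code):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# save_last=True keeps a "last.ckpt" copy whenever a checkpoint is saved
# (by default at the end of each training epoch), so a restarted job
# never resumes from a mid-epoch state.
checkpoint_callback = ModelCheckpoint(save_last=True)

trainer = pl.Trainer(default_root_dir=default_root_dir, callbacks=[checkpoint_callback])
# ckpt_path="last" tells Lightning to locate the most recent "last" checkpoint
# under default_root_dir and resume training from it.
trainer.fit(model, ckpt_path="last")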

ATTENTION: This workflow has been superseded by the following

(screenshot attached to the MR: Screenshot_2024-11-27_at_11.52.26)

All three workflows must guarantee that, upon resumption, PyTorch Lightning resumes from the latest checkpoint in the default_root_dir. This is done by setting checkpoint_resume_dir to default_root_dir if that directory exists, removing all existing HPC checkpoints, and setting checkpoint_path to the latest checkpoint in checkpoint_resume_dir (see the helper sketch below).
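A minimal sketch of the two helpers referenced here (the names follow the pseudocode below; the glob pattern, the hpc_ckpt_* naming, and the modification-time ordering are assumptions, not necessarily what the actual helpers do):

import glob
import os

def find_latest_checkpoint(ckpt_dir):
    # Return the most recently modified .ckpt under ckpt_dir, or None if there is none.
    ckpts = glob.glob(os.path.join(ckpt_dir, "**", "*.ckpt"), recursive=True)
    return max(ckpts, key=os.path.getmtime) if ckpts else None

def remove_hpc_checkpoints(ckpt_dir):
    # Delete Lightning HPC checkpoints (assumed to be named hpc_ckpt_*.ckpt) so that
    # resumption always goes through the latest epoch-end checkpoint instead.
    for path in glob.glob(os.path.join(ckpt_dir, "**", "hpc_ckpt_*.ckpt"), recursive=True):
        os.remove(path)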

Case 1: If neither a checkpoint_path nor a checkpoint_resume_dir is given, i.e. the training starts from scratch:

First run: checkpoint_resume_dir=None, checkpoint_path=None, default_root_dir does not exist

default_root_dir = get_default_root_dir(config)
stage_module = stage_module_class(config)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)

Resumed run: checkpoint_resume_dir=None, checkpoint_path=None, default_root_dir exists

default_root_dir = get_default_root_dir(config)
if os.path.isdir(default_root_dir):
    checkpoint_resume_dir = default_root_dir
if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir
    checkpoint_path = find_latest_checkpoint(default_root_dir)
if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)

Case 2: If a checkpoint_path is given, i.e. the training starts from a checkpoint saved at the end of some epoch:

First run: checkpoint_resume_dir=None, checkpoint_path=/path/to/checkpoint, default_root_dir does not exist

default_root_dir = get_default_root_dir(config)

if checkpoint_path and not os.path.isdir(default_root_dir):
    os.mkdir(default_root_dir)
    shutil.copyfile(checkpoint_path, os.path.join(default_root_dir, "resumed_checkpoint.ckpt"))

if os.path.isdir(default_root_dir):
    checkpoint_resume_dir = default_root_dir

if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir
    checkpoint_path = find_latest_checkpoint(default_root_dir)  # now "resumed_checkpoint.ckpt"

if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)

Resumed run: checkpoint_resume_dir=None, checkpoint_path=/path/to/checkpoint, default_root_dir exists

default_root_dir = get_default_root_dir(config)

if os.path.isdir(default_root_dir):
    checkpoint_resume_dir = default_root_dir

if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir
    checkpoint_path = find_latest_checkpoint(default_root_dir)  # now "last.ckpt"

if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)

Case 3: If a checkpoint_resume_dir is given, the training starts from the last checkpoint of a previous default_root_dir (e.g. continuing an interrupted run):

First run: checkpoint_resume_dir=/path/to/dir, checkpoint_path=None, default_root_dir does not exist

default_root_dir = get_default_root_dir(config)

if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir  # reset default_root_dir to the provided checkpoint_resume_dir; all subsequent checkpoints are saved here
    checkpoint_path = find_latest_checkpoint(default_root_dir)  # now "resumed_checkpoint.ckpt"

if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)

Resumed run: checkpoint_resume_dir=/path/to/dir, checkpoint_path=None, default_root_dir exists

default_root_dir = get_default_root_dir(config)

if os.path.isdir(default_root_dir):
    checkpoint_resume_dir = default_root_dir

if checkpoint_resume_dir:
    # remove all HPC checkpoints
    default_root_dir = checkpoint_resume_dir
    checkpoint_path = find_latest_checkpoint(default_root_dir)  # now "last.ckpt"

if checkpoint_path:
    stage_module, config = load_module(checkpoint_path, stage_module_class)
trainer = Trainer(default_root_dir=default_root_dir)
trainer.fit(stage_module, ckpt_path=checkpoint_path)
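Putting the three cases together, the resume logic can be sketched roughly as follows. This is an illustration of the flow, not the exact implementation: setup_and_fit is a hypothetical name, get_default_root_dir, load_module and stage_module_class are placeholders from the surrounding codebase, and find_latest_checkpoint / remove_hpc_checkpoints are the helper sketches above.

import os
import shutil

from pytorch_lightning import Trainer

def setup_and_fit(config, stage_module_class, checkpoint_path=None, checkpoint_resume_dir=None):
    default_root_dir = get_default_root_dir(config)

    # Case 2, first run: seed a fresh default_root_dir with the provided checkpoint.
    if checkpoint_path and not os.path.isdir(default_root_dir):
        os.makedirs(default_root_dir)
        shutil.copyfile(checkpoint_path, os.path.join(default_root_dir, "resumed_checkpoint.ckpt"))

    # Resumed runs (all cases): an existing default_root_dir always takes precedence.
    if os.path.isdir(default_root_dir):
        checkpoint_resume_dir = default_root_dir

    if checkpoint_resume_dir:
        remove_hpc_checkpoints(checkpoint_resume_dir)
        default_root_dir = checkpoint_resume_dir
        checkpoint_path = find_latest_checkpoint(default_root_dir)

    if checkpoint_path:
        stage_module, config = load_module(checkpoint_path, stage_module_class)
    else:
        stage_module = stage_module_class(config)

    trainer = Trainer(default_root_dir=default_root_dir)
    trainer.fit(stage_module, ckpt_path=checkpoint_path)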
  • Explicit checkpoint path works as expected
  • Explicit checkpoint dir works as expected
  • SLURM batch resume works as expected (1 GPU)
  • SLURM batch resume works as expected (4 GPUs)
  • Local W&B works as expected
  • No logger, no checkpoint, working as expected
  • Ensure that infer stage works with explicit checkpoint
  • Ensure that infer stage works automatically without checkpoint
Edited by Minh Tuan Pham
