Introducing muP

Maxence Draguet requested to merge mdraguet/salt:max-new-temp-muP into main

This MR introduces muP to enable muTransfer in salt for GN2 models. The codebase is adapted to leverage the muP pip library, making it possible to run muTransfer, a HyperParameter Optimisation (HPO) technique that tunes hyperparameters on small models and transfers them zero-shot to large ones [1].

A gentle introduction to muP is given in this talk.

In summary, the changes are:

  • Introduction of a muP_utils module in utils, with the main script being main_muP.py, callable as setup_muP from anywhere salt is installed. This script uses generateModel.py to instantiate the base and delta models (the delta having some widths that vary w.r.t. the base, indicating which parameters are muTransferred along, for example the embedding size). Additional functions are provided to check that muP is correctly set up.
  • Modifications to the configs, so that the model can load the base and delta models and properly make use of the muP library.
  • Modifications to the ModelWrapper and specific modules (transformer, attention, dense, ...) to use the proper initialisation and optimiser (muP's AdamW).
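The base/delta idea above can be sketched in a few lines. This is a minimal, hypothetical illustration (the function and config key names below are assumptions, not the actual salt or generateModel.py API): the delta model differs from the base only in the widths being muTransferred along, which is how the muP library infers which dimensions scale with model width.

```python
# Hypothetical sketch of the base/delta mechanism: the delta config varies
# only the widths we want to muTransfer along (e.g. the embedding size),
# so comparing it to the base reveals the width-like dimensions.

def make_shape_config(embed_dim: int, num_heads: int = 4) -> dict:
    """Toy stand-in for a model config; only the widths matter here."""
    return {"embed_dim": embed_dim, "num_heads": num_heads}

base_cfg = make_shape_config(embed_dim=64)
delta_cfg = make_shape_config(embed_dim=128)  # only the scaled dim differs

# Dimensions that differ between base and delta are the "width" dimensions
# muP will rescale; dimensions that match are treated as fixed.
width_dims = {k for k in base_cfg if base_cfg[k] != delta_cfg[k]}
print(width_dims)  # {'embed_dim'}
```

In the real workflow, the base and delta models produced this way would be passed to the muP library's shape-setting routine before training, so that initialisation and the (muP) AdamW learning rates are rescaled per width dimension.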

[1] Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer

Edited by Maxence Draguet
