Issues loading model trained with flash attention

TransformerV2 can use flash attention for speed. However, when such a model is saved to a checkpoint, flash attention must be installed to load that checkpoint, because the flash attention module itself is pickled into the checkpoint file. See this Mattermost discussion/error message. This isn't ideal, as we want a saved model to be loadable on any system. We also ran into some (solvable) issues with precision, which inspired !327 (merged).

I think the best/easiest solution is to extend our Checkpoint callback so that it explicitly copies the model and, if needed, switches the attention backend before saving it.
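A minimal pure-Python sketch of what the proposed callback step could look like. It assumes the attention backend is held as an attribute on the model and can be swapped without touching the weights. `FlashAttention`, `TorchAttention`, `set_attention_backend`, `portable_copy` and `save_checkpoint` are hypothetical stand-ins, not our actual classes or API. The point is that the callback pickles a deep copy with a portable backend, so the flash attention class never ends up in the file, while the model that is still training keeps using flash attention.

```python
import copy
import io
import pickle


class FlashAttention:
    """Hypothetical stand-in for a backend that may be missing on the loading machine."""


class TorchAttention:
    """Hypothetical stand-in for a portable backend available everywhere."""


class TransformerV2:
    """Toy model (hypothetical): only the attention backend matters for this sketch."""

    def __init__(self, attention=None):
        self.attention = attention if attention is not None else TorchAttention()
        self.weights = [0.1, 0.2, 0.3]

    def set_attention_backend(self, backend_cls):
        # Replace the backend object; the weights are left untouched.
        self.attention = backend_cls()


def portable_copy(model, portable_backend=TorchAttention):
    """Return a deep copy whose attention backend can be unpickled on any system."""
    model_copy = copy.deepcopy(model)  # never mutate the model that is still training
    if not isinstance(model_copy.attention, portable_backend):
        model_copy.set_attention_backend(portable_backend)
    return model_copy


def save_checkpoint(model, buffer):
    """What a Checkpoint callback could do instead of pickling the live model."""
    pickle.dump(portable_copy(model), buffer)


if __name__ == "__main__":
    model = TransformerV2(attention=FlashAttention())
    buf = io.BytesIO()
    save_checkpoint(model, buf)
    buf.seek(0)
    restored = pickle.load(buf)
    print(type(restored.attention).__name__)  # TorchAttention
    print(type(model.attention).__name__)     # FlashAttention
```

Copying first costs some memory at save time, but it avoids having to switch the live model's backend back after saving.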