
Tweaks to Transformerv2: layerscale + varlen_attn + small optim

Matthew Leigh requested to merge matt_dev into main

Description

This MR includes some changes and refactoring of the TransformerV2 architecture, with the following additions:

  • Changed the merge_masks function
    • The new behavior is to let padded nodes receive attention freely and only to block them from sending (see the mask sketch after this list)
    • The computational cost is the same (we still need an NxN matrix)
    • But this should reduce the NaNs produced for fully padded nodes, whose attention rows previously had no valid entries
  • Added FlashAttention with Varlen Support
    • Brings in the optional dependency https://github.com/Dao-AILab/flash-attention
    • Although FlashAttention2 is now supported natively by PyTorch, the incredibly useful flash_attn_varlen_qkvpacked_func is not
    • This function lets the attention mechanism run without any padding at all, saving significant memory and time
    • I have added flash attention as an optional backend for the Attention class
    • The necessary pack and unpack functions are added to the TransformerV2 forward call to facilitate this (a sketch follows the figure below)
    • The time saved grows with how padded the data is (shown below)
    • During inference / ONNX export we can always switch back to one of the other backends, as they produce identical outputs
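For illustration, a minimal sketch of the new masking behavior (the name matches merge_masks, but the argument names and exact signature are assumptions rather than the actual salt code). Padded nodes are only blocked from sending, i.e. from being attended to as keys, so every query row keeps at least one valid key and the softmax no longer returns NaN for fully padded queries:

```python
import torch


def merge_masks(
    kv_mask: torch.Tensor | None,   # (B, N_kv) True for real (non-padded) tokens
    attn_mask: torch.Tensor | None, # (B, N_q, N_kv) True where attention is allowed
    q_shape: torch.Size,            # shape of the query tensor, used for N_q
) -> torch.Tensor | None:
    """Combine a padding mask with an optional attention mask.

    Padded nodes are only blocked from *sending* (acting as keys), never from
    *receiving*, so no query row is left without a single valid key.
    """
    merged = None
    if kv_mask is not None:
        # (B, N_kv) -> (B, N_q, N_kv): the same key mask for every query position
        merged = kv_mask.unsqueeze(-2).expand(-1, q_shape[-2], -1)
    if attn_mask is not None:
        merged = attn_mask if merged is None else attn_mask & merged
    if merged is not None:
        # add the head dimension expected by scaled_dot_product_attention
        merged = merged.unsqueeze(1)
    return merged
```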

[Figure: timing of the attention backends as a function of how padded the input is]
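And roughly how the pack/unpack around flash_attn_varlen_qkvpacked_func fits together. In the MR the packing lives in the TransformerV2 forward call and the varlen call sits inside the Attention block; the self-contained sketch below collapses the two, and the tensor layout is an assumption:

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_qkvpacked_func  # optional dependency


def varlen_attn(qkv: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """qkv: (B, N, 3, H, D) projected queries/keys/values, mask: (B, N) True for real tokens.

    flash-attn kernels require fp16/bf16 tensors on a CUDA device.
    """
    B, N, _, H, D = qkv.shape

    # pack: drop all padded tokens and record cumulative sequence lengths
    seqlens = mask.sum(dim=-1, dtype=torch.int32)
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
    qkv_packed = qkv[mask]  # (total_tokens, 3, H, D)

    out_packed = flash_attn_varlen_qkvpacked_func(
        qkv_packed, cu_seqlens, max_seqlen=int(seqlens.max())
    )  # (total_tokens, H, D)

    # unpack: scatter the outputs back into the padded layout
    out = qkv.new_zeros(B, N, H, D)
    out[mask] = out_packed
    return out.reshape(B, N, H * D)
```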

  • Reduced the number of linear layers in Attention and GLU
    • Slight changes to the linear layer setup in these two blocks to slightly improve parallelism (see the GLU sketch after this list)
    • No change to the mathematics
  • Changed the way the backends are called
    • There are now four supported backends: torch-flash, torch-math, torch-meff, flash-varlen
    • The first three change which kernel the torch-attn function uses (see the backend sketch after this list)
    • The last calls its own method inside the Attention block because it takes different arguments
  • Added LayerScale and PreNormScaledResidual classes (https://arxiv.org/pdf/2103.17239.pdf)
    • LayerScale is now used by almost all deep vision transformers
    • Initializing LayerScale with None disables it, so it remains an optional feature
    • It scales the output of a function by learnable parameters initialized near zero, before the residual connection is added
    • The PreNormScaledResidual class is a convenient wrapper for a pattern we commonly see in transformers (sketched after this list)
      • output = layerscale(fn(norm(x))) + x
      • In the transformer blocks, fn is most commonly an MHA block or a feed-forward network
  • Added a self-attention block to the decoder:
    • I am not sure whether the old setup was a bug, as decoder blocks are usually: self-attention -> cross-attention -> feed-forward
    • The original was only: cross-attention -> feed-forward
    • If this was the intended behavior then we should rename the block so it does not clash with the literature
  • Added/modified the suite of tests for the new setup, including the timing speedups
    • Right now the runtimes are: torch MHA > Salt MHA w/o varlen > Salt MHA w/ varlen (slowest to fastest)
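To make the linear-layer reduction concrete, here is a rough sketch of a GLU with the gate and value projections fused into a single matmul (the hidden size and SiLU activation are assumptions; the actual salt block may differ):

```python
import torch
from torch import nn


class GLU(nn.Module):
    """Gated linear unit with the gate and value projections fused into one matmul."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        # one Linear producing both halves, instead of two separate Linears
        self.in_proj = nn.Linear(dim, 2 * hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.in_proj(x).chunk(2, dim=-1)
        return self.out_proj(nn.functional.silu(gate) * value)
```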
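The three torch backends map naturally onto PyTorch's scaled-dot-product-attention kernel selection; a sketch of how the dispatch could look (the mapping and the torch_attn function name are assumptions about the wiring, not a copy of the salt code):

```python
import torch
import torch.nn.functional as F

TORCH_BACKENDS = {
    "torch-flash": dict(enable_flash=True, enable_math=False, enable_mem_efficient=False),
    "torch-math": dict(enable_flash=False, enable_math=True, enable_mem_efficient=False),
    "torch-meff": dict(enable_flash=False, enable_math=False, enable_mem_efficient=True),
}


def torch_attn(q, k, v, attn_mask=None, backend: str = "torch-math"):
    """Scaled dot-product attention with an explicitly selected torch kernel."""
    # note: the flash and mem-efficient kernels have dtype/mask constraints,
    # so the caller must pick a backend compatible with its inputs
    with torch.backends.cuda.sdp_kernel(**TORCH_BACKENDS[backend]):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

The flash-varlen backend bypasses this path entirely and calls flash_attn_varlen_qkvpacked_func on packed inputs inside the Attention block, as sketched above.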
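And a sketch of the LayerScale / PreNormScaledResidual pattern described above (the 1e-5 initial value follows the CaiT paper; the constructor arguments are assumptions rather than the exact salt classes):

```python
import torch
from torch import nn


class LayerScale(nn.Module):
    """Scale the output of a block by learnable per-channel weights initialized near zero."""

    def __init__(self, dim: int, init_value: float = 1e-5) -> None:
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x


class PreNormScaledResidual(nn.Module):
    """Wraps the common transformer pattern: output = layerscale(fn(norm(x))) + x."""

    def __init__(self, fn: nn.Module, dim: int, ls_init: float | None = 1e-5) -> None:
        super().__init__()
        self.fn = fn  # typically an MHA block or a feed-forward network
        self.norm = nn.LayerNorm(dim)
        # passing None disables LayerScale, keeping it an optional feature
        self.ls = LayerScale(dim, ls_init) if ls_init is not None else nn.Identity()

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.ls(self.fn(self.norm(x), **kwargs)) + x
```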

Things that were removed (I am very flexible about bringing them back):

  • Removed the SelfAttention and CrossAttention blocks
    • It is much easier and faster to have a single Attention block where providing an extra kv tensor is optional (sketched after this list)
  • Removed window_size from the Attention class
    • This was not used during the forward pass but was present in the init methods
    • A window size also means we have to be careful about how we order our inputs, as attention is no longer permutation invariant
  • Removed n_kv_heads from the Attention class
    • This can be brought back, but it would take a bit of extra code to fit it in with everything else
    • The motivation for grouped-query attention is to reduce the memory overhead of storing the kv-cache during autoregressive generation, which we don't do anyway
  • Removed the add_zero_attn option
    • Its use is not advised anyway, and adding a nan_to_num check at the end is much faster than increasing the size of the message-passing matrix
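A sketch of what the merged block looks like from the outside: self-attention by default, cross-attention when an extra kv tensor is passed, with a nan_to_num guard standing in for add_zero_attn (the projection layout and argument names are assumptions, not the exact implementation):

```python
import torch
from torch import nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Single attention block: self-attention by default, cross-attention if kv is given."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(dim, dim)
        self.kv_proj = nn.Linear(dim, 2 * dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(
        self,
        x: torch.Tensor,                 # (B, N, D) queries
        kv: torch.Tensor | None = None,  # optional (B, N_kv, D) source for keys/values
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        B, N, D = x.shape
        kv = x if kv is None else kv  # self-attention unless an extra kv tensor is provided
        q = self.q_proj(x).view(B, N, self.num_heads, -1).transpose(1, 2)
        k, v = self.kv_proj(kv).view(B, kv.shape[1], 2, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        out = out.transpose(1, 2).reshape(B, N, D)
        # cheaper than add_zero_attn: clean up any fully masked rows after the fact
        return torch.nan_to_num(self.out_proj(out))
```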

Review checklist:

  • CI Passing
  • Comments addressed
  • Source branch is up to date with target
