Tweaks to Transformerv2: layerscale + varlen_attn + small optim
Description
This MR includes some changes and refactoring of the TransformerV2 architecture, with the following additions:
- Changed the `merge_masks` function
  - The new behavior is to let padded nodes receive whatever they want and only to limit what they send
  - The computational cost is the same (we still need an NxN attention matrix)
  - But this should reduce the NaNs produced by fully padded nodes (see the sketch below)
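  - A minimal sketch of the idea (the helper name and signature here are illustrative, not the actual salt implementation):

    ```python
    import torch

    def merge_masks_sketch(kv_pad_mask: torch.Tensor, q_len: int) -> torch.Tensor:
        """Build an additive attention mask from a key/value padding mask.

        kv_pad_mask: (batch, seq_len) bool tensor, True where a node is padding.
        Only the sending side (keys/columns) is masked, so padded query rows are
        never entirely -inf and therefore do not produce NaNs after the softmax.
        """
        # (batch, 1, 1, seq_len): broadcast over heads and query positions
        mask = kv_pad_mask[:, None, None, :]
        attn_mask = torch.zeros(mask.shape, dtype=torch.float32, device=mask.device)
        attn_mask = attn_mask.masked_fill(mask, float("-inf"))
        # expand to the full NxN (q_len x seq_len) matrix
        return attn_mask.expand(-1, -1, q_len, -1)
    ```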
- Added FlashAttention with varlen support
  - Brings in the optional dependency https://github.com/Dao-AILab/flash-attention
  - Although FlashAttention-2 is now supported natively by PyTorch, the incredibly useful `flash_attn_varlen_qkvpacked_func` is not
  - This function lets the attention mechanism work without any padding at all, saving significant memory and time
  - I have added flash attention as an optional backend when using the `Attention` class
  - The necessary pack and unpack functions are added to the TransformerV2 forward call to facilitate this (see the sketch below)
  - The time saved increases as a function of how padded the data is (shown below)
  - During inference / ONNX export we can always switch back to one of the other backends, as they produce identical outputs
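  - A minimal sketch of the pack -> varlen attention -> unpack round trip (the helper name and exact tensor layout are illustrative; the MR's actual helpers may differ):

    ```python
    import torch
    from flash_attn import flash_attn_varlen_qkvpacked_func  # optional dependency

    def varlen_attention_sketch(qkv: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        """qkv: (batch, seq_len, 3, num_heads, head_dim), fp16/bf16 on GPU.
        pad_mask: (batch, seq_len) bool, True where a token is padding."""
        batch, seq_len = pad_mask.shape
        keep = ~pad_mask                                  # real (unpadded) tokens
        seqlens = keep.sum(dim=1, dtype=torch.int32)      # tokens per sample
        cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
        # pack: drop padded tokens -> (total_tokens, 3, num_heads, head_dim)
        qkv_packed = qkv[keep]
        out_packed = flash_attn_varlen_qkvpacked_func(
            qkv_packed, cu_seqlens, max_seqlen=int(seqlens.max())
        )
        # unpack: scatter the result back into the padded (batch, seq_len, ...) layout
        out = qkv.new_zeros(batch, seq_len, *out_packed.shape[1:])
        out[keep] = out_packed
        return out
    ```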
- Reduced the number of linear layers in `Attention` and `GLU`
  - Slight changes to the linear layer setup in these two blocks to slightly improve parallelism (see the sketch below)
  - No change in the mathematics
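  - An illustrative example of the kind of fusion (the class name and activation choice are assumptions, not the exact salt code):

    ```python
    import torch
    from torch import nn

    class FusedGLU(nn.Module):
        """GLU where the gate and value projections share one linear layer.

        Mathematically identical to two separate projections, but a single
        larger matmul parallelises better on the GPU.
        """

        def __init__(self, dim: int, hidden_dim: int) -> None:
            super().__init__()
            self.in_proj = nn.Linear(dim, 2 * hidden_dim)  # was two separate Linear layers
            self.out_proj = nn.Linear(hidden_dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            gate, value = self.in_proj(x).chunk(2, dim=-1)
            return self.out_proj(torch.nn.functional.silu(gate) * value)
    ```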
- Changed the way the backends are called
  - There are now four supported backends: `torch-flash`, `torch-math`, `torch-meff`, `flash-varlen`
  - The first three change the behavior of the `torch-attn` function (see the sketch below)
  - The latter calls its own method inside the `Attention` block due to its unique arguments
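  - A sketch of how the dispatch might look, assuming a recent PyTorch (>= 2.3) that exposes `torch.nn.attention.sdpa_kernel`; the mapping dict and function name are illustrative:

    ```python
    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Map the salt backend names onto torch SDPA kernels (illustrative mapping)
    _TORCH_BACKENDS = {
        "torch-flash": SDPBackend.FLASH_ATTENTION,
        "torch-math": SDPBackend.MATH,
        "torch-meff": SDPBackend.EFFICIENT_ATTENTION,
    }

    def torch_attn(q, k, v, backend: str, attn_mask=None):
        """Run scaled_dot_product_attention with the requested kernel pinned."""
        with sdpa_kernel(_TORCH_BACKENDS[backend]):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    ```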
- Added `LayerScale` and `PreNormScaledResidual` classes (https://arxiv.org/pdf/2103.17239.pdf)
  - LayerScale is now used by almost all deep vision transformers
  - Initialising LayerScale with `None` disables it, so it is still an optional feature
  - It scales the output of a function by learnable parameters initialized near zero, before a residual connection
  - The `PreNormScaledResidual` class is a convenient way to wrap a common pattern we see in transformers: `output = layerscale(fn(norm(x))) + x`
  - In the transformer blocks `fn` is most commonly an MHA block or an FF network (see the sketch below)
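  - A minimal sketch of the two classes (the init value and exact signatures are assumptions):

    ```python
    import torch
    from torch import nn

    class LayerScale(nn.Module):
        """Scale a block's output by learnable per-channel weights initialised near zero."""

        def __init__(self, dim: int, init_value: float = 1e-5) -> None:
            super().__init__()
            self.gamma = nn.Parameter(init_value * torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.gamma * x

    class PreNormScaledResidual(nn.Module):
        """Wraps fn as: output = layerscale(fn(norm(x))) + x, with LayerScale optional."""

        def __init__(self, fn: nn.Module, dim: int, ls_init: float | None = 1e-5) -> None:
            super().__init__()
            self.fn = fn
            self.norm = nn.LayerNorm(dim)
            self.ls = LayerScale(dim, ls_init) if ls_init is not None else nn.Identity()

        def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
            return self.ls(self.fn(self.norm(x), **kwargs)) + x
    ```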
- Added a self-attention block to the decoder
  - I am not sure whether this was a bug, as decoder blocks are usually: self-attention -> cross-attention -> FF (see the sketch below)
  - The original was only: cross-attention -> FF
  - If this was the intended behavior then we should rename the block so it doesn't clash with the literature
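  - A schematic of the new ordering (norms and LayerScale omitted; the callables are placeholders, not the actual salt modules):

    ```python
    def decoder_block(x, enc_out, self_attn, cross_attn, ff):
        """Standard decoder ordering: self-attention, then cross-attention, then FF."""
        x = x + self_attn(x)               # new: decoder queries attend to each other
        x = x + cross_attn(x, kv=enc_out)  # existing: queries attend to encoder outputs
        return x + ff(x)
    ```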
- Added/modified the suite of tests for the new setup, including the timing speedups
  - Right now the timings are: torch MHA > Salt MHA w/o varlen > Salt MHA w/ varlen

Things that were removed (I am very flexible about bringing them back):
- Removed the `SelfAttention` and `CrossAttention` blocks
  - It is much easier and faster to have a single `Attention` block where it is optional to provide an extra `kv` tensor (see the sketch below)
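  - A minimal sketch of the interface idea, with `nn.MultiheadAttention` standing in for the actual salt attention:

    ```python
    import torch
    from torch import nn

    class Attention(nn.Module):
        """One block: pass kv for cross-attention, omit it for self-attention."""

        def __init__(self, dim: int, num_heads: int) -> None:
            super().__init__()
            self.mha = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        def forward(self, x: torch.Tensor, kv: torch.Tensor | None = None) -> torch.Tensor:
            kv = x if kv is None else kv
            out, _ = self.mha(x, kv, kv)
            return out

    attn = Attention(dim=64, num_heads=4)
    x, enc = torch.randn(2, 10, 64), torch.randn(2, 20, 64)
    self_out = attn(x)           # self-attention
    cross_out = attn(x, kv=enc)  # cross-attention
    ```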
- Removed `window_size` from the `Attention` class
  - This was not used during `forward` but was present in the init methods
  - Window size also means we have to be careful about how we order our inputs, as we are no longer permutation invariant
- Removed `n_kv_heads` from the `Attention` class
  - This can be brought back, but it would require a bit of extra code to fit it together with everything else
  - The motivation for grouped-query attention is to reduce the memory overhead of storing the KV cache during autoregressive generation, which we don't do anyway
- Removed the `add_zero_attn` option
  - It is not advised to use it anyway, and adding a `nan_to_num` check at the end is much faster than increasing the size of the message-passing matrix (see the sketch below)
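  - A one-line sketch of the cheaper alternative (the wrapper function name is illustrative):

    ```python
    import torch

    def clean_attention_output(out: torch.Tensor) -> torch.Tensor:
        """Zero out NaN rows (from fully padded queries) instead of appending a
        zero-attention token, which would enlarge the attention matrix."""
        return torch.nan_to_num(out)
    ```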
Review checklist:
- CI Passing
- Comments addressed
- Source branch is up to date with target