diff --git a/.github/workflows/joss-paper.yml b/.github/workflows/joss-paper.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f85b711e32cf99e435aaa477c4cae2ff2a51ad57
--- /dev/null
+++ b/.github/workflows/joss-paper.yml
@@ -0,0 +1,23 @@
+on: [push]
+
+jobs:
+  paper:
+    runs-on: ubuntu-latest
+    name: Paper Draft
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+      - name: Build draft PDF
+        uses: openjournals/openjournals-draft-action@master
+        with:
+          journal: joss
+          # This should be the path to the paper within your repo.
+          paper-path: paper/paper.md
+      - name: Upload
+        uses: actions/upload-artifact@v1
+        with:
+          name: paper
+          # This is the output path where Pandoc will write the compiled
+          # PDF. Note, this should be the same directory as the input
+          # paper.md
+          path: paper/paper.pdf
diff --git a/.gitignore b/.gitignore
index 828e588cacacb6212d59096cf44a58e6c8a3edd7..acbbac8d487328815bd7f5f177f843bffc33ffc3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@ python_install/
 conda/
 condor/
 logs/
+slurm/
 tmp/
 env/
 plots/
@@ -26,6 +27,7 @@ user/
 *.h5
 *.egg-info/
 *_cache/
+*.cache/
 *.ipynb
 .coverage*
 .cometml-runs
diff --git a/.gitlab/.ci-docs.yaml b/.gitlab/.ci-docs.yaml
index 49bf5334c01ec32f6a4ad73fa9cf8d67faf683c7..7b2ff477e4a6c0558760d4624df265fe339d37bd 100644
--- a/.gitlab/.ci-docs.yaml
+++ b/.gitlab/.ci-docs.yaml
@@ -13,6 +13,6 @@ pages:
   variables:
     GIT_DEPTH: 0
     GIT_STRATEGY: clone
-  #rules:
-  #  - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH && $CI_PROJECT_PATH=="atlas-flavor-tagging-tools/algorithms/salt"
-  #    when: always
+  rules:
+    - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH && $CI_PROJECT_PATH=="atlas-flavor-tagging-tools/algorithms/salt"
+      when: always
diff --git a/.gitlab/.ci-test.yaml b/.gitlab/.ci-test.yaml
index d1beb97619bfd63976236b9d3d31c27de492cd7b..4c665af84a7733c69d26ad4046caf77a1ab563ca 100644
--- a/.gitlab/.ci-test.yaml
+++ b/.gitlab/.ci-test.yaml
@@ -8,7 +8,7 @@ variables:
   artifacts:
     paths: [.coverage*]
   rules:
-    - changes: ["*", "salt/**/*.py"]
+    - changes: ["pyproject.toml", "requirements.txt", "salt/**/*.py", "**/*.yaml"]
 
 # --------------------------- UNIT TESTS ---------------------------
 unit-tests:
@@ -128,9 +128,9 @@ test-maskformer:
 test-parameterisation_concatenation:
   <<: *test-template
   script:
-    - $TEST_CMD salt/tests/test_pipeline.py::test_parameterisation_concatenation
+    - $TEST_CMD salt/tests/test_pipeline.py::test_param_concat
 
 test-parameterisation_featurewise:
   <<: *test-template
   script:
-    - $TEST_CMD salt/tests/test_pipeline.py::test_parameterisation_featurewise
+    - $TEST_CMD salt/tests/test_pipeline.py::test_param_featurewise
diff --git a/.vscode/settings.json b/.vscode/settings.json
index ece0a73a73ceaa28ad52aaefc1ce5a863db7d3a3..d4f963dcb1964550e68b674728208ef481a527ee 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -3,6 +3,9 @@
     "autoDocstring.docstringFormat": "numpy",
     "[python]": {
         "editor.defaultFormatter": "charliermarsh.ruff",
+        "editor.rulers": [
+            100
+        ],
         "editor.codeActionsOnSave": {
             "source.organizeImports": "explicit",
             "source.fixAll": "explicit"
@@ -19,4 +22,4 @@
     ],
     "python.testing.unittestEnabled": false,
     "python.testing.pytestEnabled": true
-}
\ No newline at end of file
+}
diff --git a/docs/api/transformer.md b/docs/api/transformer.md
index 74cff8591f101058c24d2cdbf54fccd9ca5907c7..e555c85c0f02c0f7204ac09fd7e5ab3aaccdf075 100644
--- a/docs/api/transformer.md
+++ b/docs/api/transformer.md
@@ -2,7 +2,6 @@
 ## ::: salt.models.transformer_v2.Attention
     options:
       members: [forward]
-## ::: salt.models.transformer_v2.SelfAttention
 ## ::: salt.models.transformer_v2.GLU
 ## ::: salt.models.transformer_v2.EncoderLayer
 ## ::: salt.models.transformer_v2.TransformerV2
diff --git a/docs/configuration.md b/docs/configuration.md
index 57d856a6e3e265a1c6ce431ef46d772b09664aa7..0954c8a67c9490848b7767f87f61f628a0793e6e 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -283,6 +283,14 @@ by [Dumoulin et al](https://distill.pub/2018/feature-wise-transformations/). Her
           dense_config_bias:
             hidden_layers: [4]
             output_size: 17
+        - layer: encoder
+          apply_norm: True
+          dense_config_scale:
+            hidden_layers: [128]
+            output_size: 256
+          dense_config_bias:
+            hidden_layers: [128]
+            output_size: 256
         - layer: global
           dense_config_scale:
             output_size: 128
@@ -292,7 +300,7 @@ by [Dumoulin et al](https://distill.pub/2018/feature-wise-transformations/). Her
 ```
 
 Here, two instances of featurewise transformations have been added to the model. For each, you must specify the layer whose features you would
-like to transform (this can currently be either `input`, which applies the transformations to the features before they are passed into the initialisation network, or `global`, which applies them to the global track representations outputted by the encoder). For each instance, you can specify either one or both of `dense_config_scale` or `dense_config_bias`, which configure dense networks whose output scales and biases the features of the chosen layer, respectively. It is important to ensure the `output_size` of these networks matches the number of features in the layer you are transforming. In this case, the transformations are applied to a model with 17 inputs per track, and an encoder that outputs 128 features for each track representation. 
+like to transform (this can currently be either `input`, which applies the transformations to the features before they are passed into the initialisation network, `encoder`, which applies the transformations to the inputs of each layer to the encoder using separate networks, or `global`, which applies them to the global track representations outputted by the encoder). For each instance, you can specify either one or both of `dense_config_scale` or `dense_config_bias`, which configure dense networks whose output scales and biases the features of the chosen layer, respectively. It is important to ensure the `output_size` of these networks matches the number of features in the layer you are transforming. In this case, the transformations are applied to a model with 17 inputs per track, the layers of an encoder with 256 features, and the output of the encoder, which has 128 features for each track representation. You can optionally apply a layer normalisation after applying the transformations by setting `apply_norm: True` for a given network, as shown above.
 
 
 ### Training
diff --git a/docs/index.md b/docs/index.md
index e8317532a1fe02fa6130b8f4873c986e0b8f9d33..f5816ee6f847074480ce5407691799e147eb086e 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -24,7 +24,7 @@ Below are some helpful links to get you started:
 
 !!! question "There is a [channel](https://mattermost.web.cern.ch/aft-algs/channels/gnns) for the framework in the [FTAG Mattermost workspace](https://mattermost.web.cern.ch/signup_user_complete/?id=1wicts5csjd49kg7uymwwt9aho&md=link&sbr=su)"
 
-!!! abstract "A tutorial on how to use Salt is provided at the [FTAG docs page](https://ftag.docs.cern.ch/software/tutorials/tutorial-salt/)"
+!!! abstract "A tutorial on how to use Salt can be found [here](tutorial.md) and [here](tutorial-Xbb.md)"
 
 !!! note "[Contributions](contributing) are welcome! Check out [existing issues](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/salt/-/issues) for inspiration, or open your own"
 
diff --git a/docs/training.md b/docs/training.md
index f89610f59df34410d54cf8f123c0c8bab0136132..2311f8d3738884c15cab62b543b53fae7b3108c5 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -167,14 +167,25 @@ The job parameters such as memory requirements, number of GPUs and CPUs requeste
 
 #### Slurm Batch
 
-Those at institutions with Slurm managed GPU batch queues can submit training jobs using
+Those at institutions with Slurm managed GPU batch queues can submit training jobs using a very similar script.
+
+All options described above for HTCondor and more (CPUs, GPUs, etc) are available as command line arguments. 
+
+```bash
+python submit/submit_slurm.py --config configs/GN2.yaml --tag test_salt --account MY-ACCOUNT --nodes 1 --gpus_per_node 2
+```
+
+The script submit/submit_slurm.py script itself can be modified if a required configuration is not supported in this way.
+
+Where arguments need to agree between Slurm and Pytorch Lightning, such as ntasks-per-node for Slurm and trainer.devices for Lightning, this is handled by the script.
+
+Lightning has the ability to requeue a job if it is killed by Slurm for exceeding the system walltime. The training state is saved in a checkpoint and loaded when the new job begins. submit_slurm.py creates a single log directory holding the checkpoints for the original and any requeue-d jobs (in the below example GN2_my_requeue_job).
 
 ```bash
-sbatch submit/submit_slurm.sh
+python submit/submit_slurm.py --config configs/GN2.yaml --requeue --salt_log_dir=my_requeue_job --signal=SIGUSR1@90
 ```
 
-The submit script only supports running from a conda environment for now.
-There are several options in the script which need to be tailored to make sure to make a look inside.
+There is also an older submit/submit_slurm.sh bash script that is kept around for compatibility. Users are strongly encouraged to use the python script.
 
 ??? info "Cleaning up after interruption"
 
diff --git a/docs/tutorial-Xbb.md b/docs/tutorial-Xbb.md
new file mode 100644
index 0000000000000000000000000000000000000000..67f9add21ad2fdd30007f0c25a8d94004a540f94
--- /dev/null
+++ b/docs/tutorial-Xbb.md
@@ -0,0 +1,606 @@
+# Salt framework for $X \rightarrow bb$ tagger tutorial
+
+## Introduction
+
+In this tutorial, you will learn to setup and use the [Salt framework](https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/salt/) in the context of $X \rightarrow bb$ tagging.
+Salt is a high-level framework for training state-of-the-art flavour tagging algorithms.
+In addition, plotting scripts are provided to plot the results of the evaluation using the [`puma`](https://github.com/umami-hep/puma) package.
+
+In this tutorial, we cover the following functionalities of Salt:
+
+1. Training of a subjet-based Xbb tagger
+2. Training of a constituent-based Xbb tagger using tracks
+3. Modification of high-level settings and network hyperparameters
+4. Evaluation of results
+
+!!!info "If you are not studying $X \rightarrow bb$ feel free to skip task 2"
+
+The tutorial is meant to be followed in a self-guided manner. You will be prompted to do certain tasks by telling you what the desired outcome will be, without telling you how to do it. Using the [Salt documentation](index.md), you can find out how to achieve your goal. In case you are stuck, you can click on the "hint" toggle box to get a hint. If you tried for more than 10 min at a problem, feel free to toggle also the solution with a working example.
+
+???+ question "What to do if you get stuck"
+    
+    In case you encounter some errors or you are completely stuck, you can reach out to the dedicated [FTAG tutorial mattermost channel](https://mattermost.web.cern.ch/aft-algs/channels/ftag-tutorials) (click [here](https://mattermost.web.cern.ch/signup_user_complete/?id=ektad7hj4fdf5nehdmh1js4zfy) to sign up).
+
+This tutorial has been run a few times, click below for intro slides which give some context on the framework and this tutorial:
+
+- [$X \rightarrow bb$ taskforce meeting](https://indico.cern.ch/event/1248303/).
+- [FTAG Workshop 2023](https://indico.cern.ch/event/1311519/timetable/?view=standard#31-salt-tutorial).
+
+
+## Prerequisites
+
+For this tutorial, you need access to a shell on either CERN's `lxplus` or your local cluster with `/cvmfs` access to retrieve the `singularity` image needed. To set this up, please follow the instructions [here](setup.md) by selecting the "singularity" tab in the ["Create Environment"](setup.md#create-environment).
+
+Efficient training is only possible on resources with GPU access.
+It is highly encouraged to use an institute-managed GPU enabled machine if one is available.
+Otherwise, CERN provides special lxplus nodes with GPU access for interactive computing.
+
+You can log in to a CERN lxplus gpu node with:
+
+```bash
+ssh -Y <username>@lxplus-gpu.cern.ch
+```
+
+You can check that your node is configured with GPU access by running 
+
+```bash
+nvidia-smi
+```
+
+If you see a tabular output with information about one or more GPUs, then you are good to continue.
+
+!!! warning "Check your machine is configured correctly"
+    
+    If you see `No devices were found` your node is badly configured, and you should log in again and hope for a new node.
+    
+
+
+### Training datasets
+
+You should copy the training files before doing the tutorial.
+If you don't the training will be much slower but you can compensate for that by reducing the number of training and validation jets as hinted in the tasks below.
+The train/val/test samples for the tutorial each have 2M jets and are stored on EOS in the following directory
+
+- `/eos/user/u/umami/tutorials/salt/2023/inputs/`
+
+=== "Copy to user EOS space"
+    If you are running on lxplus, copying the training files to your private storage on `/eos/user/${USER:0:1}/$USER/` is recommended to avoid overly high concurrent access:
+    ```
+    rsync -vaP /eos/user/u/umami/tutorials/salt/2023/inputs/ /eos/user/${USER:0:1}/$USER/training-samples
+    ```
+
+=== "Local Cluster"
+
+    If you are running on your local cluster, you can copy the files to a directory with fast access:
+
+    ```
+    rsync -vaP <cern username>@lxplus.cern.ch:/eos/user/u/umami/tutorials/salt/2023/inputs/ /fast/disk/training-samples/
+    ```
+
+??? warning "Access to EOS is slow, copying files before the tutorial is highly recommended!"
+
+    The training files are stored on EOS, which is a distributed file system. Accessing files on EOS is slow, so it is recommended to copy the files to a local directory before starting the tutorial. If you attempt to run the tutorial directly from EOS, you will experience very slow training times.
+
+
+??? error "What to do if you don't have access to the EOS folder"
+
+    The training files stored on EOS are only shared with user subscribed to the egroups/mailing lists
+    
+    - `atlas-cp-flavtag-btagging-algorithms`
+    - `atlas-cp-flavtag-jetetmiss-BoostedXbbTagging`
+
+    If you are not yet subscribed, please consider doing so to get access to the training files.
+    You can subscribe using the [CERN egroups webpage](https://e-groups.cern.ch/e-groups/EgroupsSearch.do).
+
+    If you already are subscribed and try to copy from inside a singularity container, it might fail. In that case, copy the files without using the singularity container.
+
+When training a model, [it is possible to specify a local directory](training.md#fast-disk-access) with fast access, e.g. `/tmp` to which the files will be copied.
+This will speed up the training on e.g. `lxplus` significantly (though you will still incur the initial cost of copying the files).
+
+The total size of the training, validation and test files is 17GB, make sure you have sufficient free space.
+Alongside the input h5 are the `norm_dict.yaml` and `class_dict.yaml` which are also used for training.
+
+
+### Singularity image
+
+The FTAG group provides salt-ready singularity images via `/cvmfs/unpacked.cern.ch` on lxplus (or any cluster which has `/cvmfs` mounted). On the node, you can use `singularity` to launch the container from the image on `/cvmfs/unpacked.cern.ch` with the already prepared `salt` framework.
+We'll use the tagged image for version `0.3` of the code.
+
+=== "lxplus (eos access)"
+
+    If you run on lxplus, it is advantageous to also mount the `/afs`, `/eos`, `/tmp` and `/cvmfs` directories:
+
+    ```bash
+    singularity shell -e --env KRB5CCNAME=$KRB5CCNAME --nv --bind $PWD,/afs,/eos,/tmp,/cvmfs,/run/user \
+        /cvmfs/unpacked.cern.ch/gitlab-registry.cern.ch/atlas-flavor-tagging-tools/algorithms/salt:0-3
+    ```
+
+=== "other (cvmfs only)"
+
+    ```
+    singularity shell -e --env KRB5CCNAME=$KRB5CCNAME --nv --bind $PWD,/cvmfs,/run/user \
+        /cvmfs/unpacked.cern.ch/gitlab-registry.cern.ch/atlas-flavor-tagging-tools/algorithms/salt:0-3
+    ```
+
+If you have issues accessing bound paths, ensure your Kerberos credentials are set with `export KRB5CCNAME=FILE:/run/user/${UID}/krb5cc`
+
+After running the [`singularity shell`](https://docs.sylabs.io/guides/latest/user-guide/cli/singularity_shell.html#singularity-shell) command, you can re-source your `.bashrc` to get some of the features of your normal terminal back by running 
+
+```bash
+source ~/.bashrc
+```
+
+
+## Tutorial tasks
+
+### 1. Fork, clone and install Salt
+
+Although the singularity images come with salt pre-installed, they do not allow for an editable version of the package.
+It's therefore highly recommended to re-install the package from source to give you full control.
+To do so, you need to do the following steps:
+
+1. Create a personal fork of Salt in Gitlab.
+2. Clone the forked repository to your machine using `git`.
+3. Switch to the `0.3` tag which is used for the tutorial.
+4. (Optional) Run the setup to switch to development mode.
+5. Run the test suite
+
+Go to the GitLab project page of Salt to begin with the task: <https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/salt>
+
+??? info "Hint: Create a personal fork of Salt in Gitlab"
+
+    In case you are stuck how to create your personal fork of the project, you can find some general information on git and the forking concept [here in the GitLab documentation](https://docs.gitlab.com/ee/user/project/repository/forking_workflow.html).
+
+??? info "Hint: Clone the forked repository to your machine using `git`"
+
+    The command `git clone` is the one you need. You can look up the use [here](setup.md). You can use the `--branch` argument to checkout a specific branch, e.g. `--branch 0.3` to checkout the `0.3` tag.
+
+
+??? info "Hint: installing the package in development mode"
+
+    By default, the singularity image comes with salt preinstalled, but this not an editable installation. If you want to make code changes, you can install salt in development mode using `pip` with the `-e` flag.
+
+
+??? info "Hint: Run the test suite"
+
+    You can run the suite of unit tests as outlined in the [salt documentation on ](contributing.md#test-suite). Make sure that you enter the `salt` source code directory before you execute the test suite!
+
+    ```bash
+    cd salt/
+    pytest --cov=salt --show-capture=stdout
+    ```
+
+    Note that, depending on your machine, the test suite may take a while to run. To just run a single test, you can instead use
+    
+    ```bash
+    pytest --cov=salt --show-capture=stdout tests/test_pipeline.py::TestModels::test_GN1
+    ```
+
+??? warning "Solution"
+
+    Open the website <https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/salt> in a browser. You may need to authenticate with your CERN login credentials. In the top right corner of the Salt project you see three buttons which show a bell (notifications), a star (to favourite the project) next to a number, and a forking graph (to fork the project) with the text "Fork" next to a number. Click on the word "Fork" to open a new website, allowing you to specify the namespace of your fork. Click on "Select a namespace", choose your CERN username, and create the fork by clicking on "Fork project".
+
+    Next, you need to clone the project using `git`. Open a fresh terminal on the cluster your are working on, create a new folder and proceed with the cloning. To do so, open your forked project in a browser. The address typically is `https://gitlab.cern.ch/<your CERN username>/salt`. When clicking on the blue "Clone" button at the right hand-side of the page, a drop-down mini-page appears with the ssh path to the forked git project. Let's check out your personal fork and add the original project as upstream:
+
+    ```bash
+    git clone ssh://git@gitlab.cern.ch:7999/<your CERN username>/salt.git
+    cd salt
+    git remote add upstream ssh://git@gitlab.cern.ch:7999/atlas-flavor-tagging-tools/algorithms/salt.git
+    git checkout 0.3
+    ```
+
+    You now forked and cloned Salt and should be ready to go!
+
+    Launch the salt singularity container (make sure to bind the directory containing the cloned git project) and change the directory to the top level directory of the project.
+
+
+    ```bash
+    singularity shell -e --nv --bind $PWD,/afs,/eos,/tmp,/cvmfs \
+    /cvmfs/unpacked.cern.ch/gitlab-registry.cern.ch/atlas-flavor-tagging-tools/algorithms/salt:0-3
+    ```
+
+    If you want to modify the salt code and contribute to development, you need to install the salt package to switch to development mode:
+
+    ```bash
+    python -m pip install -e .
+    ```
+
+    Finally, you can run a test to check if everything works fine:
+    
+    ```bash
+    pytest --cov=salt --show-capture=stdout tests/test_pipeline.py::test_GN1
+    ```
+
+    Make sure you are in the directory containing `tests/` when running the test suite.
+    The full test suite would likely take some time to run, but can be invoked with
+
+    ```bash
+    cd salt/
+    pytest --cov=salt --show-capture=stdout
+    ```
+
+
+### 2. (Optional) Set up logging
+
+[Comet.ml](https://www.comet.com/) is an online ML logging service. It's free for academic use. If you get stuck, consult the [comet documentation](https://www.comet.com/docs/v2/guides/getting-started/quickstart/) and the hints below.
+
+1. Create an account and a project.
+2. Generate an API key.
+3. Save the API key and project name in the relevant environment variables, or add them to your `~/bashrc` file.
+
+??? info "Hint: Creating an account and a project"
+
+    To use it, [create an account](https://www.comet.com/signup), and then create a project using the blue `+ New project` button in the GUI interface.
+
+??? info "Hint: Generating an API key"
+
+    You then need to create and save an API key for your project, which you can use to log your training runs. You can find the API key in the [account settings page](https://www.comet.com/account-settings/apiKeys).
+
+??? info "Hint: Saving info to environment variables"
+
+    See the Salt [logging docs](setup.md#setup-logging) for info on which environment variables to use.
+
+??? danger "Warning: If you don't set up logging, you may need to disable it in the training config file"
+
+    Open the `base.yaml` config file and set `logger: False` under the `trainer:` block. Remove the existing sub-blocks under `logger:`. You also need to remove the 
+    ```
+    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
+    ``` 
+    line under the `trainer: callbacks:` block, since this feature requires a logger to work.
+
+### 3. Train subjet-based tagger
+
+In this task, you will train an algorithm based on inputs from variable-radius track sub-jets of the large-radius jet. For the purposes of the tutorial, the training is configured to use the first 5M jets from the training dataset, and to run for 10 epochs.
+
+You can take a look inside the `SubjetXbb.yaml` or `norm_dict.yaml` file to see which variables will be used for training the subjet based model. The r22 GN2 PFlow scores have been included, along with some other kinematic information about the subjet.
+
+1. Modify the `SubjetXbb.yaml` model config file to use the correct paths to your locally downloaded training files.
+2. Run the training for 10 epochs.
+
+??? info "Hint: Modifying the `SubjetXbb.yaml` config file"
+
+    You'll need to specify the `SubjetXbb.yaml` model config to use. This config file needs to be edited with the correct paths to your locally downloaded training files. You'll need to modify the `train_file`, `val_file` and `scale_dict` keys under the `data:` block.
+    To change the number of epochs the training runs for, you can modify the `max_epochs` key under the `trainer:` block.
+
+??? info "Hint: Warning about the number of workers used in the dataloaders"
+
+    By default the `--data.num_workers` flag is set to 10. On your machine, you might see a warning if this is too many, and the suggested number to use instead. Including the `--data.num_workers` flag in the training command will override the default value.
+
+??? info "Hint: Running the training"
+
+    The command to run a training is described in the Salt documentation [here](training.md#training). Make sure you specify `--config configs/SubjetXbb.yaml` to use the correct config file.
+
+??? info "Hint: Speeding up the training"
+
+    Take a look at the [dataloading](training.md#dataloading) section of the training documentation. You can try increasing the worker count, moving the files to fast storage or RAM, or reducing the overall number of training jets using the `--data.num_jets_train` and `--data.num_jets_val` flags. Finally, you could decrease the number of training epochs with `--trainer.max_epochs` flag.
+
+    Make sure that you really have copied the training files from the `/eos` location to a local path, otherwise reading the input via `/eos` can also slow down the training. You could also experiment with the [`--data.move_files_temp`](training.md#fast-disk-access) flag to transfer the data to a high-speed file reading resource before training. On lxplus, you could do so by adding `--data.move_files_temp /tmp` to the training command.
+
+
+??? warning "Solution"
+
+    After modifying the `SubjetXbb.yaml` config file as described in the first hint, you can run a test training with 
+    ```bash
+    salt fit --config configs/SubjetXbb.yaml --trainer.fast_dev_run 2
+    ```
+    Assuming this completes without any errors, you can run a full training by omitting the `--trainer.fast_dev_run` flag.
+    By default, the training uses the first GPU on system. If you want to use a different GPU, you can specify it with the `--trainer.devices` flag as described in the [documentation](training.md#choosing-gpus). To run on the CPU, you can use `--trainer.accelerator cpu`.
+
+    For more training options, including tips for how to speed up the training, take a look at the [documentation](training.md).
+
+
+### 4. Train track-based tagger
+
+In this task, you will train an algorithm based directly on the tracks associated with the large-radius jet as inputs. Again, take a look inside the variable config to get an idea of which variables are being used in the track-based case.
+
+1. Modify the `GN2X.yaml` model config file to use the correct paths to your locally downloaded training files.
+2. Run the training for 10 epochs.
+3. Compare the losses of the subjet-based model and the track-based model.
+
+Note that you may ecounter _carefully planned_ errors as part of this task, please use the hints below to try and resolve them.
+
+??? info "Hint: See hints for the previous task"
+
+    This task is very similar to Task 2, for a different model config: `GN2X.yaml`.
+
+??? info "Hint: What to do about a `MisconfigurationException`" 
+    
+    You might see the following error if you run on a machine with only one accessible GPU.
+
+    ```
+    lightning.fabric.utilities.exceptions.MisconfigurationException: You requested gpu: [0, 1]
+     But your machine only has: [0]
+    ```
+    
+    This is because the default `GN2X.yaml` config asks for 2 GPUs to speed up training.
+    This is a good opportunity to learn about requesting GPUs. 
+    You can read [here](training.md#choosing-gpus) for hints about what to do.
+
+??? info "Hint: What to do about `ValueError: Variables {...} were not found in dataset`" 
+    
+    This kind of error is quite common if the inputs become out of sync with the config.
+    In our case, the config has been updated to use the new truth label names, but the samples are somewhat older and have not been updated.
+
+    You need to modify the `GN2X.yaml` config to revert to the old label names which do not have the `ftag` prefex, e.g. `ftagTruthOriginLabel` -> `truthOriginLabel`.
+    This needs to be done in the task config, i.e. around L150 and L191.
+
+
+??? warning "Solution"
+
+    The training should run in the same way as the previous task, but will take longer to complete, since we are processing up to 100 tracks per jet, rather than just the info about 3 subjets.
+
+    You should take note of the initial and final values of the losses for the two models so that you can compare them. Which loss decreases faster from its initial value? Which is lower after the training has been completed? Why do you think this is? (Remember the default choice of 2M training jets is a small fraction of the total number of jets used to train the GN2 tagger which is used to produce the subjet scores.)
+
+    In order to fix the `MisconfigurationException`, just add `--trainer.devices=1` as a command line flag to your `salt fit` call.
+
+### 5. Modify network parameters and retrain
+
+In this task, you will modify parameters of the models trained in the previous tasks and retrain the networks. You should consider what effect the changes have on the evolution of the loss, the size of the model, and the training speed. This task is open-ended and you are encouraged to experiment with modifying the different config files.
+
+??? info "Hint: Changing the number of training jets"
+
+    Inside the model config file you wish to change, look at the `num_jets_train` config key inside the `data:` block. You can take a look at how the number of jets affects the final performance of the models. You can also configure this from the CLI using `--data.num_jets_train <num>`.
+
+??? info "Hint: Changing the model architecture"
+
+    Inside the model config file you wish to change, look at the `model:` block. The core of the model is the Graph Network/Transformer configured in the `gnn:` block. You can modify the number of layers, the number of attention heads, or the embedding dimension and see what effect his has on the training.
+
+??? info "Hint: Removing auxiliary tasks"
+
+    The `GN2X.yaml` includes the auxiliary track classification task. To remove it, look in the `tasks:` block for the list item which has `name: track_classification`. Removing the associated block will disable that part of the model and remove the associated track classification loss from the overall training loss function when training.
+
+??? warning "Solution"
+
+    This task is really just about tinkering, but you may notice the following things:
+    
+    - Reducing the number of jets is detrimental to the performance of the model. If you study how the lowest value of the loss changes as a function of the number of training jets, you should come to the conclusion that larger training samples would be beneficial to improving performance.
+    - Increasing the model size (number of layers, embedding dimension) leads to slower training times, and more overtraining (especially visible with such a small number of training jets)
+    - The loss for the subjet-based model initially drops much more quickly than for the track-based model. This is because it has outputs from an already trained model as inputs.
+    - Given enough time and enough input jets, the loss for the track-based model will drop below that of the subjet-based model. This reflects the fact constituent-based tagging approach is more powerful than the subjet-based approach in the long run.
+
+
+### 6. Evaluate the models on the test set
+
+After training, the model is evaluated on an independent set of testing jets. The results are used to produce performance plots. The test file you will use for the tutorial is called `pp_output_test.h5`, and contains a mixture of the different jet classes used for training. The jet labels are specified by the `R10TruthLabel_R22v1` jet variable. The classes are specified in the `enum` [here](https://gitlab.cern.ch/atlas/athena/-/blob/master/PhysicsAnalysis/AnalysisCommon/ParticleJetTools/ParticleJetTools/LargeRJetLabelEnum.h#L11).
+
+1. Choose which models you want to evaluate and 
+2. Run the model evaluation command
+
+??? info "Hint: Comparing to a provided pre-trained model"
+
+    If you had problems with the training, or you just want to compare to a benchmark, you can use one of the provided model checkpoints to evaluate.
+    These models are found in the `/eos/user/u/umami/tutorials/salt/2023/trained-models/` directory.
+    Note that they are not claimed to be especially well performing models (they were only trained for 5 epochs) - you may well find a configuration that outperforms them!
+
+??? info "Hint: Running the evaluation command"
+
+    Take a look at the [relevant page](evaluation.md) in the salt documentation.
+    You might want to choose to evaluate on e.g. 1e5 jets, rather than the full 2M.
+
+??? warning "Solution"
+
+    Find the path to the saved config file of the model you want to evaluate.
+    This will be located in a timestamped directory under `logs/`.
+    Also have handy the path to your `/eos/home-u/umami/tutorials/salt/2023/inputs/pp_output_test.h5` (or run directly on this file).
+    To run the evaluation, use
+
+    ```bash
+    salt test --config logs/<timestamp>/config.yaml --data.test_file path/to/pp_output_test.h5
+    ```
+    
+    If you want to evaluate the pre-trained model on EOS, this command will be for example
+    ```bash
+    salt test \ 
+      --config /eos/home-u/umami/tutorials/salt/2023/trained-models/SubjetXbb_20230920-T192350/config.yaml \
+      --data.test_file /eos/home-u/umami/tutorials/salt/2023/inputs/pp_output_test.h5
+    ```
+
+    Salt automatically evaluates the checkpoint with the lowest associated validation loss for evaluation, but you can use `--ckpt_path` to specify this manually.
+    The resulting evaluation file will be saved in `ckpts/` in the training output directory, alongside the checkpoint that was used to run the evaluation.
+    Read the [salt docs](evaluation.md#running-the-test-loop) for more info.
+
+
+### 7. Create plots which quantity the trained algorithms performance
+
+In this task, you will create plots of performance metrics using the [`puma`](https://github.com/umami-hep/puma/) python package.
+You can find more information on how to use `puma` for plotting in the [corresponding plotting tutorial](https://ftag-docs.docs.cern.ch/software/tutorials/tutorial-plotting/).
+
+1. Produce a histogram of the jet scores for each class.
+2. Produce ROC curves as a function of signal efficiency.
+
+??? info "Hint: Installing puma"
+
+    Your Salt installation will install puma as a dependency. 
+    You can also follow the quickstart guide in the [puma docs](https://umami-hep.github.io/puma/main/index.html#) to learn how to install it yourself.
+
+??? info "Hint: Plotting histograms"
+
+    Take a look at the [relevant page](https://umami-hep.github.io/puma/main/examples/histograms.html) in the puma docs.
+
+??? info "Hint: Plotting ROCs"
+
+    Take a look at the [relevant page](https://umami-hep.github.io/puma/main/examples/rocs.html) in the puma docs.
+
+??? info "Hint: What to use as a discriminant?"
+
+    Since we have four classes, calculating a discriminant is more complicated than in the single b-tagging case. 
+    One option is to use the score for the signal class directly as the discriminant, but please note, this may lead to a suboptimal trade off between the different background rejections.
+
+??? info "Hint: What are the truth labels"
+
+    Take a look at their definition [here](https://gitlab.cern.ch/atlas/athena/-/blob/master/PhysicsAnalysis/AnalysisCommon/ParticleJetTools/ParticleJetTools/LargeRJetLabelEnum.h#L34). Hbb is `11`, Hcc is `12`, and top is `1` and QCD is `10`.
+
+??? warning "Solution"
+
+    The prepared script `make_plots.py` provides an implementation example to plot tagger discriminant distributions and ROC curves. 
+    It is located here: `/eos/home-u/umami/tutorials/salt/2023/make_plots.py`.
+    
+    At the beginning of the script, you are invited to fill the `networks` dictionary with one or more trained model paths and dedicated keys. In the current version, the pretrained model paths have been implemented. If you want, you can modify them with your own trained models. 
+    In addition, a `reference` is also requested for the model comparisons. It should correspond to one of the model keys added to `networks`.
+    Finally, `test_path` has to be completed with the path to your `pp_output_test.h5` sample. 
+    
+    Two different tagger discriminants are defined by `Hbb` and `Hcc` signal class probabilities.
+    
+    In a first step, the script extracts the different jet tagging probabilities as well as the needed kinematic information to define the jet selection to be applied. Implemented as a boolean `mask`, the jet selection can be easily modified. Efficiencies and rejections are also computed.
+    
+    For a given tagger discriminant, the distributions corresponding to the different jet flavours and trained models are then plotted on the same figure in order to perform a complete comparison. The Puma's `HistogramPlot` and `Histogram` objects offer a lot of configuration variables which can be modified according to cosmetic tastes and needs. Finally, the plotting of the corresponding ROCs follows similarly in another set of figures. 
+
+    Below, the content of `make_plots.py` is shown:
+
+
+    ```python
+
+    import h5py
+    import numpy as np
+    from puma import Histogram, HistogramPlot, Roc, RocPlot
+    from puma.metrics import calc_rej
+    from puma.utils import get_good_colours, get_good_linestyles, logger
+
+    networks = {
+        "SubjetXbb" : "/eos/home-u/umami/tutorials/salt/2023/trained-models/SubjetXbb_20230920-T192350/ckpts/epoch=004-val_loss=0.59234__test_pp_output_test.h5",
+        "GN2X" : "/eos/home-u/umami/tutorials/salt/2023/trained-models/GN2X_20230920-T193158/ckpts/epoch=004-val_loss=1.15303__test_pp_output_test.h5"
+    }
+
+    reference = "SubjetXbb"
+    test_path = '/eos/home-u/umami/tutorials/salt/2023/inputs/pp_output_test.h5'
+    num_jets = 100_000
+
+    # load test data
+    logger.info("Load data")
+    with h5py.File(test_path, 'r') as test_f:
+        jets = test_f['jets'][:num_jets]
+        jet_pt = jets['pt'] / 1000
+        jet_mass = jets['mass'] / 1000
+        jet_eta = np.abs(jets['eta'])
+        flav = jets['R10TruthLabel_R22v1']
+        mask = (jet_pt < 1000) & (jet_pt > 250) & (jet_mass > 50) & (jet_mass < 300)
+        is_QCD = flav == 10
+        is_Hcc = flav == 12
+        is_Hbb = flav == 11
+        is_Top = flav == 1
+        n_jets_QCD = np.sum(is_QCD & mask)
+        n_jets_Top = np.sum(is_Top & mask)
+
+    results = {}
+    logger.info("Calculate rejections")
+    for key, val in networks.items():
+        with h5py.File(val, 'r') as f:
+            jets = f['jets'][:num_jets]
+            pHbb = jets[f'{key}_phbb']
+            pHcc = jets[f'{key}_phcc']
+            pQCD = jets[f'{key}_pqcd']
+            pTop = jets[f'{key}_ptop']
+            disc_Hbb = pHbb
+            disc_Hcc = pHcc
+
+            sig_eff = np.linspace(0.4, 1, 100)
+            Hbb_rej_QCD = calc_rej(disc_Hbb[is_Hbb & mask], disc_Hbb[is_QCD & mask], sig_eff)
+            Hbb_rej_Top = calc_rej(disc_Hbb[is_Hbb & mask], disc_Hbb[is_Top & mask], sig_eff)
+            Hcc_rej_QCD = calc_rej(disc_Hcc[is_Hcc & mask], disc_Hcc[is_QCD & mask], sig_eff)
+            Hcc_rej_Top = calc_rej(disc_Hcc[is_Hcc & mask], disc_Hcc[is_Top & mask], sig_eff)
+            results[key] = {
+                'sig_eff' : sig_eff,
+                'disc_Hbb' : disc_Hbb,
+                'disc_Hcc' : disc_Hcc,
+                'Hbb_rej_QCD' : Hbb_rej_QCD,
+                'Hbb_rej_Top' : Hbb_rej_Top,
+                'Hcc_rej_QCD' : Hcc_rej_QCD,
+                'Hcc_rej_Top' : Hcc_rej_Top
+            }
+
+    logger.info("Plotting Discriminants.")
+    plot_histo = {
+        key : HistogramPlot(
+            n_ratio_panels=1,
+            ylabel="Normalised number of jets",
+            xlabel=f"{key}-jet discriminant",
+            logy=True,
+            leg_ncol=1,
+            figsize=(6.5, 4.5),
+            bins=np.linspace(0, 1, 50),
+            y_scale=1.5,
+            atlas_second_tag="$\\sqrt{s}=13$ TeV, Xbb jets",
+        ) for key in ['Hbb', 'Hcc']}
+    linestyles = get_good_linestyles()[:len(networks.keys())]
+    colours = get_good_colours()[:3]
+    for key, value in plot_histo.items():
+        for network, linestyle in zip(networks.keys(), linestyles):
+            value.add(
+                Histogram(
+                    results[network][f'disc_{key}'][is_QCD],
+                    label="QCD jets" if network == reference else None,
+                    ratio_group="QCD",
+                    colour=colours[0],
+                    linestyle=linestyle,
+                ),
+                reference=(network == reference),
+                )
+            value.add(
+                Histogram(
+                    results[network][f'disc_{key}'][is_Top],
+                    label="Top jets" if network == reference else None,
+                    ratio_group="Top",
+                    colour=colours[1],
+                    linestyle=linestyle,
+                ),
+                reference=(network == reference),
+                )
+            value.add(
+                Histogram(
+                    results[network][f'disc_{key}'][is_Hbb if key == 'Hbb' else is_Hcc],
+                    label=f"{key} jets" if network == reference else None,
+                    ratio_group=f"{key}",
+                    colour=colours[2],
+                    linestyle=linestyle,
+                ),
+                reference=(network == reference),
+                )
+        value.draw()
+        # The lines below create a legend for the linestyles
+        value.make_linestyle_legend(
+            linestyles=linestyles, labels=networks.keys(), bbox_to_anchor=(0.5, 1)
+        )
+        value.savefig(f"disc_{key}.png", transparent=False)
+
+    # here the plotting of the roc starts
+    logger.info("Plotting ROC curves.")
+    plot_roc = {
+        key : RocPlot(
+            n_ratio_panels=2,
+            ylabel="Background rejection",
+            xlabel=f"{key}-jet efficiency",
+            atlas_second_tag="$\\sqrt{s}=13$ TeV, Xbb jets",
+            figsize=(6.5, 6),
+            y_scale=1.4,
+        ) for key in ['Hbb', 'Hcc']}
+
+    for key, value in plot_roc.items():
+        for network in networks.keys():
+            value.add_roc(
+                Roc(
+                    sig_eff,
+                    results[network][f'{key}_rej_QCD'],
+                    n_test=n_jets_QCD,
+                    rej_class="qcd",
+                    signal_class=f"{key}",
+                    label=f"{network}",
+                ),
+                reference=(reference == network),
+            )
+            value.add_roc(
+                Roc(
+                    sig_eff,
+                    results[network][f'{key}_rej_Top'],
+                    n_test=n_jets_Top,
+                    rej_class="top",
+                    signal_class=f"{key}",
+                    label=f"{network}",
+                ),
+                reference=(reference == network),
+            )
+        # setting which flavour rejection ratio is drawn in which ratio panel
+        value.set_ratio_class(1, "qcd")
+        value.set_ratio_class(2, "top")
+        value.draw()
+        value.savefig(f"roc_{key}.png", transparent=False)
+    ```
diff --git a/docs/tutorial.md b/docs/tutorial.md
index bd1bcf7125517e17c32f242bb2bd37d6462e36f2..3ccaf4e49d0cd63e16c946a294a44a0c38dc288e 100644
--- a/docs/tutorial.md
+++ b/docs/tutorial.md
@@ -29,7 +29,7 @@ Create a directory at a location with sufficient free disk space. The unpacked d
 Execute the following commands to download all files to a directory which you will need to define (replace `<path to directory>` with the path to the directory of your choice).
 
 ```bash
-export DIR_TUTORIAL_DATA=<path to directory>
+export TUTORIAL_DATA=<path to directory>
 mkdir -p $TUTORIAL_DATA
 cd $TUTORIAL_DATA
 curl -o $TUTORIAL_DATA/tutorialdata.zip "https://zenodo.org/api/records/10371998/files-archive"
diff --git a/mkdocs.yaml b/mkdocs.yaml
index 84ac6e084a55672d02110e573645df27e7f7146f..9440617f112e3cc86770742a891457dbaf700594 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -38,6 +38,7 @@ nav:
   - ONNX Export: export.md
   - Contributing: contributing.md
   - Tutorial: tutorial.md
+  - Tutorial (Xbb): tutorial-Xbb.md
   - API Reference:
       - api/data.md
       - api/initialisation.md
diff --git a/paper/paper.bib b/paper/paper.bib
new file mode 100644
index 0000000000000000000000000000000000000000..2731117f6fd042996390e143d416e900e269e039
--- /dev/null
+++ b/paper/paper.bib
@@ -0,0 +1,470 @@
+% LHC
+@article{Evans:2008,
+  author       = {Evans, Lyndon and Bryant, Philip},
+  title        = {{LHC Machine}},
+  journal      = {JINST},
+  volume       = {3},
+  pages        = {S08001},
+  doi          = {10.1088/1748-0221/3/08/S08001},
+  year         = {2008},
+  slaccitation = {%%CITATION = JINST,3,S08001;%%}
+}
+
+% ATLAS
+@article{ATLAS:2008,
+  author       = {{ATLAS Collaboration}},
+  title        = {{The ATLAS Experiment at the CERN Large Hadron Collider}},
+  journal      = {JINST},
+  volume       = {3},
+  year         = {2008},
+  pages        = {S08003},
+  doi          = {10.1088/1748-0221/3/08/S08003},
+  primaryclass = {hep-ex}
+}
+
+% ATLAS software
+@booklet{ATLAS:2021,
+  author       = {{ATLAS Collaboration}},
+  title        = {{The ATLAS Collaboration Software and Firmware}},
+  howpublished = {{ATL-SOFT-PUB-2021-001}},
+  url          = {https://cds.cern.ch/record/2767187},
+  year         = {2021}
+}
+
+% Python
+@book{Rossum:2009,
+  author    = {Van Rossum, Guido and Drake, Fred L.},
+  title     = {Python 3 Reference Manual},
+  year      = {2009},
+  isbn      = {1441412697},
+  publisher = {CreateSpace},
+  address   = {Scotts Valley, CA}
+}
+
+% PEP8
+@techreport{PEP8:2001,
+  author = {Guido van Rossum and Barry Warsaw and Nick Coghlan},
+  title  = {Style Guide for {Python} Code},
+  year   = {2001},
+  type   = {PEP},
+  number = {8},
+  url    = {https://www.python.org/dev/peps/pep-0008/}
+}
+
+% YAML
+@misc{YAML:2021,
+  title        = {{YAML} Ain’t Markup Language (YAML™) version 1.2},
+  howpublished = {\url{https://yaml.org/spec/1.2.2/}},
+  year         = 2001,
+  note         = {Accessed: 2023-05-11}
+}
+
+% Setuptools
+@misc{setuptools:2023,
+  title        = {{Setuptools}},
+  howpublished = {\url{https://github.com/pypa/setuptools}},
+  year         = 2013,
+  note         = {Accessed: 2023-05-11}
+}
+
+
+% Flake8
+@misc{flake8:2023,
+  title        = {{Flake8}},
+  howpublished = {\url{https://github.com/PyCQA/flake8}},
+  year         = 2010,
+  note         = {Accessed: 2023-05-11}
+}
+
+% Black
+@misc{black:2023,
+  title        = {{Black}},
+  howpublished = {\url{https://github.com/psf/black}},
+  year         = 2018,
+  note         = {Accessed: 2023-05-11}
+}
+
+% Pytest
+@misc{pytest:2004,
+  title  = {pytest 7.3},
+  author = {Krekel, Holger and Oliveira, Bruno and Pfannschmidt, Ronny and Bruynooghe, Floris and Laugher, Brianna and Bruhin, Florian},
+  year   = {2004},
+  url    = {https://github.com/pytest-dev/pytest}
+}
+
+% mkdocs
+@misc{mkdocs:2023,
+  title        = {{MkDocs}},
+  howpublished = {\url{https://github.com/mkdocs/mkdocs}},
+  year         = 2014,
+  note         = {Accessed: 2023-05-11}
+}
+
+% Pytest
+@misc{sphinx:2023,
+  title  = {Sphinx},
+  author = {Brandl, Georg},
+  year   = {2008},
+  url    = {https://www.sphinx-doc.org}
+}
+
+
+% Docker
+@article{Merkel:2014,
+  title   = {Docker: lightweight linux containers for consistent development and deployment},
+  author  = {Merkel, Dirk},
+  journal = {Linux journal},
+  volume  = {2014},
+  number  = {239},
+  pages   = {2},
+  year    = {2014}
+}
+
+% TensorFlow
+@misc{tensorflow:2015,
+  title  = { {TensorFlow}: Large-Scale Machine Learning on Heterogeneous Systems},
+  url    = {https://www.tensorflow.org/},
+  note   = {Software available from tensorflow.org},
+  author = {
+            Mart\'{i}n~Abadi and
+            Ashish~Agarwal and
+            Paul~Barham and
+            Eugene~Brevdo and
+            Zhifeng~Chen and
+            Craig~Citro and
+            Greg~S.~Corrado and
+            Andy~Davis and
+            Jeffrey~Dean and
+            Matthieu~Devin and
+            Sanjay~Ghemawat and
+            Ian~Goodfellow and
+            Andrew~Harp and
+            Geoffrey~Irving and
+            Michael~Isard and
+            Yangqing Jia and
+            Rafal~Jozefowicz and
+            Lukasz~Kaiser and
+            Manjunath~Kudlur and
+            Josh~Levenberg and
+            Dandelion~Man\'{e} and
+            Rajat~Monga and
+            Sherry~Moore and
+            Derek~Murray and
+            Chris~Olah and
+            Mike~Schuster and
+            Jonathon~Shlens and
+            Benoit~Steiner and
+            Ilya~Sutskever and
+            Kunal~Talwar and
+            Paul~Tucker and
+            Vincent~Vanhoucke and
+            Vijay~Vasudevan and
+            Fernanda~Vi\'{e}gas and
+            Oriol~Vinyals and
+            Pete~Warden and
+            Martin~Wattenberg and
+            Martin~Wicke and
+            Yuan~Yu and
+            Xiaoqiang~Zheng},
+  year   = {2015}
+}
+
+% Keras
+@misc{chollet:2015,
+  title        = {Keras},
+  author       = {Chollet, Fran\c{c}ois},
+  year         = {2015},
+  howpublished = {\url{https://keras.io}}
+}
+
+% LWTNN
+@misc{Guest:2022,
+  doi       = {10.5281/ZENODO.6467676},
+  url       = {https://zenodo.org/record/6467676},
+  author    = {Guest,  Daniel Hay and Smith,  Joshua Wyatt and Paganini,  Michela and Kagan,  Michael and Lanfermann,  Marie and Krasznahorkay,  Attila and Marley,  Daniel Edison and Ghosh,  Aishik and Huth,  Benjamin and Feickert,  Matthew},
+  title     = {lwtnn/lwtnn: Version 2.13},
+  publisher = {Zenodo},
+  year      = {2022},
+  copyright = {MIT License}
+}
+
+% HDF5
+@online{hdf5:2023,
+  author = {{The HDF Group}},
+  title  = {{Hierarchical Data Format, version 5}},
+  year   = {1997},
+  note   = {https://www.hdfgroup.org/HDF5/}
+}
+
+% Puma
+@misc{Birk:2023,
+  doi       = {10.5281/ZENODO.7806395},
+  url       = {https://zenodo.org/record/7806395},
+  author    = {Birk,  Joschka and Froch,  Alexander and VS,  Sam and Guth,  Manuel and Gadow,  Philipp and Schr\"oer, Tomke and Kobylianskii,  Dmitrii and Rettie,  Sébastien and Strebler,  Thomas},
+  title     = {umami-hep/puma: v0.2.4},
+  publisher = {Zenodo},
+  year      = {2023},
+  copyright = {Open Access}
+}
+
+% Matplotlib
+@article{Hunter:2007,
+  author    = {Hunter, J. D.},
+  title     = {Matplotlib: A 2D graphics environment},
+  journal   = {Computing in Science \& Engineering},
+  volume    = {9},
+  number    = {3},
+  pages     = {90--95},
+  abstract  = {Matplotlib is a 2D graphics package used for Python for
+               application development, interactive scripting, and publication-quality
+               image generation across user interfaces and operating systems.},
+  publisher = {IEEE COMPUTER SOC},
+  doi       = {10.1109/MCSE.2007.55},
+  year      = 2007
+}
+
+% Machine learning in HEP
+@article{Guest:2018,
+  doi       = {10.1146/annurev-nucl-101917-021019},
+  url       = {https://doi.org/10.1146/annurev-nucl-101917-021019},
+  year      = {2018},
+  month     = oct,
+  publisher = {Annual Reviews},
+  volume    = {68},
+  number    = {1},
+  pages     = {161--181},
+  author    = {Dan Guest and Kyle Cranmer and Daniel Whiteson},
+  title     = {Deep Learning and Its Application to {LHC} Physics},
+  journal   = {Annual Review of Nuclear and Particle Science}
+}
+
+% Machine learning in CMS
+@article{Cagnotta:2022,
+  doi       = {10.3390/app122010574},
+  url       = {https://doi.org/10.3390/app122010574},
+  year      = {2022},
+  month     = oct,
+  publisher = {{MDPI} {AG}},
+  volume    = {12},
+  number    = {20},
+  pages     = {10574},
+  author    = {Antimo Cagnotta and Francesco Carnevali and Agostino De Iorio},
+  title     = {Machine Learning Applications for Jet Tagging in the {CMS} Experiment},
+  journal   = {Applied Sciences}
+}
+
+% DeepJet
+@article{Bols:2020,
+  doi       = {10.1088/1748-0221/15/12/p12012},
+  url       = {https://doi.org/10.1088/1748-0221/15/12/p12012},
+  year      = {2020},
+  month     = dec,
+  publisher = {{IOP} Publishing},
+  volume    = {15},
+  number    = {12},
+  pages     = {P12012--P12012},
+  author    = {E. Bols and J. Kieseler and M. Verzetti and M. Stoye and A. Stakia},
+  title     = {Jet flavour classification using {DeepJet}},
+  journal   = {Journal of Instrumentation}
+}
+
+% ATLAS Flavour Tagging paper
+@article{ATLAS:2019,
+  author        = {{ATLAS Collaboration}},
+  title         = {{ATLAS flavour-tagging algorithms for the LHC Run~2 \(pp\) collision dataset}},
+  year          = {2022},
+  reportnumber  = {CERN-EP-2022-226},
+  eprint        = {2211.16345},
+  archiveprefix = {arXiv},
+  primaryclass  = {physics.data-an}
+}
+
+
+% ParticleNet
+@article{Qu:2020,
+  doi       = {10.1103/physrevd.101.056019},
+  url       = {https://doi.org/10.1103/physrevd.101.056019},
+  year      = {2020},
+  month     = mar,
+  publisher = {American Physical Society ({APS})},
+  volume    = {101},
+  number    = {5},
+  author    = {Huilin Qu and Loukas Gouskos},
+  title     = {Jet tagging via particle clouds},
+  journal   = {Physical Review D}
+}
+
+% ParT
+@inproceedings{Qu:2022,
+  title     = {Particle Transformer for Jet Tagging},
+  author    = {Qu, Huilin and Li, Congqiao and Qian, Sitian},
+  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
+  pages     = {18281--18292},
+  year      = {2022},
+  editor    = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
+  volume    = {162},
+  series    = {Proceedings of Machine Learning Research},
+  month     = {17--23 Jul},
+  publisher = {PMLR},
+  pdf       = {https://proceedings.mlr.press/v162/qu22b/qu22b.pdf},
+  url       = {https://proceedings.mlr.press/v162/qu22b.html},
+  abstract  = {Jet tagging is a critical yet challenging classification task in particle physics. While deep learning has transformed jet tagging and significantly improved performance, the lack of a large-scale public dataset impedes further enhancement. In this work, we present JetClass, a new comprehensive dataset for jet tagging. The JetClass dataset consists of 100 M jets, about two orders of magnitude larger than existing public datasets. A total of 10 types of jets are simulated, including several types unexplored for tagging so far. Based on the large dataset, we propose a new Transformer-based architecture for jet tagging, called Particle Transformer (ParT). By incorporating pairwise particle interactions in the attention mechanism, ParT achieves higher tagging performance than a plain Transformer and surpasses the previous state-of-the-art, ParticleNet, by a large margin. The pre-trained ParT models, once fine-tuned, also substantially enhance the performance on two widely adopted jet tagging benchmarks. The dataset, code and models are publicly available at https://github.com/jet-universe/particle_transformer.}
+}
+
+% ADAM
+@inproceedings{Kingma:2015,
+  author    = {Diederik P. Kingma and
+               Jimmy Ba},
+  editor    = {Yoshua Bengio and
+               Yann LeCun},
+  title     = {Adam: {A} Method for Stochastic Optimization},
+  booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015,
+               San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings},
+  year      = {2015},
+  url       = {http://arxiv.org/abs/1412.6980},
+  timestamp = {Thu, 25 Jul 2019 14:25:37 +0200},
+  biburl    = {https://dblp.org/rec/journals/corr/KingmaB14.bib},
+  bibsource = {dblp computer science bibliography, https://dblp.org}
+}
+
+
+% JetClass
+@misc{JetClass:2022,
+  doi       = {10.5281/ZENODO.6619768},
+  url       = {https://zenodo.org/record/6619768},
+  author    = {Qu,  Huilin and Li,  Congqiao and Qian,  Sitian},
+  keywords  = {Particle physics,  Jet,  Jet tagging,  Machine learning},
+  title     = {JetClass: A Large-Scale Dataset for Deep Learning in Jet Physics},
+  publisher = {Zenodo},
+  year      = {2022},
+  copyright = {Creative Commons Attribution 4.0 International}
+}
+
+% Deep Learning
+@article{LeCun:2015,
+  doi       = {10.1038/nature14539},
+  url       = {https://doi.org/10.1038/nature14539},
+  year      = {2015},
+  month     = may,
+  publisher = {Springer Science and Business Media {LLC}},
+  volume    = {521},
+  number    = {7553},
+  pages     = {436--444},
+  author    = {Yann LeCun and Yoshua Bengio and Geoffrey Hinton},
+  title     = {Deep learning},
+  journal   = {Nature}
+}
+
+% Deep Sets
+@inproceedings{Zaheer:2017,
+  author    = {Zaheer, Manzil and Kottur, Satwik and Ravanbhakhsh, Siamak and P\'{o}czos, Barnab\'{a}s and Salakhutdinov, Ruslan and Smola, Alexander J},
+  title     = {Deep Sets},
+  year      = {2017},
+  isbn      = {9781510860964},
+  publisher = {Curran Associates Inc.},
+  address   = {Red Hook, NY, USA},
+  booktitle = {Proceedings of the 31st International Conference on Neural Information Processing Systems},
+  pages     = {3394–3404},
+  numpages  = {11},
+  location  = {Long Beach, California, USA},
+  series    = {NIPS'17}
+}
+
+% SHAP
+@incollection{NIPS:2017,
+  title     = {A Unified Approach to Interpreting Model Predictions},
+  author    = {Lundberg, Scott M and Lee, Su-In},
+  booktitle = {Advances in Neural Information Processing Systems 30},
+  editor    = {I. Guyon and U. V. Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett},
+  pages     = {4765--4774},
+  year      = {2017},
+  publisher = {Curran Associates, Inc.},
+  url       = {http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf}
+}
+
+@incollection{pytorch,
+  title     = {PyTorch: An Imperative Style, High-Performance Deep Learning Library},
+  author    = {Paszke, Adam and Gross, Sam and Massa, Francisco and Lerer, Adam and Bradbury, James and Chanan, Gregory and Killeen, Trevor and Lin, Zeming and Gimelshein, Natalia and Antiga, Luca and Desmaison, Alban and Kopf, Andreas and Yang, Edward and DeVito, Zachary and Raison, Martin and Tejani, Alykhan and Chilamkurthy, Sasank and Steiner, Benoit and Fang, Lu and Bai, Junjie and Chintala, Soumith},
+  booktitle = {Advances in Neural Information Processing Systems 32},
+  pages     = {8024--8035},
+  year      = {2019},
+  publisher = {Curran Associates, Inc.},
+  url       = {http://papers.neurips.cc/paper/9015-pytorch-an-imperative-style-high-performance-deep-learning-library.pdf}
+}
+
+@software{lightning,
+  author  = {Falcon, William and {The PyTorch Lightning team}},
+  doi     = {10.5281/zenodo.3828935},
+  license = {Apache-2.0},
+  month   = mar,
+  title   = {{PyTorch Lightning}},
+  url     = {https://github.com/Lightning-AI/lightning},
+  version = {1.4},
+  year    = {2019}
+}
+
+@techreport{GN1,
+  collaboration = {ATLAS},
+  title         = {{Graph Neural Network Jet Flavour Tagging with the ATLAS
+                   Detector}},
+  institution   = {CERN},
+  reportnumber  = {ATL-PHYS-PUB-2022-027},
+  address       = {Geneva},
+  year          = {2022},
+  url           = {https://cds.cern.ch/record/2811135},
+  note          = {All figures including auxiliary figures are available at
+                   https://atlas.web.cern.ch/Atlas/GROUPS/PHYSICS/PUBNOTES/ATL-PHYS-PUB-2022-027}
+}
+
+@techreport{GN2X,
+  collaboration = {ATLAS},
+  title         = {{Transformer Neural Networks for Identifying Boosted Higgs
+                   Bosons decaying into $b\bar{b}$ and $c\bar{c}$ in ATLAS}},
+  institution   = {CERN},
+  reportnumber  = {ATL-PHYS-PUB-2023-021},
+  address       = {Geneva},
+  year          = {2023},
+  url           = {https://cds.cern.ch/record/2866601},
+  note          = {All figures including auxiliary figures are available at
+                   https://atlas.web.cern.ch/Atlas/GROUPS/PHYSICS/PUBNOTES/ATL-PHYS-PUB-2023-021}
+}
+
+@misc{onnx,
+  author       = {Bai, Junjie and Lu, Fang and Zhang, Ke and others},
+  title        = {ONNX: Open Neural Network Exchange},
+  year         = {2019},
+  publisher    = {GitHub},
+  journal      = {GitHub repository},
+  howpublished = {\url{https://github.com/onnx/onnx}},
+  commit       = {94d238d96e3fb3a7ba34f03c284b9ad3516163be}
+}
+
+
+@misc{umami,
+  author       = {Barr, Jackson and others},
+  title        = {Umami: A Python toolkit for jet flavour tagging},
+  year         = {2024},
+  publisher    = {GitHub},
+  journal      = {GitHub repository},
+  howpublished = {\url{https://github.com/umami-hep/umami-preprocessing}},
+  commit       = {640369546e65937db79f0f7bbc86ea4c3114943c}
+}
+
+
+@article{2017arXiv170603762V,
+  author        = {{Vaswani}, Ashish and {Shazeer}, Noam and {Parmar}, Niki and {Uszkoreit}, Jakob and {Jones}, Llion and {Gomez}, Aidan N. and {Kaiser}, Lukasz and {Polosukhin}, Illia},
+  title         = {{Attention Is All You Need}},
+  journal       = {arXiv e-prints},
+  keywords      = {Computer Science - Computation and Language, Computer Science - Machine Learning},
+  year          = 2017,
+  month         = jun,
+  eid           = {arXiv:1706.03762},
+  pages         = {arXiv:1706.03762},
+  doi           = {10.48550/arXiv.1706.03762},
+  archiveprefix = {arXiv},
+  eprint        = {1706.03762},
+  primaryclass  = {cs.CL},
+  adsurl        = {https://ui.adsabs.harvard.edu/abs/2017arXiv170603762V},
+  adsnote       = {Provided by the SAO/NASA Astrophysics Data System}
+}
+
diff --git a/paper/paper.md b/paper/paper.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc1689b185916ec304e04bf3766572741eb6cc0e
--- /dev/null
+++ b/paper/paper.md
@@ -0,0 +1,145 @@
+---
+title: 'Salt: Multimodal Multitask Machine Learning for High Energy Physics'
+
+tags:
+  - Python
+  - high energy physics
+  - machine learning
+  - jet physics
+  - flavour tagging
+
+authors:
+  - name: Jackson Barr
+    orcid: 0000-0002-9752-9204 
+    affiliation: 1
+  - name: Diptaparna Biswas
+    orcid: 0000-0002-7543-3471
+    affiliation: 2
+  - name: Maxence Draguet
+    orcid: 0000-0003-1530-0519
+    affiliation: 3
+  - name: Philipp Gadow
+    orcid: 0000-0003-4475-6734
+    affiliation: 4
+  - name: Emil Haines
+    orcid: 0000-0002-5417-2081
+    affiliation: 1
+  - name: Osama Karkout
+    orcid: 0000-0002-4907-9499
+    affiliation: 5
+  - name: Dmitrii Kobylianskii
+    orcid: 0009-0002-0070-5900
+    affiliation: 6
+  - name: Wei Sheng Lai
+    orcid: 0009-0001-6726-9851
+    affiliation: 1
+  - name: Matthew Leigh
+    orcid: 0000-0003-1406-1413
+    affiliation: 7
+  - name: Nicholas Luongo
+    orcid: 0000-0001-6527-0253
+    affiliation: 10
+  - name: Ivan Oleksiyuk
+    orcid: 0000-0002-4784-6340
+    affiliation: 7
+  - name: Nikita Pond
+    orcid: 0000-0002-5966-0332
+    affiliation: 1
+  - name: Sébastien Rettie
+    orcid: 0000-0002-7092-3893
+    affiliation: 4
+  - name: Andrius Vaitkus
+    orcid: 0000-0002-0393-666X
+    affiliation: 1
+  - name: Samuel Van Stroud
+    orcid: 0000-0002-7969-0301
+    affiliation: 1
+  - name: Johannes Wagner
+    orcid: 0000-0002-5588-0020
+    affiliation: 9
+
+affiliations:
+ - name: University College London, United Kingdom
+   index: 1
+ - name: University of Siegen
+   index: 2
+ - name: University of Oxford, United Kingdom
+   index: 3
+ - name: European Laboratory for Particle Physics CERN, Switzerland
+   index: 4
+ - name: Nikhef
+   index: 5
+ - name: Department of Particle Physics and Astrophysics, Weizmann Institute of Science, Israel
+   index: 6
+ - name: Université de Genève, Switzerland
+   index: 7
+ - name: Technical University of Munich, Germany
+   index: 8
+ - name: University of California, Berkeley
+   index: 9
+ - name: Argonne National Laboratory
+   index: 10
+
+date: 15 Janurary 2024
+bibliography: paper.bib
+
+---
+
+# Summary
+
+High energy physics studies the fundamental particles and forces that constitute the universe, often through experiments conducted in large particle accelerators such as the Large Hadron Collider (LHC) [@Evans:2008].
+`Salt` is a Python application developed for the high energy physics community that streamlines the training and deployment of advanced machine learning (ML) models, making them more accessible and promoting shared best practices.
+`Salt` features a generic multimodal, multitask model skeleton which, coupled with a strong emphasis on modularity, configurabiltiy, and ease of use, can be used to tackle a wide variety of high energy physics ML applications.
+
+Some key features of `Salt` are listed below:
+
+- Based on established frameworks: `Salt` is built upon PyTorch [@pytorch] and Lightning [@lightning] for maximum performance and scalability with minimal boilerplate code.
+- Multimodal, multitask models: `Salt` models support multimodal inputs and can be configured to perform various tasks such as classification, regression, segmentation, and edge classification tasks. Any combination of these can be used to flexibly define models for multitask learning problems.
+- Customisable and extensible: `Salt` supports full customisation of training and model configuration through YAML config files. Its modular design allows for the easy integration of custom dataloaders, layers, and models.
+- Train at scale: `Salt` can handle large volumes of data with efficient HDF5 [@hdf5:2023] dataloaders. It also includes multi-GPU support from Lightning, enabling distributed training.
+- Deployment ready: `Salt` facilitates ONNX [@onnx] serialization for integrating models into C++ based software environments.
+
+
+# Statement of need
+
+In high energy physics research the reliance on ML for data analysis and object classification is growing [@Guest:2018; @Cagnotta:2022].
+`Salt` meets this growing need by providing a versatile, performant, and user-friendly tool for developing advanced ML models.
+`Salt` was originally developed to train state of the art flavour tagging models at the ATLAS experiment [@ATLAS:2008] at the LHC.
+Flavour tagging, the identification of jets from bottom and charm quarks, plays a crucial role in analysing ATLAS collision data. This process is key for precision Standard Model measurements, particularly in the characterisation of the Higgs boson, and for investigating new phenomena.
+The unique characteristics of hadrons containing bottom and charm quarks – such as their long lifetimes, high mass, and high decay multiplicity – create distinct signatures in particle detectors that can be effectively exploited by ML algorithms.
+The presence of hadrons containing bottom and charm quarks can be inferred via the identification of approximately 3-5 reconstructed charged particle trajectories from the weak decay of the heavy flavour hadron admist several more tracks from the primary proton-proton interaction vertex.
+
+While initially developed for flavour tagging, `Salt` has evolved into a flexible tool that can be used for a wide range of tasks, from object and event classification, regression of object properties, to object reconstruction (via edge classification or input segmentation), demonstrating its broad applicability across various data analysis challenges in high energy physics.
+
+
+# Model Architecture
+
+Salt is designed to be fully modular, but ships with a flexible model architecture that can be configured for a variety of use cases.
+This architecture facilitates the training of multimodal and multitask models as depicted in \autoref{fig:salt-arch}, and is designed to take advantage of multiple input modalities.
+In the context of jet classification, these input modalities might include global features of the jet and varying numbers of jet constituents such as charged particle trajectories, calorimeter energy depositions, reconstructed leptons, or inner detector spacepoints.
+The architecture is described briefly below.
+First, any global input features are concatentated with the features of each constituent.
+Next, an initial embedding to a shared representation space is performed separately for each type of constituent.
+The different types of constituents are then projected into a shared representation space by a series of initialisation networks.
+The embedded constituents are then combined and fed into a encoder network that processes constituents of different modalities in a unified way.
+The encoder then outputs to a set of task-specific modules, each tailored to a specific learning objective.
+This architecture allows the model to leverage all the available detector information, leading to improved performance.
+A concrete example of this architecture is in use at ATLAS [@GN1; @GN2X].
+
+![This diagram illustrates the flow of information within a generic model trained using `Salt`. In this example, global object features are provided alongisde two types of constituents. The model is configured with three training objectives, each of which may relate to the global object or the one of the constituent modalities. Concatenation is denoted by $\oplus$.\label{fig:salt-arch}](salt-arch.png){ width=90% }
+
+
+# Related work
+
+`Umami` [@umami] is a related software package in use at ATLAS. 
+While `Salt` relies on similar preprocessing techniques as those provided by `Umami`, it provides several additional features which make it a more powerful and flexible tool for creating advanced ML models.
+Namely, `Salt` provides support for multimodal and multitask learning, optimised Transformer encoders [@2017arXiv170603762V], and distributed model training.
+
+
+# Acknowledgements
+
+The development of `Salt` is part of the offline software research and development programme of the ATLAS Collaboration, and we thank the collaboration for its support and cooperation.
+This work is funded in part by the UK's Science and Technology Facilities Council via University College London's Centre for Doctoral Training in Data Intensive Science, and the Royal Society.
+
+
+# References
diff --git a/paper/paper.pdf b/paper/paper.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..c69cb401119550c0282de6948e185d69b2ee3d92
Binary files /dev/null and b/paper/paper.pdf differ
diff --git a/paper/salt-arch.png b/paper/salt-arch.png
new file mode 100644
index 0000000000000000000000000000000000000000..1105b651ff860ad1b9a2392f2a6e511b63bec561
Binary files /dev/null and b/paper/salt-arch.png differ
diff --git a/pyproject.toml b/pyproject.toml
index 0a239924bffee6594970094c5ad33903e6679ac6..d46177759465cd6309c9fd06d7db35755b2ed103 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ build-backend = "setuptools.build_meta"
 [tool.ruff]
 line-length = 100
 preview = true
-lint.select = ["ALL"]
+lint.select = ["ALL", "D212", "D417"]
 lint.ignore = [
     "COM", "D100", "D101", "D102", "D103", "D104", "D105", "D205", "D401", "EM", "FIX", "FBT",
     "S101", "S404", "S602", "PLR2004", "PLR0912", "PLR0913", "PLR0914", "PLR0915", "PLR0917",
diff --git a/requirements.txt b/requirements.txt
index 6dceab4315e87603e76002d79b3174846e8e9da5..768a89c6c5ae8fb7d4d4cc5290222c1a4a67c34d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,7 @@ torch==2.2.1
 lightning==2.2.0
 jsonargparse[all]==4.27.5
 torchmetrics==1.2.1
-onnx==1.15.0
+onnx==1.16.0
 onnxruntime==1.15.1
 atlas-ftag-tools==0.1.18
 scipy==1.12.0
diff --git a/salt/callbacks/predictionwriter.py b/salt/callbacks/predictionwriter.py
index 4ecf2dec7dc21a1166724bf5d5335d77e9259964..865a161f2c39f0f7b18760578792b0af0b8c74d2 100644
--- a/salt/callbacks/predictionwriter.py
+++ b/salt/callbacks/predictionwriter.py
@@ -15,6 +15,7 @@ from salt.models.task import (
 )
 from salt.stypes import Vars
 from salt.utils.array_utils import join_structured_arrays, maybe_pad
+from salt.utils.mask_utils import indices_from_mask
 
 
 class PredictionWriter(Callback):
@@ -180,6 +181,8 @@ class PredictionWriter(Callback):
             for out in ["object_class_probs", "object_class_targets", "mask_logits", "tgt_masks"]:
                 if out not in self.outputs["objects"]:
                     self.outputs["objects"][out] = []
+            if "mask_index" not in self.outputs["tracks"]:
+                self.outputs["tracks"]["mask_index"] = []
 
             probs_dtype = np.dtype([(n, self.precision) for n in self.object_params["label_map"]])
             self.outputs["objects"]["object_class_probs"].append(
@@ -189,6 +192,14 @@ class PredictionWriter(Callback):
                 labels["objects"][self.object_params["class_label"]].cpu().numpy()
             )
             self.outputs["objects"]["mask_logits"].append(objects["masks"].cpu().float().numpy())
+            mask_indices = indices_from_mask(objects["masks"].cpu().sigmoid() > 0.5)
+            dtype = np.dtype([("MaskIndex", "i8")])
+            mask_indices = mask_indices.int().cpu().numpy()
+            mask_indices = np.where(~this_pad_masks, mask_indices, -1)
+            # Get the mask index with a default mask cut value of 0.5
+            self.outputs["tracks"]["mask_index"].append(
+                u2s(np.expand_dims(mask_indices, -1), dtype)
+            )
             self.outputs["objects"]["tgt_masks"].append(labels["objects"]["masks"].cpu().numpy())
 
     def on_test_end(self, trainer, module):  # noqa: ARG002
diff --git a/salt/callbacks/saveconfig.py b/salt/callbacks/saveconfig.py
index cba39a8af9d8f2bf065c034672aa2de859ad3c36..f9f47fedfdd5972410ccdecbc9daab60473a33d9 100644
--- a/salt/callbacks/saveconfig.py
+++ b/salt/callbacks/saveconfig.py
@@ -13,6 +13,7 @@ import yaml
 from ftag.git_check import get_git_hash
 from lightning import Callback, LightningModule, Trainer
 from lightning.pytorch.cli import LightningArgumentParser, Namespace
+from lightning.pytorch.loggers import CometLogger
 from s3fs import S3FileSystem
 from s3path import S3Path
 
@@ -138,8 +139,8 @@ class SaveConfigCallback(Callback):
         self.write_yaml_file(self.config, config_path)
 
         # log files as assets
-        #  currently cannot save log files as assests on S3
-        if self.plm.logger is not None and not self.use_S3:
+        # currently cannot save log files as assests on S3
+        if isinstance(self.plm.logger, CometLogger) and not self.use_S3:
             self.plm.logger.experiment.log_asset(config_path)
             self.plm.logger.experiment.log_asset(nd_path)
             self.plm.logger.experiment.log_asset(cd_path)
diff --git a/salt/configs/GN2_extended.yaml b/salt/configs/GN2_extended.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dfabb11b7fac57446d3eb13cfc369578160dda84
--- /dev/null
+++ b/salt/configs/GN2_extended.yaml
@@ -0,0 +1,125 @@
+name: GN2
+
+model:
+  lrs_config:
+    initial: 1e-7
+    max: 5e-4
+    end: 1e-5
+    pct_start: 0.01
+    weight_decay: 1e-5
+
+  model:
+    class_path: salt.models.SaltModel
+    init_args:
+      init_nets:
+        - input_name: tracks
+          dense_config:
+            output_size: &embed_dim 256
+            hidden_layers: [256]
+            activation: &activation ReLU
+
+      encoder:
+        class_path: salt.models.TransformerEncoder
+        init_args:
+          embed_dim: *embed_dim
+          num_layers: 4
+          out_dim: &out_dim 128
+          mha_config:
+            num_heads: 8
+            attention: { class_path: salt.models.ScaledDotProductAttention }
+          dense_config:
+            activation: *activation
+
+      pool_net:
+        class_path: salt.models.GlobalAttentionPooling
+        init_args: { input_size: *out_dim }
+
+      tasks:
+        class_path: torch.nn.ModuleList
+        init_args:
+          modules:
+            - class_path: salt.models.ClassificationTask
+              init_args:
+                name: jets_classification
+                input_name: jets
+                label: flavour_label
+                loss:
+                  class_path: torch.nn.CrossEntropyLoss
+                  init_args: { weight: [2.0, 2.0, 2.0, 1.0, 6.25] }
+                dense_config: &task_dense_config
+                  input_size: *out_dim
+                  output_size: 5
+                  hidden_layers: [128, 64, 32]
+                  activation: *activation
+
+            - class_path: salt.models.ClassificationTask
+              init_args:
+                name: track_origin
+                input_name: tracks
+                label: ftagTruthOriginLabel
+                weight: 0.5
+                loss:
+                  class_path: torch.nn.CrossEntropyLoss
+                  init_args:
+                    weight: [3.92, 83.21, 1.0, 10.22, 7.11, 7.88, 62.91, 19.42]
+                dense_config:
+                  <<: *task_dense_config
+                  output_size: 8
+                  context_size: *out_dim
+
+            - class_path: salt.models.VertexingTask
+              init_args:
+                name: track_vertexing
+                input_name: tracks
+                label: ftagTruthVertexIndex
+                weight: 1.5
+                loss:
+                  class_path: torch.nn.BCEWithLogitsLoss
+                  init_args: { reduction: none }
+                dense_config:
+                  <<: *task_dense_config
+                  input_size: 256
+                  output_size: 1
+                  context_size: *out_dim
+
+data:
+  variables:
+    jets:
+      - pt_btagJes
+      - eta_btagJes
+    tracks:
+      - d0
+      - z0SinTheta
+      - dphi
+      - deta
+      - qOverP
+      - IP3D_signed_d0_significance
+      - IP3D_signed_z0_significance
+      - phiUncertainty
+      - thetaUncertainty
+      - qOverPUncertainty
+      - numberOfPixelHits
+      - numberOfSCTHits
+      - numberOfInnermostPixelLayerHits
+      - numberOfNextToInnermostPixelLayerHits
+      - numberOfInnermostPixelLayerSharedHits
+      - numberOfInnermostPixelLayerSplitHits
+      - numberOfPixelSharedHits
+      - numberOfPixelSplitHits
+      - numberOfSCTSharedHits
+      #- numberOfTRTHits
+      #- leptonID
+
+  train_file: /nfs/dust/atlas/user/nkumari/UPP_latest/umami-preprocessing/upp/configs/prep/output/pp_output_train.h5
+  val_file: /nfs/dust/atlas/user/nkumari/UPP_latest/umami-preprocessing/upp/configs/prep/output/pp_output_val.h5
+  norm_dict: /nfs/dust/atlas/user/nkumari/UPP_latest/umami-preprocessing/upp/configs/prep/output/norm_dict.yaml
+  class_dict: /nfs/dust/atlas/user/nkumari/UPP_latest/umami-preprocessing/upp/configs/prep/output/class_dict.yaml
+
+  batch_size: 4000
+  num_workers: 40
+
+trainer:
+  max_epochs: 40
+  accelerator: gpu
+  devices: 2
+  precision: 16-mixed
diff --git a/salt/configs/GN3.yaml b/salt/configs/GN3.yaml
index 307a67f1cdd552e335c393369ef6669b9c96a4c1..1c3fc70603de78496f61abdbd9ba3482f5b9dc0a 100644
--- a/salt/configs/GN3.yaml
+++ b/salt/configs/GN3.yaml
@@ -11,23 +11,31 @@ model:
   model:
     class_path: salt.models.SaltModel
     init_args:
-      num_register_tokens: 10
 
       init_nets:
         - input_name: tracks
           dense_config:
             output_size: &embed_dim 256
             hidden_layers: [256]
-            activation: &activation ReLU
+            activation: &activation SiLU
 
       encoder:
         class_path: salt.models.TransformerV2
         init_args:
+          num_layers: 8
           embed_dim: *embed_dim
-          num_layers: 4
           out_dim: &out_dim 128
+          attn_type: flash-varlen
+          norm: LayerNorm
+          ls_init: 1.0e-2
+          dense_kwargs:
+            activation: *activation
+            dropout: 0
+            gated: True
           attn_kwargs:
             num_heads: 8
+            dropout: 0.1
+          num_registers: 8
 
       pool_net:
         class_path: salt.models.GlobalAttentionPooling
@@ -42,44 +50,41 @@ model:
                 name: jets_classification
                 input_name: jets
                 label: flavour_label
-                loss:
-                  class_path: torch.nn.CrossEntropyLoss
-                  init_args: { weight: [1.0, 2.0, 2.0] }
+                use_class_dict: True
+                loss: torch.nn.CrossEntropyLoss
                 dense_config: &task_dense_config
                   input_size: *out_dim
-                  output_size: 3
+                  output_size: 4
                   hidden_layers: [128, 64, 32]
                   activation: *activation
 
-            - class_path: salt.models.ClassificationTask
-              init_args:
-                name: track_origin
-                input_name: tracks
-                label: ftagTruthOriginLabel
-                weight: 0.5
-                loss:
-                  class_path: torch.nn.CrossEntropyLoss
-                  init_args:
-                    weight: [4.2, 73.7, 1.0, 17.5, 12.3, 12.5, 141.7, 22.3]
-                dense_config:
-                  <<: *task_dense_config
-                  output_size: 8
-                  context_size: *out_dim
-
-            - class_path: salt.models.VertexingTask
-              init_args:
-                name: track_vertexing
-                input_name: tracks
-                label: ftagTruthVertexIndex
-                weight: 1.5
-                loss:
-                  class_path: torch.nn.BCEWithLogitsLoss
-                  init_args: { reduction: none }
-                dense_config:
-                  <<: *task_dense_config
-                  input_size: 256
-                  output_size: 1
-                  context_size: *out_dim
+# Disabling aux tasks during R&D phase to speed up training!
+#            - class_path: salt.models.ClassificationTask
+#              init_args:
+#                name: track_origin
+#                input_name: tracks
+#                label: ftagTruthOriginLabel
+#                weight: 0.5
+#                loss: torch.nn.CrossEntropyLoss
+#                dense_config:
+#                  <<: *task_dense_config
+#                  output_size: 8
+#                  context_size: *out_dim
+#
+#            - class_path: salt.models.VertexingTask
+#              init_args:
+#                name: track_vertexing
+#                input_name: tracks
+#                label: ftagTruthVertexIndex
+#                weight: 1.5
+#                loss:
+#                  class_path: torch.nn.BCEWithLogitsLoss
+#                  init_args: { reduction: none }
+#                dense_config:
+#                  <<: *task_dense_config
+#                  input_size: 256
+#                  output_size: 1
+#                  context_size: *out_dim
 
 data:
   variables:
@@ -106,8 +111,8 @@ data:
       - numberOfPixelSharedHits
       - numberOfPixelSplitHits
       - numberOfSCTSharedHits
+      - leptonID
       #- numberOfTRTHits
-      #- leptonID
 
   train_file: /unix/atlastracking/samples/ftag_dumps/vertexing/output/pp_output_train.h5
   val_file: /unix/atlastracking/samples/ftag_dumps/vertexing/output/pp_output_val.h5
diff --git a/salt/configs/MaskFormer.yaml b/salt/configs/MaskFormer.yaml
index bd8da4996157733e3ccebc8e88d9fcd521f72cef..532887b59f4cdfa032703d7c239bb44944044d7d 100644
--- a/salt/configs/MaskFormer.yaml
+++ b/salt/configs/MaskFormer.yaml
@@ -28,6 +28,7 @@ model:
             num_heads: 8
           dense_kwargs:
             activation: *activation
+          drop_registers: true
 
 
       mask_decoder:
diff --git a/salt/models/featurewise.py b/salt/models/featurewise.py
index 7695d1456d2fa16650f2c9e4f46581f4f4896211..6ecc8652aa534d9466a4ab94b9023ee3936e6901 100644
--- a/salt/models/featurewise.py
+++ b/salt/models/featurewise.py
@@ -12,6 +12,7 @@ class FeaturewiseTransformation(nn.Module):
         variables: Vars,
         dense_config_scale: dict | None = None,
         dense_config_bias: dict | None = None,
+        apply_norm: bool = False,
     ):
         """Perform feature wise transformations on the features of a layer.
         https://distill.pub/2018/feature-wise-transformations/.
@@ -19,7 +20,7 @@ class FeaturewiseTransformation(nn.Module):
         Parameters
         ----------
         layer : str
-            layer to scale/bias (either "input", or "global")
+            layer to scale/bias (either "input", "encoder", or "global")
         variables : Vars
             Input variables used in the forward pass, set automatically by the framework
         dense_config_scale : dict
@@ -28,30 +29,39 @@ class FeaturewiseTransformation(nn.Module):
         dense_config_bias : dict
             Keyword arguments for [salt.models.Dense][salt.models.Dense],
             the dense network performing the biasing.
+        apply_norm : bool
+            Apply layer normalisation to the transformed features. By default false.
         """
         super().__init__()
 
         self.layer = layer
-        if layer not in {"input", "global"}:
+        if layer not in {"input", "encoder", "global"}:
             raise ValueError(
-                "Featurewise transformations must be applied to either 'input' or 'global' layers."
+                "Select either 'input', 'encoder' or 'global' layers for featurewise nets."
             )
 
         self.scale_net = None
         self.bias_net = None
+        self.num_features = None
+        self.norm = None
 
         if dense_config_scale:
             dense_config_scale["input_size"] = len(variables.get("PARAMETERS", []))
             self.scale_net = Dense(**dense_config_scale)
+            self.num_features = self.scale_net.output_size
         if dense_config_bias:
             dense_config_bias["input_size"] = len(variables.get("PARAMETERS", []))
             self.bias_net = Dense(**dense_config_bias)
+            self.num_features = self.bias_net.output_size
 
         if not self.bias_net and not self.scale_net:
             raise ValueError(
                 "Need to specify at least one dense_config_scale or dense_config_bias."
             )
 
+        if apply_norm:
+            self.norm = nn.LayerNorm(self.num_features)
+
     def forward(self, inputs: dict, features: Tensor):
         if "PARAMETERS" not in inputs:
             raise ValueError("Featurewise transformations require 'PARAMETERS'.")
@@ -60,4 +70,6 @@ class FeaturewiseTransformation(nn.Module):
             features = self.scale_net(x).unsqueeze(1) * features
         if self.bias_net:
             features = torch.add(features, self.bias_net(x).unsqueeze(1))
+        if self.norm:
+            features = self.norm(features)
         return features
diff --git a/salt/models/maskformer.py b/salt/models/maskformer.py
index e505f4915c58845aaf886d35b3b4da67c885a41f..7aaa6fb8aee8e29ddd93e9187bb2f27a4443291f 100644
--- a/salt/models/maskformer.py
+++ b/salt/models/maskformer.py
@@ -4,7 +4,7 @@ import torch
 from torch import Tensor, nn
 
 from salt.models import MaskFormerLoss
-from salt.models.transformer_v2 import GLU, CrossAttention, SelfAttention
+from salt.models.transformer_v2 import GLU, Attention
 from salt.stypes import Tensors
 
 
@@ -108,6 +108,15 @@ class MaskDecoder(nn.Module):
         # apply norm
         q = self.norm1(self.inital_q.expand(x.shape[0], -1, -1))
         x = self.norm2(x)
+        xpad = torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype)
+
+        if pad_mask is not None:
+            padpad_mask = torch.zeros(
+                (pad_mask.shape[0], 1), device=pad_mask.device, dtype=pad_mask.dtype
+            )
+            pad_mask = torch.cat([pad_mask, padpad_mask], dim=1)
+
+        x = torch.cat([x, xpad], dim=1)
 
         intermediate_outputs: list | None = [] if self.aux_loss else None
         for layer in self.layers:
@@ -115,8 +124,10 @@ class MaskDecoder(nn.Module):
                 assert intermediate_outputs is not None
                 intermediate_outputs.append({"embed": q, **self.get_preds(q, x, pad_mask)})
             q, x = layer(q, x, kv_mask=pad_mask)
+        mf_preds = self.get_preds(q, x, pad_mask)
 
-        preds["objects"] = {"embed": q, "x": x, **self.get_preds(q, x, pad_mask)}
+        preds["objects"] = {"embed": q, "x": x[:, :-1, :], **mf_preds}
+        preds["objects"]["masks"] = preds["objects"]["masks"][:, :, :-1]
         if self.aux_loss:
             preds["intermediate_outputs"] = intermediate_outputs
 
@@ -134,6 +145,7 @@ def get_masks(x: Tensor, q: Tensor, mask_net: nn.Module, input_pad_mask: Tensor
         pred_masks[input_pad_mask.unsqueeze(1).expand_as(pred_masks)] = torch.finfo(
             pred_masks.dtype
         ).min
+
     return pred_masks
 
 
@@ -151,28 +163,34 @@ class MaskDecoderLayer(nn.Module):
         self.mask_attention = mask_attention
         self.bidirectional_ca = bidirectional_ca
 
-        self.q_ca = CrossAttention(embed_dim=embed_dim, num_heads=n_heads)
-        self.q_sa = SelfAttention(embed_dim=embed_dim, num_heads=n_heads)
+        self.q_ca = Attention(embed_dim=embed_dim, num_heads=n_heads)
+        self.q_sa = Attention(embed_dim=embed_dim, num_heads=n_heads)
         self.q_dense = GLU(embed_dim)
         if bidirectional_ca:
-            self.kv_ca = CrossAttention(embed_dim=embed_dim, num_heads=n_heads)
+            self.kv_ca = Attention(embed_dim=embed_dim, num_heads=n_heads)
             self.kv_dense = GLU(embed_dim)
         self.mask_net = mask_net
 
     def forward(self, q: Tensor, kv: Tensor, kv_mask: Tensor | None = None) -> Tensor:
         attn_mask = None
-
+        # return q, kv
         # if we want to do mask attention
         if self.mask_attention:
-            # If a BoolTensor is provided, positions with ``True`` are not allowed
-            # to attend while ``False`` values will be unchanged.
-            attn_mask = (get_masks(kv, q, self.mask_net, kv_mask).sigmoid() < 0.1).detach()
+            # New attention masking convention with transformers 2
+            # Positions with True are allowed while False are masked
+            # Compute masks and apply sigmoid
+            attn_mask = get_masks(kv, q, self.mask_net, kv_mask).sigmoid()
 
-            # if the attn mask is invalid for a given query, allow it to attend everywhere
-            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+            # Threshold and detach
+            attn_mask = (attn_mask > 0.9).detach()
+            newmask = torch.all(attn_mask == 0, dim=-1, keepdim=True).expand(attn_mask.shape)
+            # Check if all values along the last dimension are 0 (equivalent to `False` in boolean)
+            # If so, set them to 1 (equivalent to `True` in boolean)
+
+            attn_mask = attn_mask | newmask.bool()
 
         # update queries with cross attention from nodes
-        q = q + self.q_ca(q, kv, kv_mask=kv_mask, attn_mask=attn_mask)
+        q = q + self.q_ca(q, kv=kv, kv_mask=kv_mask, attn_mask=attn_mask)
 
         # update queries with self attention
         q = q + self.q_sa(q)
@@ -184,7 +202,9 @@ class MaskDecoderLayer(nn.Module):
         if self.bidirectional_ca:
             if attn_mask is not None:
                 attn_mask = attn_mask.transpose(1, 2)
-            kv = kv + self.kv_ca(kv, q, q_mask=kv_mask, attn_mask=attn_mask)
-            kv = kv + self.kv_dense(kv)
+                newmask = torch.all(attn_mask == 1, dim=-1, keepdim=True).expand(attn_mask.shape)
+                attn_mask = attn_mask | ~newmask.bool()
 
+            kv = kv + self.kv_ca(kv, q, attn_mask=attn_mask)
+            kv = kv + self.kv_dense(kv)
         return q, kv
diff --git a/salt/models/saltmodel.py b/salt/models/saltmodel.py
index 06dd583b43683f6bd07ec6ba123bd17c0c40a6ba..0fbdda5d3b77ad0d32e9467f8e421caa63f259ab 100644
--- a/salt/models/saltmodel.py
+++ b/salt/models/saltmodel.py
@@ -14,7 +14,6 @@ class SaltModel(nn.Module):
         encoder: nn.Module = None,
         mask_decoder: nn.Module = None,
         pool_net: Pooling = None,
-        num_register_tokens: int = 0,
         merge_dict: dict[str, list[str]] | None = None,
         featurewise_nets: list[dict] | None = None,
     ):
@@ -48,10 +47,6 @@ class SaltModel(nn.Module):
             Pooling network which computes a global representation of the object
             by aggregating over the constituents. If not provided, assume that
             the only inputs are global features (i.e. no constituents).
-        num_register_tokens : int
-            Number of randomly initialised register tokens of the same length as
-            any other input sequences after initialiser networks (e.g. tracks).
-            See https://arxiv.org/abs/2309.16588.
         merge_dict : dict[str, list[str]] | None
             A dictionary that lets the salt concatenate all the input
             representations of the inputs in list[str] and act on them
@@ -63,20 +58,9 @@ class SaltModel(nn.Module):
         """
         super().__init__()
 
-        self.featurewise_nets = None
+        # init featurewise networks
         if featurewise_nets:
-            self.featurewise_nets = nn.ModuleList([
-                FeaturewiseTransformation(**featurewise_net) for featurewise_net in featurewise_nets
-            ])
-        self.featurewise_nets_map = (
-            {featurewise_net.layer: featurewise_net for featurewise_net in self.featurewise_nets}
-            if self.featurewise_nets
-            else {}
-        )
-        # if available, add featurewise net to init net config
-        if "input" in self.featurewise_nets_map:
-            for init_net in init_nets:
-                init_net["featurewise"] = self.featurewise_nets_map["input"]
+            self.init_featurewise(featurewise_nets, init_nets, encoder)
 
         self.init_nets = nn.ModuleList([InitNet(**init_net) for init_net in init_nets])
         self.tasks = tasks
@@ -85,22 +69,6 @@ class SaltModel(nn.Module):
 
         self.pool_net = pool_net
         self.merge_dict = merge_dict
-        self.num_register_tokens = num_register_tokens
-
-        # init register tokens
-        if self.num_register_tokens and not self.encoder:
-            raise ValueError("encoder must be set if num_register_tokens is set")
-        if self.num_register_tokens and self.encoder:
-            self.registers = torch.nn.Parameter(
-                torch.normal(
-                    torch.zeros((self.num_register_tokens, self.encoder.embed_dim)), std=1e-4
-                )
-            )
-            self.register_mask = torch.zeros(self.num_register_tokens, dtype=torch.bool)
-            self.register_buffer("register_mask_buffer", self.register_mask)
-        else:
-            self.registers = None
-            self.register_mask = None
 
         # checks for the global object only setup
         if self.pool_net is None:
@@ -146,12 +114,6 @@ class SaltModel(nn.Module):
         for init_net in self.init_nets:
             xs[init_net.input_name] = init_net(inputs)
 
-        if self.num_register_tokens:
-            batch_size = xs[next(iter(xs))].shape[0]
-            xs["REGISTERS"] = self.registers.expand(batch_size, -1, -1)
-            if pad_masks:
-                pad_masks["REGISTERS"] = self.register_mask_buffer.expand(batch_size, -1)
-
         # handle edge features if present
         edge_x = xs.pop("EDGE", None)
         kwargs = {} if edge_x is None else {"edge_x": edge_x}
@@ -171,8 +133,12 @@ class SaltModel(nn.Module):
                         })
 
         # Generate embedding from encoder, or by concatenating the init net outputs
+        # We should change this such that all encoders return (x, mask)
         if self.encoder:
-            preds = {"embed_xs": self.encoder(xs, pad_mask=pad_masks, **kwargs)}
+            embed_xs = self.encoder(xs, pad_mask=pad_masks, inputs=inputs, **kwargs)
+            if isinstance(embed_xs, tuple):
+                embed_xs, pad_masks = embed_xs
+            preds = {"embed_xs": embed_xs}
         else:
             preds = {"embed_xs": flatten_tensor_dict(xs)}
 
@@ -182,9 +148,9 @@ class SaltModel(nn.Module):
             else (preds, labels, {})
         )
 
-        # apply featurewise transformation to global track representations if configured
-        if "global" in self.featurewise_nets_map:
-            preds["embed_xs"] = self.featurewise_nets_map["global"](inputs, preds["embed_xs"])
+        # apply featurewise transformation to global track embeddings if configured
+        if hasattr(self, "featurewise_global") and self.featurewise_global:
+            preds["embed_xs"] = self.featurewise_global(inputs, preds["embed_xs"])
 
         # pooling
         if self.pool_net:
@@ -229,3 +195,25 @@ class SaltModel(nn.Module):
             loss[task.name] = task_loss
 
         return preds, loss
+
+    def init_featurewise(
+        self, featurewise_nets: list[dict], init_nets: list[dict], encoder: nn.Module
+    ):
+        for featurewise_net in featurewise_nets:
+            if featurewise_net.get("layer") == "input":
+                for init_net in init_nets:
+                    init_net["featurewise"] = FeaturewiseTransformation(**featurewise_net)
+            elif featurewise_net.get("layer") == "encoder":
+                if encoder:
+                    for _layer in range(encoder.num_layers):
+                        encoder.featurewise.append(FeaturewiseTransformation(**featurewise_net))
+                else:
+                    raise ValueError(
+                        "Requested featurewise transforms for encoder, no encoder configured"
+                    )
+            elif featurewise_net.get("layer") == "global":
+                self.featurewise_global = FeaturewiseTransformation(**featurewise_net)
+            else:
+                raise ValueError(
+                    "Select either 'input', 'encoder' or 'global' layers for featurewise nets."
+                )
diff --git a/salt/models/task.py b/salt/models/task.py
index 8b2645c9136dd1424df0dc2dbd25a809f4a55e12..0aae93243e87d01dbe9ce50c6c14caf258a38409 100644
--- a/salt/models/task.py
+++ b/salt/models/task.py
@@ -11,7 +11,7 @@ from salt.utils.array_utils import listify
 from salt.utils.class_names import CLASS_NAMES
 from salt.utils.scalers import RegressionTargetScaler
 from salt.utils.tensor_utils import masked_softmax
-from salt.utils.union_find import get_node_assignment
+from salt.utils.union_find import get_node_assignment_jit
 
 
 class TaskBase(nn.Module, ABC):
@@ -296,6 +296,8 @@ class RegressionTask(RegressionTaskBase):
 
         Parameters
         ----------
+        scaler
+            dummy text
         **kwargs
             Keyword arguments for
             [`salt.models.RegressionTaskBase`][salt.models.RegressionTaskBase].
@@ -329,6 +331,7 @@ class RegressionTask(RegressionTaskBase):
         loss = None
         if targets is not None:
             loss = self.nan_loss(preds, targets) * self.weight
+
         return preds, loss
 
     def run_inference(self, preds: Tensor, targets_dict: Mapping, precision: str = "f4"):
@@ -511,7 +514,7 @@ class VertexingTask(TaskBase):
         return 1 + weights
 
     def run_inference(self, preds: Tensor, pad_mask: Tensor | None = None):
-        preds = get_node_assignment(preds, pad_mask)
+        preds = get_node_assignment_jit(preds, pad_mask)
         preds = mask_fill_flattened(preds, pad_mask)
         dtype = np.dtype([("VertexIndex", "i8")])
         return u2s(preds.int().cpu().numpy(), dtype)
diff --git a/salt/models/transformer.py b/salt/models/transformer.py
index 212a51c592cecae7f2ef8d93dafc8273b83a3ae0..e92508e87eb43af18909feb9e494c5c89854425b 100644
--- a/salt/models/transformer.py
+++ b/salt/models/transformer.py
@@ -6,6 +6,7 @@ from torch import BoolTensor, Tensor, cat, nn
 
 from salt.models.attention import MultiheadAttention
 from salt.models.dense import Dense
+from salt.stypes import Tensors
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -171,6 +172,7 @@ class TransformerEncoder(nn.Module):
         self.out_dim = out_dim
         self.update_edges = update_edges
         self.muP = muP
+        self.featurewise = nn.ModuleList()
 
         self.layers = nn.ModuleList([
             TransformerEncoderLayer(
@@ -203,6 +205,7 @@ class TransformerEncoder(nn.Module):
         x: Tensor | dict,
         edge_x: Tensor = None,
         pad_mask: Tensor | dict | None = None,
+        inputs: Tensors = None,
         **kwargs,
     ) -> Tensor:
         """Pass the input through all layers sequentially."""
@@ -212,7 +215,9 @@ class TransformerEncoder(nn.Module):
         if isinstance(pad_mask, dict):
             pad_mask = cat(list(pad_mask.values()), dim=1)
 
-        for layer in self.layers:
+        for i, layer in enumerate(self.layers):
+            if len(self.featurewise) > 0:
+                x = self.featurewise[i](inputs, x)
             if edge_x is not None:
                 x, edge_x = layer(x, edge_x, pad_mask=pad_mask, **kwargs)
             else:
diff --git a/salt/models/transformer_v2.py b/salt/models/transformer_v2.py
index e9f25b53a580f9ea3ede0578750081a62d0ab233..6c1887e0fe371ea2a21a3575704681d691f010f8 100644
--- a/salt/models/transformer_v2.py
+++ b/salt/models/transformer_v2.py
@@ -9,56 +9,61 @@ Features:
 - RMSNorm https://arxiv.org/abs/1910.07467
 """
 
-from abc import ABC
+import warnings
+from functools import partial
 
 import torch
-from torch import BoolTensor, Size, Tensor, nn
+import torch.nn.functional as F
+from torch import BoolTensor, Tensor, nn
 
 import salt.models.layernorm as layernorms
+from salt.stypes import Tensors
+from salt.utils.tensor_utils import redo_padding, undo_padding
 
 
 def merge_masks(
-    q_mask: BoolTensor | None,
     kv_mask: BoolTensor | None,
     attn_mask: BoolTensor | None,
-    q_shape: Size,
-    k_shape: Size,
-) -> BoolTensor:
+    q_shape: Tensor,
+) -> BoolTensor | None:
     """Create a full attention mask which incorporates the padding information.
 
-    Using pytorch transformer convention:
+    Using pytorch transformer convention for padding
         False: Real node
         True:  Zero padded
 
+    Using pytorch transformer convention for attention mask
+        False:  Not allowed in attention mechanism
+        True:   Allowed in attention mechanism
+
+    Designing attention mask such that padded tokens can't send information.
+    But they can receive them.
+    This prevents Nans in the attention scores caused by the softmax
+
     Parameters
     ----------
-    q_mask : BoolTensor | None
-        Mask for the queries, of shape (batch, q_len).
     kv_mask : BoolTensor | None
         Mask for the keys and values, of shape (batch, kv_len).
     attn_mask : BoolTensor | None
         Full attention mask, of shape (batch, q_len, kv_len).
     q_shape : Size
         Shape of the queries tensor, (batch, q_len, dim).
-    k_shape : Size
-        Shape of the keys tensor, (batch, kv_len, dim).
     """
     # Create the full mask which combines the attention and padding masks
     mask = None
 
-    # if both masks exist, combine them
-    if q_mask is not None and kv_mask is not None:
-        mask = q_mask.unsqueeze(-1) | kv_mask.unsqueeze(-2)
-
-    # if only one mask exists, expand it to the other dimension
-    if q_mask is None and kv_mask is not None:
+    # if the kv_mask mask exists, ensure that padded tokens never send information
+    if kv_mask is not None:
         mask = kv_mask.unsqueeze(-2).expand(-1, q_shape[-2], -1)
-    if kv_mask is None and q_mask is not None:
-        mask = q_mask.unsqueeze(-1).expand(-1, -1, k_shape[-2])
+        mask = ~mask  # convert the mask such that True is a valid token
 
     # include the attention mask
     if attn_mask is not None:
-        mask = attn_mask if mask is None else attn_mask | mask
+        mask = attn_mask if mask is None else attn_mask & mask
+
+    # Unsqueeze the mask to give it a dimension for num_head broadcasting
+    if mask is not None:
+        mask = mask.unsqueeze(1)
 
     return mask
 
@@ -69,50 +74,88 @@ def repeat_kv(keys: Tensor, values: Tensor, repeats: int, dim: int):
     return keys, values
 
 
-def torch_meff_attn(q: Tensor, k: Tensor, v: Tensor, mask: BoolTensor, dropout: float) -> Tensor:
-    # masking can lead to nans, see
-    # - https://github.com/pytorch/pytorch/issues/110213
-    # - https://github.com/pytorch/pytorch/issues/103749
-    # to get round this, can transform the mask from a bool to float
-    # mask = (1.0 - mask.to(q.dtype)) * torch.finfo(q.dtype).min
-    # but don't need this if add_zero_attn is True
+def change_attn_backends(module: nn.Module, backend: str) -> None:
+    """Recursively change the attention backend of a module and all its children.
 
-    # TODO: change mask convention
-    # https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/salt/-/issues/47
-    if mask is not None:
-        mask = ~mask.contiguous()
+    Used primarily for switching back to torch-math for ONNX exports.
+    """
+    for child in module.children():
+        change_attn_backends(child, backend)
+        if isinstance(child, Attention):
+            child.set_backend(backend)
 
-    return nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout)
 
+def projection_packed(
+    q: Tensor,
+    kv: Tensor | None,
+    weight: Tensor,
+    bias: Tensor | None = None,
+) -> tuple:
+    """Efficient input projection for MHA when using a single linear layer.
 
-def torch_flash_attn(q: Tensor, k: Tensor, v: Tensor, mask: BoolTensor, dropout: float) -> Tensor:
-    assert mask is None, "Flash attention does not support attention masks"
+    Essentially the same as torch.nn.functional._in_projection_packed
+    But here we use chunk which is 40x faster than unflatten
+    Not sure why they don't use chunk in the original implementation...
+
+    Parameters
+    ----------
+    q : Tensor
+        The queries tensor of shape (batch, q_len, dim).
+    kv : Tensor | None
+        The keys and values tensor of shape (batch, kv_len, dim).
+    weight : Tensor
+        The packed weight tensor of the input lienar projection with shape (3 * dim, dim).
+    bias : Tensor | None
+        The optional packed bias tensor of the input linear projection with shape (3 * dim).
+
+    Returns
+    -------
+    q_proj, k_proj, v_proj : tuple
+        The projected queries, keys, and values tensors.
+    """
+    # If the q tensor is the only input, then we assume we are doing self-attention.
+    # This is made (slightly) faster by using a single linear layer, then chunking rather than
+    # three seperate linear layers processed one at a time.
+    if kv is None:
+        return F.linear(q, weight, bias).chunk(3, dim=-1)
+
+    # If the kv tensor is present, then we are doing cross-attention.
+    # This means we must project the q and kv tensors seperately.
+    # The kv linear layer can remain packed, allowing us to project together then chunk,
+    # using the same trick as above. We must however first seperate weights (and biases if present)
+    # of the linear layers for the q and kv parts. We use torch.split which returns a veiw of the
+    # original tensor so this step doesnt required any extra memory or much time.
+    dim = q.size(-1)
+    w_q, w_kv = weight.split([dim, dim * 2])
+    b_q, b_kv = bias.split([dim, dim * 2]) if bias is not None else (None, None)
+
+    # Now we can do the seperate projections
+    q_proj = F.linear(q, w_q, b_q)
+    k_proj, v_proj = F.linear(kv, w_kv, b_kv).chunk(2, dim=-1)
+    return q_proj, k_proj, v_proj
+
+
+def torch_attn(
+    q: Tensor, k: Tensor, v: Tensor, mask: BoolTensor, dropout: float, backend: str
+) -> Tensor:
+    """Torch dot product attention with a switchable backend."""
     with torch.backends.cuda.sdp_kernel(
-        enable_flash=True, enable_math=False, enable_mem_efficient=False
+        enable_math=True,  # always enabled as a fallback
+        enable_mem_efficient=(backend == "torch-meff"),
+        enable_flash=(backend == "torch-flash"),
     ):
-        return nn.functional.scaled_dot_product_attention(
-            q, k, v, attn_mask=mask, dropout_p=dropout
-        )
+        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout)
 
 
-ATTN_BACKENDS = {
-    "torch-meff": torch_meff_attn,
-    "torch-flash": torch_flash_attn,
-}
-
-
-class Attention(nn.Module, ABC):
+class Attention(nn.Module):
     def __init__(
         self,
         embed_dim: int,
-        num_heads: int,
+        num_heads: int = 1,
         attn_type: str = "torch-meff",
-        n_kv_heads: int | None = None,
-        window_size: int | None = None,
         dropout: float = 0.0,
         bias: bool = True,
-        add_zero_attn: bool = True,
-    ):
+    ) -> None:
         """Multihead attention module.
 
         Parameters
@@ -122,153 +165,154 @@ class Attention(nn.Module, ABC):
         num_heads : int
             Number of attention heads.
         attn_type : str, optional
-            Type of backend kernel to use.
-        n_kv_heads : int | None, optional
-            Number of heads for the keys and values. If None, defaults to num_heads.
-        window_size : int | None, optional
-            Window size for flash attention kernel. If None, defaults to global attention.
+            Name of backend kernel to use.
         dropout : float, optional
             Dropout rate.
         bias : bool, optional
             Whether to include bias terms.
-        add_zero_attn : bool, optional
-            Whether to add a dummy token to attend to. This avoids nan when all tokens are padded.
         """
         super().__init__()
-
+        assert embed_dim % num_heads == 0, "Dim not div by the number of heads!"
+        assert attn_type in {
+            "torch-flash",
+            "torch-math",
+            "torch-meff",
+            "flash-varlen",
+        }, "Invalid attention type!"
+
+        # Attributes
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
-
-        self.n_kv_heads = num_heads if n_kv_heads is None else n_kv_heads
-        assert self.n_kv_heads is not None
-        self.repeats = self.num_heads // self.n_kv_heads
-        self.scale = self.head_dim**-0.5
         self.dropout = dropout
         self.bias = bias
-        self.add_zero_attn = add_zero_attn
 
+        # Better parallelism for self-attention when using parameters directly
+        self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
+        self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) if bias else None
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.reset_parameters()
+        self.set_backend(attn_type)
+
+    def set_backend(self, attn_type: str) -> str:
+        # Check the attention backend
         self.attn_type = attn_type
-        self.attn_func = ATTN_BACKENDS[self.attn_type]
-        self.backend = self._flash_backend if self.attn_type == "flash" else self._torch_backend
-        if window_size is None:
-            self.window_size = (-1, -1)
+        if self.attn_type == "flash-varlen":
+            why_not_varlen = ""
+
+            # Try importing the flash-varlen backend
+            try:
+                from flash_attn import flash_attn_varlen_qkvpacked_func
+
+                self.attn_fn = flash_attn_varlen_qkvpacked_func
+            except ImportError:
+                why_not_varlen = (
+                    "Requires the flash_attn package and CUDA 12+ which must be installed "
+                    "separately. See salt/setup/install_flash.sh for installation instructions."
+                )
+
+            # Check if a GPU is available
+            if not torch.cuda.is_available():
+                why_not_varlen = "No GPU available."
+
+            if why_not_varlen:
+                warnings.warn(
+                    f"Cannot use flash-varlen backend. {why_not_varlen} Reverting to torch-math.",
+                    stacklevel=2,
+                )
+                self.attn_type = "torch-math"
+                self.attn_fn = torch_attn
         else:
-            assert attn_type == "flash"
-            assert window_size % 2 == 0
-            self.window_size = (window_size // 2, window_size // 2)
+            self.attn_fn = torch_attn
+
+        return self.attn_type
+
+    def reset_parameters(self):
+        """Initialize the parameters."""
+        nn.init.xavier_uniform_(self.in_proj_weight)
+        if self.bias:
+            nn.init.constant_(self.in_proj_bias, 0.0)
+        self.out_proj.reset_parameters()
+
+    def _varlen_attention(self, x: Tensor, culens: Tensor, maxlen: int) -> Tensor:
+        """Attention forward pass for the flash-varlen backend."""
+        # Perform the packed input projection
+        qkv = F.linear(x, self.in_proj_weight, self.in_proj_bias)
+        qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+
+        # Run the flash-varlen backend
+        dropout = self.dropout if self.training else 0.0
+        a_out = self.attn_fn(qkv, culens, maxlen, dropout)
+        a_out = a_out.reshape(-1, self.embed_dim)
 
-        self.wq = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=self.bias)
-        self.wk = nn.Linear(self.embed_dim, self.n_kv_heads * self.head_dim, bias=self.bias)
-        self.wv = nn.Linear(self.embed_dim, self.n_kv_heads * self.head_dim, bias=self.bias)
-        self.wo = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=self.bias)
+        # Mix with final linear layer
+        return self.out_proj(a_out)
 
     def forward(
         self,
-        q: Tensor,
-        k: Tensor,
-        v: Tensor,
-        q_mask: BoolTensor | None = None,
+        x: Tensor,
+        kv: Tensor | None = None,
+        mask: BoolTensor | None = None,
         kv_mask: BoolTensor | None = None,
         attn_mask: BoolTensor | None = None,
+        culens: Tensor | None = None,
+        maxlen: int | None = None,
     ) -> Tensor:
         """Attention forward pass.
 
         Parameters
         ----------
-        q : Tensor
-            Queries of shape (batch, q_len, dim).
-        k : Tensor
-            Keys of shape (batch, kv_len, dim).
-        v : Tensor
-            Values of shape (batch, kv_len, dim).
-        q_mask : BoolTensor, optional
-            Mask for the queries, by default None.
+        x : Tensor
+            The pointcloud of shape (batch, x_len, dim).
+        kv : Tensor
+            Optional second pointcloud for cross-attn with shape (batch, kv_len, dim).
+        mask : BoolTensor, optional
+            Mask for the pointcloud x, by default None.
         kv_mask : BoolTensor, optional
-            Mask for the keys and values, by default None.
+            Mask the kv pointcloud, by default None.
         attn_mask : BoolTensor, optional
             Full attention mask, by default None.
+        culens : Tensor, optional
+            Cumulative lengths of the sequences in x, by default None.
+            Only used for the flash-varlen backend.
+        maxlen : int, optional
+            Maximum length of a sequence in the x, by default None.
+            Only used for the flash-varlen backend.
 
         Returns
         -------
         Tensor
-            Output of shape (batch, q_len, dim).
+            Output of shape (batch, x_len, dim).
         """
-        # combine masks
-        attn_mask = merge_masks(q_mask, kv_mask, attn_mask, q.shape, k.shape)
-
-        # input projections
-        q, k, v = self.wq(q), self.wk(k), self.wv(v)
-
-        # add a dummy token to attend to - avoids nan when all tokens are padded
-        if self.add_zero_attn:
-            batch = q.shape[0]
-            zero_attn_shape = (batch, 1, self.embed_dim)
-            k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
-            v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
-            if attn_mask is not None:
-                attn_mask = nn.functional.pad(attn_mask, (0, 1), value=False)
-            if kv_mask is not None:
-                kv_mask = nn.functional.pad(kv_mask, (0, 1), value=False)
+        # the varlen attention backend is called at the begining (different args)
+        if self.attn_type == "flash-varlen":
+            assert kv is None, "flash-varlen only supports self attention!"
+            assert attn_mask is None, "flash-varlen does not support attention masks!"
+            assert culens is not None, "flash-varlen requires culens!"
+            assert maxlen is not None, "flash-varlen requires maxlen!"
+            return self._varlen_attention(x, culens, maxlen)
 
-        # run attention
-        output = self.backend(q, k, v, attn_mask)
-
-        # return output projection
-        return self.wo(output)
-
-    def _torch_backend(self, q: Tensor, k: Tensor, v: Tensor, attn_mask: BoolTensor | None = None):
-        batch, q_len, _ = q.shape
-        _, kv_len, _ = k.shape
-
-        # transform tensors to (batch, num_heads, seq_len, head_dim)
-        q = q.view(batch, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        k = k.view(batch, kv_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
-        v = v.view(batch, kv_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
+        # Otherwise perform standard attention
+        B, S, D = x.shape
 
-        # repeat keys and values to match number of query heads
-        if self.repeats > 1:
-            k, v = repeat_kv(k, v, self.repeats, dim=-2)
+        # input projections -> B, S, D
+        q, k, v = projection_packed(x, kv, self.in_proj_weight, self.in_proj_bias)
 
-        # expand mask to (batch, num_heads, q_len, kv_len)
-        if attn_mask is not None:
-            attn_mask = attn_mask.view(batch, 1, q_len, kv_len).expand(-1, self.num_heads, -1, -1)
+        # transform tensors to (B, Nh, S, Hd)
+        shape = (B, -1, self.num_heads, self.head_dim)  # Dont use S for cross attn
+        q, k, v = (t.view(shape).transpose(1, 2).contiguous() for t in (q, k, v))
 
         # run attention
-        output = self.attn_func(q, k, v, mask=attn_mask, dropout=self.dropout)
-
-        # recombine heads and return
-        return output.transpose(1, 2).contiguous().view(batch, -1, self.embed_dim)
-
-
-class SelfAttention(nn.Module):
-    def __init__(self, embed_dim: int, **kwargs):
-        """Self attention module.
+        s_mask = mask if kv is None else kv_mask  # Who is sending, x or kv
+        mask = merge_masks(s_mask, attn_mask, q.shape)
+        dropout = self.dropout if self.training else 0.0
+        a_out = torch_attn(q, k, v, mask, dropout, self.attn_type)
 
-        Parameters
-        ----------
-        embed_dim : int
-            Dimension of the input.
-        kwargs : dict
-            Keyword arguments for
-            [salt.models.transformer_v2.Attention][salt.models.transformer_v2.Attention].
-        """
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.attention = Attention(embed_dim=embed_dim, **kwargs)
-
-    def forward(self, x: Tensor, **kwargs) -> Tensor:
-        return self.attention(x, x, x, **kwargs)
-
-
-class CrossAttention(nn.Module):
-    def __init__(self, embed_dim: int, **kwargs):
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.attention = Attention(embed_dim=embed_dim, **kwargs)
+        # recombine heads
+        a_out = a_out.transpose(1, 2).contiguous().view(B, S, D)
 
-    def forward(self, q: Tensor, kv: Tensor, **kwargs) -> Tensor:
-        return self.attention(q, kv, kv, **kwargs)
+        # mix with final linear layer
+        return self.out_proj(a_out)
 
 
 class GLU(nn.Module):
@@ -276,7 +320,8 @@ class GLU(nn.Module):
         self,
         embed_dim: int,
         hidden_dim: int | None = None,
-        activation: str = "ReLU",
+        activation: str = "SiLU",
+        dropout: float = 0.0,
         bias: bool = True,
         gated: bool = False,
     ):
@@ -292,6 +337,8 @@ class GLU(nn.Module):
             Dimension of the hidden layer. If None, defaults to embed_dim * 2.
         activation : str, optional
             Activation function.
+        dropout : float, optional
+            Dropout rate.
         bias : bool, optional
             Whether to include bias in the linear layers.
         gated : bool, optional
@@ -302,18 +349,104 @@ class GLU(nn.Module):
         if hidden_dim is None:
             hidden_dim = embed_dim * 2
 
-        self.in_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
+        self.gated = gated
+        self.embed_dim = embed_dim
+        self.in_proj = nn.Linear(embed_dim, hidden_dim + hidden_dim * gated, bias=bias)
         self.out_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
-        self.gate = None
-        if gated:
-            self.gate = nn.Linear(embed_dim, hidden_dim, bias=bias)
+        self.drop = nn.Dropout(dropout)
         self.activation = getattr(nn, activation)()
 
     def forward(self, x: Tensor) -> Tensor:
-        out = self.activation(self.in_proj(x))
-        if self.gate:
-            out = out * self.gate(x)
-        return self.out_proj(out)
+        x = self.in_proj(x)
+        if self.gated:
+            x1, x2 = x.chunk(2, dim=-1)
+            x = self.activation(x1) * x2
+        else:
+            x = self.activation(x)
+        x = self.drop(x)
+        return self.out_proj(x)
+
+
+class LayerScale(nn.Module):
+    """Applies the LayerScale operation from the Cait vision transformer.
+
+    Effective at improving stability and speed of deep transformers.
+    Now the standard for vision transformers
+    https://arxiv.org/abs/2103.17239
+    """
+
+    def __init__(self, dim: int, init_value: float = 1e-3) -> None:
+        super().__init__()
+        self.gamma = nn.Parameter(init_value * torch.ones(dim))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x * self.gamma
+
+
+class DropPath(nn.Module):
+    """Drop paths for a stochastic depth neural network.
+
+    Used for regularisation when applied to the main path of a residual block.
+    """
+
+    def __init__(self, drop_prob: float = 0.0):
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()  # binarize
+        return x.div(keep_prob) * random_tensor
+
+
+class PreNormResidual(nn.Module):
+    """Wraps a module with pre-norm with a residual connection.
+
+    Optionally also applies:
+    - LayerScale
+    - DropPath (Stochastic Depth)
+
+    Neat way of doing the most common transformer pattern:
+    - x = x + drop(scale * fn(norm(x)))
+    """
+
+    def __init__(
+        self,
+        fn: nn.Module,
+        norm: str = "LayerNorm",
+        ls_init: float | None = None,
+        drop_path: float = 0.0,
+        embed_dim: int = 0,
+    ) -> None:
+        """Parameters
+        ----------
+        fn : nn.Module
+            The module to wrap. Must be non-resizing.
+        norm : str, optional
+            The normalization method, by default "LayerNorm".
+        ls_init : float | None, optional
+            The initial value for the layerscale, by default 1e-3.
+            If None, then no layerscale is applied.
+        drop_path : float, optional
+            The drop path rate, by default 0.0.
+        embed_dim : int
+            The dimension of the input and output.
+            If zero we will try get it from the fn's own embed_dim attribute.
+        """
+        super().__init__()
+        dim = embed_dim or fn.embed_dim
+        assert dim > 0, "Could not determine embed_dim from fn"
+        self.fn = fn
+        self.norm = getattr(layernorms, norm)(dim)
+        self.ls = LayerScale(dim, ls_init) if ls_init is not None else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
+
+    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
+        return x + self.drop_path(self.ls(self.fn(self.norm(x), *args, **kwargs)))
 
 
 class EncoderLayer(nn.Module):
@@ -321,9 +454,11 @@ class EncoderLayer(nn.Module):
         self,
         embed_dim: int,
         norm: str = "LayerNorm",
+        ls_init: float | None = None,
+        drop_path: float = 0.0,
         dense_kwargs: dict | None = None,
         attn_kwargs: dict | None = None,
-    ):
+    ) -> None:
         """Encoder layer consisting of a self-attention and a feed-forward layer.
 
         Parameters
@@ -332,26 +467,34 @@ class EncoderLayer(nn.Module):
             Dimension of the embeddings at each layer.
         norm : str, optional
             Normalization style, by default "LayerNorm".
+        drop_path : float, optional
+            Drop path rate, by default 0.0.
+        ls_init : float | None, optional
+            Initial value for the layerscale, by default 1e-3.
         dense_kwargs : dict | None, optional
             Keyword arguments for [salt.models.transformer_v2.GLU][salt.models.transformer_v2.GLU].
         attn_kwargs : dict | None, optional
             Keyword arguments for
-            [salt.models.transformer_v2.SelfAttention][salt.models.transformer_v2.SelfAttention].
+            [salt.models.transformer_v2.Attention][salt.models.transformer_v2.Attention].
         """
         super().__init__()
+
+        # Safe defaults
         if attn_kwargs is None:
             attn_kwargs = {}
         if dense_kwargs is None:
             dense_kwargs = {}
+
+        # Attributes
         self.embed_dim = embed_dim
-        self.attn = SelfAttention(embed_dim=embed_dim, **attn_kwargs)
-        self.attn_norm = getattr(layernorms, norm)(embed_dim)
-        self.dense = GLU(embed_dim, **dense_kwargs)
-        self.dense_norm = getattr(layernorms, norm)(embed_dim)
 
-    def forward(self, x: Tensor, pad_mask: BoolTensor) -> Tensor:
-        x = x + self.attn(self.attn_norm(x), kv_mask=pad_mask)
-        return x + self.dense(self.dense_norm(x))
+        # Submodules
+        residual = partial(PreNormResidual, norm=norm, ls_init=ls_init, drop_path=drop_path)
+        self.attn = residual(Attention(embed_dim, **attn_kwargs))
+        self.dense = residual(GLU(embed_dim, **dense_kwargs))
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        return self.dense(self.attn(x, **kwargs))
 
 
 class DecoderLayer(nn.Module):
@@ -359,24 +502,39 @@ class DecoderLayer(nn.Module):
         self,
         embed_dim: int,
         norm: str = "LayerNorm",
+        ls_init: float | None = 1e-3,
+        drop_path: float = 0.0,
         dense_kwargs: dict | None = None,
         attn_kwargs: dict | None = None,
     ):
         super().__init__()
+
+        # Safe defaults
         if attn_kwargs is None:
             attn_kwargs = {}
         if dense_kwargs is None:
             dense_kwargs = {}
+
+        # Attributes
         self.embed_dim = embed_dim
-        self.attn = CrossAttention(embed_dim=embed_dim, **attn_kwargs)
-        self.q_norm = getattr(layernorms, norm)(embed_dim)
-        self.kv_norm = getattr(layernorms, norm)(embed_dim)
-        self.dense = GLU(embed_dim, **dense_kwargs)
-        self.dense_norm = getattr(layernorms, norm)(embed_dim)
 
-    def forward(self, x: Tensor, kv: Tensor, pad_mask: BoolTensor) -> Tensor:
-        x = x + self.attn(self.q_norm(x), self.kv_norm(kv), kv_mask=pad_mask)
-        return x + self.dense(self.dense_norm(x))
+        # Submodules
+        residual = partial(PreNormResidual, norm=norm, ls_init=ls_init, drop_path=drop_path)
+        self.self_attn = residual(Attention(embed_dim=embed_dim, **attn_kwargs))
+        self.cross_attn = residual(Attention(embed_dim=embed_dim, **attn_kwargs))
+        self.dense = residual(GLU(embed_dim, **dense_kwargs))
+
+    def forward(
+        self,
+        x: Tensor,
+        *,  # Indicates that kv is required
+        kv: Tensor,
+        mask: Tensor | None = None,
+        kv_mask: Tensor | None = None,
+    ) -> Tensor:
+        x = self.self_attn(x, kv_mask=mask)
+        x = self.cross_attn(x, kv=kv, kv_mask=kv_mask)
+        return self.dense(x)
 
 
 class TransformerV2(nn.Module):
@@ -386,9 +544,13 @@ class TransformerV2(nn.Module):
         embed_dim: int,
         out_dim: int | None = None,
         norm: str = "LayerNorm",
+        attn_type: str = "torch-math",
+        do_final_norm: bool = True,
+        num_registers: int = 1,
+        drop_registers: bool = False,
         **kwargs,
-    ):
-        """Transformer model consisting of a series of stacked Transformer encoder layers.
+    ) -> None:
+        """Transformer model consisting of a stack of Transformer encoder layers.
 
         Parameters
         ----------
@@ -400,29 +562,126 @@ class TransformerV2(nn.Module):
             Optionally project the output to a different dimension.
         norm : str, optional
             Normalization style, by default "LayerNorm".
+        attn_type : str, optional
+            The backend for the attention mechanism, by default "torch-flash".
+            Provided here because the varlen backend requires pre/post processing.
+        do_final_norm : bool, optional
+            Whether to apply a final normalization layer, by default True.
+        num_registers : int, optional
+            The number of registers to add to the END of the input sequence.
+            Registers are randomly initialised tokens of the same dimension as
+            any other inputs after initialiser networks. See 2309.16588.
+        drop_registers : bool, optional
+            If to drop the registers from the outputs
         kwargs : dict
             Keyword arguments for [salt.models.transformer_v2.EncoderLayer].
         """
         super().__init__()
+
+        # Check the inputs
+        if num_registers < 1:
+            raise ValueError(
+                "Some jets have no tracks, which causes NaNs in the attention scores. ",
+                "To avoid this, set num_registers to at least 1",
+            )
+
+        # Attributes
         self.num_layers = num_layers
         self.embed_dim = embed_dim
+        self.out_dim = out_dim or embed_dim
+        self.do_final_norm = do_final_norm
+        self.do_out_proj = out_dim is not None
+        self.attn_type = attn_type
+        self.num_registers = num_registers
+        self.drop_registers = drop_registers
 
+        # Submodules
         self.layers = torch.nn.ModuleList([
             EncoderLayer(embed_dim=embed_dim, norm=norm, **kwargs) for _ in range(num_layers)
         ])
-        self.out_norm = getattr(layernorms, norm)(embed_dim if out_dim is None else out_dim)
-        self.out_proj = None
-        if out_dim is not None:
+        self.attn_type = self.set_backend(attn_type)
+
+        # Optional submodules
+        if self.do_out_proj:
             self.out_proj = nn.Linear(self.embed_dim, out_dim)
+        if self.do_final_norm:
+            self.out_norm = getattr(layernorms, norm)(self.out_dim)
+        if self.num_registers:
+            self.registers = nn.Parameter(
+                torch.normal(torch.zeros((self.num_registers, self.embed_dim)), std=1e-4)
+            )
+            self.register_buffer("register_mask", torch.zeros(num_registers, dtype=torch.bool))
+        self.featurewise = nn.ModuleList()
+
+    def set_backend(self, attn_type: str) -> str:
+        for layer in self.layers:
+            attn_type = layer.attn.fn.set_backend(attn_type)
+        return attn_type  # Might change due to library availibility
 
-    def forward(self, x: Tensor, pad_mask: BoolTensor) -> Tensor:
+    def forward(
+        self,
+        x: Tensor,
+        pad_mask: BoolTensor,
+        inputs: Tensors | None = None,
+        **kwargs,
+    ) -> Tensor:
+        # Add the registers to the sequence and the mask
+        if self.num_registers:
+            x, pad_mask = self._add_registers(x, pad_mask)
+
+        # Combine the input sequences if they are dictionaries (don't overwrite pad_mask)
         if isinstance(x, dict):
             x = torch.cat(list(x.values()), dim=1)
-        if isinstance(pad_mask, dict):
-            pad_mask = torch.cat(list(pad_mask.values()), dim=1)
+        mask = torch.cat(list(pad_mask.values()), dim=1) if isinstance(pad_mask, dict) else pad_mask
 
-        for layer in self.layers:
-            x = layer(x, pad_mask)
-        if self.out_proj is not None:
+        # If using the varlen backend, pack the sequence and store the cumulative lengths
+        if self.attn_type == "flash-varlen":
+            x, kwargs["culens"], kwargs["maxlen"] = undo_padding(x, mask)
+
+        # Run through the main transformer encoder layers
+        for i, layer in enumerate(self.layers):
+            if len(self.featurewise) > 0:
+                x = self.featurewise[i](inputs, x)
+            x = layer(x, mask=mask, **kwargs)
+
+        # Run through the optional layers
+        if self.do_out_proj:
             x = self.out_proj(x)
-        return self.out_norm(x)
+        if self.do_final_norm:
+            x = self.out_norm(x)
+
+        # If using the varlen backend, unpack the sequence
+        if self.attn_type == "flash-varlen":
+            x = redo_padding(x, mask)
+
+        # Optionally drop the registers from the output
+        if self.drop_registers:
+            x = x[:, : -self.num_registers]
+            if isinstance(pad_mask, dict):
+                del pad_mask["REGISTERS"]
+            elif isinstance(pad_mask, Tensor):
+                pad_mask = pad_mask[:, : -self.num_registers]
+
+        return x, pad_mask
+
+    def _add_registers(self, x: Tensor | dict, pad_mask: BoolTensor | dict | None) -> tuple:
+        """Add the learnable registers to the end of the input sequence."""
+        # Get the batch size and expand the registers to match
+        B = next(iter(x.values())).size(0) if isinstance(x, dict) else x.size(0)
+
+        # Add as a key or concatenate at the end
+        reg = self.registers.expand(B, -1, -1)
+        if isinstance(x, dict):
+            x["REGISTERS"] = reg
+        else:
+            x = torch.cat([x, reg], dim=1)
+
+        # Also include a mask for the registers
+        if pad_mask is not None:
+            reg_mask = self.register_mask.expand(B, -1)
+            if isinstance(pad_mask, dict):
+                pad_mask["REGISTERS"] = reg_mask
+            else:
+                pad_mask = torch.cat([pad_mask, reg_mask], dim=-1)
+
+        return x, pad_mask
diff --git a/salt/submit/slurm_handler.py b/salt/submit/slurm_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a119ca1a28f9675ff7c70d81e9b5d97fbbe3399
--- /dev/null
+++ b/salt/submit/slurm_handler.py
@@ -0,0 +1,73 @@
+import logging
+import subprocess
+from pathlib import Path
+from typing import Any
+
+logging.basicConfig(level=logging.INFO)
+
+
+class SlurmHandler:
+    """A class to submit batch jobs to a Slurm scheduler.
+
+    Attributes
+    ----------
+    batch_path : Path
+        Path where the batch file which is created will be stored.
+    log_path : Path
+        Path where the batch log files will be stored.
+    base_dir : Path
+        Directory in which batch job will execute its command.
+
+    Methods
+    -------
+    activate_testmode():
+        Activate test mode: check config files in dry runs, no jobs submitted.
+    deactivate_testmode():
+        Deactivate test mode, enable submitting jobs.
+    send_job(command: str, tag: str = "slurm_job"):
+        Submit job by creating and executing Slurm batch file
+    """
+
+    def __init__(self, batch_path: str, log_path: str, basedir: str) -> None:
+        self.batch_path = Path(batch_path)
+        self.log_path = Path(log_path)
+        self.base_dir = Path(basedir) if basedir else Path.cwd()
+        self._tag = "salt_job"
+        # Keywords to be used in Slurm configuration
+        self._slurm_options_dict: dict[str, Any] = {}
+        self._test_mode = False
+
+    def activate_testmode(self) -> None:
+        logging.debug("Activated test mode: not submitting any jobs.")
+        self._test_mode = True
+
+    def deactivate_testmode(self) -> None:
+        logging.debug("Deactivated test mode: submitting jobs.")
+        self._test_mode = False
+
+    def send_job(self, command: str, tag: str = "salt_job") -> None:
+        self._tag = tag
+        batchfile = self._make_batch_file(command)
+        if self._test_mode:
+            logging.debug(f"Created batch file {batchfile}")
+        else:
+            subprocess.call(f"sbatch {batchfile}", shell=True)
+
+    def __setitem__(self, key: str, value: Any) -> None:  # noqa: ANN401
+        self._slurm_options_dict[key] = value
+
+    def _make_batch_file(self, command: str) -> Path:
+        batch_file = self.batch_path / f"sbatch_{self._tag}.sh"
+        with batch_file.open("w") as bf:
+            bf.write(f"""#!/bin/sh
+# {self._tag} batch run script\n""")
+            for key, value in self._slurm_options_dict.items():
+                if value is None:
+                    bf.write(f"#SBATCH --{key}\n")
+                else:
+                    bf.write(f"#SBATCH --{key}={value}\n")
+            bf.write(f"""BASEDIR={self.base_dir};pwd; ls -l\n""")
+            bf.write(f"""{command}""")
+        batch_file.chmod(0o755)
+        logging.debug(f"Made batch file {batch_file}")
+        return batch_file
diff --git a/salt/submit/submit_slurm.py b/salt/submit/submit_slurm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c4021a67aa05e3dd2ff3b1d9cf9607cfb9cd917
--- /dev/null
+++ b/salt/submit/submit_slurm.py
@@ -0,0 +1,145 @@
+import argparse
+from datetime import datetime
+from pathlib import Path
+
+from slurm_handler import SlurmHandler
+
+# Set up argument parser
+parser = argparse.ArgumentParser(description="Submit batch jobs to Slurm.")
+parser.add_argument("-c", "--config", required=True, type=Path, help="Configuration file for job.")
+parser.add_argument("-t", "--tag", default="salt_job", help="Tag for job to be submitted.")
+parser.add_argument("-p", "--partition", default=None, type=str, help="Partition to submit job.")
+parser.add_argument(
+    "-cn", "--constraint", default=None, type=str, help="Constraint on requested resources."
+)
+parser.add_argument("-a", "--account", default=None, type=str, help="Slurm account name.")
+parser.add_argument(
+    "-e",
+    "--environment",
+    default="conda",
+    choices=["conda", "singularity", "local"],
+    help="Environment for job to be submitted.",
+)
+parser.add_argument("-q", "--qos", default=None, type=str, help="Quality Of Service for job")
+parser.add_argument("-n", "--nodes", default=1, type=int, help="Nodes to split training across")
+parser.add_argument("-g", "--gpus_per_node", default=1, type=int, help="GPUs for each node")
+parser.add_argument(
+    "-gt",
+    "--gpu_type",
+    default="",
+    type=str,
+    help="GPU type e.g. v100, leave empty for no preference",
+)
+parser.add_argument("-cpt", "--cpus_per_task", default=10, type=int, help="CPUs for each task")
+parser.add_argument("-m", "--memory", default="100G", type=str, help="Memory per node")
+parser.add_argument("-ex", "--exclusive", action="store_true")
+parser.add_argument("-ti", "--time", default=None, type=str, help="Job time limit e.g. '24:00:00'")
+parser.add_argument("-f", "--force", action="store_true")
+parser.add_argument(
+    "-b",
+    "--bind",
+    nargs="+",
+    help="List of binds for singularity (e.g. /path/to/upp/output:/inputs)",
+)
+parser.add_argument("-r", "--requeue", action="store_true")
+parser.add_argument(
+    "-s",
+    "--signal",
+    default="SIGUSR1@90",
+    type=str,
+    help="Signal from Slurm to trigger Lightning to prepare for requeue",
+)
+parser.add_argument(
+    "-sls",
+    "--salt_log_suffix",
+    default=None,
+    help="Appended to model name to create Salt log directory",
+)
+args = parser.parse_args()
+
+if args.bind and args.environment != "singularity":
+    parser.error("--bind option is only allowed with --environment singularity")
+
+# Define directories
+batch_dir = Path.cwd() / "slurm"
+batch_path = batch_dir / "batch"
+log_path = batch_dir / "batch_logs"
+for directory in [batch_path, log_path]:
+    directory.mkdir(parents=True, exist_ok=True)
+
+# Variables that need to be harmonized between Slurm and salt
+nodes = args.nodes
+gpus_per_node = args.gpus_per_node
+cpus_per_task = args.cpus_per_task
+
+gpu_type = args.gpu_type
+gres = f"gpu:{gpu_type}:{gpus_per_node}" if gpu_type else f"gpu:{gpus_per_node}"
+
+# Set up Slurm options
+job_basedir = Path(__file__).resolve().parent.parent.parent
+handler = SlurmHandler(str(batch_path), str(log_path), str(job_basedir))
+handler["job-name"] = args.tag
+if args.partition is not None:
+    handler["partition"] = args.partition
+if args.constraint is not None:
+    handler["constraint"] = args.constraint
+if args.account is not None:
+    handler["account"] = args.account
+if args.qos is not None:
+    handler["qos"] = args.qos
+handler["nodes"] = nodes
+handler["gres"] = gres
+handler["ntasks-per-node"] = gpus_per_node
+handler["mem"] = args.memory  # memory, 100 GiB - in MiB
+if args.exclusive:
+    handler["exclusive"] = None  # Exclusive access to nodes
+handler["cpus-per-task"] = cpus_per_task  # Don't use this if you have exclusive access to the node
+handler["export"] = "ALL"
+handler["output"] = f"{log_path}/slurm-%j.out"
+handler["error"] = f"{log_path}/slurm-%j.err"
+if args.time is not None:
+    handler["time"] = args.time  # Time limit of job, default is system specified
+if args.requeue:
+    handler["requeue"] = None
+    handler["signal"] = args.signal
+
+log_suffix = args.salt_log_suffix
+if args.requeue and not log_suffix:
+    log_suffix = datetime.now().strftime("%Y%m%d-T%H%M%S")
+
+# Construct and submit the job command
+command = "cd ${BASEDIR} && " "export OMP_NUM_THREADS=1\n"
+if args.environment == "conda":
+    command += (
+        "source conda/bin/activate && conda activate salt\n"
+        'echo "Activated environment ${CONDA_DEFAULT_ENV}"\n'
+    )
+elif args.environment == "singularity":
+    command += "srun singularity exec -e --nv \\\n"
+    command += " \\\n".join([f"--bind {b}" for b in args.bind]) + " \\\n"
+    command += (
+        "--home ${BASEDIR} \\\n"
+        "/cvmfs/unpacked.cern.ch/gitlab-registry.cern.ch/atlas-flavor-tagging-tools/algorithms/salt:latest/ \\\n"  # noqa: E501
+        'sh -c "'
+    )
+command += (
+    "echo 'CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}' &&\n"
+    "cat /proc/cpuinfo | awk '/^processor/{print $3}' | tail -1 &&\n"
+    "cd ${BASEDIR}/salt && pwd &&\n"
+    + ("srun " if args.environment == "conda" else "")
+    + f"salt fit --config {args.config.resolve()} "
+    f"--trainer.devices={gpus_per_node} "
+    f"--trainer.num_nodes={nodes} "
+    f"--data.num_workers={cpus_per_task} "
+)
+
+if args.requeue:
+    command += f"--overwrite_config --log_suffix={log_suffix} "
+
+if args.force:
+    command += "--force "
+if args.environment == "singularity":
+    command += '"'
+
+# handler.activate_testmode() # To inspect batch script before running
+handler.send_job(command, args.tag)
diff --git a/salt/configs/parameterisation_concatenation.yaml b/salt/tests/configs/param_concat.yaml
similarity index 87%
rename from salt/configs/parameterisation_concatenation.yaml
rename to salt/tests/configs/param_concat.yaml
index 14a1f72cc47084236be001588ca30c5fbcfe78cb..f15d2861654a4d363216a3fd6c7f31262d49fa29 100644
--- a/salt/configs/parameterisation_concatenation.yaml
+++ b/salt/tests/configs/param_concat.yaml
@@ -1,4 +1,5 @@
-name: parameterisation_concatenation
+# test config for parameterisation using input concatenation
+name: param_concat
 
 model:
   lrs_config:
@@ -47,10 +48,10 @@ model:
                 label: flavour_label
                 loss:
                   class_path: torch.nn.CrossEntropyLoss
-                  init_args: { weight: [1.0, 2.0, 2.0] }
+                  init_args: { weight: [1.0, 2.0, 2.0, 16.8] }
                 dense_config: &task_dense_config
                   input_size: *out_dim
-                  output_size: 3
+                  output_size: 4
                   hidden_layers: [128, 64, 32]
                   activation: *activation
 
@@ -112,11 +113,6 @@ data:
     PARAMETERS:
       - mass
 
-  train_file: /share/lustre/ehaines/umami-preprocessing/train_test/pp_output_train.h5
-  val_file: /share/lustre/ehaines/umami-preprocessing/val_test/pp_output_val.h5
-  norm_dict: /share/lustre/ehaines/umami-preprocessing/train_test/norm_dict.yaml
-  class_dict: /share/lustre/ehaines/umami-preprocessing/train_test/class_dict.yaml
-
   PARAMETERS:
     mass:
       train: [5, 40, 55]
@@ -127,6 +123,5 @@ data:
 
 trainer:
   max_epochs: 2
-  accelerator: gpu
   devices: 1
   precision: 32
diff --git a/salt/configs/parameterisation_featurewise.yaml b/salt/tests/configs/param_featurewise.yaml
similarity index 88%
rename from salt/configs/parameterisation_featurewise.yaml
rename to salt/tests/configs/param_featurewise.yaml
index 5f113a116b323a53c54299bf687c04d47250d6fa..a653c7bcaf1ccd8d4afedfe43c7e3a708eadd4f7 100644
--- a/salt/configs/parameterisation_featurewise.yaml
+++ b/salt/tests/configs/param_featurewise.yaml
@@ -1,4 +1,5 @@
-name: parameterisation_featurewise
+# test config for parameterisation using featurewise transformations
+name: param_featurewise
 
 model:
   lrs_config:
@@ -25,6 +26,13 @@ model:
           dense_config_bias:
             hidden_layers: [4]
             output_size: 21
+        - layer: encoder
+          dense_config_scale:
+            hidden_layers: [128]
+            output_size: 256
+          dense_config_bias:
+            hidden_layers: [128]
+            output_size: 256
         - layer: global
           dense_config_scale:
             output_size: 128
@@ -60,10 +68,10 @@ model:
                 label: flavour_label
                 loss:
                   class_path: torch.nn.CrossEntropyLoss
-                  init_args: { weight: [1.0, 2.0, 2.0] }
+                  init_args: { weight: [1.0, 2.0, 2.0, 16.8] }
                 dense_config: &task_dense_config
                   input_size: *out_dim
-                  output_size: 3
+                  output_size: 4
                   hidden_layers: [128, 64, 32]
                   activation: *activation
 
@@ -125,11 +133,6 @@ data:
     PARAMETERS:
       - mass
 
-  train_file: /share/lustre/ehaines/umami-preprocessing/train_test/pp_output_train.h5
-  val_file: /share/lustre/ehaines/umami-preprocessing/val_test/pp_output_val.h5
-  norm_dict: /share/lustre/ehaines/umami-preprocessing/train_test/norm_dict.yaml
-  class_dict: /share/lustre/ehaines/umami-preprocessing/train_test/class_dict.yaml
-
   PARAMETERS:
     mass:
       train: [5, 40, 55]
@@ -140,6 +143,5 @@ data:
 
 trainer:
   max_epochs: 2
-  accelerator: gpu
   devices: 1
   precision: 32
diff --git a/salt/tests/test_masks.py b/salt/tests/test_masks.py
index 0ec7603ae60b41743369aa28dfa09fef08831ade..f2bfcce0a0b297d841cba79c60f8797c5e212d66 100644
--- a/salt/tests/test_masks.py
+++ b/salt/tests/test_masks.py
@@ -67,13 +67,20 @@ def test_indices_from_mask_3d(mask_3d, indices_2d):
 
 
 def test_indices_from_mask_empty():
-    mask = torch.tensor([[False, False], [False, False], [False, True]])
+    mask = torch.tensor([
+        [
+            False,
+            False,
+        ],
+        [False, False],
+        [False, False],
+    ])
     indices = indices_from_mask(mask)
-    assert torch.all(indices == torch.tensor([-1, 0]))
+    assert torch.all(indices == torch.tensor([1, 2]))
 
     mask = torch.tensor([
-        [[False, False], [False, False], [False, True]],
+        [[False, False], [False, False], [False, False]],
         [[False, False], [False, False], [False, False]],
     ])
     indices = indices_from_mask(mask)
-    assert torch.all(indices == torch.tensor([[-1, 0], [-1, -1]]))
+    assert torch.all(indices == torch.tensor([[1, 2], [1, 2]]))
diff --git a/salt/tests/test_models.py b/salt/tests/test_models.py
index c391efea3296227c36c1101f8fcabb111da7c9fd..4abddaa81c0f7952aaf07858d8e8dcc58d27f93c 100644
--- a/salt/tests/test_models.py
+++ b/salt/tests/test_models.py
@@ -139,7 +139,7 @@ def test_transformer_cross_attention_encoder() -> None:
     mask["type1"] = torch.zeros(extended_x["type1"].shape[:-1]).bool()
     mask["type1"][:, -1] = True
     out_with_pad = net(extended_x, mask)["type1"][:, :-1]
-    assert torch.all(out["type1"] == out_with_pad)
+    torch.testing.assert_allclose(out["type1"], out_with_pad)
 
 
 def test_mha_allvalid_mask() -> None:
diff --git a/salt/tests/test_pipeline.py b/salt/tests/test_pipeline.py
index 1b75b2a3c33c11c8d778a201fc320e817f246163..a2c873ff169abfda9e0b1dfd405737133b636494 100644
--- a/salt/tests/test_pipeline.py
+++ b/salt/tests/test_pipeline.py
@@ -11,10 +11,11 @@ from salt.utils.inputs import write_dummy_file, write_dummy_norm_dict
 
 w = "ignore::lightning.fabric.utilities.warnings.PossibleUserWarning:"
 CONFIG = "GN2.yaml"
+TAU_CONFIGS = {"GN2.yaml", "GN3.yaml"}
 
 
 def run_train(tmp_path, config_path, train_args, do_xbb=False, do_muP=False, inc_params=False):
-    incl_taus = config_path.name == CONFIG
+    incl_taus = config_path.name in TAU_CONFIGS
     tmp_path = Path(tmp_path)
     train_h5_path = tmp_path / "dummy_train_inputs.h5"
     nd_path = tmp_path / "dummy_norm_dict.yaml"
@@ -28,9 +29,9 @@ def run_train(tmp_path, config_path, train_args, do_xbb=False, do_muP=False, inc
     args += [f"--data.class_dict={cd_path}"]
     args += [f"--data.train_file={train_h5_path}"]
     args += [f"--data.val_file={train_h5_path}"]
-    args += ["--data.num_train=500"]
-    args += ["--data.num_val=200"]
-    args += ["--data.batch_size=100"]
+    args += ["--data.num_train=50"]
+    args += ["--data.num_val=20"]
+    args += ["--data.batch_size=10"]
     args += ["--data.num_workers=0"]
     args += ["--trainer.max_epochs=1"]
     args += ["--trainer.accelerator=cpu"]
@@ -61,8 +62,10 @@ def run_eval(tmp_path, train_config_path, nd_path, do_xbb=False):
     write_dummy_file(test_h5_path, nd_path, do_xbb)
 
     args = ["test"]
+
     args += [f"--config={train_config_path}"]
     args += [f"--data.test_file={test_h5_path}"]
+    args += ["--data.batch_size=100"]
     args += ["--data.num_test=1000"]
     main(args)
 
@@ -108,7 +111,12 @@ def run_onnx(train_dir, args=None):
         args = []
     args += [f"--ckpt_path={ckpt_path}"]
     args += ["--track_selection=dipsLoose202102"]
-    args += args
+
+    if "MaskFormer" in str(train_dir):
+        args += ["-mf=vertexing"]
+    print("ONNX" * 100)
+    print(train_dir)
+    # args += args
     to_onnx(args)
     get_onnx_metadata([str(train_dir / "network.onnx")])
 
@@ -262,10 +270,14 @@ def test_maskformer(tmp_path) -> None:
 
 
 @pytest.mark.filterwarnings(w)
-def test_parameterisation_concatenation(tmp_path) -> None:
-    run_combined(tmp_path, "parameterisation_concatenation.yaml", do_onnx=False, inc_params=True)
+def test_param_concat(tmp_path) -> None:
+    args = [f"--config={Path(__file__).parent.parent / 'tests' / 'configs' / 'param_concat.yaml'}"]
+    run_combined(tmp_path, CONFIG, do_onnx=False, inc_params=True, train_args=args)
 
 
 @pytest.mark.filterwarnings(w)
-def test_parameterisation_featurewise(tmp_path) -> None:
-    run_combined(tmp_path, "parameterisation_featurewise.yaml", do_onnx=False, inc_params=True)
+def test_param_featurewise(tmp_path) -> None:
+    args = [
+        f"--config={Path(__file__).parent.parent / 'tests' / 'configs' / 'param_featurewise.yaml'}"
+    ]
+    run_combined(tmp_path, CONFIG, do_onnx=False, inc_params=True, train_args=args)
diff --git a/salt/tests/test_transformerv2.py b/salt/tests/test_transformerv2.py
index 22df9d499c048e54a91fafbfcb578862e0f7809b..0cd769f261c3294b0587eafba91dc2d6a77b7ba3 100644
--- a/salt/tests/test_transformerv2.py
+++ b/salt/tests/test_transformerv2.py
@@ -1,13 +1,20 @@
-import time
+import importlib.util
 
 import pytest
 import torch
 from torch import nn
+from torch.utils.benchmark import Timer
 
 from salt.models.attention import MultiheadAttention
 from salt.models.layernorm import RMSNorm
-from salt.models.transformer import TransformerEncoderLayer
-from salt.models.transformer_v2 import Attention, DecoderLayer, EncoderLayer, merge_masks
+from salt.models.transformer_v2 import (
+    Attention,
+    DecoderLayer,
+    TransformerV2,
+    merge_masks,
+    redo_padding,
+    undo_padding,
+)
 
 N_BATCH = 10
 Q_SEQ = 20
@@ -21,54 +28,35 @@ def create_bool_tensor(shape, value):
 
 def test_merge_masks_none_inputs():
     q_shape = (N_BATCH, Q_SEQ, DIM)
-    k_shape = (N_BATCH, KV_SEQ, DIM)
-    mask = merge_masks(None, None, None, q_shape, k_shape)
+    mask = merge_masks(None, None, q_shape)
     assert mask is None
 
 
-def test_merge_masks_only_q_mask():
+def test_merge_masks_only_attn_mask():
     q_shape = (N_BATCH, Q_SEQ, DIM)
-    k_shape = (N_BATCH, KV_SEQ, DIM)
-    q_mask = create_bool_tensor(q_shape[:-1], False)
-    mask = merge_masks(q_mask, None, None, q_shape, k_shape)
-    assert mask.shape == (N_BATCH, Q_SEQ, KV_SEQ)
+    attn_shape = (N_BATCH, Q_SEQ, KV_SEQ)
+    attn_mask = create_bool_tensor(attn_shape, False)
+    mask = merge_masks(None, attn_mask, q_shape)
+    assert mask.shape == (N_BATCH, 1, Q_SEQ, KV_SEQ)
 
 
 def test_merge_masks_only_kv_mask():
     q_shape = (N_BATCH, Q_SEQ, DIM)
     k_shape = (N_BATCH, KV_SEQ, DIM)
     kv_mask = create_bool_tensor(k_shape[:-1], False)
-    mask = merge_masks(None, kv_mask, None, q_shape, k_shape)
-    assert mask.shape == (N_BATCH, Q_SEQ, KV_SEQ)
-
-
-def test_merge_masks_q_and_kv_masks():
-    q_shape = (N_BATCH, Q_SEQ, DIM)
-    k_shape = (N_BATCH, KV_SEQ, DIM)
-    q_mask = create_bool_tensor(q_shape[:-1], False)
-    kv_mask = create_bool_tensor(k_shape[:-1], True)
-    mask = merge_masks(q_mask, kv_mask, None, q_shape, k_shape)
-    assert mask.shape == (N_BATCH, Q_SEQ, KV_SEQ)
-    assert torch.all(mask)
+    mask = merge_masks(kv_mask, None, q_shape)
+    assert mask.shape == (N_BATCH, 1, Q_SEQ, KV_SEQ)
 
 
-def test_merge_masks_with_attn_mask():
+def test_merge_masks_attn_and_kv_masks():
     q_shape = (N_BATCH, Q_SEQ, DIM)
     k_shape = (N_BATCH, KV_SEQ, DIM)
-    attn_mask = create_bool_tensor((3, 4, 5), False)
-    mask = merge_masks(None, None, attn_mask, q_shape, k_shape)
-    assert mask.shape == attn_mask.shape
-    assert torch.equal(mask, attn_mask)
-
-
-def test_merge_masks_different_shapes():
-    q_shape = (2, 3, 10)
-    k_shape = (2, 4, 10)
-    q_mask = create_bool_tensor(q_shape[:-1], False)
+    attn_shape = (N_BATCH, Q_SEQ, KV_SEQ)
     kv_mask = create_bool_tensor(k_shape[:-1], False)
-    attn_mask = create_bool_tensor((2, 3, 4), False)
-    mask = merge_masks(q_mask, kv_mask, attn_mask, q_shape, k_shape)
-    assert mask.shape == attn_mask.shape
+    attn_mask = create_bool_tensor(attn_shape, True)
+    mask = merge_masks(kv_mask, attn_mask, q_shape)
+    assert mask.shape == (N_BATCH, 1, Q_SEQ, KV_SEQ)
+    assert torch.all(mask)
 
 
 def test_padding_mask():
@@ -96,132 +84,222 @@ def test_padding_mask():
     # ]])
 
 
-def compare_attention_outputs(custom_attn, torch_attn, q, k, v, kv_mask=None):
-    """Helper function to compare outputs of custom and torch attention modules."""
-    custom_output = custom_attn(q, k, v, kv_mask=kv_mask)
-    torch_output, _ = torch_attn(q, k, v, key_padding_mask=kv_mask)
-    torch.testing.assert_close(custom_output, torch_output)
-    assert not torch.isnan(custom_output).any()
-
+def get_models(dim, num_heads) -> tuple:
+    salt_attn = Attention(dim, num_heads=num_heads)
+    torch_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
+    salt_attn.in_proj_weight = torch_attn.in_proj_weight
+    salt_attn.in_proj_bias = torch_attn.in_proj_bias
+    salt_attn.out_proj.weight = torch_attn.out_proj.weight
+    salt_attn.out_proj.bias = torch_attn.out_proj.bias
+    return salt_attn, torch_attn
 
-def get_models(dim, num_heads, add_zero_attn):
-    salt_attn = Attention(dim, num_heads=num_heads, add_zero_attn=add_zero_attn)
-    torch_attn = nn.MultiheadAttention(
-        dim, num_heads, batch_first=True, add_zero_attn=add_zero_attn
-    )
 
-    # Set the weights of the custom attention module to be the same as the torch module
-    weights = torch.rand((3 * dim, dim))
-    bias = torch.rand(3 * dim)
-    torch_attn.in_proj_weight = nn.Parameter(weights)
-    torch_attn.in_proj_bias = nn.Parameter(bias)
-
-    wq, wk, wv = weights.chunk(3)
-    bq, bk, bv = bias.chunk(3)
-    salt_attn.wq.weight = nn.Parameter(wq)
-    salt_attn.wk.weight = nn.Parameter(wk)
-    salt_attn.wv.weight = nn.Parameter(wv)
-    salt_attn.wq.bias = nn.Parameter(bq)
-    salt_attn.wk.bias = nn.Parameter(bk)
-    salt_attn.wv.bias = nn.Parameter(bv)
-    salt_attn.wo.weight = torch_attn.out_proj.weight
-    salt_attn.wo.bias = torch_attn.out_proj.bias
-    return salt_attn, torch_attn
+def get_cross_attn_inputs(batch_size, q_len, kv_len, dim, frac_pad=0.0) -> tuple:
+    torch.manual_seed(0)
+    q = torch.randn(batch_size, q_len, dim)
+    kv = torch.randn(batch_size, kv_len, dim)
+    kv_mask = torch.rand(batch_size, kv_len) > frac_pad
+    kv_mask[:, 0] = False  # Make sure something can send
+    return q, kv, kv_mask
 
 
-def get_test_inputs(batch_size, seq_len, dim, frac_pad=0.0):
+def get_self_attn_inputs(batch_size, seq_len, dim, frac_pad=0.0) -> tuple:
     torch.manual_seed(0)
-    q = torch.randn(batch_size, seq_len, dim)
-    k = torch.randn(batch_size, seq_len, dim)
-    v = torch.randn(batch_size, seq_len, dim)
-    kv_mask = torch.rand(batch_size, seq_len) < frac_pad
-    q[kv_mask] = 0
-    k[kv_mask] = 0
-    v[kv_mask] = 0
-    return q, k, v, kv_mask
+    x = torch.randn(batch_size, seq_len, dim)
+    mask = torch.rand(batch_size, seq_len) > frac_pad
+    mask[:, 0] = False  # Make sure something can send
+    return x, mask
 
 
 @pytest.mark.parametrize("batch_size", [1, 10])
-@pytest.mark.parametrize("seq_len", [0, 1, 2, 10])
+@pytest.mark.parametrize("q_len", [1, 10])
+@pytest.mark.parametrize("kv_len", [1, 10])
 @pytest.mark.parametrize("dim", [32])
-@pytest.mark.parametrize("num_heads", [1, 2])
-@pytest.mark.parametrize("add_zero_attn", [False, True])
-@pytest.mark.parametrize("frac_pad", [0.0, 0.5, 1.0])
-def test_attention_output(batch_size, seq_len, dim, num_heads, add_zero_attn, frac_pad):
-    salt_attn, torch_attn = get_models(dim, num_heads, add_zero_attn=add_zero_attn)
-    q, k, v, kv_mask = get_test_inputs(batch_size, seq_len, dim, frac_pad=frac_pad)
+@pytest.mark.parametrize("frac_pad", [0.0, 0.5, 0.9])
+def test_cross_attention(
+    batch_size,
+    q_len,
+    kv_len,
+    dim,
+    frac_pad,
+) -> None:
+    salt_attn, torch_attn = get_models(dim, 2)
+    q, kv, kv_mask = get_cross_attn_inputs(batch_size, q_len, kv_len, dim, frac_pad)
+    custom_output = salt_attn(q, kv, kv_mask=kv_mask)
+    torch_output, _ = torch_attn(q, kv, kv, key_padding_mask=kv_mask)
+    torch.testing.assert_close(custom_output, torch_output)
+    assert not torch.isnan(custom_output).any()
 
-    # if not adding a dummy token to attend to, ensure at least one element is not masked
-    if not add_zero_attn and kv_mask.shape[-1] != 0:
-        kv_mask[..., 0] = False
 
-    compare_attention_outputs(salt_attn, torch_attn, q, k, v, kv_mask)
+@pytest.mark.parametrize("batch_size", [1, 10])
+@pytest.mark.parametrize("seq_len", [1, 2, 10])
+@pytest.mark.parametrize("dim", [32])
+@pytest.mark.parametrize("num_heads", [1, 2])
+@pytest.mark.parametrize("frac_pad", [0.0, 0.5, 0.9])
+def test_self_attention(
+    batch_size,
+    seq_len,
+    dim,
+    num_heads,
+    frac_pad,
+) -> None:
+    salt_attn, torch_attn = get_models(dim, num_heads)
+    x, mask = get_self_attn_inputs(batch_size, seq_len, dim, frac_pad)
+    custom_output = salt_attn(x, mask=mask)
+    torch_output, _ = torch_attn(x, x, x, key_padding_mask=mask)
+    torch.testing.assert_close(custom_output, torch_output)
+    assert not torch.isnan(custom_output).any()
 
 
+@pytest.mark.parametrize("batch_size", [1, 10])
+@pytest.mark.parametrize("seq_len", [1, 2, 10])
 @pytest.mark.parametrize("dim", [32])
 @pytest.mark.parametrize("num_heads", [1, 2])
 @pytest.mark.parametrize("frac_pad", [0.0, 0.5])
+@pytest.mark.parametrize("attn_type", ["torch-flash", "torch-meff", "flash-varlen"])
+def test_attention_backends(
+    batch_size,
+    seq_len,
+    dim,
+    num_heads,
+    frac_pad,
+    attn_type,
+) -> None:
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    if importlib.util.find_spec("flash_attn") is None:
+        pytest.skip("flash_attn not available")
+
+    # FlashVarlenAttention requires half precision
+    with torch.autocast("cuda", enabled=True):
+        # Get the inputs and move to device
+        x, mask = get_self_attn_inputs(batch_size, seq_len, dim, frac_pad)
+        x = x.cuda()
+        mask = mask.cuda()
+
+        # Change the masking to None for the torch backends as they dont support it
+        if "torch" in attn_type:
+            mask = None
+
+        # Perform the standard attention (math)
+        attn = Attention(dim, num_heads=num_heads).to("cuda")
+        output = attn(x, mask=mask)
+
+        # ensure zero padded
+        if mask is not None:
+            output *= ~mask.unsqueeze(-1)
+
+        # Switch to the attention backend
+        attn.set_backend(attn_type)
+        if attn_type == "flash-varlen":
+            x_p, culens, maxlen = undo_padding(x, mask)
+            output_2 = attn(x_p, mask=mask, culens=culens, maxlen=maxlen)
+            output_2 = redo_padding(output_2, mask)
+        else:
+            output_2 = attn(x, mask=mask)
+
+        # Test all close with less strict due to half precision
+        torch.testing.assert_close(output, output_2, atol=1e-3, rtol=1e-3)
+        assert not torch.isnan(output_2).any()
+
+
+def sync_v1v2_attn(v1_attn, v2_attn):
+    wq, wk, wv = v2_attn.in_proj_weight.chunk(3)
+    bq, bk, bv = v2_attn.in_proj_bias.chunk(3)
+    v1_attn.linear_q.weight.data = wq
+    v1_attn.linear_k.weight.data = wk
+    v1_attn.linear_v.weight.data = wv
+    v1_attn.linear_q.bias.data = bq
+    v1_attn.linear_k.bias.data = bk
+    v1_attn.linear_v.bias.data = bv
+
+
+@pytest.mark.parametrize("dim", [32])
+@pytest.mark.parametrize("num_heads", [1, 2])
+@pytest.mark.parametrize("frac_pad", [0.0, 0.5, 0.9])
 def test_v1_v2_attention_output(dim, num_heads, frac_pad):
     v1_attn = MultiheadAttention(
         dim, num_heads, {"class_path": "salt.models.ScaledDotProductAttention"}
     )
-    v2_attn = Attention(dim, num_heads=num_heads, add_zero_attn=False)
-    v1_attn.linear_q = v2_attn.wq
-    v1_attn.linear_k = v2_attn.wk
-    v1_attn.linear_v = v2_attn.wv
-    v1_attn.linear_out = v2_attn.wo
-    q, k, v, kv_mask = get_test_inputs(10, 20, dim, frac_pad=frac_pad)
-    v1_out = v1_attn(q, k, v, kv_mask=kv_mask)
-    v2_out = v2_attn(q, k, v, kv_mask=kv_mask)
+    v2_attn = Attention(dim, num_heads=num_heads)
+    sync_v1v2_attn(v1_attn, v2_attn)
+    v1_attn.linear_out = v2_attn.out_proj
+    q, kv, kv_mask = get_cross_attn_inputs(10, 20, 20, dim, frac_pad=frac_pad)
+    v1_out = v1_attn(q, kv, kv_mask=kv_mask)
+    v2_out = v2_attn(q, kv, kv_mask=kv_mask)
     torch.testing.assert_close(v1_out, v2_out)
 
 
-@pytest.mark.parametrize("dim", [32])
-@pytest.mark.parametrize("num_heads", [1, 2])
-@pytest.mark.parametrize("frac_pad", [0])  # note that this fails for frac_pad > 0
-def test_v1_v2_encoder_output(dim, num_heads, frac_pad):
-    v1_enc = TransformerEncoderLayer(
-        dim,
-        {
-            "num_heads": num_heads,
-            "attention": {"class_path": "salt.models.ScaledDotProductAttention"},
-        },
-        {"activation": "ReLU"},
+@pytest.mark.parametrize("num_registers", [1, 4])
+@pytest.mark.parametrize("num_layers", [1, 3])
+@pytest.mark.parametrize("ls_init", [None, 0.1])
+@pytest.mark.parametrize("drop_path", [0, 0.1])
+def test_transformerv2_tensor_input(num_registers, num_layers, ls_init, drop_path):
+    x, mask = get_self_attn_inputs(5, 10, 32, 0.5)
+    trans = TransformerV2(
+        num_layers=num_layers,
+        embed_dim=32,
+        attn_type="torch-math",
+        dense_kwargs={"activation": "SiLU"},
+        attn_kwargs={"num_heads": 2},
+        num_registers=num_registers,
+        ls_init=ls_init,
+        drop_path=drop_path,
     )
-    v2_enc = EncoderLayer(
-        dim,
-        attn_kwargs={"num_heads": num_heads, "add_zero_attn": False},
-        dense_kwargs={"gated": False},
+    x, mask = trans(x, pad_mask=mask)
+    assert x.shape == (5, 10 + num_registers, 32)
+    assert not x.isnan().any()
+
+
+@pytest.mark.parametrize("ls_init", [None, 0.1])
+@pytest.mark.parametrize("drop_path", [0, 0.1])
+def test_decoder_layer(ls_init, drop_path):
+    q, kv, kv_mask = get_cross_attn_inputs(5, 10, 5, 32, 0.5)
+    decoder = DecoderLayer(
+        embed_dim=32,
+        dense_kwargs={"activation": "SiLU"},
+        attn_kwargs={"num_heads": 2},
+        ls_init=ls_init,
+        drop_path=drop_path,
     )
-
-    v1_enc.mha.linear_q = v2_enc.attn.attention.wq
-    v1_enc.mha.linear_k = v2_enc.attn.attention.wk
-    v1_enc.mha.linear_v = v2_enc.attn.attention.wv
-    v1_enc.mha.linear_out = v2_enc.attn.attention.wo
-
-    v1_enc.dense.net[0] = v2_enc.dense.in_proj
-    v1_enc.dense.net[2] = v2_enc.dense.out_proj
-    v1_enc.norm1 = v2_enc.attn_norm
-    v1_enc.norm2 = v2_enc.dense_norm
-
-    q, _, _, kv_mask = get_test_inputs(10, 20, dim, frac_pad=frac_pad)
-
-    v1_out = v1_enc(q, pad_mask=kv_mask)
-    v2_out = v2_enc(q, pad_mask=kv_mask)
-
-    torch.testing.assert_close(v1_out, v2_out)
+    x = decoder(q, kv=kv, kv_mask=kv_mask)
+    assert x.shape == q.shape
+    assert not x.isnan().any()
+
+
+@pytest.mark.parametrize("num_registers", [1, 4])
+def test_transformerv2_dict_input(num_registers):
+    x1, m1 = get_self_attn_inputs(5, 10, 32, 0.5)
+    x2, m2 = get_self_attn_inputs(5, 3, 32, 0.5)
+    x3, m3 = get_self_attn_inputs(5, 2, 32, 0.5)
+    x = {"m1": x1, "m2": x2, "m3": x3}  # Multimodal inputs
+    mask = {"m1": m1, "m2": m2, "m3": m3}
+    trans = TransformerV2(
+        num_layers=3,
+        embed_dim=32,
+        attn_type="torch-math",
+        dense_kwargs={"activation": "SiLU"},
+        attn_kwargs={"num_heads": 2},
+        num_registers=num_registers,
+    )
+    x, mask = trans(x, pad_mask=mask)
+    assert x.shape == (5, 10 + 3 + 2 + num_registers, 32)
+    assert all(k in mask for k in ["m1", "m2", "m3", "REGISTERS"])
 
 
-def test_times_torch_vs_salt():  # pragma: no cover
+def test_times_torch_vs_salt() -> None:
     # skip if cuda is not available
     if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
-    batch_size, seq_len, dim, num_heads = 1000, 40, 128, 8
-    salt_attn, torch_attn = get_models(dim, num_heads, add_zero_attn=True)
-    q, k, v, kv_mask = get_test_inputs(batch_size, seq_len, dim, frac_pad=0.5)
+
+    # Define the input parameters for the timings
+    batch_size, seq_len, dim, num_heads = 1000, 64, 128, 8
+    salt_attn, torch_attn = get_models(dim, num_heads)
+    x, mask = get_self_attn_inputs(batch_size, seq_len, dim, frac_pad=0.5)
 
     # move tensors and models to cuda
-    q, k, v, kv_mask = q.cuda(), k.cuda(), v.cuda(), kv_mask.cuda()
+    x = x.cuda()
+    mask = mask.cuda()
     salt_attn.cuda()
     torch_attn.cuda()
 
@@ -229,32 +307,81 @@ def test_times_torch_vs_salt():  # pragma: no cover
     salt_attn.training = True
     torch_attn.training = True
 
-    # warm up
-    for _ in range(10):
-        salt_attn(q, k, v, kv_mask=kv_mask)
-        torch_attn(q, k, v, key_padding_mask=kv_mask)
+    # Using timers also performs warm up
+    salt_timer = Timer(
+        stmt="salt_attn(x, kv_mask=mask)",
+        globals={"salt_attn": salt_attn, "x": x, "mask": mask},
+        label="salt",
+        num_threads=1,
+    )
 
-    salt_times = []
-    for _ in range(50):
-        start = time.time()
-        salt_attn(q, k, v, kv_mask=kv_mask)
-        end = time.time()
-        salt_times.append(end - start)
+    torch_timer = Timer(
+        stmt="torch_attn(x, x, x, key_padding_mask=mask)",
+        globals={"torch_attn": torch_attn, "x": x, "mask": mask},
+        label="torch",
+        num_threads=1,
+    )
 
-    torch_times = []
-    for _ in range(50):
-        start = time.time()
-        torch_attn(q, k, v, key_padding_mask=kv_mask)
-        end = time.time()
-        torch_times.append(end - start)
+    salt_time = salt_timer.timeit(300).mean
+    torch_time = torch_timer.timeit(300).mean
+    assert salt_time < torch_time, f"mean: {salt_time} vs {torch_time}"
 
-    salt_mean = sum(salt_times) / len(salt_times)
-    torch_mean = sum(torch_times) / len(torch_times)
-    salt_median = sorted(salt_times)[len(salt_times) // 2]
-    torch_median = sorted(torch_times)[len(torch_times) // 2]
 
-    assert salt_mean < torch_mean, f"mean: {salt_mean} vs {torch_mean}"
-    assert salt_median < torch_median, f"median: {salt_median} vs {torch_median}"
+def test_times_varlen_vs_default() -> None:
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    if importlib.util.find_spec("flash_attn") is None:
+        pytest.skip("flash_attn not available")
+
+    # FlashVarlenAttention requires half precision
+    with torch.autocast("cuda", enabled=True):
+        # Define the input parameters for the timings
+        num_layers = 4
+        num_heads = 4
+        batch_size = 256
+        seq_len = 64
+        dim = 128
+        x, mask = get_self_attn_inputs(batch_size, seq_len, dim, frac_pad=0.5)
+
+        # Create the transformers
+        standard_attn = TransformerV2(
+            num_layers=num_layers,
+            embed_dim=dim,
+            attn_type="torch-math",
+            dense_kwargs={"activation": "SiLU"},
+            attn_kwargs={"num_heads": num_heads},
+        )
+
+        varlen_attn = TransformerV2(
+            num_layers=num_layers,
+            embed_dim=dim,
+            attn_type="flash-varlen",
+            dense_kwargs={"activation": "SiLU"},
+            attn_kwargs={"num_heads": num_heads},
+        )
+
+        # move tensors and models to cuda
+        x = x.cuda()
+        mask = mask.cuda()
+        standard_attn.cuda()
+        varlen_attn.cuda()
+
+        # Time the models
+        s_timer = Timer(
+            stmt="standard_attn(x, pad_mask=mask)",
+            globals={"standard_attn": standard_attn, "x": x, "mask": mask},
+            label="salt",
+            num_threads=1,
+        )
+        v_timer = Timer(
+            stmt="varlen_attn(x, pad_mask=mask)",
+            globals={"varlen_attn": varlen_attn, "x": x, "mask": mask},
+            label="salt",
+            num_threads=1,
+        )
+        st = s_timer.timeit(20).mean
+        vt = v_timer.timeit(20).mean
+        assert vt < st, f"mean: {vt} vs {st}"
 
 
 def test_RMSNorm():
@@ -266,5 +393,5 @@ def test_RMSNorm():
 def test_DecoderLayer():
     layer = DecoderLayer(embed_dim=32, attn_kwargs={"num_heads": 2})
     x = torch.randn(5, 10, 32)
-    y = torch.randn(5, 10, 32)
-    layer(x, y, pad_mask=None)
+    kv = torch.randn(5, 10, 32)
+    layer(x, kv=kv)
diff --git a/salt/to_onnx.py b/salt/to_onnx.py
index 78b0170c97c401f57d7819347f5678a039fc4be1..a65522833bd279699085433c3ed2e60c2436bf91 100644
--- a/salt/to_onnx.py
+++ b/salt/to_onnx.py
@@ -14,9 +14,12 @@ from torch.nn.functional import softmax
 from tqdm import tqdm
 
 from salt.models.task import mask_fill_flattened
+from salt.models.transformer_v2 import change_attn_backends
 from salt.modelwrapper import ModelWrapper
+from salt.utils.configs import MaskformerConfig
 from salt.utils.inputs import inputs_sep_no_pad, inputs_sep_with_pad
-from salt.utils.union_find import get_node_assignment
+from salt.utils.mask_utils import indices_from_mask
+from salt.utils.union_find import get_node_assignment_jit
 
 torch.manual_seed(42)
 # https://gitlab.cern.ch/atlas/athena/-/blob/master/PhysicsAnalysis/JetTagging/FlavorTagDiscriminants/Root/DataPrepUtilities.cxx
@@ -77,6 +80,10 @@ def parse_args(args):
         help="Include auxiliary task outputs (if available)",
         action="store_true",
     )
+    parser.add_argument(
+        "-mf",
+        "--object_name",
+    )
     parser.add_argument(
         "-f",
         "--force",
@@ -92,8 +99,70 @@ def get_probs(outputs: Tensor):
     return tuple(output.squeeze() for output in torch.split(outputs, 1, -1))
 
 
+def get_maskformer_outputs(objects):
+    # Convert the (N,M) -> (M,) mask indices
+    masks = objects["masks"]
+    class_probs = objects["class_probs"]
+    regression = objects["regression"]
+    object_leading = objects["regression"]
+    n_tracks = masks.shape[-1]
+    n_obj = masks.shape[1]
+    n_reg = regression.shape[-1]
+
+    # If we have a jet with no tracks,
+    if n_tracks == 0:
+        return (
+            torch.ones((1, n_obj)) * torch.nan,
+            None,
+            class_probs,
+            torch.ones((1, n_obj, n_reg)) * torch.nan,
+        )
+    # For testing purposes - this will likely blow up our fake rate
+    null_preds = class_probs[:, :, -1] > 0.9
+    if not null_preds.any():
+        # If we have no predicted objects, we return dummy values
+        return (
+            torch.ones((1, n_obj)) * torch.nan,
+            torch.zeros((1, n_obj, n_tracks), dtype=torch.bool),
+            class_probs,
+            torch.ones((1, n_obj, n_reg)) * torch.nan,
+        )
+
+    masks = masks.sigmoid() > 0.5
+    object_leading[null_preds] = -999
+    regression[null_preds] = np.nan
+
+    # Define the leading object as the one with the highest regression[0] value
+    # in vertexing case, this is the pT
+    order = torch.argsort(object_leading[:, :, 0], descending=True)
+    order_expanded = order.unsqueeze(-1).expand(-1, -1, masks.size(-1))
+
+    # Use gather to reorder tensors along a specific dimension
+    masks = torch.gather(masks, 1, order_expanded)
+    class_probs = torch.gather(
+        class_probs, 1, order.unsqueeze(-1).expand(-1, -1, class_probs.size(-1))
+    )
+    regression = torch.gather(
+        regression, 1, order.unsqueeze(-1).expand(-1, -1, regression.size(-1))
+    )
+    # Define the leading object as that with the highest [0] (pt for vertexing)
+    leading_regression = regression[:, 0]
+
+    # Convert our masks (N,M), now in pT order, to be (M,) indices
+    obj_indices = indices_from_mask(masks)
+
+    return leading_regression, obj_indices, class_probs, regression
+
+
 class ONNXModel(ModelWrapper):
-    def __init__(self, name: str | None = None, include_aux: bool = False, **kwargs) -> None:
+    def __init__(
+        self,
+        name: str | None = None,
+        include_aux: bool = False,
+        object_name: str | None = None,
+        mf_config: dict | None = None,
+        **kwargs,
+    ) -> None:
         super().__init__(**kwargs)
         self.name = name if name else self.name
         assert len(self.model.init_nets) == 1, "Multi input ONNX models are not yet supported."
@@ -101,10 +170,23 @@ class ONNXModel(ModelWrapper):
         assert "-" not in self.name, "Model name cannot contain dashes."
         self.include_aux = include_aux
         self.const = "tracks"
+        if sum([bool(object_name), bool(mf_config)]) not in {0, 2}:
+            raise ValueError("If one of object name or mf config is defined, so must the other.")
+        self.object = object_name
+        self.mf_config = MaskformerConfig(**mf_config) if mf_config else None
+        if self.object and self.mf_config:
+            self.object_params = {
+                "class_label": self.mf_config.object.class_label,
+                "label_map": [f"p{name}" for name in self.mf_config.object.class_names],
+            }
+            print("OBJECT PARAMS", self.object_params)
         self.input_names = ["jet_features", "track_features"]
         jets, tracks = inputs_sep_no_pad(
             1, 40, self.input_dims[self.global_object], self.input_dims[self.const]
         )
+        self.has_global_task = (
+            len([t for t in self.model.tasks if t.input_name == self.global_object]) > 0
+        )
         self.example_input_array = jets, tracks.squeeze(0)  # used for the tracing during export
 
     @property
@@ -117,10 +199,12 @@ class ONNXModel(ModelWrapper):
         """The output names are a list of strings, one for each output of the model."""
         # get the global task output
         global_tasks = [t for t in self.model.tasks if t.input_name == self.global_object]
-        assert len(global_tasks) == 1, "Multi global task ONNX models are not yet supported."
-        object_classes = global_tasks[0].class_names
-        outputs = [f"{self.model_name}_p{flav.rstrip('jets')}" for flav in object_classes]
-
+        assert len(global_tasks) <= 1, "Multi global task ONNX models are not yet supported."
+        if self.has_global_task:
+            object_classes = global_tasks[0].class_names
+            outputs = [f"{self.model_name}_p{flav.rstrip('jets')}" for flav in object_classes]
+        else:
+            outputs = []
         # aux task output names
         if self.include_aux:
             if "track_origin" in [t.name for t in self.model.tasks]:
@@ -130,6 +214,16 @@ class ONNXModel(ModelWrapper):
             if "track_vertexing" in [t.name for t in self.model.tasks]:
                 out_name = f"{self.model_name}_VertexIndex"
                 outputs.append(out_name)
+        if self.object:
+            regression_task = [
+                t for t in self.model.tasks if t.input_name == "objects" and t.name == "regression"
+            ]
+            assert len(regression_task) == 1, "Object outputs require a regression task"
+            # First we append the leading jet regression variables
+            outputs += [
+                f"{self.model_name}_leading_{self.object}_{v}" for v in regression_task[0].targets
+            ]
+            outputs += [f"{self.model_name}_{self.object}Index"]
 
         return outputs
 
@@ -147,6 +241,9 @@ class ONNXModel(ModelWrapper):
             if "track_vertexing" in [t.name for t in self.model.tasks]:
                 out_name = f"{self.model_name}_VertexIndex"
                 dynamic_axes[out_name] = {0: "n_tracks"}
+        if self.object:
+            out_name = f"{self.model_name}_{self.object}"
+            dynamic_axes[out_name] = {0: "n_tracks"}
         return dynamic_axes
 
     def forward(self, jets: Tensor, tracks: Tensor, labels=None):  # type: ignore[override]
@@ -158,9 +255,10 @@ class ONNXModel(ModelWrapper):
         # forward pass
         outputs = super().forward({self.global_object: jets, self.const: tracks}, None)[0]
 
-        # get class probabilities
-        onnx_outputs = get_probs(
-            outputs[self.global_object][f"{self.global_object}_classification"]
+        onnx_outputs = (
+            get_probs(outputs[self.global_object][f"{self.global_object}_classification"])
+            if self.has_global_task
+            else ()
         )
 
         # add aux outputs
@@ -174,10 +272,36 @@ class ONNXModel(ModelWrapper):
             if "track_vertexing" in track_outs:
                 pad_mask = torch.zeros(tracks.shape[:-1], dtype=torch.bool)
                 edge_scores = track_outs["track_vertexing"]
-                vertex_indices = get_node_assignment(edge_scores, pad_mask)
+                vertex_indices = get_node_assignment_jit(edge_scores, pad_mask)
                 vertex_list = mask_fill_flattened(vertex_indices, pad_mask)
                 onnx_outputs += (vertex_list.reshape(-1).char(),)
 
+        if self.object:
+            assert "objects" in outputs, "No MF objects in outputs"
+            regression_tasks = [
+                t for t in self.model.tasks if t.input_name == "objects" and t.name == "regression"
+            ]
+            assert len(regression_tasks) == 1, "Object outputs require a regression task"
+            regression_task = regression_tasks[0]
+
+            # Get the (hopefully) correctly (un)scaled regression predictions
+            for i, t in enumerate(regression_task.targets):
+                unscaled_preds = regression_task.scaler.inverse(
+                    t, outputs["objects"]["regression"][:, :, i]
+                )
+                outputs["objects"]["regression"][:, :, i] = unscaled_preds
+
+            # Extract the mf outputs.
+            # TODO: write all regression values, this will require work on the athena end as well
+            # https://gitlab.cern.ch/atlas-flavor-tagging-tools/algorithms/salt/-/issues/53
+            leading_reg, indices, class_probs, regression = get_maskformer_outputs(  # noqa: F841
+                outputs["objects"]
+            )
+
+            for r in leading_reg[0]:
+                onnx_outputs += (r,)
+            onnx_outputs += (indices.reshape(-1).char(),)
+
         return onnx_outputs
 
 
@@ -190,7 +314,11 @@ def compare_output(pt_model, onnx_session, include_aux, n_track=40):
 
     inputs_pt = {"jets": jets, "tracks": tracks}
     outputs_pt = pt_model(inputs_pt, {"tracks": pad_mask})[0]
-    pred_pt_jc = [p.detach().numpy() for p in get_probs(outputs_pt["jets"]["jets_classification"])]
+    pred_pt_jc = (
+        [p.detach().numpy() for p in get_probs(outputs_pt["jets"]["jets_classification"])]
+        if "jets" in outputs_pt
+        else []
+    )
 
     inputs_onnx = {
         "jet_features": jets.numpy(),
@@ -229,9 +357,9 @@ def compare_output(pt_model, onnx_session, include_aux, n_track=40):
         )
 
     # test vertexing
-    if include_aux:
+    if include_aux and "track_vertexing" in outputs_pt["tracks"]:
         pred_pt_scores = outputs_pt["tracks"]["track_vertexing"].detach()
-        pred_pt_indices = get_node_assignment(pred_pt_scores, pad_mask)
+        pred_pt_indices = get_node_assignment_jit(pred_pt_scores, pad_mask)
         pred_pt_vtx = mask_fill_flattened(pred_pt_indices, pad_mask)
 
         pred_onnx_vtx = outputs_onnx[-1]
@@ -276,23 +404,43 @@ def main(args=None):
         config_path = args.ckpt_path.parents[1] / "config.yaml"
         assert config_path.is_file(), f"Could not find config file at {config_path}"
 
+    with open(config_path) as f:
+        config = yaml.safe_load(f)
+
     # instantiate pytorch and wrapper models
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
 
         pt_model = ModelWrapper.load_from_checkpoint(
-            args.ckpt_path, map_location=torch.device("cpu")
+            args.ckpt_path,
+            map_location=torch.device("cpu"),
+            norm_config=config["model"]["norm_config"],
         )
         pt_model.eval()
         pt_model.float()
 
+        if args.object_name:
+            with open(config_path) as f:
+                config = yaml.safe_load(f)
+            mf_config = config["data"].get("mf_config")
+            if not mf_config:
+                raise ValueError("No mf_config in config")
+        else:
+            mf_config = {}
         onnx_model = ONNXModel.load_from_checkpoint(
             args.ckpt_path,
             name=args.name,
             include_aux=args.include_aux,
+            object_name=args.object_name,
+            mf_config=mf_config,
             map_location=torch.device("cpu"),
+            norm_config=config["model"]["norm_config"],
         )
+        print("OUTPUTS", onnx_model.output_names)
         onnx_model.eval()
+        change_attn_backends(
+            onnx_model.model, "torch-math"
+        )  # Only applies to transformer_v2 layers
 
     print("\n" + "-" * 100)
     print("Converting model to ONNX...")
@@ -376,6 +524,7 @@ def add_metadata(
 
     # write metadata as json string
     metadata = {"gnn_config": json.dumps(metadata)}
+
     for k, v in metadata.items():
         meta = onnx_model.metadata_props.add()
         meta.key = k
diff --git a/salt/utils/benchmarking.py b/salt/utils/benchmarking.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e0ae07d2fece210e59a004e45a5256104a1afa7
--- /dev/null
+++ b/salt/utils/benchmarking.py
@@ -0,0 +1,171 @@
+"""Benchmarking utilities for Pytorch models."""
+
+from collections.abc import Callable
+
+import torch
+from torch import Tensor, dtype
+from torch.utils import benchmark
+
+
+def time_forward(
+    fn: Callable,
+    *args,
+    repeats: int = 10,
+    block_time: float = 0.0,
+    desc: str = "",
+    verbose: bool = True,
+    amp: bool = False,
+    amp_dtype: dtype = torch.float16,
+    **kwargs,
+) -> tuple:
+    """Use Pytorch Benchmark on the forward pass of an arbitrary function.
+
+    Parameters
+    ----------
+    fn : function
+        The function to benchmark.
+    args : list
+        The args to the function.
+    repeats : int
+        Number of times to repeat the benchmark.
+    block_time : float
+        Instead of repeats, run the benchmark for a fixed amount of time.
+    desc : str
+        Description of the benchmark.
+    verbose : bool
+        Whether to print the benchmark results.
+    amp : bool
+        Whether to use automatic mixed precision.
+    amp_dtype : torch.dtype
+        The dtype to use for automatic mixed precision.
+    kwargs : dict
+        Additional keyword arguments to pass to the function.
+    """
+    if verbose:
+        print(desc, " - Foward pass")
+
+    # Define the automatic mixed precision wrapper
+    def fn_with_amp(*args, **kwargs):
+        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
+            fn(*args, **kwargs)
+
+    # Create the benchmark timer
+    t = benchmark.Timer(
+        stmt="fn_with_amp(*args, **kwargs)",
+        globals={"fn_with_amp": fn_with_amp, "args": args, "kwargs": kwargs},
+        num_threads=torch.get_num_threads(),
+    )
+
+    # Run the benchmark
+    m = t.blocked_autorange(min_run_time=block_time) if block_time > 0 else t.timeit(repeats)
+
+    if verbose:
+        print(m)
+    return t, m
+
+
+def time_backward(
+    fn: Callable,
+    *args,
+    repeats: int = 10,
+    block_time: float = 0.0,
+    desc: str = "",
+    verbose: bool = True,
+    amp: bool = False,
+    amp_dtype: dtype = torch.float16,
+    **kwargs,
+) -> tuple:
+    """Use Pytorch Benchmark on the backward pass of an arbitrary function.
+
+    Parameters
+    ----------
+    fn : function
+        The function to benchmark.
+    args : list
+        The args to the function.
+    repeats : int
+        Number of times to repeat the benchmark.
+    block_time : float
+        Instead of repeats, run the benchmark for a fixed amount of time.
+    desc : str
+        Description of the benchmark.
+    verbose : bool
+        Whether to print the benchmark results.
+    amp : bool
+        Whether to use automatic mixed precision.
+    amp_dtype : torch.dtype
+        The dtype to use for automatic mixed precision.
+    kwargs : dict
+        Additional keyword arguments to pass to the function.
+    """
+    if verbose:
+        print(desc, " - Backward pass")
+
+    # Run in forward to get the output so we can backpropagate
+    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
+        y = fn(*args, **kwargs)
+        if type(y) is tuple:
+            y = y[0]
+        elif type(y) is dict:
+            y = next(iter(y.values()))
+
+    # Generate a random gradient
+    grad = torch.randn_like(y)
+
+    # Define the backward function
+    def bwd(*args, y, grad):
+        for x in args:  # Turn off gradients for all args
+            if isinstance(x, Tensor):
+                x.grad = None
+        y.backward(grad, retain_graph=True)
+
+    # Create the benchmark timer
+    t = benchmark.Timer(
+        stmt="f(*args, y=y, grad=grad)",
+        globals={"f": bwd, "args": args, "y": y, "grad": grad},
+        num_threads=torch.get_num_threads(),
+    )
+
+    # Run the benchmark
+    m = t.blocked_autorange(min_run_time=block_time) if block_time > 0 else t.timeit(repeats)
+    if verbose:
+        print(m)
+    return t, m
+
+
+def benchmark_gpu_memory(
+    fn: Callable,
+    *args,
+    amp: bool = False,
+    amp_dtype: dtype = torch.float16,
+    **kwargs,
+) -> tuple:
+    """Calculate the maximum GPU memory used by a function.
+
+    Parameters
+    ----------
+    fn : function
+        The function to benchmark.
+    args : list
+        The args to the function.
+    amp : bool
+        Whether to use automatic mixed precision.
+    amp_dtype : torch.dtype
+        The dtype to use for automatic mixed precision.
+    kwargs : dict
+        Additional keyword arguments to pass to the function.
+    """
+    # Clear the cache and reset memory stats
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+    torch.cuda.synchronize()
+
+    # Run the function
+    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
+        fn(*args, **kwargs)
+    torch.cuda.synchronize()
+
+    # Calculate the max memory used in GB
+    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
+    torch.cuda.empty_cache()
+    return mem
diff --git a/salt/utils/cli.py b/salt/utils/cli.py
index 6853dfb9fdecbb9b76ec9c70f2b7944183b01bde..ab711db96726723941510963d3cd7a7224bef895 100644
--- a/salt/utils/cli.py
+++ b/salt/utils/cli.py
@@ -66,6 +66,16 @@ class SaltCLI(LightningCLI):
         parser.add_argument(
             "--compile", action="store_true", help="Compile the model to speed up training."
         )
+        parser.add_argument(
+            "-oc", "--overwrite_config", action="store_true", help="Overwrite config file."
+        )
+        parser.add_argument(
+            "-ls",
+            "--log_suffix",
+            default=None,
+            type=str,
+            help="Appended to model name to create the log directory.",
+        )
         self.apply_links(parser)
 
     def fit(self, model, **kwargs):
@@ -178,6 +188,7 @@ class SaltCLI(LightningCLI):
                 input_name = task["init_args"]["input_name"]
                 if task["init_args"]["label"] in class_dict[input_name]:
                     class_weights = class_dict[input_name][task["init_args"]["label"]]
+                    class_weights = torch.Tensor(class_weights)
                     task["init_args"]["loss"]["init_args"]["weight"] = class_weights
                 else:
                     raise ValueError(
@@ -203,7 +214,11 @@ class SaltCLI(LightningCLI):
                 pass
 
             # set the timestampped dir
-            dirname = f"{name}_{timestamp}"
+            if sc["log_suffix"]:
+                log_suffix = sc["log_suffix"]
+                dirname = f"{name}_{log_suffix}"
+            else:
+                dirname = f"{name}_{timestamp}"
             if "s3:/" not in sc["trainer.default_root_dir"]:
                 log_dir_timestamp = str(Path(log_dir / dirname).resolve())
             else:
@@ -225,6 +240,9 @@ class SaltCLI(LightningCLI):
                         "automated salt tag",
                     )
 
+            if sc["overwrite_config"]:
+                self.save_config_kwargs["overwrite"] = True
+
         if self.subcommand == "test":
             print("\n" + "-" * 100)
 
diff --git a/salt/utils/configs.py b/salt/utils/configs.py
index 5580bfc81049b2c870ce09db0464803f44f0ca80..43e073001cee446b4716cab4255b16a70cfe2c0a 100644
--- a/salt/utils/configs.py
+++ b/salt/utils/configs.py
@@ -85,3 +85,9 @@ class MaskformerConfig:
 
     object: MaskformerObjectConfig
     constituent: MaskformerObjectConfig
+
+    def __post_init__(self):
+        if isinstance(self.object, dict):
+            self.object = MaskformerObjectConfig(**self.object)
+        if isinstance(self.constituent, dict):
+            self.constituent = MaskformerObjectConfig(**self.constituent)
diff --git a/salt/utils/inputs.py b/salt/utils/inputs.py
index 1125c67ab678a54ba004d3e343a592e60f121345..15073e8170c39ba89e7024d977a0b7a0af0f0161 100644
--- a/salt/utils/inputs.py
+++ b/salt/utils/inputs.py
@@ -76,6 +76,7 @@ TRACK_VARS = [
     "eta",
     "phi",
     "subjetIndex",
+    "leptonID",
 ]
 
 ELECTRON_VARS = [
@@ -151,8 +152,13 @@ def write_dummy_norm_dict(nd_path: Path, cd_path: Path):
     sd["flow"] = {n: {"std": 1.0, "mean": 1.0} for n in TRACK_VARS}
     with open(nd_path, "w") as file:
         yaml.dump(sd, file, sort_keys=False)
+
+    cd: dict = {}
+    cd["jets"] = {"HadronConeExclTruthLabelID": [1.0, 2.0, 2.0, 2.0]}
+    cd["jets"]["flavour_label"] = cd["jets"]["HadronConeExclTruthLabelID"]
+    cd["tracks"] = {"ftagTruthOriginLabel": [4.2, 73.7, 1.0, 17.5, 12.3, 12.5, 141.7, 22.3]}
     with open(cd_path, "w") as file:
-        yaml.dump(sd, file, sort_keys=False)
+        yaml.dump(cd, file, sort_keys=False)
 
 
 def get_dummy_inputs(n_jets=1000, n_jet_features=2, n_track_features=21, n_tracks_per_jet=40):
diff --git a/salt/utils/mask_utils.py b/salt/utils/mask_utils.py
index beb197d67199c83fa6815730d8e0223444c6edb7..82e2b125804a52e91a144901eda8385dad81976f 100644
--- a/salt/utils/mask_utils.py
+++ b/salt/utils/mask_utils.py
@@ -15,6 +15,8 @@ def build_target_masks(object_ids, input_ids, shuffle=False):
         The unqiue ids of the truth object labels
     input_ids : Tensor
         The ids of the per-input labels
+    shuffle: bool
+        Shuffle object ids
 
     Returns
     -------
@@ -74,26 +76,24 @@ def mask_from_indices(indices: Tensor, num_masks: int | None = None) -> BoolTens
     return mask
 
 
-def indices_from_mask(mask: BoolTensor, noindex: int = -1) -> Tensor:
-    """Convert a sparse bool mask to a dense index tensor.
-
-    Indices are arbitrary and start from 0.
+def indices_from_mask(mask: BoolTensor, noindex: int = -1):
+    """Converts a spares bool mask to a dense index tensor, where any
+    index NOT part of a mask is given an increasing index value.
 
     Examples
     --------
-    [[True, False, False], [False, True, True]] -> [0, 1, 1]
+    [
+        [True, True, False, False, False, False],
+        [False, False, True, False, False, True]
+    ] -> [0, 0, 1, 2, 3, 1]
 
     Parameters
     ----------
     mask : BoolTensor
         The sparse mask
     noindex : int
-        The value to use for no index
+        The value to insert for padding in the mask
 
-    Returns
-    -------
-    Tensor
-        The dense indices
     """
     mask = torch.as_tensor(mask)
     kwargs = {"dtype": torch.long, "device": mask.device}
@@ -101,18 +101,33 @@ def indices_from_mask(mask: BoolTensor, noindex: int = -1) -> Tensor:
         indices = torch.ones(mask.shape[-1], **kwargs) * noindex
         nonzero_idx = torch.where(mask)
         indices[nonzero_idx[1]] = nonzero_idx[0]
-    elif mask.ndim == 3:
-        indices = torch.ones((mask.shape[0], mask.shape[-1]), **kwargs) * noindex
-        nonzero_idx = torch.where(mask)
-        indices[nonzero_idx[0], nonzero_idx[2]] = nonzero_idx[1]
-    else:
-        raise ValueError("mask must be 2D for single sample or 3D for batch")
-
-    # ensure indices start from 0
-    indices -= indices[indices >= 0].min()
-    indices[indices < 0] = noindex
-
-    return indices
+        # The idx of all indices that are part of a mask
+        if mask.shape[-1] == 0:
+            return torch.arange(mask.shape[-1], **kwargs)
+
+        idx_exist = indices >= 0
+        if idx_exist.any():
+            min_val = torch.min(indices[idx_exist]).item()
+            indices[idx_exist] = indices[idx_exist] - min_val
+            max_val = torch.max(indices[idx_exist]).item()
+        else:
+            min_val = 0  # Default value if the tensor is empty
+            max_val = 0
+
+        neg_ind = torch.where(indices < 0)[0]
+        if len(neg_ind) == 0:
+            return indices
+        replacement_vals = torch.arange(max_val + 1, max_val + 1 + neg_ind.shape[0])
+        indices[neg_ind] = replacement_vals
+        return indices
+    if mask.ndim == 3:
+        # Not a fan, but CBA to do this properly for now as its only used
+        # by the onnx model, so speed isn't an issue
+        indices = torch.full((mask.shape[0], mask.shape[-1]), noindex, **kwargs)
+        for i in range(mask.shape[0]):
+            indices[i] = indices_from_mask(mask[i])
+        return indices
+    raise ValueError("mask must be 2D for single sample or 3D for batch")
 
 
 def sanitise_mask(
diff --git a/salt/utils/tensor_utils.py b/salt/utils/tensor_utils.py
index 10fcafb2ba415739ab10cfe2beba58d41725261d..288ac888032f812b9259d2b7a7a11b626bca07e3 100644
--- a/salt/utils/tensor_utils.py
+++ b/salt/utils/tensor_utils.py
@@ -1,6 +1,6 @@
 import torch
 from torch import BoolTensor, Tensor
-from torch.nn.functional import softmax
+from torch.nn.functional import pad, softmax
 
 from salt.stypes import Tensors
 
@@ -22,11 +22,14 @@ def flatten_tensor_dict(
 
     Parameters
     ----------
-        x: Dictionary of tensors to flatten.
-        include: List of keys defining the tensors to be concatenated. If None, all tensors will be
-            concatenated unless defined by 'exclude'. Cannot be used with 'exclude'.
-        exclude: List of keys to exclude from the concatenation. If None, all tensors will be
-            concatenated unless defined by 'include'. Cannot be used with 'include'.
+    x: dict[str, Tensor]
+        Dictionary of tensors to flatten.
+    include: list[str] | None, optional
+        List of keys defining the tensors to be concatenated. If None, all tensors will be
+        concatenated unless defined by 'exclude'. Cannot be used with 'exclude'.
+    exclude: list[str] | None, optional
+        List of keys to exclude from the concatenation. If None, all tensors will be
+        concatenated unless defined by 'include'. Cannot be used with 'include'.
 
     Returns
     -------
@@ -56,6 +59,24 @@ def masked_softmax(x: Tensor, mask: BoolTensor, dim: int = -1) -> Tensor:
     return x
 
 
+def undo_padding(seq: Tensor, mask: BoolTensor) -> tuple:
+    """Remove all padded elements from a tensor and return the sequence lengths."""
+    mask = ~mask  # convert the mask such that True is a valid token
+    seqlens = mask.sum(dim=-1)
+    maxlen = seqlens.max().item()
+    culens = pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
+    return seq[mask], culens, maxlen
+
+
+def redo_padding(unpadded_seq: Tensor, mask: BoolTensor) -> Tensor:
+    """Redo the padding and return a zero-padded tensor."""
+    mask = ~mask  # convert the mask such that True is a valid token
+    shape = (*mask.shape, unpadded_seq.shape[-1])
+    out = torch.zeros(shape, dtype=unpadded_seq.dtype, device=unpadded_seq.device)
+    out[mask] = unpadded_seq
+    return out
+
+
 def add_dims(x: Tensor, ndim: int):
     """Adds dimensions to a tensor to match the shape of another tensor."""
     if (dim_diff := ndim - x.dim()) < 0:
diff --git a/salt/utils/union_find.py b/salt/utils/union_find.py
index 8522b612b9dd595ed967c5bdf853f4389ce211a3..7a11676495744a1562025524e421a8ecdfec5b74 100644
--- a/salt/utils/union_find.py
+++ b/salt/utils/union_find.py
@@ -2,7 +2,6 @@ import torch
 from torch import Tensor
 
 
-@torch.jit.script
 def symmetrize_edge_scores(scores: Tensor, node_numbers: Tensor):
     """Function to make edge scores symmetric.
 
@@ -36,7 +35,6 @@ def symmetrize_edge_scores(scores: Tensor, node_numbers: Tensor):
     return torch.sigmoid(edge_scores.float())
 
 
-@torch.jit.script
 def update_node_indices(
     scores: Tensor, node_indices: Tensor, update_indices: Tensor, node_numbers: Tensor
 ):
@@ -92,7 +90,6 @@ def update_node_indices(
     return node_indices, update_indices
 
 
-@torch.jit.script
 def get_node_assignment(output: Tensor, mask: Tensor):
     """Run edge score symmetrization and union find.
 
@@ -117,3 +114,8 @@ def get_node_assignment(output: Tensor, mask: Tensor):
         )
 
     return node_indices.unsqueeze(-1)
+
+
+@torch.jit.script
+def get_node_assignment_jit(output: Tensor, mask: Tensor):
+    return get_node_assignment(output, mask)
diff --git a/setup/Dockerfile b/setup/Dockerfile
index 9a8cc07f686003cde1a7837952dfa79035203451..963f303d4c711db7abdd4ff0e1a4563ee6655b43 100644
--- a/setup/Dockerfile
+++ b/setup/Dockerfile
@@ -1,5 +1,5 @@
 # base image
-FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-runtime
+FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel
 
 # local and envs
 ENV LANG C.UTF-8
@@ -24,6 +24,10 @@ RUN python -m pip install -r requirements.txt
 # add some other packages to the image, instead of as a package dependency
 RUN python -m pip install puma-hep umami-preprocessing
 
+# Flash attention sometimes has issues in a requirements file
+RUN python -m pip install wheel packaging ninja
+RUN python -m pip install flash-attn==2.5.7
+
 # copy and install package
 COPY . .
 RUN python -m pip install -e .
diff --git a/setup/install_flash.sh b/setup/install_flash.sh
new file mode 100644
index 0000000000000000000000000000000000000000..71141a08ade0ff5a52739e78903e53ff2596b7e0
--- /dev/null
+++ b/setup/install_flash.sh
@@ -0,0 +1,5 @@
+# To optionally install the flash-attn package for transformer-v2 models
+python -m pip install wheel packaging
+python -m pip install ninja==1.11.1.1
+python -m pip install flash-attn==2.5.7
+