From 5c0a329418c7a00007ddd99f6dc4ac46bf2c7422 Mon Sep 17 00:00:00 2001
From: Sam Van Stroud <sam.van.stroud@cern.ch>
Date: Wed, 16 Aug 2023 11:03:29 +0100
Subject: [PATCH 01/10] fix in_dims

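Read each input dimension from the initialiser network's declared
input_size attribute instead of the shape of its first parameter
tensor. The old approach assumes the input projection's weight is the
first registered parameter, which breaks as soon as a layer with its
own parameters (e.g. an affine LayerNorm) is registered before it.

Illustrative failure mode (generic torch modules, not salt classes):

    import torch.nn as nn

    net = nn.Sequential(
        nn.LayerNorm(21),    # weight of shape (21,) is registered first
        nn.Linear(21, 256),  # the layer we actually wanted to inspect
    )
    first = list(net.parameters())[0]
    first.shape[1]           # IndexError: dimension out of range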
---
 salt/lightning_module.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/salt/lightning_module.py b/salt/lightning_module.py
index 6fa59160..cae8c669 100644
--- a/salt/lightning_module.py
+++ b/salt/lightning_module.py
@@ -28,8 +28,7 @@ class LightningTagger(L.LightningModule):
         self.model = model
         self.lrs_config = lrs_config
         self.name = name
-
-        self.in_dims = [list(net.parameters())[0].shape[1] for net in self.model.init_nets]
+        self.in_dims = [net.net.input_size for net in self.model.init_nets]
 
     def forward(self, x, mask, labels=None):
         """Forward pass through the model.
-- 
GitLab


From 8f4632a606af138735fad01fe3dfe49626e1a9d0 Mon Sep 17 00:00:00 2001
From: Sam Van Stroud <sam.van.stroud@cern.ch>
Date: Wed, 16 Aug 2023 11:03:44 +0100
Subject: [PATCH 02/10] add short args

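Give the boolean flags conventional single-letter aliases. argparse
accepts any number of option strings per argument, so the long names
keep working unchanged. Minimal usage sketch:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--force", action="store_true",
                        help="Run with uncommitted changes.")
    print(parser.parse_args(["-f"]).force)       # True
    print(parser.parse_args(["--force"]).force)  # True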
---
 salt/to_onnx.py   | 8 +++++++-
 salt/utils/cli.py | 8 ++++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/salt/to_onnx.py b/salt/to_onnx.py
index a9469fcb..588c99c6 100644
--- a/salt/to_onnx.py
+++ b/salt/to_onnx.py
@@ -72,11 +72,17 @@ def parse_args(args):
         action="store_true",
     )
     parser.add_argument(
+        "-a",
         "--include_aux",
         help="Include auxiliary task outputs (if available)",
         action="store_true",
     )
-    parser.add_argument("--force", help="Run with uncomitted changes.", action="store_true")
+    parser.add_argument(
+        "-f",
+        "--force",
+        help="Run with uncomitted changes.",
+        action="store_true",
+    )
 
     return parser.parse_args(args)
 
diff --git a/salt/utils/cli.py b/salt/utils/cli.py
index 02ad08f4..808c63c8 100644
--- a/salt/utils/cli.py
+++ b/salt/utils/cli.py
@@ -53,8 +53,12 @@ class SaltCLI(LightningCLI):
         parser.link_arguments("name", "trainer.logger.init_args.experiment_name")
         parser.link_arguments("name", "model.name")
         parser.link_arguments("trainer.default_root_dir", "trainer.logger.init_args.save_dir")
-        parser.add_argument("--force", action="store_true", help="Run with uncomitted changes.")
-        parser.add_argument("--tag", action="store_true", help="Push a tag for the current code.")
+        parser.add_argument(
+            "-f", "--force", action="store_true", help="Run with uncomitted changes."
+        )
+        parser.add_argument(
+            "-t", "--tag", action="store_true", help="Push a tag for the current code."
+        )
 
     def before_instantiate_classes(self) -> None:
         sc = self.config[self.subcommand]
-- 
GitLab


From 78a03bf4aa96ae9de9809c643b5796fe0c693116 Mon Sep 17 00:00:00 2001
From: Sam Van Stroud <sam.van.stroud@cern.ch>
Date: Wed, 16 Aug 2023 11:04:10 +0100
Subject: [PATCH 03/10] normalise onnx output names

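The exported output names include the model name, and they are used
as Athena AuxVars, which do not allow dashes (see the inline comment),
so map any dashes to underscores up front:

    # illustrative name, not from the repo
    model_name = "GN2-test".replace("-", "_")  # -> "GN2_test"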
---
 salt/to_onnx.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/salt/to_onnx.py b/salt/to_onnx.py
index 588c99c6..d3755083 100644
--- a/salt/to_onnx.py
+++ b/salt/to_onnx.py
@@ -199,6 +199,7 @@ def main(args=None):
 
     config = yaml.safe_load(config_path.read_text())
     model_name = args.name if args.name else config["name"]
+    model_name = model_name.replace("-", "_")  # dashes not allowed for AuxVars
 
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-- 
GitLab


From d99a31fe964824ae8ea36363aec0d048a83528dd Mon Sep 17 00:00:00 2001
From: Sam Van Stroud <sam.van.stroud@cern.ch>
Date: Wed, 16 Aug 2023 11:04:24 +0100
Subject: [PATCH 04/10] don't tweak layernorm

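Stop forcing elementwise_affine=False on the normalisation layer.
torch's nn.LayerNorm defaults to elementwise_affine=True, i.e. a
learnable per-feature scale and bias, so dropping the override
restores the standard parameterisation:

    import torch.nn as nn

    ln = nn.LayerNorm(256)  # learnable weight + bias
    print(sum(p.numel() for p in ln.parameters()))  # 512

    ln_off = nn.LayerNorm(256, elementwise_affine=False)
    print(sum(p.numel() for p in ln_off.parameters()))  # 0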
---
 salt/models/dense.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/salt/models/dense.py b/salt/models/dense.py
index f507cb66..e1822b4d 100644
--- a/salt/models/dense.py
+++ b/salt/models/dense.py
@@ -59,7 +59,7 @@ class Dense(nn.Module):
 
             # normalisation first
             if norm_layer and (norm_final_layer or not is_final_layer):
-                layers.append(getattr(nn, norm_layer)(node_list[i], elementwise_affine=False))
+                layers.append(getattr(nn, norm_layer)(node_list[i]))
 
             # then dropout
             if dropout and (norm_final_layer or not is_final_layer):
-- 
GitLab


From 42de145459433438df799fa98e293279e8f4c1c8 Mon Sep 17 00:00:00 2001
From: Sam Van Stroud <sam.van.stroud@cern.ch>
Date: Wed, 16 Aug 2023 11:05:09 +0100
Subject: [PATCH 05/10] pre-LN style in transformer

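Move the encoder and cross-attention layers to the pre-LN layout:
normalise the input of each sublayer and keep the residual path as a
plain identity, rather than normalising the sublayer output before
adding it back. Pre-LN residual streams are generally more stable to
train and less dependent on learning-rate warmup.

Generic sketch of the pattern (torch built-ins, not the salt MHA):

    import torch.nn as nn

    class PreLNBlock(nn.Module):
        def __init__(self, dim: int, num_heads: int):
            super().__init__()
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            self.dense = nn.Sequential(
                nn.Linear(dim, 2 * dim), nn.ReLU(), nn.Linear(2 * dim, dim)
            )

        def forward(self, x):
            h = self.norm1(x)
            a, _ = self.attn(h, h, h, need_weights=False)
            x = x + a  # residual path stays identity
            return x + self.dense(self.norm2(x))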
---
 salt/models/transformer.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/salt/models/transformer.py b/salt/models/transformer.py
index 2de69388..68c50fc4 100644
--- a/salt/models/transformer.py
+++ b/salt/models/transformer.py
@@ -102,11 +102,11 @@ class TransformerEncoderLayer(nn.Module):
                 attn_bias=attn_bias,
             )
 
-        x = x + self.norm2(xi)
+        x = x + xi
         if self.update_edges:
             edge_x = edge_x + self.enorm2(edge_xi)
         if self.dense:
-            x = x + self.dense(x, context)
+            x = x + self.dense(self.norm2(x), context)
 
         if edge_x is not None:
             return x, edge_x
@@ -218,14 +218,12 @@ class TransformerCrossAttentionLayer(TransformerEncoderLayer):
         key_value_mask: BoolTensor | None = None,
         context: Tensor | None = None,
     ) -> Tensor:
-        query = query + self.norm2(
-            self.mha(
-                self.norm1(query),
-                self.norm0(key_value),
-                q_mask=query_mask,
-                kv_mask=key_value_mask,
-            )
+        query = query + self.mha(
+            self.norm1(query),
+            self.norm0(key_value),
+            q_mask=query_mask,
+            kv_mask=key_value_mask,
         )
         if self.dense:
-            query = query + self.dense(query, context)
+            query = query + self.dense(self.norm2(query), context)
         return query
-- 
GitLab


From 520413c36c76694530313958989e0d7f9678b31f Mon Sep 17 00:00:00 2001
From: Sam Van Stroud <sam.van.stroud@cern.ch>
Date: Wed, 16 Aug 2023 11:07:21 +0100
Subject: [PATCH 06/10] wide, shallow, ReLU, nodrop, nonorm

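Architecture experiment: wider dense hidden layers (512), a shallower
encoder (4 layers instead of 6), ReLU instead of SiLU, and LayerNorm
and dropout disabled (commented out rather than deleted, so they are
easy to switch back on). The out_proj: False override on the MHA
config is also dropped.

With norm and dropout off, each encoder dense block reduces to a plain
two-layer MLP, roughly as below (illustrative; salt.models.Dense
builds the real thing from the config):

    import torch.nn as nn

    dense = nn.Sequential(
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, 256),
    )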
---
 salt/configs/GN2.yaml | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/salt/configs/GN2.yaml b/salt/configs/GN2.yaml
index b5f1a5db..c6ff0876 100644
--- a/salt/configs/GN2.yaml
+++ b/salt/configs/GN2.yaml
@@ -15,27 +15,26 @@ model:
                   class_path: salt.models.Dense
                   init_args:
                     input_size: 21
-                    output_size: &embed_dim 192
+                    output_size: &embed_dim 256
                     hidden_layers: [256]
-                    activation: &activation SiLU
-                    norm_layer: &norm_layer LayerNorm
+                    activation: &activation ReLU
+                    #norm_layer: &norm_layer LayerNorm
 
       gnn:
         class_path: salt.models.TransformerEncoder
         init_args:
           embed_dim: *embed_dim
-          num_layers: 6
+          num_layers: 4
           out_dim: &out_dim 128
           mha_config:
             num_heads: 8
             attention:
               class_path: salt.models.ScaledDotProductAttention
-            out_proj: False
           dense_config:
-            norm_layer: *norm_layer
+            #norm_layer: *norm_layer
             activation: *activation
-            hidden_layers: [256]
-            dropout: &dropout 0.1
+            hidden_layers: [512]
+            #dropout: &dropout 0.1
 
       pool_net:
         class_path: salt.models.GlobalAttentionPooling
@@ -62,8 +61,8 @@ model:
                     output_size: 3
                     hidden_layers: [128, 64, 32]
                     activation: *activation
-                    norm_layer: *norm_layer
-                    dropout: *dropout
+                    #norm_layer: *norm_layer
+                    #dropout: *dropout
 
             - class_path: salt.models.ClassificationTask
               init_args:
@@ -83,8 +82,8 @@ model:
                     output_size: 8
                     hidden_layers: [128, 64, 32]
                     activation: *activation
-                    norm_layer: *norm_layer
-                    dropout: *dropout
+                    #norm_layer: *norm_layer
+                    #dropout: *dropout
 
             - class_path: salt.models.VertexingTask
               init_args:
-- 
GitLab


From 074b3d29fd91c8dad9d3c22adad1c55e72e9d7e7 Mon Sep 17 00:00:00 2001
From: Sam Van Stroud <sam.van.stroud@cern.ch>
Date: Wed, 16 Aug 2023 11:08:03 +0100
Subject: [PATCH 07/10] boost lrs

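Add an explicit one-cycle learning-rate config: warm up over the first
1% of steps, peak at 1e-3, then anneal towards 1e-5. As wired up in
lightning_module.py, these keys feed torch's OneCycleLR; a
self-contained sketch of the resulting shape (the optimizer and step
count stand in for the configured ones):

    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    params = [torch.zeros(1, requires_grad=True)]
    opt = torch.optim.AdamW(params, lr=1e-3)
    n_steps = 100_000  # stands in for trainer.estimated_stepping_batches

    # div_factor = max / initial = 1e4: start at 1e-7 and peak at
    # 1e-3 after the first 1% of steps
    sch = OneCycleLR(opt, max_lr=1e-3, total_steps=n_steps,
                     div_factor=1e4, pct_start=0.01)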
---
 salt/configs/GN2.yaml | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/salt/configs/GN2.yaml b/salt/configs/GN2.yaml
index c6ff0876..bbb68fe2 100644
--- a/salt/configs/GN2.yaml
+++ b/salt/configs/GN2.yaml
@@ -1,6 +1,11 @@
 name: GN2
 
 model:
+  lrs_config:
+    initial: 1e-7
+    max: 1e-3
+    end: 1e-5
+    pct_start: 0.01
   model:
     class_path: salt.models.JetTagger
     init_args:
-- 
GitLab


From 95491fe2bf049e926287db9707d0f084633f8720 Mon Sep 17 00:00:00 2001
From: Sam Van Stroud <sam.van.stroud@cern.ch>
Date: Thu, 17 Aug 2023 11:46:12 +0100
Subject: [PATCH 08/10] bugfix for final learning rate

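OneCycleLR derives its floor from the initial learning rate, not the
peak:

    initial_lr = max_lr / div_factor
    min_lr     = initial_lr / final_div_factor

so final_div_factor must be initial / end for the schedule to actually
finish at lrs_config["end"]. With initial=1e-7, max=1e-3, end=1e-5:

    old: final_div_factor = max / end     = 100   -> min_lr = 1e-7 / 100  = 1e-9
    new: final_div_factor = initial / end = 0.01  -> min_lr = 1e-7 / 0.01 = 1e-5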
---
 salt/lightning_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/salt/lightning_module.py b/salt/lightning_module.py
index cae8c669..bb56ada0 100644
--- a/salt/lightning_module.py
+++ b/salt/lightning_module.py
@@ -110,7 +110,7 @@ class LightningTagger(L.LightningModule):
             max_lr=self.lrs_config["max"],
             total_steps=self.trainer.estimated_stepping_batches,
             div_factor=self.lrs_config["max"] / self.lrs_config["initial"],
-            final_div_factor=self.lrs_config["max"] / self.lrs_config["end"],
+            final_div_factor=self.lrs_config["initial"] / self.lrs_config["end"],
             pct_start=float(self.lrs_config["pct_start"]),
         )
         sch = {"scheduler": sch, "interval": "step"}
-- 
GitLab


From 7745a0ab0f77a782112ac179fd255b71e1a25645 Mon Sep 17 00:00:00 2001
From: Sam Van Stroud <sam.van.stroud@cern.ch>
Date: Fri, 18 Aug 2023 10:03:46 +0100
Subject: [PATCH 09/10] revert lr change

---
 salt/configs/GN2.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/salt/configs/GN2.yaml b/salt/configs/GN2.yaml
index bbb68fe2..9eddd548 100644
--- a/salt/configs/GN2.yaml
+++ b/salt/configs/GN2.yaml
@@ -3,7 +3,7 @@ name: GN2
 model:
   lrs_config:
     initial: 1e-7
-    max: 1e-3
+    max: 5e-4
     end: 1e-5
     pct_start: 0.01
   model:
-- 
GitLab


From bbc1368bd85aeb721e4e30558c4d4dc77f8f0d88 Mon Sep 17 00:00:00 2001
From: Nikita I Pond <zcappon@ucl.ac.uk>
Date: Fri, 1 Sep 2023 06:40:47 +0100
Subject: [PATCH 10/10] pin memory option

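Expose pin_memory as a datamodule option. Pinned (page-locked) host
memory enables faster, asynchronous host-to-GPU copies, so it stays on
by default, but it only costs RAM on CPU-only runs, where it can now
be switched off. Sketch of what the flag ultimately controls:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.zeros(8, 3))
    # pin_memory=True allocates batches in page-locked host memory
    loader = DataLoader(ds, batch_size=4, num_workers=0, pin_memory=False)

    # or, via the datamodule:
    #   JetDataModule(..., pin_memory=False)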
---
 salt/data/datamodules.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/salt/data/datamodules.py b/salt/data/datamodules.py
index ba979ee1..d56d1cfe 100644
--- a/salt/data/datamodules.py
+++ b/salt/data/datamodules.py
@@ -20,6 +20,7 @@ class JetDataModule(L.LightningDataModule):
         class_dict: str | None = None,
         test_file: str | None = None,
         test_suff: str | None = None,
+        pin_memory: bool = True,
         **kwargs,
     ):
         """h5 jet datamodule.
@@ -49,6 +50,8 @@ class JetDataModule(L.LightningDataModule):
             Test file path, default is None
         test_suff : str
             Test file suffix, default is None
+        pin_memory : bool
+            Pin memory for faster GPU transfer, default is True
         **kwargs
             Additional arguments to pass to the Dataset class
         """
@@ -65,6 +68,7 @@ class JetDataModule(L.LightningDataModule):
         self.num_jets_test = num_jets_test
         self.class_dict = class_dict
         self.move_files_temp = move_files_temp
+        self.pin_memory = pin_memory
         self.kwargs = kwargs
 
     def prepare_data(self):
@@ -122,7 +126,7 @@ class JetDataModule(L.LightningDataModule):
             sampler=RandomBatchSampler(dataset, self.batch_size, shuffle, drop_last=drop_last),
             num_workers=self.num_workers,
             shuffle=False,
-            pin_memory=True,
+            pin_memory=self.pin_memory,
         )
 
     def train_dataloader(self):
-- 
GitLab