From 5c0a329418c7a00007ddd99f6dc4ac46bf2c7422 Mon Sep 17 00:00:00 2001 From: Sam Van Stroud <sam.van.stroud@cern.ch> Date: Wed, 16 Aug 2023 11:03:29 +0100 Subject: [PATCH 01/10] fix in_dims --- salt/lightning_module.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/salt/lightning_module.py b/salt/lightning_module.py index 6fa59160..cae8c669 100644 --- a/salt/lightning_module.py +++ b/salt/lightning_module.py @@ -28,8 +28,7 @@ class LightningTagger(L.LightningModule): self.model = model self.lrs_config = lrs_config self.name = name - - self.in_dims = [list(net.parameters())[0].shape[1] for net in self.model.init_nets] + self.in_dims = [self.model.init_nets[0].net.input_size for net in self.model.init_nets] def forward(self, x, mask, labels=None): """Forward pass through the model. -- GitLab From 8f4632a606af138735fad01fe3dfe49626e1a9d0 Mon Sep 17 00:00:00 2001 From: Sam Van Stroud <sam.van.stroud@cern.ch> Date: Wed, 16 Aug 2023 11:03:44 +0100 Subject: [PATCH 02/10] add short args --- salt/to_onnx.py | 8 +++++++- salt/utils/cli.py | 8 ++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/salt/to_onnx.py b/salt/to_onnx.py index a9469fcb..588c99c6 100644 --- a/salt/to_onnx.py +++ b/salt/to_onnx.py @@ -72,11 +72,17 @@ def parse_args(args): action="store_true", ) parser.add_argument( + "-a", "--include_aux", help="Include auxiliary task outputs (if available)", action="store_true", ) - parser.add_argument("--force", help="Run with uncomitted changes.", action="store_true") + parser.add_argument( + "-f", + "--force", + help="Run with uncomitted changes.", + action="store_true", + ) return parser.parse_args(args) diff --git a/salt/utils/cli.py b/salt/utils/cli.py index 02ad08f4..808c63c8 100644 --- a/salt/utils/cli.py +++ b/salt/utils/cli.py @@ -53,8 +53,12 @@ class SaltCLI(LightningCLI): parser.link_arguments("name", "trainer.logger.init_args.experiment_name") parser.link_arguments("name", "model.name") parser.link_arguments("trainer.default_root_dir", "trainer.logger.init_args.save_dir") - parser.add_argument("--force", action="store_true", help="Run with uncomitted changes.") - parser.add_argument("--tag", action="store_true", help="Push a tag for the current code.") + parser.add_argument( + "-f", "--force", action="store_true", help="Run with uncomitted changes." + ) + parser.add_argument( + "-t", "--tag", action="store_true", help="Push a tag for the current code." + ) def before_instantiate_classes(self) -> None: sc = self.config[self.subcommand] -- GitLab From 78a03bf4aa96ae9de9809c643b5796fe0c693116 Mon Sep 17 00:00:00 2001 From: Sam Van Stroud <sam.van.stroud@cern.ch> Date: Wed, 16 Aug 2023 11:04:10 +0100 Subject: [PATCH 03/10] normalise onnx output names --- salt/to_onnx.py | 1 + 1 file changed, 1 insertion(+) diff --git a/salt/to_onnx.py b/salt/to_onnx.py index 588c99c6..d3755083 100644 --- a/salt/to_onnx.py +++ b/salt/to_onnx.py @@ -199,6 +199,7 @@ def main(args=None): config = yaml.safe_load(config_path.read_text()) model_name = args.name if args.name else config["name"] + model_name = model_name.replace("-", "_") # dashes not allowed for AuxVars with warnings.catch_warnings(): warnings.simplefilter("ignore") -- GitLab From d99a31fe964824ae8ea36363aec0d048a83528dd Mon Sep 17 00:00:00 2001 From: Sam Van Stroud <sam.van.stroud@cern.ch> Date: Wed, 16 Aug 2023 11:04:24 +0100 Subject: [PATCH 04/10] don't tweak layernorm --- salt/models/dense.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/salt/models/dense.py b/salt/models/dense.py index f507cb66..e1822b4d 100644 --- a/salt/models/dense.py +++ b/salt/models/dense.py @@ -59,7 +59,7 @@ class Dense(nn.Module): # normalisation first if norm_layer and (norm_final_layer or not is_final_layer): - layers.append(getattr(nn, norm_layer)(node_list[i], elementwise_affine=False)) + layers.append(getattr(nn, norm_layer)(node_list[i])) # then dropout if dropout and (norm_final_layer or not is_final_layer): -- GitLab From 42de145459433438df799fa98e293279e8f4c1c8 Mon Sep 17 00:00:00 2001 From: Sam Van Stroud <sam.van.stroud@cern.ch> Date: Wed, 16 Aug 2023 11:05:09 +0100 Subject: [PATCH 05/10] pre-ln style in transfomer --- salt/models/transformer.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/salt/models/transformer.py b/salt/models/transformer.py index 2de69388..68c50fc4 100644 --- a/salt/models/transformer.py +++ b/salt/models/transformer.py @@ -102,11 +102,11 @@ class TransformerEncoderLayer(nn.Module): attn_bias=attn_bias, ) - x = x + self.norm2(xi) + x = x + xi if self.update_edges: edge_x = edge_x + self.enorm2(edge_xi) if self.dense: - x = x + self.dense(x, context) + x = x + self.dense(self.norm2(x), context) if edge_x is not None: return x, edge_x @@ -218,14 +218,12 @@ class TransformerCrossAttentionLayer(TransformerEncoderLayer): key_value_mask: BoolTensor | None = None, context: Tensor | None = None, ) -> Tensor: - query = query + self.norm2( - self.mha( - self.norm1(query), - self.norm0(key_value), - q_mask=query_mask, - kv_mask=key_value_mask, - ) + query = query + self.mha( + self.norm1(query), + self.norm0(key_value), + q_mask=query_mask, + kv_mask=key_value_mask, ) if self.dense: - query = query + self.dense(query, context) + query = query + self.dense(self.norm2(query), context) return query -- GitLab From 520413c36c76694530313958989e0d7f9678b31f Mon Sep 17 00:00:00 2001 From: Sam Van Stroud <sam.van.stroud@cern.ch> Date: Wed, 16 Aug 2023 11:07:21 +0100 Subject: [PATCH 06/10] wide, shallow, ReLU, nodrop, nonorm --- salt/configs/GN2.yaml | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/salt/configs/GN2.yaml b/salt/configs/GN2.yaml index b5f1a5db..c6ff0876 100644 --- a/salt/configs/GN2.yaml +++ b/salt/configs/GN2.yaml @@ -15,27 +15,26 @@ model: class_path: salt.models.Dense init_args: input_size: 21 - output_size: &embed_dim 192 + output_size: &embed_dim 256 hidden_layers: [256] - activation: &activation SiLU - norm_layer: &norm_layer LayerNorm + activation: &activation ReLU + #norm_layer: &norm_layer LayerNorm gnn: class_path: salt.models.TransformerEncoder init_args: embed_dim: *embed_dim - num_layers: 6 + num_layers: 4 out_dim: &out_dim 128 mha_config: num_heads: 8 attention: class_path: salt.models.ScaledDotProductAttention - out_proj: False dense_config: - norm_layer: *norm_layer + #norm_layer: *norm_layer activation: *activation - hidden_layers: [256] - dropout: &dropout 0.1 + hidden_layers: [512] + #dropout: &dropout 0.1 pool_net: class_path: salt.models.GlobalAttentionPooling @@ -62,8 +61,8 @@ model: output_size: 3 hidden_layers: [128, 64, 32] activation: *activation - norm_layer: *norm_layer - dropout: *dropout + #norm_layer: *norm_layer + #dropout: *dropout - class_path: salt.models.ClassificationTask init_args: @@ -83,8 +82,8 @@ model: output_size: 8 hidden_layers: [128, 64, 32] activation: *activation - norm_layer: *norm_layer - dropout: *dropout + #norm_layer: *norm_layer + #dropout: *dropout - class_path: salt.models.VertexingTask init_args: -- GitLab From 074b3d29fd91c8dad9d3c22adad1c55e72e9d7e7 Mon Sep 17 00:00:00 2001 From: Sam Van Stroud <sam.van.stroud@cern.ch> Date: Wed, 16 Aug 2023 11:08:03 +0100 Subject: [PATCH 07/10] boost lrs --- salt/configs/GN2.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/salt/configs/GN2.yaml b/salt/configs/GN2.yaml index c6ff0876..bbb68fe2 100644 --- a/salt/configs/GN2.yaml +++ b/salt/configs/GN2.yaml @@ -1,6 +1,11 @@ name: GN2 model: + lrs_config: + initial: 1e-7 + max: 1e-3 + end: 1e-5 + pct_start: 0.01 model: class_path: salt.models.JetTagger init_args: -- GitLab From 95491fe2bf049e926287db9707d0f084633f8720 Mon Sep 17 00:00:00 2001 From: Sam Van Stroud <sam.van.stroud@cern.ch> Date: Thu, 17 Aug 2023 11:46:12 +0100 Subject: [PATCH 08/10] bugfix for final learning rate --- salt/lightning_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/salt/lightning_module.py b/salt/lightning_module.py index cae8c669..bb56ada0 100644 --- a/salt/lightning_module.py +++ b/salt/lightning_module.py @@ -110,7 +110,7 @@ class LightningTagger(L.LightningModule): max_lr=self.lrs_config["max"], total_steps=self.trainer.estimated_stepping_batches, div_factor=self.lrs_config["max"] / self.lrs_config["initial"], - final_div_factor=self.lrs_config["max"] / self.lrs_config["end"], + final_div_factor=self.lrs_config["initial"] / self.lrs_config["end"], pct_start=float(self.lrs_config["pct_start"]), ) sch = {"scheduler": sch, "interval": "step"} -- GitLab From 7745a0ab0f77a782112ac179fd255b71e1a25645 Mon Sep 17 00:00:00 2001 From: Sam Van Stroud <sam.van.stroud@cern.ch> Date: Fri, 18 Aug 2023 10:03:46 +0100 Subject: [PATCH 09/10] revert lr change --- salt/configs/GN2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/salt/configs/GN2.yaml b/salt/configs/GN2.yaml index bbb68fe2..9eddd548 100644 --- a/salt/configs/GN2.yaml +++ b/salt/configs/GN2.yaml @@ -3,7 +3,7 @@ name: GN2 model: lrs_config: initial: 1e-7 - max: 1e-3 + max: 5e-4 end: 1e-5 pct_start: 0.01 model: -- GitLab From bbc1368bd85aeb721e4e30558c4d4dc77f8f0d88 Mon Sep 17 00:00:00 2001 From: Nikita I Pond <zcappon@ucl.ac.uk> Date: Fri, 1 Sep 2023 06:40:47 +0100 Subject: [PATCH 10/10] pin memory option --- salt/data/datamodules.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/salt/data/datamodules.py b/salt/data/datamodules.py index ba979ee1..d56d1cfe 100644 --- a/salt/data/datamodules.py +++ b/salt/data/datamodules.py @@ -20,6 +20,7 @@ class JetDataModule(L.LightningDataModule): class_dict: str | None = None, test_file: str | None = None, test_suff: str | None = None, + pin_memory: bool = True, **kwargs, ): """h5 jet datamodule. @@ -49,6 +50,8 @@ class JetDataModule(L.LightningDataModule): Test file path, default is None test_suff : str Test file suffix, default is None + pin_memory: bool + Pin memory for faster GPU transfer, default is True **kwargs Additional arguments to pass to the Dataset class """ @@ -65,6 +68,7 @@ class JetDataModule(L.LightningDataModule): self.num_jets_test = num_jets_test self.class_dict = class_dict self.move_files_temp = move_files_temp + self.pin_memory = pin_memory self.kwargs = kwargs def prepare_data(self): @@ -122,7 +126,7 @@ class JetDataModule(L.LightningDataModule): sampler=RandomBatchSampler(dataset, self.batch_size, shuffle, drop_last=drop_last), num_workers=self.num_workers, shuffle=False, - pin_memory=True, + pin_memory=self.pin_memory, ) def train_dataloader(self): -- GitLab