diff --git a/salt/data/datamodules.py b/salt/data/datamodules.py
index ba979ee1f6408d922d641dbb83c6fc3c0d995140..d56d1cfe7ba36775512aaef4c1a3718c777e6af3 100644
--- a/salt/data/datamodules.py
+++ b/salt/data/datamodules.py
@@ -20,6 +20,7 @@ class JetDataModule(L.LightningDataModule):
         class_dict: str | None = None,
         test_file: str | None = None,
         test_suff: str | None = None,
+        pin_memory: bool = True,
         **kwargs,
     ):
         """h5 jet datamodule.
@@ -49,6 +50,8 @@ class JetDataModule(L.LightningDataModule):
             Test file path, default is None
         test_suff : str
             Test file suffix, default is None
+        pin_memory: bool
+            Pin memory for faster GPU transfer, default is True
         **kwargs
             Additional arguments to pass to the Dataset class
         """
@@ -65,6 +68,7 @@ class JetDataModule(L.LightningDataModule):
         self.num_jets_test = num_jets_test
         self.class_dict = class_dict
         self.move_files_temp = move_files_temp
+        self.pin_memory = pin_memory
         self.kwargs = kwargs
 
     def prepare_data(self):
@@ -122,7 +126,7 @@ class JetDataModule(L.LightningDataModule):
             sampler=RandomBatchSampler(dataset, self.batch_size, shuffle, drop_last=drop_last),
             num_workers=self.num_workers,
             shuffle=False,
-            pin_memory=True,
+            pin_memory=self.pin_memory,
         )
 
     def train_dataloader(self):