diff --git a/salt/data/datamodules.py b/salt/data/datamodules.py index ba979ee1f6408d922d641dbb83c6fc3c0d995140..d56d1cfe7ba36775512aaef4c1a3718c777e6af3 100644 --- a/salt/data/datamodules.py +++ b/salt/data/datamodules.py @@ -20,6 +20,7 @@ class JetDataModule(L.LightningDataModule): class_dict: str | None = None, test_file: str | None = None, test_suff: str | None = None, + pin_memory: bool = True, **kwargs, ): """h5 jet datamodule. @@ -49,6 +50,8 @@ class JetDataModule(L.LightningDataModule): Test file path, default is None test_suff : str Test file suffix, default is None + pin_memory: bool + Pin memory for faster GPU transfer, default is True **kwargs Additional arguments to pass to the Dataset class """ @@ -65,6 +68,7 @@ class JetDataModule(L.LightningDataModule): self.num_jets_test = num_jets_test self.class_dict = class_dict self.move_files_temp = move_files_temp + self.pin_memory = pin_memory self.kwargs = kwargs def prepare_data(self): @@ -122,7 +126,7 @@ class JetDataModule(L.LightningDataModule): sampler=RandomBatchSampler(dataset, self.batch_size, shuffle, drop_last=drop_last), num_workers=self.num_workers, shuffle=False, - pin_memory=True, + pin_memory=self.pin_memory, ) def train_dataloader(self):