Commit fc08c38b authored by Nikita I Pond
Browse files

.

parent 21af1f88
Pipeline #3310030 failed with stages
in 20 seconds
......@@ -230,6 +230,8 @@ sampling:
# If set to -1: max out to target numbers (limited by fractions ratio)
njets: 25e6
save_tracks: True
# Maximum number of tracks to keep; if left empty (None), the default of 40 is used
n_tracks:
tracks_name: "tracks"
# this stores the indices per sample into an intermediate file
intermediate_index_file: indices.h5
......
......@@ -74,6 +74,10 @@ class PrepareSamples:
# Check if tracks are used
self.save_tracks = self.config.sampling["options"]["save_tracks"]
self.n_tracks = self.config.sampling["options"]["n_tracks"]
if(self.n_tracks is None):
self.n_tracks = 40
# Check for tracks name. If not there, use default
if (
"tracks_name" in self.config.sampling["options"]
......@@ -149,7 +153,7 @@ class PrepareSamples:
tracks = data_set[self.tracks_name][
batch[0] : batch[1]
]
tracks = np.delete(tracks, indices_to_remove, axis=0)
tracks = np.delete(tracks, indices_to_remove, axis=0)[:, :self.n_tracks]
else:
tracks = None
yield (jets, tracks)
......
......@@ -81,21 +81,31 @@ class TrainSampleWriter:
yield jets, labels, flavour
elif self.bool_use_tracks is True:
logger.info(f"File name: {input_file}")
logger.info(len(indices_selected))
# temp = h5py.File(input_file, "r")
# logger.info(temp.keys())
# logger.info(temp['/tracks'])
# tracks_test = temp['/tracks'][100:200]
# logger.info(tracks_test)
# Load tracks
trks = np.asarray(
h5py.File(input_file, "r")["/tracks"][
indices_selected
],
trks_h5 = f["/tracks"][indices_selected]
trks = np.array(trks_h5,
dtype=self.precision,
)
trks = trks[rng_index]
if "track_labels" in f.keys():
track_labels = np.asarray(
h5py.File(input_file, "r")["/track_labels"][
indices_selected
]
)
track_labels = track_labels[rng_index]
track_labels_h5 = f["/track_labels"][rng_index]
track_labels = np.array(track_labels_h5,
dtype=self.precision)
# trks = np.array(trks_h5,
# dtype=self.precision,
# )
# track_labels = np.asarray(
# f["/track_labels"][
# indices_selected
# ]
# )
else:
track_labels = None
......@@ -120,7 +130,7 @@ class TrainSampleWriter:
self,
input_file: str = None,
output_file: str = None,
chunkSize: int = int(1e6),
chunkSize: int = int(5e5),
):
"""
Input:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment