Commit 44de4839 authored by Alexander Froch's avatar Alexander Froch
Browse files

Merge branch maggiechen_label_binarize_fix with refs/heads/master into...

Merge branch maggiechen_label_binarize_fix with refs/heads/master into refs/merge-requests/409/train
parents e72a7650 a8d8dabc
Pipeline #3636061 passed with stages
in 14 minutes and 51 seconds
......@@ -1545,11 +1545,12 @@ class PDFSampling(Resampling): # pylint: disable=too-many-public-methods
load_chunk = load_more
# One hot encode the loaded labels
label_classes.append(-1)
labels = label_binarize(
(np.ones(len(indices)) * label),
classes=label_classes,
)
)[:, :-1]
label_classes = label_classes[:-1]
# Open the input file and read the jets and tracks
# in a fancy way which allows double index loading
with h5py.File(in_file, "r") as file_df:
......
......@@ -88,10 +88,12 @@ def SamplingGenerator(
tupled_indices = rng.choice(tupled_indices, len(tupled_indices), replace=False)
for index_tuple in tupled_indices:
loading_indices = indices[index_tuple[0] : index_tuple[1]]
label_classes.append(-1)
labels = label_binarize(
(np.ones(index_tuple[1] - index_tuple[0]) * label),
classes=label_classes,
)
)[:, :-1]
label_classes = label_classes[:-1]
if duplicate and quick_check_duplicates(loading_indices):
# Duplicate indices, fancy indexing of H5 not working, manual approach.
list_loading_indices = []
......@@ -604,13 +606,13 @@ class Resampling:
# Iterate over the indicies
for index_tuple in tupled_indices:
loading_indices = indices[index_tuple[0] : index_tuple[1]]
label_classes.append(-1)
# One hot encode the labels
labels = label_binarize(
(np.ones(index_tuple[1] - index_tuple[0]) * label),
classes=label_classes,
)
)[:, :-1]
label_classes = label_classes[:-1]
# Yield the jets and labels
# If tracks are used, also yield the tracks
if use_tracks:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment