Read loss weights from class_dict if available
When ClassificationTask
loss weights are not set in the config file, automatically read them from class_dict.yaml
, if provided. If label is not found in the file, leave weights as 1.
ClassificationTask
has a new attribute, use_class_dict_weights
- within
task.py
it does nothing, just saves it as a class attribute. All the code is withincli.py
. - if the flag isn't set in the config file, it is assumed to be
False
and class loss weights are processed as before/ - if the flag is
True
:- finds and sets loss weights from
class_dict
based on thelabel
- raises
ValueError
ifclass_dict
isn't specified - raises
ValueError
if thelabel
isn't found in the dict - raises
ValueError
if loss weights are already specified and are about to be overwritten
- finds and sets loss weights from
Also, found a bug: currently in GN2X.yaml the hard-coded weights are written in the wrong order, the are saved as top-qcd-hbb-hcc (like R10TruthLabel_R22v1
), while the new flavour_label
order is hbb-hcc-top-qcd. Left it as it is for now.
Closes #36 (closed)
Edited by Andrius Vaitkus