Skip to content

Read loss weights from class_dict if available

Andrius Vaitkus requested to merge avaitkus/salt:class-dict-improvement into main

When ClassificationTask loss weights are not set in the config file, automatically read them from class_dict.yaml, if provided. If label is not found in the file, leave weights as 1.

ClassificationTask has a new attribute, use_class_dict_weights

  • within task.py it does nothing, just saves it as a class attribute. All the code is within cli.py.
  • if the flag isn't set in the config file, it is assumed to be False and class loss weights are processed as before/
  • if the flag is True:
    • finds and sets loss weights from class_dict based on the label
    • raises ValueError if class_dict isn't specified
    • raises ValueError if the label isn't found in the dict
    • raises ValueError if loss weights are already specified and are about to be overwritten

Also, found a bug: currently in GN2X.yaml the hard-coded weights are written in the wrong order, the are saved as top-qcd-hbb-hcc (like R10TruthLabel_R22v1), while the new flavour_label order is hbb-hcc-top-qcd. Left it as it is for now.

Closes #36 (closed)

Edited by Andrius Vaitkus

Merge request reports