Skip to content

Fixing Gaussian Regression

Mathias Josef Backes requested to merge backes/salt2:main into main

There are several aspects which needed to be changed:

  • predictionwriter.py L131: needed a shape correction, because for gaussian regression there are twice as many outputs than inputs
  • task.py L273: scale the standard deviation given by the network correctly and ensure positiveness (see lines below)
  • task.py L277: small correction to not run in shape errors
  • task.py L353: torch.nn.GaussianNLLLoss needs variances, not standard deviations. Therefore calculate the square of the network output, the softplus function is hence not necessary anymore to ensure positiveness. From what I tested this did not have a negative impact on the training.
  • hitz.yaml and Dipz.yaml are config files for Gaussian regression
  • base.yaml: changed precision to 16-mixed
Edited by Mathias Josef Backes

Merge request reports