
Always cast inputs to full precision before scaling

Jackson Barr requested to merge jabarr/salt:dtype into main

Currently we only cast floats to half precision, so if a 32-bit integer is given as input, the cast is skipped and things don't work. This MR casts 32-bit ints to 16-bit floats as well. An exclude flag is also added to the get_dtype function: some of our labels (e.g. truthOriginLabel) are 32-bit ints, and since the loss functions don't support half precision, those are explicitly excluded from the conversion. A sketch of the intended behaviour is below.
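For concreteness, here's a minimal sketch of the dtype logic described above. This is not the actual salt implementation; the signature and the name of the `exclude` parameter are assumptions on my part:

```python
import numpy as np

def get_dtype(dtype: np.dtype, exclude: bool = False) -> np.dtype:
    """Return the dtype a field should be cast to before scaling.

    32-bit floats *and* 32-bit ints are downcast to float16; excluded
    fields (e.g. integer labels like truthOriginLabel, which the loss
    functions need at full precision) are left untouched.
    """
    if exclude:
        # Labels keep their original precision for the loss functions.
        return np.dtype(dtype)
    if dtype in (np.float32, np.int32):
        # Previously only float32 was downcast; int32 now follows suit.
        return np.dtype(np.float16)
    return np.dtype(dtype)
```

So `get_dtype(np.dtype("int32"))` now gives `float16`, while `get_dtype(np.dtype("int32"), exclude=True)` leaves the label as `int32`.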

P.S. I'm not sure about the regression case, where the labels are floats; maybe those labels shouldn't be excluded?

