Always cast inputs to full precision before scaling
Currently we only cast floats to half precision, so if a 32-bit integer is given as input, things break. This MR casts 32-bit ints to 16-bit floats as well. An exclude flag is also added to the `get_dtype` function: some of our labels (e.g. `truthOriginLabel`) are 32-bit ints, and since the loss functions don't support half precision, they are explicitly excluded from the conversion.
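A minimal sketch of the casting logic described above, assuming a PyTorch setup. The `exclude` flag and the `truthOriginLabel` key come from this MR's description; the actual signature of `get_dtype` in the repo, the `EXCLUDED_LABELS` set, and the `cast_inputs` helper are hypothetical:

```python
import torch

# Labels that must stay 32-bit ints because the loss functions
# don't support half precision (per this MR); set name is assumed.
EXCLUDED_LABELS = {"truthOriginLabel"}

def get_dtype(tensor: torch.Tensor, exclude: bool = False) -> torch.dtype:
    """Return the target dtype for a tensor, honouring the exclude flag.

    Hypothetical sketch: the real get_dtype may differ.
    """
    if exclude:
        # Excluded tensors (e.g. int labels fed to the loss) keep their dtype.
        return tensor.dtype
    # Cast 32-bit floats *and* 32-bit ints to half precision, so that
    # integer inputs no longer break the half-precision path.
    if tensor.dtype in (torch.float32, torch.int32):
        return torch.float16
    return tensor.dtype

def cast_inputs(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Cast every tensor in the batch, skipping excluded label keys."""
    return {
        name: t.to(get_dtype(t, exclude=name in EXCLUDED_LABELS))
        for name, t in batch.items()
    }
```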
P.S. I'm not sure about the regression case, where the labels are floats; maybe those labels shouldn't be excluded?