masked_fill and nan_to_num discussion
We recently switched from torch.masked_fill to torch.nan_to_num, since @mleigh reported the latter to be significantly faster. @dguest wanted to follow up on this, as the change introduces some NaN-safety concerns: torch.nan_to_num silently replaces any NaN, not just the ones produced by masking, so unrelated numerical bugs could go unnoticed. We should check whether we can use the mask itself to fill the NaNs in a reasonable time, rather than relying on torch.nan_to_num. Failing that, we should assert that there are no NaNs before the softmax operation that is expected to introduce them.
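A minimal sketch of the two options being discussed (the tensor shapes, variable names, and the toy mask here are illustrative, not taken from our actual attention code). Softmax over a fully masked row (all `-inf`) yields NaN; option A cleans up globally with `torch.nan_to_num`, option B uses the mask itself, and the fallback assert catches NaNs that exist *before* the softmax:

```python
import torch

# Toy attention scores: 2 query rows over 3 key positions.
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

# True = masked-out key; the second query row is fully masked.
mask = torch.tensor([[False, False, True],
                     [True,  True,  True]])

# Fallback safety check: NaNs here would indicate an unrelated bug,
# since only the softmax below is expected to introduce them.
assert not torch.isnan(scores).any()

masked_scores = scores.masked_fill(mask, float("-inf"))
attn = masked_scores.softmax(dim=-1)  # fully masked row -> all NaN

# Option A (current approach): blanket replacement, fast but also
# hides NaNs that did not come from masking.
attn_a = torch.nan_to_num(attn)

# Option B: use the mask to zero exactly the fully masked rows,
# leaving any other (unexpected) NaN visible.
fully_masked = mask.all(dim=-1, keepdim=True)
attn_b = attn.masked_fill(fully_masked, 0.0)
```

Both options give identical results when the only NaNs are the mask-induced ones; they differ only in whether other NaNs are silently swallowed.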
Edited by Samuel Van Stroud