Skip to content

Fix auxiliary task outputs

Nikita Ivvan Pond requested to merge npond/salt:fix_aux_output into main

I was getting the following error when running auxiliary task outputs:

  File "/home/xzcappon/phd/tools/salt/salt/salt/utils/union_find.py", line 112, in get_node_assignment

    # symmetrize edge scores
    scores = symmetrize_edge_scores(output, node_numbers) if output.shape[0] > 0 else output
             ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

    # update node assignments until no more changes occur
  File "/home/xzcappon/phd/tools/salt/salt/salt/utils/union_find.py", line 40, in symmetrize_edge_scores
    edge_scores = (scores + scores[sym_ind]) / 2.0

    return torch.sigmoid(edge_scores)
           ~~~~~~~~~~~~~ <--- HERE
RuntimeError: "sigmoid_cpu" not implemented for 'Half'

Not sure why this isn't getting picked up on tests. This fix simply casts the edge scores to 'float' before the sigmoid call, allowing this to work fine.

Merge request reports