Does data.num_test have to be a multiple of the batch_size?

When I try to predict pu, pb, pc with a trained GN2 model using a batch size of 512 and data.num_test set to 1_000 in the config, I get this error:

Traceback (most recent call last):
  File "/usr/local/bin/salt", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/workspace/salt/salt/main.py", line 17, in main
    SaltCLI(
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 388, in __init__
    self._run_subcommand(self.subcommand)
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/cli.py", line 679, in _run_subcommand
    fn(**fn_kwargs)
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 753, in test
    return call._call_and_handle_interrupt(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 793, in _test_impl
    results = self._run(model, ckpt_path=ckpt_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
    return self._evaluation_loop.run()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 410, in _evaluation_step
    call._call_callback_hooks(trainer, hook_name, output, *hook_kwargs.values())
  File "/usr/local/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 208, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/workspace/salt/salt/callbacks/predictionwriter.py", line 252, in on_test_batch_end
    self._write_batch_outputs(to_write, out_pads, batch_idx)
  File "/workspace/salt/salt/callbacks/predictionwriter.py", line 166, in _write_batch_outputs
    to_write[name] = join_structured_arrays(this_outputs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/salt/salt/utils/array_utils.py", line 23, in join_structured_arrays
    newrecarray[name] = a[name]
    ~~~~~~~~~~~^^^^^^
ValueError: could not broadcast input array from shape (488,) into shape (512,)

The error seems to be occurring in the prediction writer here, which is called twice: once with batch_outputs containing a structured array of shape (512,), and once with a structured array of shape (488,). The first call works, but on the second call the prediction writer reads data from the test h5 file with shape (512,), because blow, bhigh are set to 512, 1024 here. This mismatches the shape of the batch_outputs array, so joining the two arrays here raises the error.
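To illustrate, here is a hypothetical minimal reproduction of the shape mismatch (the array names, dtypes, and the clamping fix are illustrative assumptions, not the actual salt code): a join in the style of join_structured_arrays copies each field into a freshly allocated array, and the field copy fails when the source and target lengths disagree.

```python
import numpy as np

# The h5 read is sized by batch_size (512), but the final batch of
# num_test = 1_000 events only holds 1_000 % 512 = 488 entries.
batch_size = 512
num_test = 1_000
last_batch = num_test % batch_size  # 488

model_out = np.zeros(last_batch, dtype=[("pu", "f4")])  # shape (488,)
h5_slice = np.zeros(batch_size, dtype=[("pt", "f4")])   # shape (512,)

# Joining copies each field into a new array sized like the h5 slice;
# the copy of the shorter model output then fails to broadcast.
joined = np.zeros(len(h5_slice), dtype=[("pu", "f4"), ("pt", "f4")])
err = None
try:
    joined["pu"] = model_out["pu"]
except ValueError as e:
    err = str(e)
print(err)  # could not broadcast input array from shape (488,) into shape (512,)

# If the read range were clamped to the remaining events
# (e.g. bhigh = min(blow + batch_size, num_test)), the shapes would agree:
n = min(batch_size, last_batch)
joined_ok = np.zeros(n, dtype=[("pu", "f4"), ("pt", "f4")])
joined_ok["pu"] = model_out["pu"][:n]
joined_ok["pt"] = h5_slice["pt"][:n]
print(joined_ok.shape)  # (488,)
```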

My question is: what is the fix here? Is this intended behavior, meaning I am supposed to set data.num_test to a multiple of the batch size (e.g. 1024)? Or am I missing something, and this only happens due to some other misconfiguration?