Adding cross-compatibility between ONNX and LWTNN as a library for NNs
Objectives
These changes enable parameter files to store the net either as an ONNX bytestream or as an LWTNN JSON file.
Up to this point all parameter files have been LWTNN. We hope that storing parameter files as ONNX may reduce the memory footprint of the ATLAS FastSim3. To do this, we need to be able to write and read a net in ONNX format. With this update:
- We can write "Big" param files for the voxel GAN and the longitudinal weight prediction network using ONNX.
- We can read and run "Big" param files for the voxel GAN and the longitudinal weight prediction network using ONNX.
- We can still write, read and run "Big" param files with LWTNN.
- We can still read and run "Big" param files created before these changes were introduced.
- We have the flexibility to easily introduce a new format (such as SOPHIE), if ONNX doesn't solve our issues.
The first two points are the main objectives. The third allows us to go back to LWTNN without delay if ONNX doesn't improve memory. The fourth lets us compare new and old fast simulation versions in the same codebase (and didn't take much to implement anyway). The fifth is likely to be relevant soon, as David Attard has demonstrated that SOPHIE is a very effective way to run these networks too.
Implementation
A small hierarchy has been introduced to manage the possible network formats:
- VNetworkBase: This is the abstract base class (hence the V) from which all network formats must inherit. It defines the interface for a network format. This is the only class that should be directly referenced in code that runs a network; this way the file format of the network is not hard coded.
  - TFCSONNXHandler: Inherits from VNetworkBase. Is the concrete class for any ONNX network.
  - VNetworkLWTNN: Inherits from VNetworkBase. Is an abstract base class that handles some common functionality for the different variants of LWTNN networks.
    - TFCSSimpleLWTNNHandler: Inherits from VNetworkLWTNN. Is the concrete class for networks of format lwt::LightweightNeuralNetwork. This is what is used for the weight prediction network.
    - TFCSGANLWTNNHandler: Inherits from VNetworkLWTNN. Is the concrete class for networks of format lwt::LightweightGraph. This is what is used for the GAN.
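The shape of this hierarchy can be sketched roughly as follows. This is a minimal illustration rather than the code in this MR: the class names come from the list above, but the method names (compute, format), the NetworkInputs and NetworkOutputs aliases, and the echo-the-inputs bodies are all assumptions made so the example compiles on its own.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical input/output types; the real interface may differ.
using NetworkInputs = std::map<std::string, double>;
using NetworkOutputs = std::map<std::string, double>;

// Abstract interface: the only type that network-running code should name.
class VNetworkBase {
public:
  virtual ~VNetworkBase() = default;
  virtual NetworkOutputs compute(const NetworkInputs& inputs) const = 0;
  virtual std::string format() const = 0;
};

// Stand-in for the ONNX handler; a real one would run an ONNX session.
class TFCSONNXHandler : public VNetworkBase {
public:
  NetworkOutputs compute(const NetworkInputs& inputs) const override { return inputs; }
  std::string format() const override { return "ONNX"; }
};

// Still abstract: holds behaviour shared by all LWTNN variants.
class VNetworkLWTNN : public VNetworkBase {
public:
  std::string format() const override { return "LWTNN"; }
};

// Stand-in for the lwt::LightweightNeuralNetwork handler.
class TFCSSimpleLWTNNHandler : public VNetworkLWTNN {
public:
  NetworkOutputs compute(const NetworkInputs& inputs) const override { return inputs; }
};

// Stand-in for the lwt::LightweightGraph handler.
class TFCSGANLWTNNHandler : public VNetworkLWTNN {
public:
  NetworkOutputs compute(const NetworkInputs& inputs) const override { return inputs; }
};

// Calling code depends only on VNetworkBase, never on a concrete format.
double sumOutputs(const VNetworkBase& net, const NetworkInputs& inputs) {
  double total = 0.0;
  for (const auto& kv : net.compute(inputs)) total += kv.second;
  return total;
}
```

Because sumOutputs takes a VNetworkBase reference, swapping an LWTNN handler for an ONNX one requires no change in the calling code.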
Then there is also a TFCSNetworkFactory. This factory allows us to use only pointers to VNetworkBase in all other parts of the code, by returning a VNetworkBase smart pointer that points to the appropriate handler for the input. For example, if TFCSNetworkFactory.Create(file_path) is called and file_path is a string that points to a .onnx file, then a TFCSONNXHandler will be created and returned in a std::unique_ptr<VNetworkBase>. On the other hand, if TFCSNetworkFactory.Create(file_path) is called and file_path is a string that points to a .json file, then either a TFCSSimpleLWTNNHandler or a TFCSGANLWTNNHandler will be created (depending on the content of the .json) and returned in a std::unique_ptr<VNetworkBase>.
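The dispatch described above might look something like the sketch below. It is illustrative only: the free function createNetwork, the name() method, and in particular the rule used to tell the two LWTNN variants apart (looking for a "nodes" key in the JSON text) are assumptions, not the logic in TFCSNetworkFactory. The JSON text is passed in as a string so that the example needs no files.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>

// Minimal stand-ins for the handler classes; name() is a hypothetical helper.
class VNetworkBase {
public:
  virtual ~VNetworkBase() = default;
  virtual std::string name() const = 0;
};
class TFCSONNXHandler : public VNetworkBase {
public:
  std::string name() const override { return "TFCSONNXHandler"; }
};
class TFCSSimpleLWTNNHandler : public VNetworkBase {
public:
  std::string name() const override { return "TFCSSimpleLWTNNHandler"; }
};
class TFCSGANLWTNNHandler : public VNetworkBase {
public:
  std::string name() const override { return "TFCSGANLWTNNHandler"; }
};

// True if text ends with suffix.
bool endsWith(const std::string& text, const std::string& suffix) {
  return text.size() >= suffix.size() &&
         text.compare(text.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Hypothetical factory function: choose a handler from the file extension,
// and for .json files from the content.
std::unique_ptr<VNetworkBase> createNetwork(const std::string& filePath,
                                            const std::string& jsonText) {
  if (endsWith(filePath, ".onnx")) {
    return std::make_unique<TFCSONNXHandler>();
  }
  if (endsWith(filePath, ".json")) {
    // Assumed rule: a graph-style file has a "nodes" key, a sequential one does not.
    if (jsonText.find("\"nodes\"") != std::string::npos) {
      return std::make_unique<TFCSGANLWTNNHandler>();
    }
    return std::make_unique<TFCSSimpleLWTNNHandler>();
  }
  throw std::invalid_argument("Unrecognised network file: " + filePath);
}
```

The caller only ever holds a std::unique_ptr<VNetworkBase>, so adding a further format would mean adding one more branch here and one more handler class, with no changes to the code that runs the network.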
Unit tests
There are 3 unit tests with these extensions.
- GenericNetwork_test.cxx simply loads each handler class using a toy network. It checks they can run, write and read.
- TFCSEnergyAndHitsGANV2_test.cxx calls the unit test in TFCSEnergyAndHitsGANV2.cxx on both ONNX and LWTNN files. This is the longest running test, taking about 90s.
- TFCSPredictExtrapWeights_test.cxx is for the unit test in TFCSPredictExtrapWeights.cxx, also on both ONNX and LWTNN. It's shorter.
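As an illustration of the kind of run, write and read check that the first test performs, here is a minimal round-trip sketch. It does not reproduce GenericNetwork_test.cxx: the ToyNetwork struct, its text serialisation, and the roundTripPreservesOutput helper are all hypothetical and use only the standard library.

```cpp
#include <cassert>
#include <cstddef>
#include <iomanip>
#include <istream>
#include <limits>
#include <ostream>
#include <sstream>
#include <vector>

// Hypothetical toy network: y = bias + sum_i weights[i] * x[i].
struct ToyNetwork {
  std::vector<double> weights;
  double bias = 0.0;

  // "Run" the network on one input vector.
  double compute(const std::vector<double>& x) const {
    double y = bias;
    for (std::size_t i = 0; i < weights.size() && i < x.size(); ++i) {
      y += weights[i] * x[i];
    }
    return y;
  }

  // Write the parameters as text, with enough digits to restore doubles exactly.
  void write(std::ostream& out) const {
    out << std::setprecision(std::numeric_limits<double>::max_digits10);
    out << weights.size() << ' ' << bias;
    for (double w : weights) out << ' ' << w;
  }

  // Read the parameters back in the same order they were written.
  static ToyNetwork read(std::istream& in) {
    ToyNetwork net;
    std::size_t count = 0;
    in >> count >> net.bias;
    net.weights.resize(count);
    for (double& w : net.weights) in >> w;
    assert(!in.fail());  // the stream must contain every parameter
    return net;
  }
};

// Run, write, read back, run again, and require identical output.
bool roundTripPreservesOutput(const ToyNetwork& net, const std::vector<double>& x) {
  std::stringstream buffer;
  net.write(buffer);
  const ToyNetwork restored = ToyNetwork::read(buffer);
  return restored.compute(x) == net.compute(x);
}
```

The exact equality in the last line is deliberate: a network restored from its own serialised form should give the same output as the original, not merely a similar one.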
Manual tests
Before taking this MR out of draft I have done a few more intensive experiments to check things work ok.
- Old param files in the new code. I have run pion particle gun events with an old GAN param file in this codebase. Results are numerically identical.
- Writing and running LWTNN param files in the new code. I have done this using David Attard's choice of LWTNN files, and the output looks physically correct.
- Writing and running ONNX param files in the new code. I have done this using David Attard's generated ONNX files, and the output looks physically correct and has the same physics as the LWTNN output. It's not identical, but the distributions are the same.
Running an old LWTNN param file with existing code (blue line) and new code (red line) produces numerically identical output.
Running an old and a new LWTNN param file in the new code also produces numerically identical output.
Interdependencies
To actually make ONNX param files we will also need https://gitlab.cern.ch/atlas-simulation-fastcalosim/FCSParametrization/-/merge_requests/50 and https://gitlab.cern.ch/atlas-simulation-fastcalosim/FastCaloSimCommon/-/merge_requests/54