diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/CMakeLists.txt b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/CMakeLists.txt index 022092e075cc629c101d38927971c3597361bda2..9f323813771b92a615642083ebd3dd9e7efbccda 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/CMakeLists.txt +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/CMakeLists.txt @@ -8,7 +8,8 @@ find_package( Boost ) find_package( CLHEP ) find_package( HepPDT ) find_package( ROOT COMPONENTS Core Tree MathCore MathMore Hist RIO Matrix Physics ) -find_package( lwtnn ) +find_package( lwtnn REQUIRED ) +find_package( onnxruntime REQUIRED ) find_package( LibXml2 ) option(USE_GPU "whether to run FCS on GPU or not" OFF) @@ -17,13 +18,14 @@ if(USE_GPU AND CMAKE_CUDA_COMPILER) add_definitions( -DUSE_GPU) endif() -#Remove the-- as - needed linker flags: +#Remove the --as-needed linker flags: atlas_disable_as_needed() #Component(s) in the package: atlas_add_root_dictionary( ISF_FastCaloSimEvent _dictSource ROOT_HEADERS ISF_FastCaloSimEvent/IntArray.h ISF_FastCaloSimEvent/DoubleArray.h + ISF_FastCaloSimEvent/MLogging.h ISF_FastCaloSimEvent/TFCSFunction.h ISF_FastCaloSimEvent/TFCS1DFunction.h ISF_FastCaloSimEvent/TFCS1DFunctionHistogram.h @@ -90,26 +92,37 @@ atlas_add_root_dictionary( ISF_FastCaloSimEvent _dictSource ISF_FastCaloSimEvent/TFCSSimulationState.h ISF_FastCaloSimEvent/TFCSTruthState.h ISF_FastCaloSimEvent/TFCSVoxelHistoLateralCovarianceFluctuations.h + ISF_FastCaloSimEvent/VNetworkBase.h + ISF_FastCaloSimEvent/VNetworkLWTNN.h + ISF_FastCaloSimEvent/TFCSSimpleLWTNNHandler.h + ISF_FastCaloSimEvent/TFCSONNXHandler.h + ISF_FastCaloSimEvent/TFCSNetworkFactory.h ISF_FastCaloSimEvent/LinkDef.h EXTERNAL_PACKAGES HepPDT) +set(ALL_ONNX_LIBS ${onnxruntime_LIBRARY} ${onnxruntime_LIBRARIES} ${ONNXRUNTIME_LIBRARIES}) +message(NOTICE "-- ISF_FastCaloSimEvent/CMakeLists: Using onnx libs=${ALL_ONNX_LIBS}") +set(ALL_ONNX_INCS ${onnxruntime_INCLUDE_DIR} ${ONNXRUNTIME_INCLUDE_DIRS}) +message(NOTICE "-- ISF_FastCaloSimEvent/CMakeLists: Using onnx incs=${ALL_ONNX_INCS}") + + if(USE_GPU AND CMAKE_CUDA_COMPILER) message("compiling ISF_FastCaoSimEvent using cuda") atlas_add_library( ISF_FastCaloSimEvent ISF_FastCaloSimEvent/*.h ${_dictSource} src/*.cxx PUBLIC_HEADERS ISF_FastCaloSimEvent - INCLUDE_DIRS ${CLHEP_INCLUDE_DIRS} ${HEPPDT_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} ${LIBXML2_INCLUDE_DIRS} ${LWTNN_INCLUDE_DIRS} + INCLUDE_DIRS ${CLHEP_INCLUDE_DIRS} ${HEPPDT_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} ${LIBXML2_INCLUDE_DIRS} ${LWTNN_INCLUDE_DIRS} ${ALL_ONNX_INCS} DEFINITIONS ${CLHEP_DEFINITIONS} - LINK_LIBRARIES ${CLHEP_LIBRARIES} ${HEPPDT_LIBRARIES} ${ROOT_LIBRARIES} AthContainers AthenaKernel AthenaBaseComps CaloDetDescrLib ${LWTNN_LIBRARIES} ${LIBXML2_LIBRARIES} + LINK_LIBRARIES ${CLHEP_LIBRARIES} ${HEPPDT_LIBRARIES} ${ROOT_LIBRARIES} AthContainers AthenaKernel AthenaBaseComps CaloDetDescrLib ${LWTNN_LIBRARIES} ${LIBXML2_LIBRARIES} ${ALL_ONNX_LIBS} CaloGeoHelpers CxxUtils TileSimEvent ISF_FastCaloGpuLib PRIVATE_LINK_LIBRARIES GaudiKernel ) else() atlas_add_library( ISF_FastCaloSimEvent ISF_FastCaloSimEvent/*.h ${_dictSource} src/*.cxx PUBLIC_HEADERS ISF_FastCaloSimEvent - INCLUDE_DIRS ${Boost_INCLUDE_DIRS} ${CLHEP_INCLUDE_DIRS} ${HEPPDT_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} ${LIBXML2_INCLUDE_DIRS} ${LWTNN_INCLUDE_DIRS} + INCLUDE_DIRS ${Boost_INCLUDE_DIRS} ${CLHEP_INCLUDE_DIRS} ${HEPPDT_INCLUDE_DIRS} ${ROOT_INCLUDE_DIRS} ${LIBXML2_INCLUDE_DIRS} ${LWTNN_INCLUDE_DIRS} ${ALL_ONNX_INCS} DEFINITIONS ${CLHEP_DEFINITIONS} - LINK_LIBRARIES ${Boost_LIBRARIES} ${CLHEP_LIBRARIES} ${HEPPDT_LIBRARIES} ${ROOT_LIBRARIES} AthContainers AthenaKernel AthenaBaseComps CaloDetDescrLib ${LWTNN_LIBRARIES} ${LIBXML2_LIBRARIES} + LINK_LIBRARIES ${Boost_LIBRARIES} ${CLHEP_LIBRARIES} ${HEPPDT_LIBRARIES} ${ROOT_LIBRARIES} AthContainers AthenaKernel AthenaBaseComps CaloDetDescrLib ${LWTNN_LIBRARIES} ${LIBXML2_LIBRARIES} ${ALL_ONNX_LIBS} CaloGeoHelpers CxxUtils TileSimEvent PRIVATE_LINK_LIBRARIES GaudiKernel ) endif() @@ -118,3 +131,25 @@ atlas_add_dictionary( ISF_FastCaloSimEventDict ISF_FastCaloSimEvent/ISF_FastCaloSimEventDict.h ISF_FastCaloSimEvent/selection.xml LINK_LIBRARIES ISF_FastCaloSimEvent ) + +# This really should only take 10s, but something strange is happening in CI +# TODO resolve +atlas_add_test( GenericNetwork_test + SOURCES test/GenericNetwork_test.cxx + INCLUDE_DIRS ${LWTNN_INCLUDE_DIRS} ${ALL_ONNX_INCS} + LINK_LIBRARIES ${LWTNN_LIBRARIES} ${ALL_ONNX_LIBS} ISF_FastCaloSimEvent + PROPERTIES TIMEOUT 600 ) + +# Takes a bit longer, could be converted to a atlas_add_citest +atlas_add_test( TFCSEnergyAndHitGANV2_test + SOURCES test/TFCSEnergyAndHitGANV2_test.cxx + INCLUDE_DIRS ${LWTNN_INCLUDE_DIRS} ${ALL_ONNX_INCS} + LINK_LIBRARIES ${LWTNN_LIBRARIES} ${ALL_ONNX_LIBS} ISF_FastCaloSimEvent + LOG_IGNORE_PATTERN "*(TFCSCenterPositionCalculation*)0x*" # it's a pointer + PROPERTIES TIMEOUT 1200 ) + +atlas_add_test( TFCSPredictExtrapWeights_test + SOURCES test/TFCSPredictExtrapWeights_test.cxx + INCLUDE_DIRS ${LWTNN_INCLUDE_DIRS} ${ALL_ONNX_INCS} + LINK_LIBRARIES ${LWTNN_LIBRARIES} ${ALL_ONNX_LIBS} ISF_FastCaloSimEvent ) + diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/LinkDef.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/LinkDef.h index 8b8a68a210015ea0a98c0dee3eba6c07bbdc831a..48900330199292696342401d60f6b6ab841c3c49 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/LinkDef.h +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/LinkDef.h @@ -50,6 +50,11 @@ #include "ISF_FastCaloSimEvent/TFCSEnergyAndHitGAN.h" #include "ISF_FastCaloSimEvent/TFCSEnergyAndHitGANV2.h" #include "ISF_FastCaloSimEvent/TFCSPredictExtrapWeights.h" +#include "ISF_FastCaloSimEvent/VNetworkBase.h" +#include "ISF_FastCaloSimEvent/VNetworkLWTNN.h" +#include "ISF_FastCaloSimEvent/TFCSSimpleLWTNNHandler.h" +#include "ISF_FastCaloSimEvent/TFCSONNXHandler.h" +#include "ISF_FastCaloSimEvent/TFCSNetworkFactory.h" #endif #include "ISF_FastCaloSimEvent/TFCSLateralShapeParametrization.h" @@ -606,6 +611,11 @@ #pragma link C++ class TFCSEnergyAndHitGAN - ; #pragma link C++ class TFCSEnergyAndHitGANV2 + ; #pragma link C++ class TFCSPredictExtrapWeights - ; + +#pragma link C++ class VNetworkBase + ; +#pragma link C++ class VNetworkLWTNN + ; +#pragma link C++ class TFCSSimpleLWTNNHandler - ; +#pragma link C++ class TFCSONNXHandler - ; #endif #pragma link C++ class TFCSLateralShapeParametrization + ; diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/MLogging.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/MLogging.h index fce96ca23c31dfd00a799000670900b08e124976..b0a214309889399ca4a1864d02b4570a674f338a 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/MLogging.h +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/MLogging.h @@ -8,14 +8,6 @@ #include <TNamed.h> //for ClassDef #include "CxxUtils/checker_macros.h" -// One macro for use outside classes. -// Use this in standalone functions or static methods. -#define ATH_MSG_NOCLASS(logger_name, x) \ - do { \ - logger_name.msg() << logger_name.startMsg(MSG::ALWAYS, __FILE__, __LINE__) \ - << x << std::endl; \ - } while (0) - #if defined(__FastCaloSimStandAlone__) #include <iomanip> #include <iostream> @@ -32,6 +24,14 @@ enum Level { NUM_LEVELS }; // enum Level } // end namespace MSG + +// Macro for use outside classes. +// Use this in standalone functions or static methods. +#define ATH_MSG_NOCLASS(logger_name, x) \ + do { \ + logger_name.msg() << logger_name.startMsg(MSG::ALWAYS, __FILE__, __LINE__) \ + << x << std::endl; \ + } while (0) #else // not __FastCaloSimStandAlone__ We get some things from AthenaKernal. // STL includes #include <iosfwd> @@ -45,6 +45,15 @@ enum Level { #include "AthenaKernel/getMessageSvc.h" #include <boost/thread/tss.hpp> + +// Macro for use outside classes. +// Use this in standalone functions or static methods. +// Differs, becuase it must call doOutput +#define ATH_MSG_NOCLASS(logger_name, x) \ + do { \ + logger_name.msg(MSG::ALWAYS) << x << std::endl; \ + logger_name.msg().doOutput(); \ + } while (0) #endif // end not __FastCaloSimStandAlone__ // Declare the class accessories in a namespace @@ -202,8 +211,8 @@ private: std::string m_nm; //! Do not persistify! /// MsgStream instance (a std::cout like with print-out levels) - inline static boost::thread_specific_ptr<MsgStream> - m_msg_tls ATLAS_THREAD_SAFE; //! Do not persistify! + inline static boost::thread_specific_ptr<MsgStream> m_msg_tls + ATLAS_THREAD_SAFE; //! Do not persistify! ClassDef(MLogging, 0) }; diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSEnergyAndHitGANV2.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSEnergyAndHitGANV2.h index b72ad406bcaa16bb38aa7eb03e0dd031f9a35d17..1e428a7cfa13396a5dcb5c8177bad2f76906dc51 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSEnergyAndHitGANV2.h +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSEnergyAndHitGANV2.h @@ -70,7 +70,7 @@ public: return m_slice->GetExtrapolatorWeights(); }; - bool initializeNetwork(int pid, int etaMin, + bool initializeNetwork(int const &pid, int const &etaMin, const std::string &FastCaloGANInputFolderName); bool fillEnergy(TFCSSimulationState &simulstate, const TFCSTruthState *truth, @@ -81,6 +81,12 @@ public: virtual void Print(Option_t *option = "") const override; + static void test_path(std::string path, + TFCSSimulationState *simulstate = nullptr, + const TFCSTruthState *truth = nullptr, + const TFCSExtrapolationState *extrapol = nullptr, + std::string outputname = "unnamed", int pid = 211); + static void unit_test(TFCSSimulationState *simulstate = nullptr, const TFCSTruthState *truth = nullptr, const TFCSExtrapolationState *extrapol = nullptr); @@ -90,7 +96,7 @@ protected: std::string FastCaloGANInputFolderName); private: - static int GetBinsInFours(double bins); + static int GetBinsInFours(double const &bins); int GetAlphaBinsForRBin(const TAxis *x, int ix, int yBinNum) const; std::vector<int> m_bin_ninit; @@ -103,7 +109,7 @@ private: TFCSGANEtaSlice *m_slice = nullptr; TFCSGANXMLParameters m_param; - ClassDefOverride(TFCSEnergyAndHitGANV2, 1) // TFCSEnergyAndHitGANV2 + ClassDefOverride(TFCSEnergyAndHitGANV2, 2) // TFCSEnergyAndHitGANV2 }; #endif diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSGANEtaSlice.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSGANEtaSlice.h index ebb063ed8555cf050fed5466fbd5be28233b6036..9315057f7fad88b01d3131e23844e25448fa7f59 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSGANEtaSlice.h +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSGANEtaSlice.h @@ -17,11 +17,12 @@ #include "ISF_FastCaloSimEvent/TFCSSimulationState.h" #include "ISF_FastCaloSimEvent/TFCSExtrapolationState.h" #include "ISF_FastCaloSimEvent/TFCSGANXMLParameters.h" -#include "ISF_FastCaloSimEvent/TFCSGANLWTNNHandler.h" #include "ISF_FastCaloSimEvent/MLogging.h" -#include "lwtnn/LightweightGraph.hh" -#include "lwtnn/parse_json.hh" +// generic network class +#include "ISF_FastCaloSimEvent/VNetworkBase.h" +// net class for legacy loading +#include "ISF_FastCaloSimEvent/TFCSGANLWTNNHandler.h" #include <fstream> @@ -62,16 +63,26 @@ private: FitResultsPerLayer m_allFitResults; ExtrapolatorWeights m_extrapolatorWeights; + // legacy - keep or streamers are confused by + // old classes that didn't inherit TFCSGANLWTNNHandler *m_gan_all = nullptr; TFCSGANLWTNNHandler *m_gan_low = nullptr; TFCSGANLWTNNHandler *m_gan_high = nullptr; + // updated - can take an old or new class + std::unique_ptr<VNetworkBase> m_net_all = nullptr; + std::unique_ptr<VNetworkBase> m_net_low = nullptr; + std::unique_ptr<VNetworkBase> m_net_high = nullptr; + // getters so that we are insensitive to where the data actually is + VNetworkBase *GetNetAll() const; + VNetworkBase *GetNetLow() const; + VNetworkBase *GetNetHigh() const; bool LoadGANNoRange(std::string inputFileName); bool LoadGANFromRange(std::string inputFileName, std::string energyRange); TFCSGANXMLParameters m_param; - ClassDef(TFCSGANEtaSlice, 4) // TFCSGANEtaSlice + ClassDef(TFCSGANEtaSlice, 5) // TFCSGANEtaSlice }; #endif //> !ISF_TFCSGANETASLICE_H diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSGANLWTNNHandler.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSGANLWTNNHandler.h index f2c1a1fe75ccab934c27f3a3cb88e52dc5ef3a9f..2266eebc184c1c827b45810f60ea4cb91922cf5f 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSGANLWTNNHandler.h +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSGANLWTNNHandler.h @@ -1,36 +1,120 @@ -/* - Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration -*/ +/** + * Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + * + * Class for a neural network read in the LWTNN format. + * Derived from the abstract base class VNetworkBase + * such that it can be used interchangably with it's + * sibling class, TFCSONNXHandler, TFCSGANLWTNNHandler, + * TFCSSimpleLWTNNHandler. + * + * Frustratingly, LightweightNeuralNetwork and LightweightGraph + * from lwtnn do not have a common ancestor, + * they could be connected with the bridge pattern, + * but that is more complex that currently required. + * This one handles the graph case, TFCSSimpleLWTNNHandler + * is for the non-graph case. + * + * The LoadNetwork function has VNetworkBase as it's return type + * so that it can make a run-time decision about which derived class + * to use, based on the file name presented. + **/ -////////////////////////////////////////////////////////////////// -// TFCSGANLWTNNHandler.h, (c) ATLAS Detector software -/////////////////////////////////////////////////////////////////// +// Hopefully documentation gets inherited from VNetworkBase -#ifndef ISF_TFCSGANLWTNNHANDLER_H -#define ISF_TFCSGANLWTNNHANDLER_H 1 +#ifndef TFCSGANLWTNNHANDLER_H +#define TFCSGANLWTNNHANDLER_H -#include "ISF_FastCaloSimEvent/TFCSTruthState.h" -#include "ISF_FastCaloSimEvent/TFCSSimulationState.h" -#include "ISF_FastCaloSimEvent/TFCSExtrapolationState.h" +#include "ISF_FastCaloSimEvent/VNetworkLWTNN.h" +#include <iostream> +// Becuase we have a field of type LightweightGraph #include "lwtnn/LightweightGraph.hh" -#include "lwtnn/parse_json.hh" -#include <string> -class TFCSGANLWTNNHandler { +// For writing to a tree +#include "TTree.h" + +class TFCSGANLWTNNHandler : public VNetworkLWTNN { public: - TFCSGANLWTNNHandler(); - virtual ~TFCSGANLWTNNHandler(); + // Don't lose default constructors + using VNetworkLWTNN::VNetworkLWTNN; + + /** + * @brief TFCSGANLWTNNHandler constructor. + * + * Calls setupPersistedVariables and setupNet. + * + * @param inputFile file-path on disk (with file name) of a readable + * lwtnn file containing a json format description + * of the network to be constructed, or the json + * itself as a string. + **/ + explicit TFCSGANLWTNNHandler(const std::string &inputFile); + + /** + * @brief TFCSGANLWTNNHandler copy constructor. + * + * Will copy the variables that would be generated by + * setupPersistedVariables and setupNet. + * + * @param copy_from existing network that we are copying + **/ + TFCSGANLWTNNHandler(const TFCSGANLWTNNHandler ©_from); - const lwt::LightweightGraph *GetGraph() const { return m_graph; } + /** + * @brief Function to pass values to the network. + * + * This function hides variations in the formated needed + * by different network libraries, providing a uniform input + * and output type. + * + * @param inputs values to be evaluated by the network + * @return the output of the network + * @see VNetworkBase::NetworkInputs + * @see VNetworkBase::NetworkOutputs + **/ + NetworkOutputs compute(NetworkInputs const &inputs) const override; - bool LoadGAN(const std::string &inputFile); + /** + * @brief List the names of the outputs. + * + * Outputs are stored in an NetworkOutputs object + * which is indexed by strings. This function + * returns the list of all strings that will index the outputs. + * + **/ + std::vector<std::string> getOutputLayers() const override; + +protected: + /** + * @brief Perform actions that prepare network for use. + * + * Will be called in the streamer or class constructor + * after the inputs have been set (either automaically by the + * streamer or by setupPersistedVariables in the constructor). + * Does not delete any resources used. + * + **/ + void setupNet() override; private: - const lwt::LightweightGraph *m_graph; //! Do not persistify + // unique ptr deletes the object when it goes out of scope + /** + * @brief The network that we are wrapping here. + **/ + std::unique_ptr<lwt::LightweightGraph> m_lwtnn_graph; //! Do not persistify + + /** + * @brief List of names that index the output layer. + **/ + std::vector<std::string> m_outputLayers; //! Do not persistify + + /** + * @brief Just for backcompatability. + **/ std::string *m_input = nullptr; - ClassDef(TFCSGANLWTNNHandler, 5) // TFCSGANLWTNNHandler + // Suppling a ClassDef for writing to file. + ClassDefOverride(TFCSGANLWTNNHandler, 6); }; -#endif //> !ISF_TFCSGANLWTNNHANDLER_H +#endif // TFCSGANLWTNNHANDLER_H diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSNetworkFactory.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSNetworkFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..97159bbaa584e614e2d0c42421d40772666de9f0 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSNetworkFactory.h @@ -0,0 +1,131 @@ +/** + * Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + * + * Class to perform runtime selection from the derived + * classes of VNetworkBase given inout for a network. + * + * Has only static functions becuase no statelike + * information is needed to make this decision. + * + * Information about the which network would be + * apropreate can be specified, or left entirely + * to the factory to determine. + * + **/ +#ifndef TFCSNETWORKFACTORY_H +#define TFCSNETWORKFACTORY_H + +#include "ISF_FastCaloSimEvent/VNetworkBase.h" +#include <memory> +#include <string> +#include <vector> +#include <filesystem> + +class TFCSNetworkFactory { +public: + // Unspecified TFCSGANLWTNNHandler or TFCSSimpleLWTNNHandler, take a guess + /** + * @brief Given a string, make a network. + * + * This function will first check if the string is the path of a readable + * file on the disk. If so, the file suffix is used to decide if it's an + * onnx (.onnx) or lwtnn file (.json). If the filepath ends in .* then + * first an onnx then an lwtnn file will be tried. The file is read and + * parsed into a network. + * + * If the string is not a filepath, it is assumed to be the content of a + * json file to make an lwtnn network. + * + * When an lwtnn network is created, first the TFCSSimpleLWTNNHandler + * format is tried, and if this raises an error, the TFCSGANLWTNNHandler + * is applied. The former is simpler than the latter, so it will always + * fail to parse the more complex graph format. + * + * @param input Either a file path, or the content of a file. + * + **/ + static std::unique_ptr<VNetworkBase> create(std::string input); + // Specified TFCSGANLWTNNHandler or TFCSSimpleLWTNNHandler + /** + * @brief Given a string, and information about format, make a network. + * + * This function will first check if the string is the path of a readable + * file on the disk. If so, the file suffix is used to decide if it's an + * onnx (.onnx) or lwtnn file (.json). If the filepath ends in .* then + * first an onnx then an lwtnn file will be tried. The file is read and + * parsed into a network. + * + * If the string is not a filepath, it is assumed to be the content of a + * json file to make an lwtnn network. + * + * When an lwtnn network is created, if graph_form is true + * the network will be a TFCSSimpleLWTNNHandler otherwise it is + * a TFCSGANLWTNNHandler. + * + * @param input Either a file path, or the content of a file. + * @param graph_form Is the network the more complex graph form? + * + **/ + static std::unique_ptr<VNetworkBase> create(std::string input, + bool graph_form); + + /** + * @brief Given a vector of chars (bytes), make a network. + * + * This function will always create a TFCSONNXHandler. + * Caution: this function is designed to modify its input. + * + * @param input The content of an onnx proto file. + * + **/ + static std::unique_ptr<VNetworkBase> create(std::vector<char> const &input); + + /** + * @brief Create a network from whichever input isn't empty. + * + * If the vector_input is not empty, construct a network from that, + * otherwise, use the string_input to construct a network. + * + * @param vector_input The content of an onnx proto file. + * @param string_input Either a file path, or the content of a file. + **/ + static std::unique_ptr<VNetworkBase> + create(std::vector<char> const &vector_input, std::string string_input); + /** + * @brief Create a network from whichever input isn't empty. + * + * If the vector_input is not empty, construct a network from that, + * otherwise, use the string_input to construct a network. + * Whether the network is in graph form is specifed for LWTNN networks. + * + * @param vector_input The content of an onnx proto file. + * @param string_input Either a file path, or the content of a file. + * @param graph_form Is the network the more compelx graph form? + **/ + static std::unique_ptr<VNetworkBase> + create(std::vector<char> const &vector_input, std::string string_input, + bool graph_form); + +private: + /** + * @brief If the filepath ends in .* change it to .onnx or .json + * + * If the filepath doesn't end in .*, no change is made. + * Will check first for a .onnx file, then look for a .json. + * Throws an exception if niether are found. + * + * @param filename Path to check. + **/ + static void resolveGlobs(std::string &filename); + + /** + * @brief Check if a filename seems to be an onnx file. + * + * Really just checks if the input ends in ".onnx" + * + * @param filename Path to check. + **/ + static bool isOnnxFile(std::string const &filename); +}; + +#endif // TFCSNETWORKFACTORY_H diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSONNXHandler.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSONNXHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..c01d02a2f2fd51cbf61ec73300846c694a258498 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSONNXHandler.h @@ -0,0 +1,324 @@ +/** + * Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + * + * Class for a neural network read in the ONNX format. + * Derived from the abstract base class VNetworkBase + * such that it can be used interchangably with it's + * sibling classes, TFCSSimpleLWTNNHandler and TFCSGANLWTNNHandler. + * + * The TFCSNetworkFactory::Create function has VNetworkBase as its return + * type so that it can make a run-time decision about which derived + * class to use, based on the data or file presented. As such it's + * best not to create this directly, instead allow TFCSNetworkFactory::Create + * to create the appropreat network object so that new network formats can + * be accomidated by writing new subclasses of VNetworkBase. + **/ + +#ifndef TFCSONNXHANDLER_H +#define TFCSONNXHANDLER_H + +// inherits from +#include "ISF_FastCaloSimEvent/VNetworkBase.h" + +#include <iostream> + +// ONNX Runtime include(s). +#include <core/session/onnxruntime_cxx_api.h> + +// For reading and writing to root +#include "TFile.h" +#include "TTree.h" + +// For storing the lambda function +#include <functional> + + +/** + * @brief A handler specific for an ONNX network + * + * Inherits from the generic interface VNetworkBase, + * such that it cna be used interchangably with other network + * formats and libraries. + **/ +class TFCSONNXHandler : public VNetworkBase { +public: + // Don't lose the default constructor + using VNetworkBase::VNetworkBase; + + /** + * @brief TFCSONNXHandler constructor. + * + * Calls setupPersistedVariables and setupNet. + * + * @param inputFile file-path on disk (with file name) of a readable + * onnx file containing a proto format description + * of the network to be constructed. + **/ + explicit TFCSONNXHandler(const std::string &inputFile); + + /** + * @brief TFCSONNXHandler constructor. + * + * As this passes nothing to the super constructor + * the setupPersistedVariables will not be called. + * + * @param bytes byte content of a .onnx file, (which are a subset + * if proto files). Allows TFCSONNXHandler objects to be + * created from data in memory, retrived rom any source. + * The bytes are not copied interally, and must remain + * in memory while the net is in use. + * (TODO check that assertion) + * + **/ + explicit TFCSONNXHandler(const std::vector<char> &bytes); + + /** + * @brief TFCSONNXHandler copy constructor. + * + * Will copy the variables taht would be generated by + * setupPersistedVariables and setupNet. + * + * @param copy_from existing network that we are copying + **/ + TFCSONNXHandler(const TFCSONNXHandler ©_from); + + /** + * @brief Function to pass values to the network. + * + * This function is used to actually run data through the loaded + * network and obtain results. + * + * @param inputs values to be evaluated by the network + * @return the output of the network + * @see VNetworkBase::NetworkInputs + * @see VNetworkBase::NetworkOutputs + **/ + NetworkOutputs compute(NetworkInputs const &inputs) const override; + + // Output to a ttree file + using VNetworkBase::writeNetToTTree; + + /** + * @brief Save the network to a TTree. + * + * All data required to recreate the network object is saved + * into a TTree. The format is not specified. + * Will still work even if deleteAllButNet has already + * been called. + * + * @param tree The tree to save inside. + **/ + void writeNetToTTree(TTree &tree) override; + + /** + * @brief List the names of the outputs. + * + * Outputs are stored in an NetworkOutputs object + * which is indexed by strings. This function + * returns the list of all strings that will index the outputs. + * + **/ + std::vector<std::string> getOutputLayers() const override; + + /** + * @brief Get rid of any memory objects that arn't needed to run the net. + * + * Minimise memory usage by deleting nay inputs that are + * no longer required to run the compute function. + * Doesn't actually do anything for this network type. + * + **/ + void deleteAllButNet() override; + +protected: + /** + * @brief Write a short description of this net to the string stream. + * + * Specialised for ONNX to print the input and output nodes with their + * dimensions. + * + * @param strm output parameter, to which the description will be written. + **/ + virtual void print(std::ostream &strm) const override; + + /** + * @brief Perform actions that prep data to create the net + * + * Will be called in the class constructor + * before calling setupNet, but not in the streamer. + * It sets any variables that the sreamer would persist + * when saving or loading to file. + * + **/ + void setupPersistedVariables() override; + + /** + * @brief Perform actions that prepare network for use. + * + * Will be called in the streamer or class constructor + * after the inputs have been set (either automaically by the + * streamer or by setupPersistedVariables in the constructor). + * Does not delete any resources used. + * + **/ + void setupNet() override; + +private: + /** + * @brief Content of the proto file. + **/ + std::vector<char> m_bytes; + /** + * @brief Return content of the proto (.onnx) file in memory. + * + * Get the session as a stream of bytes + * It's a vector<char> rather than a string becuase we need the guarantee + * that &bytes[0]+n == bytes[n] (string has this only after c++11). + * Also bytes may not be terminated by a null byte + * (which early strings required). + * + **/ + std::vector<char> + getSerializedSession(std::string tree_name = m_defaultTreeName); + /** + * @brief Retrieve the content of the proto file from a TTree + * + * If the ONNX file has been saved as a loose variable in a TTree + * this method will read it back into m_bytes. + * + **/ + std::vector<char> readBytesFromTTree(TTree &tree); + /** + * @brief Write the content of the proto file to a TTree as a branch + * + * The ONNX proto file is saved as a simple branch (no streamers involved). + * + **/ + void writeBytesToTTree(TTree &tree, const std::vector<char> &bytes); + + // unique ptr deletes the object when it goes out of scope + /** + * @brief The network session itself + * + * This is the object created by onnxruntime_cxx_api which + * contains information about the network and can run inputs + * through it. + * + * Held as a unique pointer to prevent the need for manual + * memory management + **/ + std::unique_ptr<Ort::Session> m_session; //! Do not persistify + /** + * @brief Using content of the proto (.onnx) file make a session. + * + * The m_session variable is initialised from the m_bytes variable + * so that the net can be run. + * Requires that the m_bytes variable is retained while the net is + * used. + * + **/ + void readSerializedSession(); + + /** + * @brief names that index the input nodes + * + * An ONNX network is capable of having two layers of + * labels, input node names, then labels within each node, + * but it's twin, LWTNN is not. LWTNN supports one list of + * nodes indexed by strings, and each input node may have + * more than one value, indexed by positive integers (list + * like), so this interfae only supports that more limited + * format. + * + **/ + std::vector<const char *> m_inputNodeNames; //! Do not persistify + +#if ORT_API_VERSION > 11 + /** + * @brief Manage memory of the input node names + * + * In newer versions of ONNX the caller is responsible for + * managing the memory required to store node names + **/ + std::vector<Ort::AllocatedStringPtr> m_storeInputNodeNames; //! Do not persistify +#endif + + /** + * @brief the names that index the output nodes + * + * An ONNX network is capable of having two layers of + * labels, input node names, then labels within each node, + * but it's twin, LWTNN is not. LWTNN supports one list of + * nodes indexed by strings, and each input node may have + * more than one value, indexed by positive integers (list + * like), so this interfae only supports that more limited + * format. + * + **/ + std::vector<const char *> m_outputNodeNames; //! Do not persistify + +#if ORT_API_VERSION > 11 + /** + * @brief Manage memory of the output node names + * + * In newer versions of ONNX the caller is responsible for + * managing the memory required to store node names + **/ + std::vector<Ort::AllocatedStringPtr> m_storeOutputNodeNames; //! Do not persistify +#endif + + /** + * @brief dimension lengths in each named input node + * + * Describes the shape of the input nodes. + * @see TFCSONNXHandler::m_inputNodeNames + **/ + std::vector<std::vector<int64_t>> m_inputNodeDims; //! Do not persistify + /** + * @brief dimension lengths in each named output node + * + * As the final output must be flat in each output node, + * this is for internal manipulations only. + * @see TFCSONNXHandler::m_inputNodeDims + **/ + std::vector<std::vector<int64_t>> m_outputNodeDims; //! Do not persistify + /** + * @brief total elements in each named output node + * + * For internal use only, gives the total number of elements in + * the output nodes. + * @see TFCSONNXHandler::m_inputNodeDims + **/ + std::vector<int64_t> m_outputNodeSize; //! Do not persistify + + /** + * @brief Computation template with adjustable types for input. + * + * A lambda function will be used to make the correct type choice + * for the session/net used as a member variable during setupNet. + **/ + template <typename Tin, typename Tout> + NetworkOutputs computeTemplate(NetworkInputs const &input); + + /** + * @brief computeTemplate with apropreate types selected. + **/ + std::function<NetworkOutputs(NetworkInputs)> + m_computeLambda; //! Do not persistify + + /** + * @brief Specifies memory behavior for vectors in ONNX. + **/ + Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu( + OrtArenaAllocator, OrtMemTypeDefault); //! Do not persistify + + /** + * @brief Externally visible names that index the output. + **/ + std::vector<std::string> m_outputLayers; //! Do not persistify + + // For the streamer + ClassDefOverride(TFCSONNXHandler, 1); +}; + +#endif // TFCSONNXHANDLER_H diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSPredictExtrapWeights.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSPredictExtrapWeights.h index 3e52f07b614612bb0efc377e633b35f31df60e9f..495d9ad9fd17ce9e9aac03b94f81d22ecc6529dd 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSPredictExtrapWeights.h +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSPredictExtrapWeights.h @@ -51,6 +51,10 @@ public: const std::string &FastCaloTXTInputFolderName); // Test function + static void test_path(std::string &net_path, std::string const &norm_path, + TFCSSimulationState *simulstate = nullptr, + const TFCSTruthState *truth = nullptr, + const TFCSExtrapolationState *extrapol = nullptr); static void unit_test(TFCSSimulationState *simulstate = nullptr, const TFCSTruthState *truth = nullptr, const TFCSExtrapolationState *extrapol = nullptr); diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSSimpleLWTNNHandler.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSSimpleLWTNNHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..217ae3790bf6b288178eac0c4b78e36630f6493a --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/TFCSSimpleLWTNNHandler.h @@ -0,0 +1,113 @@ +/** + * Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + * + * Class for a neural network read in the LWTNN format. + * Derived from the abstract base class VNetworkBase + * such that it can be used interchangably with it's + * sibling classes, TFCSSimpleLWTNNHandler, TFCSGANLWTNNHandler, + * TFCSONNXHandler. + * + * Frustratingly, LightweightNeuralNetwork and LightweightGraph + * from lwtnn do not have a common ancestor, + * they could be connected with the bridge pattern, + * but that is more complex that currently required. + * + * The LoadNetwork function has VNetworkBase as it's return type + * so that it can make a run-time decision about which derived class + * to use, based on the file name presented. + **/ + +// Hopefully documentation gets inherited from VNetworkBase + +#ifndef TFCSSIMPLELWTNNHANDLER_H +#define TFCSSIMPLELWTNNHANDLER_H + +#include "ISF_FastCaloSimEvent/VNetworkLWTNN.h" +#include <iostream> + +// Becuase we have a field of type LightweightNeuralNetwork +#include "lwtnn/LightweightNeuralNetwork.hh" + +// For writing to a tree +#include "TTree.h" + +class TFCSSimpleLWTNNHandler : public VNetworkLWTNN { +public: + // Don't lose the default constructor + using VNetworkLWTNN::VNetworkLWTNN; + + /** + * @brief TFCSSimpleLWTNNHandler constructor. + * + * Calls setupPersistedVariables and setupNet. + * + * @param inputFile file-path on disk (with file name) of a readable + * lwtnn file containing a json format description + * of the network to be constructed, or the json + * itself as a string. + **/ + explicit TFCSSimpleLWTNNHandler(const std::string &inputFile); + + /** + * @brief TFCSSimpleLWTNNHandler copy constructor. + * + * Will copy the variables that would be generated by + * setupPersistedVariables and setupNet. + * + * @param copy_from existing network that we are copying + **/ + TFCSSimpleLWTNNHandler(const TFCSSimpleLWTNNHandler ©_from); + + /** + * @brief Function to pass values to the network. + * + * This function, hides variations in the formated needed + * by different network libraries, providing a uniform input + * and output type. + * + * @param inputs values to be evaluated by the network + * @return the output of the network + * @see VNetworkBase::NetworkInputs + * @see VNetworkBase::NetworkOutputs + **/ + NetworkOutputs compute(NetworkInputs const &inputs) const override; + + /** + * @brief List the names of the outputs. + * + * Outputs are stored in an NetworkOutputs object + * which is indexed by strings. This function + * returns the list of all strings that will index the outputs. + * + **/ + std::vector<std::string> getOutputLayers() const override; + +protected: + /** + * @brief Perform actions that prepare network for use. + * + * Will be called in the streamer or class constructor + * after the inputs have been set (either automaically by the + * streamer or by setupPersistedVariables in the constructor). + * Does not delete any resources used. + * + **/ + void setupNet() override; + +private: + // unique ptr deletes the object when it goes out of scope + /** + * @brief The network that we are wrapping here. + **/ + std::unique_ptr<lwt::LightweightNeuralNetwork> + m_lwtnn_neural; //! Do not persistify + /** + * @brief List of names that index the output layer. + **/ + std::vector<std::string> m_outputLayers; //! Do not persistify + + // Suppling a ClassDef for writing to file. + ClassDefOverride(TFCSSimpleLWTNNHandler, 1); +}; + +#endif // TFCSSIMPLELWTNNHANDLER_H diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/VNetworkBase.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/VNetworkBase.h new file mode 100644 index 0000000000000000000000000000000000000000..b389e76e6c90e39d0d7a8b791bc96e228d64c734 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/VNetworkBase.h @@ -0,0 +1,311 @@ +/** + * Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + * + * Abstract base class for Neural networks. + * Intially aimed at replacing instances of an lwtnn network + * with a network that could be either lwtnn or ONNX, + * so it is an interface which mirrors that of lwtnn graphs. + * At least 3 derived classes are avaliable; + * + * - TFCSSimpleLWTNNHandler; Designed to wrap a lwtnn neural network + * - TFCSGANLWTNNHandler; Designed to wrap a lwtnn graph network + * - TFCSONNXHandler; Designed to wrap an ONNX network + * + * The TFCSNetworkFactory::Create function has this class as its return type + * so that it can make a run-time decision about which derived class + * to use, based on the file or data presented. + **/ +#ifndef VNETWORKBASE_H +#define VNETWORKBASE_H + +// For conversion to ostream +#include <iostream> +#include <map> + +// For reading and writing +#include "TFile.h" +#include "TTree.h" + +// For messaging +#include "ISF_FastCaloSimEvent/MLogging.h" +using ISF_FCS::MLogging; + +/** + * @brief A template defining the interface to a neural network. + * + * Has various subclasses to cover differing network + * libraries and save formats. + **/ +class VNetworkBase : public MLogging { +public: + /** + * @brief VNetworkBase default constructor. + * + * For use in streamers. + **/ + VNetworkBase(); + + // explicit = Don't let this do implicit type conversion + /** + * @brief VNetworkBase constructor. + * + * Only saves inputFile to m_inputFile; + * Inherting classes should call setupPersistedVariables + * and setupNet in constructor; + * + * @param inputFile file-path on disk (with file name) of a readable + * file containing a description of the network to + * be constructed or the content of the file. + **/ + explicit VNetworkBase(const std::string &inputFile); + + /** + * @brief VNetworkBase copy constructor. + * + * Does not call setupPersistedVariables or setupNet + * but will pass on m_inputFile. + * Inherting classes should do whatever they need to move the variables + * created in the setup functions. + * + * @param copy_from existing network that we are copying + **/ + VNetworkBase(const VNetworkBase ©_from); + + // virtual destructor, to ensure that it is always called, even + // when a base class is deleted via a pointer to a derived class + virtual ~VNetworkBase(); + + // same as for lwtnn + /** + * @brief Format for network inputs. + * + * The doubles are the values to be passed into the network. + * Strings in the outer map identify the input node, which + * must corrispond to the names of the nodes as read from the + * description of the network found by the constructor. + * Strings in the inner map identify the part of the input node, + * for some networks these must be simple integers, in string form, + * as parts of nodes do not always have the ability to carry + * real string labels. + **/ + typedef std::map<std::string, std::map<std::string, double>> NetworkInputs; + /** + * @brief Format for network outputs. + * + * The doubles are the values generated by the network. + * Strings identify which node this value came from, + * and when nodes have multiple values, are suffixed with + * a number to indicate which part of the node they came from. + * So in multi-value nodes the format becomes "<node_name>_<part_n>" + **/ + typedef std::map<std::string, double> NetworkOutputs; + + /** + * @brief String representation of network inputs + * + * Create a string that summarises a set of network inputs. + * Gives basic dimensions plus a few values, up to the maxValues + * + * @param inputs values to be evaluated by the network + * @param maxValues maximum number of values to include in the representaiton + * @return string represetning the inputs + **/ + static std::string representNetworkInputs(NetworkInputs const &inputs, + int maxValues = 3); + + /** + * @brief String representation of network outputs + * + * Create a string that summarises a set of network outputs. + * Gives basic dimensions plus a few values, up to the maxValues + * + * @param outputs output of the network + * @param maxValues maximum number of values to include in the representaiton + * @return string represetning the outputs + **/ + static std::string representNetworkOutputs(NetworkOutputs const &outputs, + int maxValues = 3); + + // pure virtual, derived classes must impement this + /** + * @brief Function to pass values to the network. + * + * This function hides variations in the formated needed + * by different network libraries, providing a uniform input + * and output type. + * + * @param inputs values to be evaluated by the network + * @return the output of the network + * @see VNetworkBase::NetworkInputs + * @see VNetworkBase::NetworkOutputs + **/ + virtual NetworkOutputs compute(NetworkInputs const &inputs) const = 0; + + // Conversion to ostream + // It's not possible to have a virtual friend function + // so instead, have a friend function that calls a virtual protected method + /** + * @brief Put-to operator to facilitate printing. + * + * It is useful to be able to display a reasonable representation of + * a network for debugging. + * This can be altered by subclasses by changing the protected + * print function of this class. + **/ + friend std::ostream &operator<<(std::ostream &strm, + const VNetworkBase &vNetworkBase) { + vNetworkBase.print(strm); + return strm; + } + + /** + * @brief Save the network to a TTree. + * + * All data required to recreate the network object is saved + * into a TTree. The format is not specified. + * + * @param tree The tree to save inside. + **/ + virtual void writeNetToTTree(TTree &tree) = 0; + + /** + * @brief Default name for the TTree to save in. + **/ + inline static const std::string m_defaultTreeName = "onnxruntime_session"; + + /** + * @brief Save the network to a TTree. + * + * All data required to recreate the network object is saved + * into a TTree. The format is not specified. + * + * @param root_file The file to save inside. + * @param tree_name The name of the TTree to save inside. + **/ + void writeNetToTTree(TFile &root_file, + std::string const &tree_name = m_defaultTreeName); + + /** + * @brief Save the network to a TTree. + * + * All data required to recreate the network object is saved + * into a TTree. The format is not specified. + * + * @param root_name The path of the file to save inside. + * @param tree_name The name of the TTree to save inside. + **/ + void writeNetToTTree(std::string const &root_name, + std::string const &tree_name = m_defaultTreeName); + + /** + * @brief List the names of the outputs. + * + * Outputs are stored in an NetworkOutputs object + * which is indexed by strings. This function + * returns the list of all strings that will index the outputs. + * + **/ + virtual std::vector<std::string> getOutputLayers() const = 0; + + /** + * @brief Check if a string is the path of a file on disk. + * + * Determines if a string corrisponds to tha path of a file + * that can be read on the disk. + * + * @param inputFile name of the pottential file + * @return is it a readable file on disk + **/ + static bool isFile(std::string const &inputFile); + + /** + * @brief Check if the argument inputFile is the path of a file on disk. + * + * Determines if the string that was passed to the constructor as + * inputFile corrisponds to tha path of a file + * that can be read on the disk. + * + * @return is it a readable file on disk + **/ + bool isFile() const; + + /** + * @brief Get rid of any memory objects that arn't needed to run the net. + * + * Minimise memory usage by deleting any inputs that are + * no longer required to run the compute function. + * May prevent the net from being saved. + * + **/ + virtual void deleteAllButNet() = 0; + +protected: + /** + * @brief Path to the file describing the network, including filename. + **/ + std::string m_inputFile; + + /** + * @brief Perform actions that prep data to create the net + * + * Will be called in the class constructor + * before calling setupNet, but not in the streamer. + * It sets any variables that the sreamer would persist + * when saving or loading to file. + * + **/ + virtual void setupPersistedVariables() = 0; + + /** + * @brief Perform actions that prepare network for use. + * + * Will be called in the streamer or class constructor + * after the inputs have been set (either automaically by the + * streamer or by setupPersistedVariables in the constructor). + * Does not delete any resources used. + * + **/ + virtual void setupNet() = 0; + + /** + * @brief Write a short description of this net to the string stream. + * + * Intended to facilitate the put-to operator, allowing subclasses + * to change how this object is displayed. + * + * @param strm output parameter, to which the description will be written. + **/ + virtual void print(std::ostream &strm) const; + + /** + * @brief Check if a string is possibly a root file path. + * + * Just checks if the string ends in .root + * as there are almost no reliable rules for file paths. + * + * @param inputFile name of the pottential file + * if blank, m_inputFile is used. + * @return is it the path of a root file + **/ + bool isRootFile(std::string const &filename = "") const; + + /** + * @brief Remove any common prefix from the outputs. + * + * @param outputs The outputs, changed in place. + **/ + void removePrefixes(NetworkOutputs &outputs) const; + + /** + * @brief Remove any common prefix from the outputs. + * + * @param outputs The output names, changed in place. + **/ + void removePrefixes(std::vector<std::string> &output_names) const; + +private: + // Suppling a ClassDef for writing to file. + ClassDef(VNetworkBase, 1); +}; + +#endif diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/VNetworkLWTNN.h b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/VNetworkLWTNN.h new file mode 100644 index 0000000000000000000000000000000000000000..0b29f72face21f623529ea2fb27ddd11e4a9f13e --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/ISF_FastCaloSimEvent/VNetworkLWTNN.h @@ -0,0 +1,150 @@ +/** + * Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration + * + * Abstract base class for LWTNN Neural networks. + * Inherits from a generic abstract network class, + * and defined aspects common to all LWTNN networks. + * + * In particular, this derived virtual class handles + * saving, memory managment and printing. + * + * To classes derive from this class; + * - TFCSSimpleLWTNNHandler; Designed to wrap a lwtnn neural network + * - TFCSGANLWTNNHandler; Designed to wrap a lwtnn graph network + * + **/ +#ifndef VNETWORKLWTNN_H +#define VNETWORKLWTNN_H +// inherits from +#include "VNetworkBase.h" + +// For reading and writing +#include "TTree.h" +#include <sstream> + +/** + * @brief A template defining the interface to a lwtnn network. + * + * Has various subclasses to cover the various formats of lwtnn + * networks. + **/ +class VNetworkLWTNN : public VNetworkBase { +public: + // Not sure if this is needed + using VNetworkBase::VNetworkBase; + + /** + * @brief VNetworkLWTNN copy constructor. + * + * Will copy the variables that would be generated by + * setupPersistedVariables and setupNet. + * Will fail if deleteAllButNet has already been called. + * + * @param copy_from existing network that we are copying + **/ + VNetworkLWTNN(const VNetworkLWTNN ©_from); + + // Ensure we inherit methods of the same name with different signatures + using VNetworkBase::writeNetToTTree; + + /** + * @brief Save the network to a TTree. + * + * All data required to recreate the network object is saved + * into a TTree. The format is not specified. + * + * @param tree The tree to save inside. + **/ + void writeNetToTTree(TTree &tree) override; + + // virtual destructor, to ensure that it is always called, even + // when a base class is deleted via a pointer to a derived class + virtual ~VNetworkLWTNN(); + + /** + * @brief Get rid of any memory objects that arn't needed to run the net. + * + * Minimise memory usage by deleting nay inputs that are + * no longer required to run the compute function. + * Will prevent the net from being saved, if you need + * to call writeNetToTTree that must happen before this is called. + * + **/ + void deleteAllButNet() override; + +protected: + /** + * @brief String containing json input file + * + * Is needed to save the network with writeNetToTTree + * but not needed to run the network with compute. + * Is eraised by deleteAllButNet + * Should be persisted. + **/ + std::string m_json; + + /** + * @brief Write a short description of this net to the string stream. + * + * Outputs a printable name, which maybe a file name, or + * a note specifying that the file has been provided from memory. + * + * @param strm output parameter, to which the description will be written. + **/ + virtual void print(std::ostream &strm) const override; + + /** + * @brief Perform actions that prep data to create the net + * + * Will be called in the base class constructor + * before calling setupNet, but not in the streamer. + * It sets any variables that the sreamer would persist + * when saving or loading to file. + * + **/ + void setupPersistedVariables() override; + +private: + /** + * @brief Fill out m_json from a file provided to the constructor + * + * Provided the string provided as inputFile to the constructor + * is a known file type (root or json) this function retreives + * the json string itself and puts it into m_json. + * + * @param tree_name TTree name to check in when reading root files. + **/ + void fillJson(std::string const &tree_name = m_defaultTreeName); + + /** + * @brief Get json string from TTree. + * + * Given a TTree object, retrive the json string from the + * standard branch. This is used to retrive a network previously + * saved using writeNetToTTree. + * + * @param tree TTree with the json saved inside. + **/ + std::string readStringFromTTree(TTree &tree); + + /** + * @brief Get json string from TTree. + * + * Given a TTree object, retrive the json string from the + * standard branch. This is used to retrive a network previously + * saved using writeNetToTTree. + * + * @param tree TTree with the json saved inside. + **/ + void writeStringToTTree(TTree &tree, std::string json_string); + + /** + * @brief Stores a printable identifyer for the net. Not unique. + **/ + std::string m_printable_name; + + // Suppling a ClassDef for writing to file. + ClassDefOverride(VNetworkLWTNN, 1); +}; + +#endif diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/GenericNetwork_test.ref b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/GenericNetwork_test.ref new file mode 100644 index 0000000000000000000000000000000000000000..62a94481bee6e34f5361df9588a0edbba64b7a68 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/GenericNetwork_test.ref @@ -0,0 +1,75 @@ +Athena::getMessageSvc: WARNING MessageSvc not found, will use std::cout +ISF_FastCaloSim...SUCCESS Testing fastCaloGAN format. + +ISF_FastCaloSim...SUCCESS NetworkInputs, outer size 1 + key->inputs; node_1=1.000000, node_2=2.000000, + + +ISF_FastCaloSim...SUCCESS NetworkOutputs, size 2; +0=3.000000, 1=0.500000, + + +ISF_FastCaloSim...SUCCESS Writing to a root file; with_lwtnn_network.root + +ISF_FastCaloSim...SUCCESS Reading copy written to root + +ISF_FastCaloSim...SUCCESS Running copy from root file + +ISF_FastCaloSim...SUCCESS NetworkOutputs, size 2; +0=3.000000, 1=0.500000, + + +ISF_FastCaloSim...SUCCESS Outputs should before and after writing shoud be identical + +ISF_FastCaloSim...SUCCESS Writing with a streamer to; with_lwtnn_network.root + +ISF_FastCaloSim...SUCCESS Reading streamer copy written to root + +ISF_FastCaloSim...SUCCESS Running copy streamed from root file + +ISF_FastCaloSim...SUCCESS NetworkOutputs, size 2; +0=3.000000, 1=0.500000, + + +ISF_FastCaloSim...SUCCESS Outputs should before and after writing shoud be identical + +ISF_FastCaloSim...SUCCESS Testing fastCaloSim format. + +ISF_FastCaloSim...SUCCESS Made the net. + +ISF_FastCaloSim...SUCCESS Made the inputs. + +ISF_FastCaloSim...SUCCESS NetworkInputs, outer size 2 + key->node_0; 0=0.000000, + key->node_1; 0=0.000000, 1=1.000000, + + +ISF_FastCaloSim...SUCCESS NetworkOutputs, size 3; +0=0.400000, 1=0.800000, 2=1.200000, + + +ISF_FastCaloSim...SUCCESS Writing to a root file; with_lwtnn_graph.root + +ISF_FastCaloSim...SUCCESS Reading copy written to root + +ISF_FastCaloSim...SUCCESS Running copy from root file + +ISF_FastCaloSim...SUCCESS NetworkOutputs, size 3; +0=0.400000, 1=0.800000, 2=1.200000, + + +ISF_FastCaloSim...SUCCESS Outputs should before and after writing shoud be identical + +ISF_FastCaloSim...SUCCESS Writing with a streamer to; with_lwtnn_graph.root + +ISF_FastCaloSim...SUCCESS Reading streamer copy written to root + +ISF_FastCaloSim...SUCCESS Running copy streamed from root file + +ISF_FastCaloSim...SUCCESS NetworkOutputs, size 3; +0=0.400000, 1=0.800000, 2=1.200000, + + +ISF_FastCaloSim...SUCCESS Outputs should before and after writing shoud be identical + +ISF_FastCaloSim...SUCCESS Program ends diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/NormPredExtrapSample/MeanStdDevEnergyFractions_eta_0_5.txt b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/NormPredExtrapSample/MeanStdDevEnergyFractions_eta_0_5.txt new file mode 100644 index 0000000000000000000000000000000000000000..c17c48f439a089a838c03d1a28bc7fd36cbafeae --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/NormPredExtrapSample/MeanStdDevEnergyFractions_eta_0_5.txt @@ -0,0 +1,6 @@ +ef_12 0.0156819433269 0.100666099728 +etrue 138024.69032 456755.044259 +ef_0 0.105855443693 0.211264804792 +ef_1 0.30872837623 0.236349313614 +ef_2 0.564853918311 0.309581277596 +ef_3 0.00488031833227 0.0127889361749 diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/TFCSEnergyAndHitGANV2_test.ref b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/TFCSEnergyAndHitGANV2_test.ref new file mode 100644 index 0000000000000000000000000000000000000000..f69d84e6154bd8e96eb04099030e70ddf3e07b99 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/TFCSEnergyAndHitGANV2_test.ref @@ -0,0 +1,193 @@ +Athena::getMessageSvc: WARNING MessageSvc not found, will use std::cout +ISF_FastCaloSim...SUCCESS Running TFCSEnergyAndHitGANV2 on LWTNN + +ISF_FastCaloSim...SUCCESS Running test on /cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/FastCaloSim/LWTNNsample + + +ISF_FastCaloSim...SUCCESS New particle + +ISF_FastCaloSim...SUCCESS Initialize Networks + +ISF_FastCaloSim... INFO Using FastCaloGANInputFolderName: /cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/FastCaloSim/LWTNNsample +ISF_FastCaloSim... INFO Parameters taken from XML +ISF_FastCaloSim... INFO symmetrisedAlpha: 0 +ISF_FastCaloSim... INFO ganVersion:2 +ISF_FastCaloSim... INFO latentDim: 100 +ISF_FastCaloSim... INFO relevantlayers: 0 1 2 3 12 13 14 +ISF_FastCaloSim... INFO Binning along r for layer 0 +ISF_FastCaloSim... INFO 0,5,10,30,50,100,200,400,600, +ISF_FastCaloSim... INFO Binning along r for layer 1 +ISF_FastCaloSim... INFO 0,1,4,7,10,15,30,50,90,150,200, +ISF_FastCaloSim... INFO Binning along r for layer 2 +ISF_FastCaloSim... INFO 0,5,10,20,30,50,80,130,200,300,400, +ISF_FastCaloSim... INFO Binning along r for layer 3 +ISF_FastCaloSim... INFO 0,50,100,200,400,600, +ISF_FastCaloSim... INFO layer 4 not used +ISF_FastCaloSim... INFO layer 5 not used +ISF_FastCaloSim... INFO layer 6 not used +ISF_FastCaloSim... INFO layer 7 not used +ISF_FastCaloSim... INFO layer 8 not used +ISF_FastCaloSim... INFO layer 9 not used +ISF_FastCaloSim... INFO layer 10 not used +ISF_FastCaloSim... INFO layer 11 not used +ISF_FastCaloSim... INFO Binning along r for layer 12 +ISF_FastCaloSim... INFO 0,10,20,30,50,80,100,130,160,200,250,300,350,400,1000,2000, +ISF_FastCaloSim... INFO Binning along r for layer 13 +ISF_FastCaloSim... INFO 0,10,20,30,50,80,100,130,160,200,250,300,350,400,600,1000,2000, +ISF_FastCaloSim... INFO Binning along r for layer 14 +ISF_FastCaloSim... INFO 0,50,100,150,200,250,300,400,600,1000,2000, +ISF_FastCaloSim... INFO layer 15 not used +ISF_FastCaloSim... INFO layer 16 not used +ISF_FastCaloSim... INFO layer 17 not used +ISF_FastCaloSim... INFO layer 18 not used +ISF_FastCaloSim... INFO layer 19 not used +ISF_FastCaloSim... INFO layer 20 not used +ISF_FastCaloSim... INFO layer 21 not used +ISF_FastCaloSim... INFO layer 22 not used +ISF_FastCaloSim... INFO layer 23 not used +ISF_FastCaloSim... INFO LWTNN Handler parameters +ISF_FastCaloSim... INFO pid: 211 +ISF_FastCaloSim... INFO etaMin:20 +ISF_FastCaloSim... INFO etaMax: 25 +ISF_FastCaloSim... INFO Parameters taken from XML +ISF_FastCaloSim... INFO symmetrisedAlpha: 0 +ISF_FastCaloSim... INFO ganVersion:2 +ISF_FastCaloSim... INFO latentDim: 100 +ISF_FastCaloSim... INFO relevantlayers: 0 1 2 3 12 13 14 +ISF_FastCaloSim... INFO Binning along r for layer 0 +ISF_FastCaloSim... INFO 0,5,10,30,50,100,200,400,600, +ISF_FastCaloSim... INFO Binning along r for layer 1 +ISF_FastCaloSim... INFO 0,1,4,7,10,15,30,50,90,150,200, +ISF_FastCaloSim... INFO Binning along r for layer 2 +ISF_FastCaloSim... INFO 0,5,10,20,30,50,80,130,200,300,400, +ISF_FastCaloSim... INFO Binning along r for layer 3 +ISF_FastCaloSim... INFO 0,50,100,200,400,600, +ISF_FastCaloSim... INFO layer 4 not used +ISF_FastCaloSim... INFO layer 5 not used +ISF_FastCaloSim... INFO layer 6 not used +ISF_FastCaloSim... INFO layer 7 not used +ISF_FastCaloSim... INFO layer 8 not used +ISF_FastCaloSim... INFO layer 9 not used +ISF_FastCaloSim... INFO layer 10 not used +ISF_FastCaloSim... INFO layer 11 not used +ISF_FastCaloSim... INFO Binning along r for layer 12 +ISF_FastCaloSim... INFO 0,10,20,30,50,80,100,130,160,200,250,300,350,400,1000,2000, +ISF_FastCaloSim... INFO Binning along r for layer 13 +ISF_FastCaloSim... INFO 0,10,20,30,50,80,100,130,160,200,250,300,350,400,600,1000,2000, +ISF_FastCaloSim... INFO Binning along r for layer 14 +ISF_FastCaloSim... INFO 0,50,100,150,200,250,300,400,600,1000,2000, +ISF_FastCaloSim... INFO layer 15 not used +ISF_FastCaloSim... INFO layer 16 not used +ISF_FastCaloSim... INFO layer 17 not used +ISF_FastCaloSim... INFO layer 18 not used +ISF_FastCaloSim... INFO layer 19 not used +ISF_FastCaloSim... INFO layer 20 not used +ISF_FastCaloSim... INFO layer 21 not used +ISF_FastCaloSim... INFO layer 22 not used +ISF_FastCaloSim... INFO layer 23 not used +Info in <TCanvas::MakeDefCanvas>: created default TCanvas with name c1 +Warning in <TFile::Append>: Replacing existing TH1: h (Potential memory leak). +Warning in <TFile::Append>: Replacing existing TH1: h (Potential memory leak). +Warning in <TFile::Append>: Replacing existing TH1: h (Potential memory leak). +Warning in <TFile::Append>: Replacing existing TH1: h (Potential memory leak). +Warning in <TFile::Append>: Replacing existing TH1: h (Potential memory leak). +Warning in <TFile::Append>: Replacing existing TH1: h (Potential memory leak). +ISF_FastCaloSim...SUCCESS Filename ends in glob. + +ISF_FastCaloSim...SUCCESS Succedeed in creating LWTNN graph from string + +ISF_FastCaloSim... INFO GAN TFCSEnergyAndHitGANV2 +ISF_FastCaloSim... INFO PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 0 Run for bin 0 +ISF_FastCaloSim... INFO 0 center layer 0 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 0 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 0 Ekin_bin=all ; calosample=0 +ISF_FastCaloSim... INFO 0 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 1 Run for bin 1 +ISF_FastCaloSim... INFO 1 center layer 1 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 1 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 1 Ekin_bin=all ; calosample=1 +ISF_FastCaloSim... INFO 1 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 2 Run for bin 2 +ISF_FastCaloSim... INFO 2 center layer 2 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 2 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 2 Ekin_bin=all ; calosample=2 +ISF_FastCaloSim... INFO 2 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 3 Run for bin 3 +ISF_FastCaloSim... INFO 3 center layer 3 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 3 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 3 Ekin_bin=all ; calosample=3 +ISF_FastCaloSim... INFO 3 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 12Run for bin 12 +ISF_FastCaloSim... INFO 12center layer 12 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 12 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 12 Ekin_bin=all ; calosample=12 +ISF_FastCaloSim... INFO 12 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 13Run for bin 13 +ISF_FastCaloSim... INFO 13center layer 13 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 13 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 13 Ekin_bin=all ; calosample=13 +ISF_FastCaloSim... INFO 13 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 14Run for bin 14 +ISF_FastCaloSim... INFO 14center layer 14 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 14 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 14 Ekin_bin=all ; calosample=14 +ISF_FastCaloSim... INFO 14 Weight for extrapolated position: 0.5 +ISF_FastCaloSim...SUCCESS Writing GAN to unnamed + +TFile** FCSGANtest_unnamed.root + TFile* FCSGANtest_unnamed.root + KEY: TFCSEnergyAndHitGANV2 GAN;1 object title +ISF_FastCaloSim...SUCCESS Open FCSGANtest_unnamed.root + +ISF_FastCaloSim... INFO GAN TFCSEnergyAndHitGANV2 +ISF_FastCaloSim... INFO PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 0 Run for bin 0 +ISF_FastCaloSim... INFO 0 center layer 0 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 0 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 0 Ekin_bin=all ; calosample=0 +ISF_FastCaloSim... INFO 0 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 1 Run for bin 1 +ISF_FastCaloSim... INFO 1 center layer 1 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 1 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 1 Ekin_bin=all ; calosample=1 +ISF_FastCaloSim... INFO 1 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 2 Run for bin 2 +ISF_FastCaloSim... INFO 2 center layer 2 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 2 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 2 Ekin_bin=all ; calosample=2 +ISF_FastCaloSim... INFO 2 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 3 Run for bin 3 +ISF_FastCaloSim... INFO 3 center layer 3 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 3 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 3 Ekin_bin=all ; calosample=3 +ISF_FastCaloSim... INFO 3 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 12Run for bin 12 +ISF_FastCaloSim... INFO 12center layer 12 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 12 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 12 Ekin_bin=all ; calosample=12 +ISF_FastCaloSim... INFO 12 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 13Run for bin 13 +ISF_FastCaloSim... INFO 13center layer 13 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 13 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 13 Ekin_bin=all ; calosample=13 +ISF_FastCaloSim... INFO 13 Weight for extrapolated position: 0.5 +ISF_FastCaloSim... INFO 14Run for bin 14 +ISF_FastCaloSim... INFO 14center layer 14 TFCSCenterPositionCalculation +ISF_FastCaloSim... INFO 14 PDGID: -211, 211 ; Ekin=all ; eta=0.225 [0.2 , 0.25) +ISF_FastCaloSim... INFO 14 Ekin_bin=all ; calosample=14 +ISF_FastCaloSim... INFO 14 Weight for extrapolated position: 0.5 +ISF_FastCaloSim...SUCCESS Before running GAN2->simulate() + +ISF_FastCaloSim... INFO Ebin=-1 E=49599.3 #cells=0 +ISF_FastCaloSim... INFO E0(PreSamplerB)=12.5874 E0/E=0.000253781 +ISF_FastCaloSim... INFO E2(EMB2)=16662.3 E2/E=0.335937 +ISF_FastCaloSim... INFO E3(EMB3)=3954.1 E3/E=0.0797209 +ISF_FastCaloSim... INFO E12(TileBar0)=12609.2 E12/E=0.254221 +ISF_FastCaloSim... INFO E13(TileBar1)=13976.6 E13/E=0.28179 +ISF_FastCaloSim... INFO E14(TileBar2)=2384.6 E14/E=0.0480772 +ISF_FastCaloSim... INFO AuxInfo has 2 elements +ISF_FastCaloSim... INFO 776773022 : bool=23 char= int=23 float=3.22299e-44 double=1.13635e-322 void*=0x17 +ISF_FastCaloSim... INFO 2331614212 : bool=0 char=� int=0 float=0 double=1 void*=0x3ff0000000000000 +ISF_FastCaloSim...SUCCESS Program ends + diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/TFCSPredictExtrapWeights_test.ref b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/TFCSPredictExtrapWeights_test.ref new file mode 100644 index 0000000000000000000000000000000000000000..c5b61cba7ba8cea8d18e5a9903222c82e0bb4e08 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/TFCSPredictExtrapWeights_test.ref @@ -0,0 +1,18 @@ +Athena::getMessageSvc: WARNING MessageSvc not found, will use std::cout +ISF_FastCaloSim...SUCCESS Running TFCSPredictExtrapWeights + +ISF_FastCaloSim...SUCCESS Testing net path ...TNNPredExtrapSample/ and norm path ...ormPredExtrapSample/ + +ISF_FastCaloSim...SUCCESS True energy 5.24288e+08 pdgId 22 eta -0 + +ISF_FastCaloSim...SUCCESS etaBin = 0_5 + +ISF_FastCaloSim... INFO Using FastCaloNNInputFolderName: /cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/FastCaloSim/LWTNNPredExtrapSample/ +ISF_FastCaloSim... INFO For pid: 22 and etaBin0_5, loading json file /cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/FastCaloSim/LWTNNPredExtrapSample/NN_0_5.json +ISF_FastCaloSim...SUCCESS computing with m_nn + +TFile** FCSNNtest.root + TFile* FCSNNtest.root + KEY: TFCSPredictExtrapWeights NN;1 NN +ISF_FastCaloSim...SUCCESS Program ends + diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/toy_network.onnx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/toy_network.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3afff3b68052e7ddbaeab0b649dcb125d1514db5 Binary files /dev/null and b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/share/toy_network.onnx differ diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/MLogging.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/MLogging.cxx index 8c0425f2f0a650329ba8558d579d9dc08d5f5633..df5b572ca946132c626406cfeba534ee1c4a3a1b 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/MLogging.cxx +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/MLogging.cxx @@ -104,8 +104,8 @@ MLogging &MLogging::operator=(const MLogging &rhs) { void MLogging::setLevel(MSG::Level lvl) { lvl = (lvl >= MSG::NUM_LEVELS) ? MSG::ALWAYS - : (lvl < MSG::NIL) ? MSG::NIL - : lvl; + : (lvl < MSG::NIL) ? MSG::NIL + : lvl; msg().setLevel(lvl); } diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSEnergyAndHitGANV2.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSEnergyAndHitGANV2.cxx index 9afa13a3513dc316dfda552c367c970c7b21716e..713e25500034d5f993ab4ac1ce5da827dae00462 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSEnergyAndHitGANV2.cxx +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSEnergyAndHitGANV2.cxx @@ -69,7 +69,8 @@ void TFCSEnergyAndHitGANV2::set_nr_of_init(unsigned int bin, // initialize lwtnn network bool TFCSEnergyAndHitGANV2::initializeNetwork( - int pid, int etaMin, const std::string &FastCaloGANInputFolderName) { + int const &pid, int const &etaMin, + const std::string &FastCaloGANInputFolderName) { // initialize all necessary constants // FIXME eventually all these could be stored in the .json file @@ -77,7 +78,7 @@ bool TFCSEnergyAndHitGANV2::initializeNetwork( ATH_MSG_INFO( "Using FastCaloGANInputFolderName: " << FastCaloGANInputFolderName); // get neural net JSON file as an std::istream object - int etaMax = etaMin + 5; + const int etaMax = etaMin + 5; reset_match_all_pdgid(); set_pdgid(pid); @@ -94,7 +95,7 @@ bool TFCSEnergyAndHitGANV2::initializeNetwork( pidForXml = 211; } - int etaMid = (etaMin + etaMax) / 2; + const int etaMid = (etaMin + etaMax) / 2; m_param.InitialiseFromXML(pidForXml, etaMid, FastCaloGANInputFolderName); m_param.Print(); m_slice = new TFCSGANEtaSlice(pid, etaMin, etaMax, m_param); @@ -139,16 +140,15 @@ bool TFCSEnergyAndHitGANV2::fillEnergy( const TFCSGANEtaSlice::FitResultsPerLayer &fitResults = m_slice->GetFitResults(); // used only if GAN version > 1 - unsigned int energyBins = outputs.size(); - ATH_MSG_VERBOSE("energy voxels size = " << energyBins); + ATH_MSG_DEBUG("energy voxels size = " << outputs.size()); double totalEnergy = 0; - for (unsigned int i = 0; i < energyBins; ++i){ - totalEnergy += outputs.at("out_" + std::to_string(i)); + for (auto output : outputs) { + totalEnergy += output.second; } - if (totalEnergy < 0){ + if (totalEnergy < 0) { ATH_MSG_WARNING("Energy from GAN is negative, skipping particle"); - return false; + return false; } ATH_MSG_VERBOSE("Get binning"); @@ -157,11 +157,11 @@ bool TFCSEnergyAndHitGANV2::fillEnergy( int vox = 0; for (const auto &element : binsInLayers) { - int layer = element.first; + const int layer = element.first; const TH2D *h = &element.second; - int xBinNum = h->GetNbinsX(); - int yBinNum = h->GetNbinsY(); + const int xBinNum = h->GetNbinsX(); + const int yBinNum = h->GetNbinsY(); const TAxis *x = h->GetXaxis(); // If only one bin in r means layer is empty, no value should be added @@ -181,7 +181,7 @@ bool TFCSEnergyAndHitGANV2::fillEnergy( for (int ix = 1; ix <= xBinNum; ++ix) { double binsInAlphaInRBin = GetAlphaBinsForRBin(x, ix, yBinNum); for (int iy = 1; iy <= binsInAlphaInRBin; ++iy) { - double energyInVoxel = outputs.at("out_" + std::to_string(vox)); + const double energyInVoxel = outputs.at(std::to_string(vox)); ATH_MSG_VERBOSE(" Vox " << vox << " energy " << energyInVoxel << " binx " << ix << " biny " << iy); @@ -206,10 +206,10 @@ bool TFCSEnergyAndHitGANV2::fillEnergy( vox = 0; for (const auto &element : binsInLayers) { - int layer = element.first; + const int layer = element.first; const TH2D *h = &element.second; - int xBinNum = h->GetNbinsX(); - int yBinNum = h->GetNbinsY(); + const int xBinNum = h->GetNbinsX(); + const int yBinNum = h->GetNbinsY(); const TAxis *x = h->GetXaxis(); const TAxis *y = h->GetYaxis(); @@ -289,13 +289,13 @@ bool TFCSEnergyAndHitGANV2::fillEnergy( // Now create hits for (int ix = 1; ix <= xBinNum; ++ix) { - int binsInAlphaInRBin = GetAlphaBinsForRBin(x, ix, yBinNum); + const int binsInAlphaInRBin = GetAlphaBinsForRBin(x, ix, yBinNum); // Horrible work around for variable # of bins along alpha direction - int binsToMerge = yBinNum == 32 ? 32 / binsInAlphaInRBin : 1; + const int binsToMerge = yBinNum == 32 ? 32 / binsInAlphaInRBin : 1; for (int iy = 1; iy <= binsInAlphaInRBin; ++iy) { - double energyInVoxel = outputs.at("out_" + std::to_string(vox)); - int lowEdgeIndex = (iy - 1) * binsToMerge + 1; + const double energyInVoxel = outputs.at(std::to_string(vox)); + const int lowEdgeIndex = (iy - 1) * binsToMerge + 1; ATH_MSG_VERBOSE(" Vox " << vox << " energy " << energyInVoxel << " binx " << ix << " biny " << iy); @@ -317,15 +317,15 @@ bool TFCSEnergyAndHitGANV2::fillEnergy( nHitsR = x->GetBinUpEdge(ix) - x->GetBinLowEdge(ix); if (yBinNum == 1) { // nbins in alpha depend on circumference lenght - double r = x->GetBinUpEdge(ix); + const double r = x->GetBinUpEdge(ix); nHitsAlpha = ceil(2 * TMath::Pi() * r / binResolution); } else { // d = 2*r*sin (a/2r) this distance at the upper r must be 1mm for // layer 1 or 5, 5mm otherwise. const TAxis *y = h->GetYaxis(); - double angle = y->GetBinUpEdge(iy) - y->GetBinLowEdge(iy); - double r = x->GetBinUpEdge(ix); - double d = 2 * r * sin(angle / 2 * r); + const double angle = y->GetBinUpEdge(iy) - y->GetBinLowEdge(iy); + const double r = x->GetBinUpEdge(ix); + const double d = 2 * r * sin(angle / 2 * r); nHitsAlpha = ceil(d / binResolution); } @@ -333,7 +333,7 @@ bool TFCSEnergyAndHitGANV2::fillEnergy( // For layers that are not EMB1 or EMEC1 use a maximum of 10 hits // per direction, a higher granularity is needed for the other // layers - int maxNhits = 10; + const int maxNhits = 10; nHitsAlpha = std::min(maxNhits, std::max(1, nHitsAlpha)); nHitsR = std::min(maxNhits, std::max(1, nHitsR)); } @@ -528,8 +528,9 @@ TFCSEnergyAndHitGANV2::simulate(TFCSSimulationState &simulstate, void TFCSEnergyAndHitGANV2::Print(Option_t *option) const { TFCSParametrization::Print(option); TString opt(option); - bool shortprint = opt.Index("short") >= 0; - bool longprint = msgLvl(MSG::DEBUG) || (msgLvl(MSG::INFO) && !shortprint); + const bool shortprint = opt.Index("short") >= 0; + const bool longprint = + msgLvl(MSG::DEBUG) || (msgLvl(MSG::INFO) && !shortprint); TString optprint = opt; optprint.ReplaceAll("short", ""); @@ -563,6 +564,25 @@ void TFCSEnergyAndHitGANV2::unit_test(TFCSSimulationState *simulstate, const TFCSTruthState *truth, const TFCSExtrapolationState *extrapol) { ISF_FCS::MLogging logger; + ATH_MSG_NOCLASS(logger, "Start lwtnn test" << std::endl); + std::string path = "/eos/atlas/atlascerngroupdisk/proj-simul/AF3_Run3/" + "InputsToBigParamFiles/FastCaloGANWeightsVer02/"; + test_path(path, simulstate, truth, extrapol, "lwtnn"); + + ATH_MSG_NOCLASS(logger, "Start onnx test" << std::endl); + path = "/eos/atlas/atlascerngroupdisk/proj-simul/AF3_Run3/" + "InputsToBigParamFiles/FastCaloGANWeightsONNXVer08/"; + test_path(path, simulstate, truth, extrapol, "onnx"); + ATH_MSG_NOCLASS(logger, "Finish all tests" << std::endl); +} + +void TFCSEnergyAndHitGANV2::test_path(std::string path, + TFCSSimulationState *simulstate, + const TFCSTruthState *truth, + const TFCSExtrapolationState *extrapol, + std::string outputname, int pid) { + ISF_FCS::MLogging logger; + ATH_MSG_NOCLASS(logger, "Running test on " << path << std::endl); if (!simulstate) { simulstate = new TFCSSimulationState(); #if defined(__FastCaloSimStandAlone__) @@ -575,7 +595,7 @@ void TFCSEnergyAndHitGANV2::unit_test(TFCSSimulationState *simulstate, ATH_MSG_NOCLASS(logger, "New particle"); TFCSTruthState *t = new TFCSTruthState(); t->SetPtEtaPhiM(65536, 0, 0, 139.6); - t->set_pdgid(211); + t->set_pdgid(pid); truth = t; } if (!extrapol) { @@ -599,21 +619,18 @@ void TFCSEnergyAndHitGANV2::unit_test(TFCSSimulationState *simulstate, } TFCSEnergyAndHitGANV2 GAN("GAN", "GAN"); - GAN.setLevel(MSG::VERBOSE); - int pid = 211; - int etaMin = 20; - int etaMax = etaMin + 5; + GAN.setLevel(MSG::INFO); + const int etaMin = 20; + const int etaMax = etaMin + 5; ATH_MSG_NOCLASS(logger, "Initialize Networks"); - GAN.initializeNetwork(pid, etaMin, - "/eos/atlas/atlascerngroupdisk/proj-simul/AF3_Run3/" - "InputsToBigParamFiles/FastCaloGANWeightsVer02"); + GAN.initializeNetwork(pid, etaMin, path); for (int i = 0; i < 24; ++i) if (GAN.is_match_calosample(i)) { TFCSCenterPositionCalculation *c = new TFCSCenterPositionCalculation( Form("center%d", i), Form("center layer %d", i)); c->set_calosample(i); c->setExtrapWeight(0.5); - c->setLevel(MSG::VERBOSE); + c->setLevel(MSG::INFO); c->set_pdgid(pid); if (pid == 11) c->add_pdgid(-pid); @@ -629,24 +646,28 @@ void TFCSEnergyAndHitGANV2::unit_test(TFCSSimulationState *simulstate, GAN.Print(); - ATH_MSG_NOCLASS(logger, "Writing GAN to FCSGANtest.root"); - TFile *fGAN = TFile::Open("FCSGANtest.root", "recreate"); - GAN.Write(); + ATH_MSG_NOCLASS(logger, "Writing GAN to " << outputname); + const std::string outname = "FCSGANtest_" + outputname + ".root"; + TFile *fGAN = TFile::Open(outname.c_str(), "recreate"); + fGAN->cd(); + // GAN.Write(); + fGAN->WriteObjectAny(&GAN, "TFCSEnergyAndHitGANV2", "GAN"); + fGAN->ls(); fGAN->Close(); - ATH_MSG_NOCLASS(logger, "Open FCSGANtest.root"); - fGAN = TFile::Open("FCSGANtest.root"); + ATH_MSG_NOCLASS(logger, "Open " << outname); + fGAN = TFile::Open(outname.c_str()); TFCSEnergyAndHitGANV2 *GAN2 = (TFCSEnergyAndHitGANV2 *)(fGAN->Get("GAN")); + GAN2->setLevel(MSG::INFO); GAN2->Print(); - GAN2->setLevel(MSG::DEBUG); ATH_MSG_NOCLASS(logger, "Before running GAN2->simulate()"); GAN2->simulate(*simulstate, truth, extrapol); simulstate->Print(); } -int TFCSEnergyAndHitGANV2::GetBinsInFours(double bins) { +int TFCSEnergyAndHitGANV2::GetBinsInFours(double const &bins) { if (bins < 4) return 4; else if (bins < 8) @@ -661,14 +682,15 @@ int TFCSEnergyAndHitGANV2::GetAlphaBinsForRBin(const TAxis *x, int ix, int yBinNum) const { double binsInAlphaInRBin = yBinNum; if (yBinNum == 32) { - double widthX = x->GetBinWidth(ix); - double radious = x->GetBinCenter(ix); + ATH_MSG_DEBUG("yBinNum is special value 32"); + const double widthX = x->GetBinWidth(ix); + const double radious = x->GetBinCenter(ix); double circumference = radious * 2 * TMath::Pi(); if (m_param.IsSymmetrisedAlpha()) { circumference = radious * TMath::Pi(); } - double bins = circumference / widthX; + const double bins = circumference / widthX; binsInAlphaInRBin = GetBinsInFours(bins); ATH_MSG_DEBUG("Bin in alpha: " << binsInAlphaInRBin << " for r bin: " << ix << " (" << x->GetBinLowEdge(ix) << "-" diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSGANEtaSlice.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSGANEtaSlice.cxx index d18c6e892d27037d197d64a512db96dbcdd07174..4d826166f1898d47bca9eecbfa827db998d1a4bc 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSGANEtaSlice.cxx +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSGANEtaSlice.cxx @@ -9,6 +9,8 @@ // class header include #include "ISF_FastCaloSimEvent/TFCSGANEtaSlice.h" +#include "ISF_FastCaloSimEvent/TFCSNetworkFactory.h" + #include "CLHEP/Random/RandGauss.h" #include "TFitResult.h" @@ -28,32 +30,38 @@ TFCSGANEtaSlice::TFCSGANEtaSlice() {} TFCSGANEtaSlice::TFCSGANEtaSlice(int pid, int etaMin, int etaMax, const TFCSGANXMLParameters ¶m) - : m_pid (pid), - m_etaMin (etaMin), - m_etaMax (etaMax), - m_param (param) -{ -} + : m_pid(pid), m_etaMin(etaMin), m_etaMax(etaMax), m_param(param) {} TFCSGANEtaSlice::~TFCSGANEtaSlice() { - if (m_gan_all != nullptr) { - delete m_gan_all; - } - if (m_gan_low != nullptr) { - delete m_gan_low; - } - if (m_gan_high != nullptr) { - delete m_gan_high; - } + // Deleting a nullptr is a noop + delete m_gan_all; + delete m_gan_low; + delete m_gan_high; +} + +VNetworkBase *TFCSGANEtaSlice::GetNetAll() const { + if (m_net_all != nullptr) + return m_net_all.get(); + return m_gan_all; +} +VNetworkBase *TFCSGANEtaSlice::GetNetLow() const { + if (m_net_low != nullptr) + return m_net_low.get(); + return m_gan_low; +} +VNetworkBase *TFCSGANEtaSlice::GetNetHigh() const { + if (m_net_high != nullptr) + return m_net_high.get(); + return m_gan_high; } bool TFCSGANEtaSlice::IsGanCorrectlyLoaded() const { if (m_pid == 211 || m_pid == 2212) { - if (m_gan_all == nullptr) { + if (GetNetAll() == nullptr) { return false; } } else { - if (m_gan_high == nullptr || m_gan_low == nullptr) { + if (GetNetHigh() == nullptr || GetNetLow() == nullptr) { return false; } } @@ -61,43 +69,47 @@ bool TFCSGANEtaSlice::IsGanCorrectlyLoaded() const { } bool TFCSGANEtaSlice::LoadGAN() { + // Now load new data std::string inputFileName; CalculateMeanPointFromDistributionOfR(); ExtractExtrapolatorMeansFromInputs(); + bool success = true; + if (m_pid == 211) { inputFileName = m_param.GetInputFolder() + "/neural_net_" + std::to_string(m_pid) + "_eta_" + std::to_string(m_etaMin) + - "_" + std::to_string(m_etaMax) + "_All.json"; + "_" + std::to_string(m_etaMax) + "_All.*"; ATH_MSG_DEBUG("Gan input file name " << inputFileName); - m_gan_all = new TFCSGANLWTNNHandler(); - return m_gan_all->LoadGAN(inputFileName); + m_net_all = TFCSNetworkFactory::create(inputFileName); + if (m_net_all == nullptr) + success = false; } else if (m_pid == 2212) { inputFileName = m_param.GetInputFolder() + "/neural_net_" + std::to_string(m_pid) + "_eta_" + std::to_string(m_etaMin) + - "_" + std::to_string(m_etaMax) + "_High10.json"; + "_" + std::to_string(m_etaMax) + "_High10.*"; ATH_MSG_DEBUG("Gan input file name " << inputFileName); - m_gan_all = new TFCSGANLWTNNHandler(); - return m_gan_all->LoadGAN(inputFileName); + m_net_all = TFCSNetworkFactory::create(inputFileName); + if (m_net_all == nullptr) + success = false; } else { - bool returnValue; inputFileName = m_param.GetInputFolder() + "/neural_net_" + std::to_string(m_pid) + "_eta_" + std::to_string(m_etaMin) + - "_" + std::to_string(m_etaMax) + "_High12.json"; - m_gan_high = new TFCSGANLWTNNHandler(); - returnValue = m_gan_high->LoadGAN(inputFileName); - if (!returnValue) { - return returnValue; - } + "_" + std::to_string(m_etaMax) + "_High12.*"; + ATH_MSG_DEBUG("Gan input file name " << inputFileName); + m_net_high = TFCSNetworkFactory::create(inputFileName); + if (m_net_high == nullptr) + success = false; inputFileName = m_param.GetInputFolder() + "/neural_net_" + std::to_string(m_pid) + "_eta_" + std::to_string(m_etaMin) + - "_" + std::to_string(m_etaMax) + "_UltraLow12.json"; - m_gan_low = new TFCSGANLWTNNHandler(); - return m_gan_low->LoadGAN(inputFileName); - return true; + "_" + std::to_string(m_etaMax) + "_UltraLow12.*"; + m_net_low = TFCSNetworkFactory::create(inputFileName); + if (m_net_low == nullptr) + success = false; } + return success; } void TFCSGANEtaSlice::CalculateMeanPointFromDistributionOfR() { @@ -158,7 +170,7 @@ void TFCSGANEtaSlice::ExtractExtrapolatorMeansFromInputs() { } } -TFCSGANEtaSlice::NetworkOutputs +VNetworkBase::NetworkOutputs TFCSGANEtaSlice::GetNetworkOutputs(const TFCSTruthState *truth, const TFCSExtrapolationState *extrapol, TFCSSimulationState simulstate) const { @@ -194,7 +206,7 @@ TFCSGANEtaSlice::GetNetworkOutputs(const TFCSTruthState *truth, for (int i = 0; i < m_param.GetLatentSpaceSize(); i++) { randUniformZ = CLHEP::RandGauss::shoot(simulstate.randomEngine(), 0.5, 0.5); - inputs["node_0"].insert(std::pair<std::string, double>( + inputs["Noise"].insert(std::pair<std::string, double>( "variable_" + std::to_string(i), randUniformZ)); } @@ -204,35 +216,43 @@ TFCSGANEtaSlice::GetNetworkOutputs(const TFCSTruthState *truth, // truth->P() <<" mass:" << truth->M() <<" Ekin_off:" << // truth->Ekin_off() << " Ekin_min:"<<Ekin_min<<" // Ekin_max:"<<Ekin_max); - // inputs["node_1"].insert ( std::pair<std::string,double>("variable_0", + // inputs["mycond"].insert ( std::pair<std::string,double>("variable_0", // truth->Ekin()/(std::pow(2,maxExp))) ); //Old conditioning using linear // interpolation, now use logaritminc interpolation - inputs["node_1"].insert(std::pair<std::string, double>( + inputs["mycond"].insert(std::pair<std::string, double>( "variable_0", log(truth->Ekin() / Ekin_min) / log(Ekin_max / Ekin_min))); if (m_param.GetGANVersion() >= 2) { if (false) { // conditioning on eta, should only be needed in transition // regions and added only to the GANs that use it, for now all // GANs have 3 conditioning inputs so filling zeros - inputs["node_1"].insert(std::pair<std::string, double>( + inputs["mycond"].insert(std::pair<std::string, double>( "variable_1", fabs(extrapol->IDCaloBoundary_eta()))); } else { - inputs["node_1"].insert(std::pair<std::string, double>("variable_1", 0)); + inputs["mycond"].insert(std::pair<std::string, double>("variable_1", 0)); } } + VNetworkBase::NetworkOutputs outputs; if (m_param.GetGANVersion() == 1 || m_pid == 211 || m_pid == 2212) { - return m_gan_all->GetGraph()->compute(inputs); + outputs = GetNetAll()->compute(inputs); } else { if (truth->P() > 4096) { // This is the momentum, not the energy, because the split is // based on the samples which are produced with the momentum ATH_MSG_DEBUG("Computing outputs given inputs for high"); - return m_gan_high->GetGraph()->compute(inputs); + outputs = GetNetHigh()->compute(inputs); } else { - return m_gan_low->GetGraph()->compute(inputs); + outputs = GetNetLow()->compute(inputs); } } + ATH_MSG_DEBUG("Start Network inputs ~~~~~~~~"); + ATH_MSG_DEBUG(VNetworkBase::representNetworkInputs(inputs, 10000)); + ATH_MSG_DEBUG("End Network inputs ~~~~~~~~"); + ATH_MSG_DEBUG("Start Network outputs ~~~~~~~~"); + ATH_MSG_DEBUG(VNetworkBase::representNetworkOutputs(outputs, 10000)); + ATH_MSG_DEBUG("End Network outputs ~~~~~~~~"); + return outputs; } void TFCSGANEtaSlice::Print() const { diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSGANLWTNNHandler.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSGANLWTNNHandler.cxx index 082c27be7f2b10bde0cdb29b24863597a19bdcff..9b5a71587b72a5f97d9df808230b3965ac1fbcd2 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSGANLWTNNHandler.cxx +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSGANLWTNNHandler.cxx @@ -1,70 +1,114 @@ -/* - Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration -*/ - -/////////////////////////////////////////////////////////////////// -// TFCSGANLWTNNHandler.cxx, (c) ATLAS Detector software // -/////////////////////////////////////////////////////////////////// - -// class header include #include "ISF_FastCaloSimEvent/TFCSGANLWTNNHandler.h" -#include "TFile.h" //Needed for TBuffer +// For writing to a tree +#include "TBranch.h" +#include "TTree.h" -#include <iostream> -#include <fstream> -#include <string> -#include <sstream> +// LWTNN +#include "lwtnn/LightweightGraph.hh" +#include "lwtnn/parse_json.hh" -TFCSGANLWTNNHandler::TFCSGANLWTNNHandler() { m_graph = nullptr; } +TFCSGANLWTNNHandler::TFCSGANLWTNNHandler(const std::string &inputFile) + : VNetworkLWTNN(inputFile) { + ATH_MSG_DEBUG("Setting up from inputFile."); + setupPersistedVariables(); + setupNet(); +}; -TFCSGANLWTNNHandler::~TFCSGANLWTNNHandler() { - if (m_input != nullptr) { +TFCSGANLWTNNHandler::TFCSGANLWTNNHandler(const TFCSGANLWTNNHandler ©_from) + : VNetworkLWTNN(copy_from) { + // Cannot take copies of lwt::LightweightGraph + // (copy constructor disabled) + ATH_MSG_DEBUG("Making a new m_lwtnn_graph for copied network"); + std::stringstream json_stream(m_json); + const lwt::GraphConfig config = lwt::parse_json_graph(json_stream); + m_lwtnn_graph = std::make_unique<lwt::LightweightGraph>(config); + m_outputLayers = copy_from.m_outputLayers; +}; + +void TFCSGANLWTNNHandler::setupNet() { + // Backcompatability, previous versions stored this in m_input + if (m_json.length() == 0 && m_input != nullptr) { + m_json = *m_input; delete m_input; + m_input = nullptr; } - if (m_graph != nullptr) { - delete m_graph; - } -} - -bool TFCSGANLWTNNHandler::LoadGAN(const std::string &inputFile) { - std::ifstream input(inputFile); - std::stringstream sin; - sin << input.rdbuf(); - input.close(); // build the graph - auto config = lwt::parse_json_graph(sin); - m_graph = new lwt::LightweightGraph(config); - if (m_graph == nullptr) { - return false; - } - if (m_input != nullptr) { - delete m_input; + ATH_MSG_VERBOSE("m_json has size " << m_json.length()); + ATH_MSG_DEBUG("m_json starts with " << m_json.substr(0, 10)); + ATH_MSG_VERBOSE("Reading the m_json string stream into a graph network"); + std::stringstream json_stream(m_json); + const lwt::GraphConfig config = lwt::parse_json_graph(json_stream); + m_lwtnn_graph = std::make_unique<lwt::LightweightGraph>(config); + // Get the output layers + ATH_MSG_VERBOSE("Getting output layers for neural network"); + for (auto node : config.outputs) { + const std::string node_name = node.first; + const lwt::OutputNodeConfig node_config = node.second; + for (std::string label : node_config.labels) { + ATH_MSG_VERBOSE("Found output layer called " << node_name << "_" + << label); + m_outputLayers.push_back(node_name + "_" + label); + } + }; + ATH_MSG_VERBOSE("Removing prefix from stored layers."); + removePrefixes(m_outputLayers); + ATH_MSG_VERBOSE("Finished output nodes."); +}; + +std::vector<std::string> TFCSGANLWTNNHandler::getOutputLayers() const { + return m_outputLayers; +}; + +// This is implement the specific compute, and ensure the output is returned in +// regular format. For LWTNN, that's easy. +TFCSGANLWTNNHandler::NetworkOutputs TFCSGANLWTNNHandler::compute( + TFCSGANLWTNNHandler::NetworkInputs const &inputs) const { + ATH_MSG_DEBUG("Running computation on LWTNN graph network"); + NetworkInputs local_copy = inputs; + if (inputs.find("Noise") != inputs.end()) { + // Graphs from EnergyAndHitsGANV2 have the local_copy encoded as Noise = + // node_0 and mycond = node_1 + auto noiseNode = local_copy.extract("Noise"); + noiseNode.key() = "node_0"; + local_copy.insert(std::move(noiseNode)); + auto mycondNode = local_copy.extract("mycond"); + mycondNode.key() = "node_1"; + local_copy.insert(std::move(mycondNode)); } - m_input = new std::string(sin.str()); - return true; -} + // now we can compute + TFCSGANLWTNNHandler::NetworkOutputs outputs = + m_lwtnn_graph->compute(local_copy); + removePrefixes(outputs); + ATH_MSG_DEBUG("Computation on LWTNN graph network done, returning."); + return outputs; +}; -void TFCSGANLWTNNHandler::Streamer(TBuffer &R__b) { - // Stream an object of class TFCSGANLWTNNHandler - if (R__b.IsReading()) { - R__b.ReadClassBuffer(TFCSGANLWTNNHandler::Class(), this); - if (m_graph != nullptr) { - delete m_graph; - m_graph = nullptr; - } - if (m_input != nullptr) { - std::stringstream sin; - sin.str(*m_input); - auto config = lwt::parse_json_graph(sin); - m_graph = new lwt::LightweightGraph(config); - } +// Giving this it's own streamer to call setupNet +void TFCSGANLWTNNHandler::Streamer(TBuffer &buf) { + ATH_MSG_DEBUG("In streamer of " << __FILE__); + if (buf.IsReading()) { + ATH_MSG_DEBUG("Reading buffer in TFCSGANLWTNNHandler "); + // Get the persisted variables filled in + TFCSGANLWTNNHandler::Class()->ReadBuffer(buf, this); + ATH_MSG_DEBUG("m_json has size " << m_json.length()); + ATH_MSG_DEBUG("m_json starts with " << m_json.substr(0, 10)); + // Setup the net, creating the non persisted variables + // exactly as in the constructor + this->setupNet(); #ifndef __FastCaloSimStandAlone__ - // When running inside Athena, delete config to free the memory - delete m_input; - m_input = nullptr; + // When running inside Athena, delete persisted information + // to conserve memory + this->deleteAllButNet(); #endif } else { - R__b.WriteClassBuffer(TFCSGANLWTNNHandler::Class(), this); - } -} + if (!m_json.empty()) { + ATH_MSG_DEBUG("Writing buffer in TFCSGANLWTNNHandler "); + } else { + ATH_MSG_WARNING( + "Writing buffer in TFCSGANLWTNNHandler, but m_json is empty"); + }; + // Persist variables + TFCSGANLWTNNHandler::Class()->WriteBuffer(buf, this); + }; +}; diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSNetworkFactory.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSNetworkFactory.cxx new file mode 100644 index 0000000000000000000000000000000000000000..3c5be96b7c11e450d53df6cea26595597e7e7259 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSNetworkFactory.cxx @@ -0,0 +1,148 @@ +#include "ISF_FastCaloSimEvent/TFCSNetworkFactory.h" +#include "ISF_FastCaloSimEvent/TFCSGANLWTNNHandler.h" +#include "ISF_FastCaloSimEvent/TFCSONNXHandler.h" +#include "ISF_FastCaloSimEvent/TFCSSimpleLWTNNHandler.h" +#include "ISF_FastCaloSimEvent/VNetworkBase.h" + +#include <boost/property_tree/ptree.hpp> +#include <fstream> // For checking if files exist +#include <stdexcept> + +// For messaging +#include "ISF_FastCaloSimEvent/MLogging.h" +using ISF_FCS::MLogging; + +void TFCSNetworkFactory::resolveGlobs(std::string &filename) { + ISF_FCS::MLogging logger; + const std::string ending = ".*"; + const int ending_len = ending.length(); + const int filename_len = filename.length(); + if (filename_len < ending_len) { + ATH_MSG_NOCLASS(logger, "Filename is implausably short."); + } else if (0 == + filename.compare(filename_len - ending_len, ending_len, ending)) { + ATH_MSG_NOCLASS(logger, "Filename ends in glob."); + // Remove the glob + filename.pop_back(); + if (std::filesystem::exists(filename + "onnx")) { + filename += "onnx"; + } else if (std::filesystem::exists(filename + "json")) { + filename += std::string("json"); + } else { + throw std::invalid_argument("No file found matching globbed filename " + + filename); + }; + }; +}; + +bool TFCSNetworkFactory::isOnnxFile(std::string const &filename) { + ISF_FCS::MLogging logger; + const std::string ending = ".onnx"; + const int ending_len = ending.length(); + const int filename_len = filename.length(); + bool is_onnx; + if (filename_len < ending_len) { + is_onnx = false; + } else { + is_onnx = + (0 == filename.compare(filename_len - ending_len, ending_len, ending)); + }; + return is_onnx; +}; + +std::unique_ptr<VNetworkBase> +TFCSNetworkFactory::create(std::vector<char> const &input) { + ISF_FCS::MLogging logger; + ATH_MSG_NOCLASS(logger, "Directly creating ONNX network from bytes length " + << input.size()); + std::unique_ptr<VNetworkBase> created(new TFCSONNXHandler(input)); + return created; +}; + +std::unique_ptr<VNetworkBase> TFCSNetworkFactory::create(std::string input) { + ISF_FCS::MLogging logger; + resolveGlobs(input); + if (VNetworkBase::isFile(input) && isOnnxFile(input)) { + ATH_MSG_NOCLASS(logger, "Creating ONNX network from file ..." + << input.substr(input.length() - 10)); + std::unique_ptr<VNetworkBase> created(new TFCSONNXHandler(input)); + return created; + } else { + try { + std::unique_ptr<VNetworkBase> created(new TFCSSimpleLWTNNHandler(input)); + ATH_MSG_NOCLASS(logger, + "Succedeed in creating LWTNN nn from string starting " + << input.substr(0, 10)); + return created; + } catch (const boost::property_tree::ptree_bad_path &e) { + // If we get this error, it was actually a graph, not a NeuralNetwork + std::unique_ptr<VNetworkBase> created(new TFCSGANLWTNNHandler(input)); + ATH_MSG_NOCLASS(logger, "Succedeed in creating LWTNN graph from string"); + return created; + }; + }; +}; + +std::unique_ptr<VNetworkBase> TFCSNetworkFactory::create(std::string input, + bool graph_form) { + ISF_FCS::MLogging logger; + resolveGlobs(input); + if (VNetworkBase::isFile(input) && isOnnxFile(input)) { + ATH_MSG_NOCLASS(logger, "Creating ONNX network from file ..." + << input.substr(input.length() - 10)); + std::unique_ptr<VNetworkBase> created(new TFCSONNXHandler(input)); + return created; + } else if (graph_form) { + ATH_MSG_NOCLASS(logger, "Creating LWTNN graph from string"); + std::unique_ptr<VNetworkBase> created(new TFCSGANLWTNNHandler(input)); + return created; + } else { + std::unique_ptr<VNetworkBase> created(new TFCSSimpleLWTNNHandler(input)); + ATH_MSG_NOCLASS(logger, "Creating LWTNN nn from string"); + return created; + }; +}; + +std::unique_ptr<VNetworkBase> +TFCSNetworkFactory::create(std::vector<char> const &vector_input, + std::string string_input) { + ISF_FCS::MLogging logger; + ATH_MSG_NOCLASS(logger, "Given both bytes and a string to create an nn."); + resolveGlobs(string_input); + if (vector_input.size() > 0) { + ATH_MSG_NOCLASS(logger, + "Bytes contains data, size=" << vector_input.size() + << ", creating from bytes."); + return create(vector_input); + } else if (string_input.length() > 0) { + ATH_MSG_NOCLASS(logger, "No data in bytes, string contains data, " + << "creating from string."); + return create(string_input); + } else { + throw std::invalid_argument( + "Neither vector_input nor string_input contain data"); + }; +}; + +std::unique_ptr<VNetworkBase> +TFCSNetworkFactory::create(std::vector<char> const &vector_input, + std::string string_input, bool graph_form) { + ISF_FCS::MLogging logger; + ATH_MSG_NOCLASS( + logger, + "Given both bytes, a string and graph form sepcified to create an nn."); + resolveGlobs(string_input); + if (vector_input.size() > 0) { + ATH_MSG_NOCLASS(logger, + "Bytes contains data, size=" << vector_input.size() + << ", creating from bytes."); + return create(vector_input); + } else if (string_input.length() > 0) { + ATH_MSG_NOCLASS(logger, "No data in bytes, string contains data, " + << "creating from string."); + return create(string_input, graph_form); + } else { + throw std::invalid_argument( + "Neither vector_input nor string_input contain data"); + }; +}; diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSONNXHandler.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSONNXHandler.cxx new file mode 100644 index 0000000000000000000000000000000000000000..d12b7d15f5daf27ad0dbcaea978e64efda70bde6 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSONNXHandler.cxx @@ -0,0 +1,463 @@ +// See headder file for documentation. +#include "ISF_FastCaloSimEvent/TFCSONNXHandler.h" + +// For reading the binary onnx files +#include <fstream> +#include <iterator> +#include <vector> + +// ONNX Runtime include(s). +#include <core/session/onnxruntime_cxx_api.h> + +// For reading and writing to root +#include "TBranch.h" +#include "TFile.h" +#include "TTree.h" + +// For throwing exceptions +#include <stdexcept> + +TFCSONNXHandler::TFCSONNXHandler(const std::string &inputFile) + : VNetworkBase(inputFile) { + ATH_MSG_INFO("Setting up from inputFile."); + setupPersistedVariables(); + setupNet(); + ATH_MSG_DEBUG("Setup from file complete"); +}; + +TFCSONNXHandler::TFCSONNXHandler(const std::vector<char> &bytes) + : m_bytes(bytes) { + ATH_MSG_INFO("Given onnx session bytes as input."); + // The super constructor got no inputFile, + // so it won't call setupNet itself + setupNet(); + ATH_MSG_DEBUG("Setup from session complete"); +}; + +TFCSONNXHandler::TFCSONNXHandler(const TFCSONNXHandler ©_from) + : VNetworkBase(copy_from) { + ATH_MSG_DEBUG("TFCSONNXHandler copy construtor called"); + m_bytes = copy_from.m_bytes; + // Cannot copy a session + // m_session = copy_from.m_session; + // But can read it from bytes + readSerializedSession(); + m_inputNodeNames = copy_from.m_inputNodeNames; + m_inputNodeDims = copy_from.m_inputNodeDims; + m_outputNodeNames = copy_from.m_outputNodeNames; + m_outputNodeDims = copy_from.m_outputNodeDims; + m_outputLayers = copy_from.m_outputLayers; +}; + +TFCSONNXHandler::NetworkOutputs +TFCSONNXHandler::compute(TFCSONNXHandler::NetworkInputs const &inputs) const { + return m_computeLambda(inputs); +}; + +// Writing out to ttrees +void TFCSONNXHandler::writeNetToTTree(TTree &tree) { + ATH_MSG_DEBUG("TFCSONNXHandler writing net to tree."); + this->writeBytesToTTree(tree, m_bytes); +}; + +std::vector<std::string> TFCSONNXHandler::getOutputLayers() const { + ATH_MSG_DEBUG("TFCSONNXHandler output layers requested."); + return m_outputLayers; +}; + +void TFCSONNXHandler::deleteAllButNet() { + // As we don't copy the bytes, and the inputFile + // is at most a name, nothing is needed here. + ATH_MSG_DEBUG("Deleted nothing for ONNX."); +}; + +void TFCSONNXHandler::print(std::ostream &strm) const { + if (m_inputFile.empty()) { + strm << "Unknown network"; + } else { + strm << m_inputFile; + }; + strm << "\nHas input nodes (name:dimensions);\n"; + for (size_t inp_n = 0; inp_n < m_inputNodeNames.size(); inp_n++) { + strm << "\t" << m_inputNodeNames[inp_n] << ":["; + for (int dim : m_inputNodeDims[inp_n]) { + strm << " " << dim << ","; + }; + strm << "]\n"; + }; + strm << "\nHas output nodes (name:dimensions);\n"; + for (size_t out_n = 0; out_n < m_outputNodeNames.size(); out_n++) { + strm << "\t" << m_outputNodeNames[out_n] << ":["; + for (int dim : m_outputNodeDims[out_n]) { + strm << " " << dim << ","; + }; + strm << "]\n"; + }; +}; + +void TFCSONNXHandler::setupPersistedVariables() { + ATH_MSG_DEBUG("Setting up persisted variables for ONNX network."); + // depending which constructor was called, + // bytes may already be filled + if (m_bytes.empty()) { + m_bytes = getSerializedSession(); + }; + ATH_MSG_DEBUG("Setup persisted variables for ONNX network."); +}; + +void TFCSONNXHandler::setupNet() { + // From + // https://gitlab.cern.ch/atlas/athena/-/blob/master/Control/AthOnnxruntimeUtils/AthOnnxruntimeUtils/OnnxUtils.h + // m_session = AthONNX::CreateORTSession(inputFile); + // This segfaults. + + // TODO; should I be using m_session_options? see + // https://github.com/microsoft/onnxruntime-inference-examples/blob/2b42b442526b9454d1e2d08caeb403e28a71da5f/c_cxx/squeezenet/main.cpp#L71 + ATH_MSG_INFO("Setting up ONNX session."); + this->readSerializedSession(); + + // Need the type from the first node (which will be used to set + // just set it to undefined to avoid not initialised warnings + ONNXTensorElementDataType first_input_type = + ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // iterate over all input nodes + ATH_MSG_DEBUG("Getting input nodes."); + const int num_input_nodes = m_session->GetInputCount(); + Ort::AllocatorWithDefaultOptions allocator; + for (int i = 0; i < num_input_nodes; i++) { + +#if ORT_API_VERSION > 11 + Ort::AllocatedStringPtr node_names = m_session->GetInputNameAllocated(i, allocator); + m_storeInputNodeNames.push_back(std::move(node_names)); + const char *input_name = m_storeInputNodeNames.back().get(); +#else + const char *input_name = m_session->GetInputName(i, allocator); +#endif + m_inputNodeNames.push_back(input_name); + ATH_MSG_VERBOSE("Found input node named " << input_name); + + Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i); + + // For some reason unless auto is used as the return type + // this causes a segfault once the loop ends.... + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + if (i == 0) + first_input_type = tensor_info.GetElementType(); + // Check the type has not changed + if (tensor_info.GetElementType() != first_input_type) { + ATH_MSG_ERROR("First type was " << first_input_type << ". In node " << i + << " found type " + << tensor_info.GetElementType()); + throw std::runtime_error("Networks with varying input types not " + "yet impelmented in TFCSONNXHandler."); + }; + + std::vector<int64_t> recieved_dimension = tensor_info.GetShape(); + ATH_MSG_VERBOSE("There are " << recieved_dimension.size() + << " dimensions."); + // This vector sometimes includes a symbolic dimension + // which is represented by -1 + // A symbolic dimension is usually a conversion error, + // from a numpy array with a shape like (None, 7), + // in which case it's safe to treat it as having + // dimension 1. + std::vector<int64_t> dimension_of_node; + for (int64_t node_dim : recieved_dimension) { + if (node_dim < 1) { + ATH_MSG_WARNING("Found symbolic dimension " + << node_dim << " in node named " << input_name + << ". Will treat this as dimension 1."); + dimension_of_node.push_back(1); + } else { + dimension_of_node.push_back(node_dim); + }; + }; + m_inputNodeDims.push_back(dimension_of_node); + }; + ATH_MSG_DEBUG("Finished looping on inputs."); + + // Outputs + // Store the type from the first node (which will be used to set + // m_computeLambda) + ONNXTensorElementDataType first_output_type = + ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // iterate over all output nodes + int num_output_nodes = m_session->GetOutputCount(); + ATH_MSG_DEBUG("Getting " << num_output_nodes << " output nodes."); + for (int i = 0; i < num_output_nodes; i++) { +#if ORT_API_VERSION > 11 + Ort::AllocatedStringPtr node_names = m_session->GetOutputNameAllocated(i, allocator); + m_storeOutputNodeNames.push_back(std::move(node_names)); + const char *output_name = m_storeOutputNodeNames.back().get(); +#else + const char *output_name = m_session->GetOutputName(i, allocator); +#endif + m_outputNodeNames.push_back(output_name); + ATH_MSG_VERBOSE("Found output node named " << output_name); + + const Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + if (i == 0) + first_output_type = tensor_info.GetElementType(); + + // Check the type has not changed + if (tensor_info.GetElementType() != first_output_type) { + ATH_MSG_ERROR("First type was " << first_output_type << ". In node " << i + << " found type " + << tensor_info.GetElementType()); + throw std::runtime_error("Networks with varying output types not " + "yet impelmented in TFCSONNXHandler."); + }; + + const std::vector<int64_t> recieved_dimension = tensor_info.GetShape(); + ATH_MSG_VERBOSE("There are " << recieved_dimension.size() + << " dimensions."); + // Again, check for sybolic dimensions + std::vector<int64_t> dimension_of_node; + int node_size = 1; + for (int64_t node_dim : recieved_dimension) { + if (node_dim < 1) { + ATH_MSG_WARNING("Found symbolic dimension " + << node_dim << " in node named " << output_name + << ". Will treat this as dimension 1."); + dimension_of_node.push_back(1); + } else { + dimension_of_node.push_back(node_dim); + node_size *= node_dim; + }; + }; + m_outputNodeDims.push_back(dimension_of_node); + m_outputNodeSize.push_back(node_size); + + // The outputs are treated as a flat vector + for (int part_n = 0; part_n < node_size; part_n++) { + // compose the output name + std::string layer_name = + std::string(output_name) + "_" + std::to_string(part_n); + ATH_MSG_VERBOSE("Found output layer named " << layer_name); + m_outputLayers.push_back(layer_name); + } + } + ATH_MSG_DEBUG("Removing prefix from stored layers."); + removePrefixes(m_outputLayers); + ATH_MSG_DEBUG("Finished output nodes."); + + ATH_MSG_DEBUG("Setting up m_computeLambda with input type " + << first_input_type << " and output type " + << first_output_type); + if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT && + first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + // gotta capture this in the lambda so it can access class methods + m_computeLambda = [this](NetworkInputs const &inputs) { + return computeTemplate<float, float>(inputs); + }; + } else if (first_input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE && + first_output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + m_computeLambda = [this](NetworkInputs const &inputs) { + return computeTemplate<double, double>(inputs); + }; + } else { + throw std::runtime_error("Haven't yet implemented that combination of " + "input and output types as a subclass of VState."); + }; + ATH_MSG_DEBUG("Finished setting lambda function."); +}; + +// Needs to also work if the input file is a root file +std::vector<char> TFCSONNXHandler::getSerializedSession(std::string tree_name) { + ATH_MSG_DEBUG("Getting serialized session for ONNX network."); + + if (this->isRootFile()) { + ATH_MSG_INFO("Reading bytes from root file."); + TFile tfile(this->m_inputFile.c_str(), "READ"); + TTree *tree = (TTree *)tfile.Get(tree_name.c_str()); + std::vector<char> bytes = this->readBytesFromTTree(*tree); + ATH_MSG_DEBUG("Found bytes size " << bytes.size()); + return bytes; + } else { + ATH_MSG_INFO("Reading bytes from text file."); + // see https://stackoverflow.com/a/50317432 + std::ifstream input(this->m_inputFile, std::ios::binary); + + std::vector<char> bytes((std::istreambuf_iterator<char>(input)), + (std::istreambuf_iterator<char>())); + + input.close(); + ATH_MSG_DEBUG("Found bytes size " << bytes.size()); + return bytes; + } +}; + +std::vector<char> TFCSONNXHandler::readBytesFromTTree(TTree &tree) { + ATH_MSG_DEBUG("TFCSONNXHandler reading bytes from tree."); + std::vector<char> bytes; + char data; + tree.SetBranchAddress("serialized_m_session", &data); + for (int i = 0; tree.LoadTree(i) >= 0; i++) { + tree.GetEntry(i); + bytes.push_back(data); + }; + ATH_MSG_DEBUG("TFCSONNXHandler read bytes from tree."); + return bytes; +}; + +void TFCSONNXHandler::writeBytesToTTree(TTree &tree, + const std::vector<char> &bytes) { + ATH_MSG_DEBUG("TFCSONNXHandler writing bytes to tree."); + char m_session_data; + tree.Branch("serialized_m_session", &m_session_data, + "serialized_m_session/B"); + for (Char_t here : bytes) { + m_session_data = here; + tree.Fill(); + }; + tree.Write(); + ATH_MSG_DEBUG("TFCSONNXHandler written bytes to tree."); +}; + +void TFCSONNXHandler::readSerializedSession() { + ATH_MSG_DEBUG("Transforming bytes to session."); + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); + Ort::SessionOptions opts({nullptr}); + m_session = + std::make_unique<Ort::Session>(env, m_bytes.data(), m_bytes.size(), opts); + ATH_MSG_DEBUG("Transformed bytes to session."); +}; + +template <typename Tin, typename Tout> +VNetworkBase::NetworkOutputs +TFCSONNXHandler::computeTemplate(VNetworkBase::NetworkInputs const &inputs) { + // working from + // https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/squeezenet/main.cpp#L71 + // and + // https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/MNIST/MNIST.cpp + ATH_MSG_DEBUG("Setting up inputs for computation on ONNX network."); + ATH_MSG_DEBUG("Input type " << typeid(Tin).name() << " output type " + << typeid(Tout).name()); + + // The inputs must be reformatted to the correct data structure. + const size_t num_input_nodes = m_inputNodeNames.size(); + // A pointer to all the nodes we will make + // Gonna keep the data in each node flat, becuase that's easier + std::vector<std::vector<Tin>> input_values(num_input_nodes); + std::vector<Ort::Value> node_values; + // Non const values that will be needed at each step. + std::string node_name; + int n_dimensions, elements_in_node, key_number; + size_t first_digit; + // Move along the list of node names gathered in the constructor + // we need both the node name, and the dimension + // so we cannot itterate directly on the vector. + ATH_MSG_DEBUG("Looping over " << num_input_nodes + << " input nodes of ONNX network."); + for (size_t node_n = 0; node_n < m_inputNodeNames.size(); node_n++) { + ATH_MSG_DEBUG("Node n = " << node_n); + node_name = m_inputNodeNames[node_n]; + ATH_MSG_DEBUG("Node name " << node_name); + // Get the shape of this node + n_dimensions = m_inputNodeDims[node_n].size(); + ATH_MSG_DEBUG("Node dimensions " << n_dimensions); + elements_in_node = 1; + for (int dimension_len : m_inputNodeDims[node_n]) { + elements_in_node *= dimension_len; + }; + ATH_MSG_DEBUG("Elements in node " << elements_in_node); + for (auto inp : inputs) { + ATH_MSG_DEBUG("Have input named " << inp.first); + }; + // Get the node content and remove any common prefix from the elements + const std::map<std::string, double> node_inputs = inputs.at(node_name); + std::vector<Tin> node_elements(elements_in_node); + + ATH_MSG_DEBUG("Found node named " << node_name << " with " + << elements_in_node << " elements."); + // Then the rest should be numbers from 0 up + for (auto element : node_inputs){ + first_digit = element.first.find_first_of("0123456789"); + // if there is no digit, it's not an element + if (first_digit < element.first.length()){ + key_number = std::stoi(element.first.substr(first_digit)); + node_elements[key_number] = element.second; + } + } + input_values[node_n] = node_elements; + + ATH_MSG_DEBUG("Creating ort tensor n_dimensions = " + << n_dimensions + << ", elements_in_node = " << elements_in_node); + // Doesn't copy data internally, so vector arguments need to stay alive + Ort::Value node = Ort::Value::CreateTensor<Tin>( + m_memoryInfo, input_values[node_n].data(), elements_in_node, + m_inputNodeDims[node_n].data(), n_dimensions); + // Problems with the string steam when compiling seperatly. + // ATH_MSG_DEBUG("Created input node " << node << " from values " << + // input_values[node_n]); + + node_values.push_back(std::move(node)); + } + + ATH_MSG_DEBUG("Running computation on ONNX network."); + // All inputs have been correctly formatted and the net can be run. + auto output_tensors = m_session->Run( + Ort::RunOptions{nullptr}, m_inputNodeNames.data(), &node_values[0], + num_input_nodes, m_outputNodeNames.data(), m_outputNodeNames.size()); + + ATH_MSG_DEBUG("Sorting outputs from computation on ONNX network."); + // Finaly, the output must be rearanged in the expected format. + TFCSONNXHandler::NetworkOutputs outputs; + // as the output format is just a string to double map + // the outputs will be keyed like "<node_name>_<part_n>" + std::string output_name; + const Tout *output_node; + for (size_t node_n = 0; node_n < m_outputNodeNames.size(); node_n++) { + // get a pointer to the data + output_node = output_tensors[node_n].GetTensorMutableData<Tout>(); + ATH_MSG_VERBOSE("output node " << output_node); + elements_in_node = m_outputNodeSize[node_n]; + node_name = m_outputNodeNames[node_n]; + // Does the GetTensorMutableData really always return a + // flat array? + // Likely yes, see use of memcopy on line 301 of + // onnxruntime/core/languge_interop_ops/pyop/pyop.cc + for (int part_n = 0; part_n < elements_in_node; part_n++) { + ATH_MSG_VERBOSE("Node part " << part_n << " contains " + << output_node[part_n]); + // compose the output name + output_name = node_name + "_" + std::to_string(part_n); + outputs[output_name] = static_cast<double>(output_node[part_n]); + } + } + removePrefixes(outputs); + ATH_MSG_DEBUG("Returning outputs from computation on ONNX network."); + return outputs; +}; + +// Possible to avoid copy? +// https://github.com/microsoft/onnxruntime/issues/8328 +// https://github.com/microsoft/onnxruntime/pull/11789 +// https://github.com/microsoft/onnxruntime/pull/8502 + +// Giving this its own streamer to call setupNet +void TFCSONNXHandler::Streamer(TBuffer &buf) { + ATH_MSG_DEBUG("In TFCSONNXHandler streamer."); + if (buf.IsReading()) { + ATH_MSG_INFO("Reading buffer in TFCSONNXHandler "); + // Get the persisted variables filled in + TFCSONNXHandler::Class()->ReadBuffer(buf, this); + // Setup the net, creating the non persisted variables + // exactly as in the constructor + this->setupNet(); +#ifndef __FastCaloSimStandAlone__ + // When running inside Athena, delete persisted information + // to conserve memory + this->deleteAllButNet(); +#endif + } else { + ATH_MSG_INFO("Writing buffer in TFCSONNXHandler "); + // Persist variables + TFCSONNXHandler::Class()->WriteBuffer(buf, this); + }; + ATH_MSG_DEBUG("Finished TFCSONNXHandler streamer."); +}; diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSParametrizationBase.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSParametrizationBase.cxx index 180de694a45915746ed9778ccd7b40cb2ada769e..97ca53f28f7963939e1d516733c16c0c5127877f 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSParametrizationBase.cxx +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSParametrizationBase.cxx @@ -57,8 +57,7 @@ void TFCSParametrizationBase::Print(Option_t *option) const { optprint.ReplaceAll("short", ""); if (longprint) { - ATH_MSG_INFO(optprint << GetTitle() << " (" << IsA()->GetName() << "*)" - << this); + ATH_MSG_INFO(optprint << GetTitle() << " " << IsA()->GetName() ); ATH_MSG(INFO) << optprint << " PDGID: "; if (is_match_all_pdgid()) { ATH_MSG(INFO) << "all"; diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSPredictExtrapWeights.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSPredictExtrapWeights.cxx index 7495f0c8ba74e7eb7165375e979da7e4e3b5378c..627106bd1d68cee3f11a259c63cb958313e7ebb7 100644 --- a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSPredictExtrapWeights.cxx +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSPredictExtrapWeights.cxx @@ -185,7 +185,7 @@ TFCSPredictExtrapWeights::prepareInputs(TFCSSimulationState &simulstate, inputVariables["pdgId"] = 1; // one hot enconding } else if (is_match_pdgid(11) || is_match_pdgid(-11)) { inputVariables["pdgId"] = 0; // one hot enconding - } + }; return inputVariables; } @@ -301,6 +301,7 @@ bool TFCSPredictExtrapWeights::initializeNetwork( ATH_MSG_INFO( "Using FastCaloNNInputFolderName: " << FastCaloNNInputFolderName); + set_pdgid(pid); std::string inputFileName = FastCaloNNInputFolderName + "NN_" + etaBin + ".json"; @@ -373,7 +374,28 @@ void TFCSPredictExtrapWeights::Streamer(TBuffer &R__b) { void TFCSPredictExtrapWeights::unit_test( TFCSSimulationState *simulstate, const TFCSTruthState *truth, const TFCSExtrapolationState *extrapol) { + const std::string this_file = __FILE__; + const std::string parent_dir = this_file.substr(0, this_file.find("/src/")); + const std::string norm_path = parent_dir + "/share/NormPredExtrapSample/"; + std::string net_path = "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/" + "FastCaloSim/LWTNNPredExtrapSample/"; + test_path(net_path, norm_path, simulstate, truth, extrapol); + //net_path = "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/FastCaloSim/" + // "ONNXPredExtrapSample/"; + //test_path(net_path, norm_path, simulstate, truth, extrapol); +} + +// test_path() +// Function for testing +void TFCSPredictExtrapWeights::test_path( + std::string &net_path, std::string const &norm_path, + TFCSSimulationState *simulstate, const TFCSTruthState *truth, + const TFCSExtrapolationState *extrapol) { ISF_FCS::MLogging logger; + ATH_MSG_NOCLASS(logger, "Testing net path ..." + << net_path.substr(net_path.length() - 20) + << " and norm path ..." + << norm_path.substr(norm_path.length() - 20)); if (!simulstate) { simulstate = new TFCSSimulationState(); #if defined(__FastCaloSimStandAlone__) @@ -430,7 +452,7 @@ void TFCSPredictExtrapWeights::unit_test( << " eta " << eta); // Find eta bin - int Eta = eta * 10; + const int Eta = eta * 10; std::string etaBin = ""; for (int i = 0; i <= 25; ++i) { int etaTmp = i * 5; @@ -442,13 +464,10 @@ void TFCSPredictExtrapWeights::unit_test( ATH_MSG_NOCLASS(logger, "etaBin = " << etaBin); TFCSPredictExtrapWeights NN("NN", "NN"); - NN.setLevel(MSG::VERBOSE); + NN.setLevel(MSG::INFO); const int pid = truth->pdgid(); - NN.initializeNetwork(pid, etaBin, - "/eos/atlas/atlascerngroupdisk/proj-simul/AF3_Run3/Jona/" - "lwtnn_inputs/json/v23/"); - NN.getNormInputs(etaBin, "/eos/atlas/atlascerngroupdisk/proj-simul/AF3_Run3/" - "Jona/lwtnn_inputs/txt/v23/"); + NN.initializeNetwork(pid, etaBin, net_path); + NN.getNormInputs(etaBin, norm_path); // Get extrapWeights and save them as AuxInfo in simulstate @@ -457,15 +476,16 @@ void TFCSPredictExtrapWeights::unit_test( NN.prepareInputs(*simulstate, truth->E() * 0.001); // Get predicted extrapolation weights + ATH_MSG_NOCLASS(logger, "computing with m_nn"); auto outputs = NN.m_nn->compute(inputVariables); - std::vector<int> layers = {0, 1, 2, 3, 12}; + const std::vector<int> layers = {0, 1, 2, 3, 12}; for (int ilayer : layers) { simulstate->setAuxInfo<float>( ilayer, outputs["extrapWeight_" + std::to_string(ilayer)]); } // Simulate - int layer = 0; + const int layer = 0; NN.set_calosample(layer); TFCSLateralShapeParametrizationHitBase::Hit hit; NN.simulate_hit(hit, *simulstate, truth, extrapol); @@ -481,9 +501,9 @@ void TFCSPredictExtrapWeights::unit_test( fNN = TFile::Open("FCSNNtest.root"); TFCSPredictExtrapWeights *NN2 = (TFCSPredictExtrapWeights *)(fNN->Get("NN")); - NN2->setLevel(MSG::DEBUG); + NN2->setLevel(MSG::INFO); NN2->simulate_hit(hit, *simulstate, truth, extrapol); - simulstate->Print(); + //simulstate->Print(); return; } diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSSimpleLWTNNHandler.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSSimpleLWTNNHandler.cxx new file mode 100644 index 0000000000000000000000000000000000000000..d738fcb4cbfe949cc67e39dd67c09901d5bb6a35 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/TFCSSimpleLWTNNHandler.cxx @@ -0,0 +1,101 @@ +#include "ISF_FastCaloSimEvent/TFCSSimpleLWTNNHandler.h" + +// For writing to a tree +#include "TBranch.h" +#include "TTree.h" + +// LWTNN +#include "lwtnn/LightweightNeuralNetwork.hh" +#include "lwtnn/parse_json.hh" + +TFCSSimpleLWTNNHandler::TFCSSimpleLWTNNHandler(const std::string &inputFile) + : VNetworkLWTNN(inputFile) { + ATH_MSG_DEBUG("Setting up from inputFile."); + setupPersistedVariables(); + setupNet(); +}; + +TFCSSimpleLWTNNHandler::TFCSSimpleLWTNNHandler( + const TFCSSimpleLWTNNHandler ©_from) + : VNetworkLWTNN(copy_from) { + // Cannot take copy of lwt::LightweightNeuralNetwork + // (copy constructor disabled) + ATH_MSG_DEBUG("Making new m_lwtnn_neural for copy of network."); + std::stringstream json_stream(m_json); + const lwt::JSONConfig config = lwt::parse_json(json_stream); + m_lwtnn_neural = std::make_unique<lwt::LightweightNeuralNetwork>( + config.inputs, config.layers, config.outputs); + m_outputLayers = copy_from.m_outputLayers; +}; + +void TFCSSimpleLWTNNHandler::setupNet() { + // build the graph + ATH_MSG_DEBUG("Reading the m_json string stream into a neural network"); + std::stringstream json_stream(m_json); + const lwt::JSONConfig config = lwt::parse_json(json_stream); + m_lwtnn_neural = std::make_unique<lwt::LightweightNeuralNetwork>( + config.inputs, config.layers, config.outputs); + // Get the output layers + ATH_MSG_DEBUG("Getting output layers for neural network"); + for (std::string name : config.outputs) { + ATH_MSG_VERBOSE("Found output layer called " << name); + m_outputLayers.push_back(name); + }; + ATH_MSG_DEBUG("Removing prefix from stored layers."); + removePrefixes(m_outputLayers); + ATH_MSG_DEBUG("Finished output nodes."); +} + +std::vector<std::string> TFCSSimpleLWTNNHandler::getOutputLayers() const { + return m_outputLayers; +}; + +// This is implement the specific compute, and ensure the output is returned in +// regular format. For LWTNN, that's easy. +TFCSSimpleLWTNNHandler::NetworkOutputs TFCSSimpleLWTNNHandler::compute( + TFCSSimpleLWTNNHandler::NetworkInputs const &inputs) const { + ATH_MSG_DEBUG("Running computation on LWTNN neural network"); + ATH_MSG_DEBUG(VNetworkBase::representNetworkInputs(inputs, 20)); + // Flatten the map depth + if (inputs.size() != 1) { + ATH_MSG_ERROR("The inputs have multiple elements." + << " An LWTNN neural network can only handle one node."); + }; + std::map<std::string, double> flat_inputs; + for (auto node : inputs) { + flat_inputs = node.second; + } + // Now we have flattened, we can compute. + NetworkOutputs outputs = m_lwtnn_neural->compute(flat_inputs); + removePrefixes(outputs); + ATH_MSG_DEBUG(VNetworkBase::representNetworkOutputs(outputs, 20)); + ATH_MSG_DEBUG("Computation on LWTNN neural network done, returning"); + return outputs; +}; + +// Giving this it's own streamer to call setupNet +void TFCSSimpleLWTNNHandler::Streamer(TBuffer &buf) { + ATH_MSG_DEBUG("In streamer of " << __FILE__); + if (buf.IsReading()) { + ATH_MSG_DEBUG("Reading buffer in TFCSSimpleLWTNNHandler "); + // Get the persisted variables filled in + TFCSSimpleLWTNNHandler::Class()->ReadBuffer(buf, this); + // Setup the net, creating the non persisted variables + // exactly as in the constructor + this->setupNet(); +#ifndef __FastCaloSimStandAlone__ + // When running inside Athena, delete persisted information + // to conserve memory + this->deleteAllButNet(); +#endif + } else { + if (!m_json.empty()) { + ATH_MSG_DEBUG("Writing buffer in TFCSSimpleLWTNNHandler "); + } else { + ATH_MSG_WARNING( + "Writing buffer in TFCSSimpleLWTNNHandler, but m_json is empty."); + } + // Persist variables + TFCSSimpleLWTNNHandler::Class()->WriteBuffer(buf, this); + }; +}; diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/VNetworkBase.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/VNetworkBase.cxx new file mode 100644 index 0000000000000000000000000000000000000000..93e38495addd279411fb5966216d4b51bb05c5e5 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/VNetworkBase.cxx @@ -0,0 +1,159 @@ + +#include "ISF_FastCaloSimEvent/VNetworkBase.h" +#include <iostream> + +// For streamer +#include "TBuffer.h" + +// For reading and writing to root +#include "TFile.h" +#include "TTree.h" + +// Probably called by a streamer. +VNetworkBase::VNetworkBase() : m_inputFile("unknown"){}; + +// record the input file and provided it's not empty call SetUp +VNetworkBase::VNetworkBase(const std::string &inputFile) + : m_inputFile(inputFile) { + ATH_MSG_DEBUG("Constructor called with inputFile"); +}; + +// No setupPersistedVariables or setupNet here! +VNetworkBase::VNetworkBase(const VNetworkBase ©_from) : MLogging() { + m_inputFile = std::string(copy_from.m_inputFile); +}; + +// Nothing is needed from the destructor right now. +// We don't use new anywhere, so the whole thing should clean +// itself up. +VNetworkBase::~VNetworkBase(){}; + +std::string +VNetworkBase::representNetworkInputs(VNetworkBase::NetworkInputs const &inputs, + int maxValues) { + std::string representation = + "NetworkInputs, outer size " + std::to_string(inputs.size()); + int valuesIncluded = 0; + for (const auto &outer : inputs) { + representation += "\n key->" + outer.first + "; "; + for (const auto &inner : outer.second) { + representation += inner.first + "=" + std::to_string(inner.second) + ", "; + ++valuesIncluded; + if (valuesIncluded > maxValues) + break; + }; + if (valuesIncluded > maxValues) + break; + }; + representation += "\n"; + return representation; +}; + +std::string VNetworkBase::representNetworkOutputs( + VNetworkBase::NetworkOutputs const &outputs, int maxValues) { + std::string representation = + "NetworkOutputs, size " + std::to_string(outputs.size()) + "; \n"; + int valuesIncluded = 0; + for (const auto &item : outputs) { + representation += item.first + "=" + std::to_string(item.second) + ", "; + ++valuesIncluded; + if (valuesIncluded > maxValues) + break; + }; + representation += "\n"; + return representation; +}; + +// this is also used for the stream operator +void VNetworkBase::print(std::ostream &strm) const { + if (m_inputFile.empty()) { + ATH_MSG_DEBUG("Making a network without a named inputFile"); + strm << "Unknown network"; + } else { + ATH_MSG_DEBUG("Making a network with input file " << m_inputFile); + strm << m_inputFile; + }; +}; + +void VNetworkBase::writeNetToTTree(TFile &root_file, + std::string const &tree_name) { + ATH_MSG_DEBUG("Making tree name " << tree_name); + root_file.cd(); + const std::string title = "onnxruntime saved network"; + TTree tree(tree_name.c_str(), title.c_str()); + this->writeNetToTTree(tree); + root_file.Write(); +}; + +void VNetworkBase::writeNetToTTree(std::string const &root_name, + std::string const &tree_name) { + ATH_MSG_DEBUG("Making or updating file name " << root_name); + TFile root_file(root_name.c_str(), "UPDATE"); + this->writeNetToTTree(root_file, tree_name); + root_file.Close(); +}; + +bool VNetworkBase::isRootFile(std::string const &filename) const { + const std::string *to_check = &filename; + if (filename.length() == 0) { + to_check = &this->m_inputFile; + ATH_MSG_DEBUG("No file name given, so using m_inputFile, " << m_inputFile); + }; + const std::string ending = ".root"; + const int ending_len = ending.length(); + const int filename_len = to_check->length(); + if (filename_len < ending_len) { + return false; + } + return (0 == + to_check->compare(filename_len - ending_len, ending_len, ending)); +}; + +bool VNetworkBase::isFile() const { return isFile(m_inputFile); }; + +bool VNetworkBase::isFile(std::string const &inputFile) { + if (FILE *file = std::fopen(inputFile.c_str(), "r")) { + std::fclose(file); + return true; + } else { + return false; + }; +}; + +namespace { +int GetPrefixLength(const std::vector<std::string> strings) { + const std::string first = strings[0]; + int length = first.length(); + for (std::string this_string : strings) { + for (int i = 0; i < length; i++) { + if (first[i] != this_string[i]) { + length = i; + break; + } + } + } + return length; +}; +} // namespace + +void VNetworkBase::removePrefixes( + std::vector<std::string> &output_names) const { + const int length = GetPrefixLength(output_names); + for (long unsigned int i = 0; i < output_names.size(); i++) + output_names[i] = output_names[i].substr(length); +}; + +void VNetworkBase::removePrefixes(VNetworkBase::NetworkOutputs &outputs) const { + std::vector<std::string> output_layers; + for (auto const &output : outputs) + output_layers.push_back(output.first); + const int length = GetPrefixLength(output_layers); + for (std::string layer_name : output_layers) { + // remove this output + auto nodeHandle = outputs.extract(layer_name); + // change the key + nodeHandle.key() = layer_name.substr(length); + // replace the output + outputs.insert(std::move(nodeHandle)); + } +}; diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/VNetworkLWTNN.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/VNetworkLWTNN.cxx new file mode 100644 index 0000000000000000000000000000000000000000..d7f906d59d8da2b5b5c5e2de6f1edc65fbecc5d1 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/src/VNetworkLWTNN.cxx @@ -0,0 +1,95 @@ + +#include "ISF_FastCaloSimEvent/VNetworkLWTNN.h" +#include <fstream> +#include <sstream> +#include <stdexcept> + +// For reading and writing to root +#include "TFile.h" +#include "TTree.h" + +VNetworkLWTNN::VNetworkLWTNN(const VNetworkLWTNN ©_from) + : VNetworkBase(copy_from) { + m_json = copy_from.m_json; + if (m_json.length() == 0) { + throw std::invalid_argument( + "Trying to copy a VNetworkLWTNN with length 0 m_json, probably " + "deleteAllButNet was called on the object being coppied from."); + }; + m_printable_name = copy_from.m_printable_name; +}; + +VNetworkLWTNN::~VNetworkLWTNN(){}; + +// This setup is going to do it's best to +// fill in m_json. +void VNetworkLWTNN::setupPersistedVariables() { + if (this->isFile(m_inputFile)) { + ATH_MSG_DEBUG("Making an LWTNN network using a file on disk, " + << m_inputFile); + m_printable_name = m_inputFile; + fillJson(); + } else { + ATH_MSG_DEBUG("Making an LWTNN network using a json in memory, length " + << m_inputFile.length()); + m_printable_name = "JSON from memory"; + m_json = m_inputFile; + }; +}; + +void VNetworkLWTNN::print(std::ostream &strm) const { + strm << m_printable_name; +}; + +void VNetworkLWTNN::writeNetToTTree(TTree &tree) { + writeStringToTTree(tree, m_json); +}; + +void VNetworkLWTNN::fillJson(std::string const &tree_name) { + ATH_MSG_VERBOSE("Trying to fill the m_json variable"); + if (this->isRootFile()) { + ATH_MSG_VERBOSE("Treating input file as a root file"); + TFile tfile(this->m_inputFile.c_str(), "READ"); + TTree *tree = (TTree *)tfile.Get(tree_name.c_str()); + std::string found = this->readStringFromTTree(*tree); + ATH_MSG_DEBUG("Read json from root file, length " << found.length()); + m_json = found; + } else { + ATH_MSG_VERBOSE("Treating input file as a text json file"); + // The input file is read into a stringstream + std::ifstream input(m_inputFile); + std::ostringstream sstr; + sstr << input.rdbuf(); + m_json = sstr.str(); + input.close(); + ATH_MSG_DEBUG("Read json from text file"); + } +} + +std::string VNetworkLWTNN::readStringFromTTree(TTree &tree) { + std::string found = std::string(); + std::string *to_found = &found; + tree.SetBranchAddress("lwtnn_json", &to_found); + tree.GetEntry(0); + return found; +}; + +void VNetworkLWTNN::writeStringToTTree(TTree &tree, std::string json_string) { + tree.Branch("lwtnn_json", &json_string); + tree.Fill(); + tree.Write(); +}; + +void VNetworkLWTNN::deleteAllButNet() { + ATH_MSG_DEBUG("Replacing m_inputFile with unknown"); + m_inputFile.assign("unknown"); + m_inputFile.shrink_to_fit(); + ATH_MSG_DEBUG("Emptying the m_json string"); + m_json.clear(); + m_json.shrink_to_fit(); + ATH_MSG_VERBOSE("m_json now has capacity " + << m_json.capacity() << ". m_inputFile now has capacity " + << m_inputFile.capacity() + << ". m_printable_name now has capacity " + << m_printable_name.capacity()); +}; diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/test/GenericNetwork_test.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/test/GenericNetwork_test.cxx new file mode 100644 index 0000000000000000000000000000000000000000..6c4219b79a4db6473e9420253535f92ac6fce053 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/test/GenericNetwork_test.cxx @@ -0,0 +1,362 @@ +// test interface for a wrapper module that can open various kinds of neural +// network graph file + +// include things to test +#include "ISF_FastCaloSimEvent/TFCSGANLWTNNHandler.h" +#include "ISF_FastCaloSimEvent/TFCSSimpleLWTNNHandler.h" +#include "ISF_FastCaloSimEvent/TFCSONNXHandler.h" + +// include things to allow us to check the streamers +#include "TFile.h" + +// include generic utilities +#include "ISF_FastCaloSimEvent/MLogging.h" +#include <iostream> +#include <fstream> +#include <vector> + +using ISF_FCS::MLogging; + +// Very crude way to make a file for input +// writes a fake lwtnn GAN to the disk at the +// specified file name +void setup_fastCaloGAN(std::string outputFile) { + std::ofstream output; + output.open(outputFile); + + const char *text = "{\n" + " \"defaults\": {},\n" + " \"inputs\": [\n" + " {\n" + " \"name\": \"node_1\",\n" + " \"offset\": 0,\n" + " \"scale\": 1\n" + " },\n" + " {\n" + " \"name\": \"node_2\",\n" + " \"offset\": 0,\n" + " \"scale\": 1\n" + " }\n" + " ],\n" + " \"layers\": [\n" + " {\n" + " \"activation\": \"rectified\",\n" + " \"architecture\": \"dense\",\n" + " \"bias\": [\n" + " 1.0,\n" + " 0.0\n" + " ],\n" + " \"weights\": [\n" + " 1.0,\n" + " 0.5,\n" + " 0.5,\n" + " 0.0\n" + " ]\n" + " }\n" + " ],\n" + " \"outputs\": [\n" + " \"EXTRAPWEIGHT_0\",\n" + " \"EXTRAPWEIGHT_1\"\n" + " ]\n" + "}\n"; + output << text; + + output.close(); +} + +// ditto for Sim +void setup_fastCaloSim(std::string outputFile) { + std::ofstream output; + output.open(outputFile); + + const char *text = "{\n" + " \"input_sequences\": [],\n" + " \"inputs\": [\n" + " {\n" + " \"name\": \"node_0\",\n" + " \"variables\": [\n" + " {\n" + " \"name\": \"0\",\n" + " \"offset\": 0,\n" + " \"scale\": 1\n" + " }\n" + " ]\n" + " },\n" + " {\n" + " \"name\": \"node_1\",\n" + " \"variables\": [\n" + " {\n" + " \"name\": \"0\",\n" + " \"offset\": 0,\n" + " \"scale\": 1\n" + " },\n" + " {\n" + " \"name\": \"1\",\n" + " \"offset\": 0,\n" + " \"scale\": 1\n" + " }\n" + " ]\n" + " }\n" + " ],\n" + " \"layers\": [\n" + " {\n" + " \"activation\": \"rectified\",\n" + " \"architecture\": \"dense\",\n" + " \"bias\": [\n" + " 0.1,\n" + " 0.2,\n" + " 0.3\n" + " ],\n" + " \"weights\": [\n" + " 0.1,\n" + " 0.2,\n" + " 0.3,\n" + " 0.4,\n" + " 0.5,\n" + " 0.6,\n" + " 0.7,\n" + " 0.8,\n" + " 0.9\n" + " ]\n" + " }\n" + " ],\n" + " \"nodes\": [\n" + " {\n" + " \"size\": 1,\n" + " \"sources\": [\n" + " 0\n" + " ],\n" + " \"type\": \"input\"\n" + " },\n" + " {\n" + " \"size\": 2,\n" + " \"sources\": [\n" + " 1\n" + " ],\n" + " \"type\": \"input\"\n" + " },\n" + " {\n" + " \"sources\": [\n" + " 0,\n" + " 1\n" + " ],\n" + " \"type\": \"concatenate\"\n" + " },\n" + " {\n" + " \"layer_index\": 0,\n" + " \"sources\": [\n" + " 2\n" + " ],\n" + " \"type\": \"feed_forward\"\n" + " }\n" + " ],\n" + " \"outputs\": {\n" + " \"output_layer\": {\n" + " \"labels\": [\n" + " \"out_0\",\n" + " \"out_1\",\n" + " \"out_2\"\n" + " ],\n" + " \"node_index\": 3\n" + " }\n" + " }\n" + "}\n"; + output << text; + + output.close(); +} + +void test_fastCaloGAN(std::string inputFile, ISF_FCS::MLogging logger) { + ATH_MSG_NOCLASS(logger, "Testing fastCaloGAN format."); + TFCSSimpleLWTNNHandler my_net(inputFile); + + // Fake inputs + std::map<std::string, double> input_nodes; + for (int node = 1; node < 3; node++) { + input_nodes["node_" + std::to_string(node)] = node; + } + decltype(my_net)::NetworkInputs inputs; + inputs["inputs"] = input_nodes; + ATH_MSG_NOCLASS(logger, VNetworkBase::representNetworkInputs(inputs, 100)); + + // run the net + decltype(my_net)::NetworkOutputs outputs = my_net.compute(inputs); + + ATH_MSG_NOCLASS(logger, VNetworkBase::representNetworkOutputs(outputs, 100)); + + // Save the net to root file + std::string output_root_file = "with_lwtnn_network.root"; + ATH_MSG_NOCLASS(logger, "Writing to a root file; " << output_root_file); + my_net.writeNetToTTree(output_root_file); + + ATH_MSG_NOCLASS(logger, "Reading copy written to root"); + TFCSSimpleLWTNNHandler copy_net(output_root_file); + ATH_MSG_NOCLASS(logger, "Running copy from root file"); + decltype(copy_net)::NetworkOutputs other_out = copy_net.compute(inputs); + ATH_MSG_NOCLASS(logger, + VNetworkBase::representNetworkOutputs(other_out, 100)); + + ATH_MSG_NOCLASS(logger, + "Outputs should before and after writing shoud be identical"); + + // Finally, save the network using a streamer. + ATH_MSG_NOCLASS(logger, "Writing with a streamer to; " << output_root_file); + TFile test_stream_write(output_root_file.c_str(), "RECREATE"); + test_stream_write.WriteObjectAny(&my_net, "TFCSSimpleLWTNNHandler", + "test_net"); + test_stream_write.Close(); + + ATH_MSG_NOCLASS(logger, "Reading streamer copy written to root"); + TFile test_stream_read(output_root_file.c_str(), "READ"); + TFCSSimpleLWTNNHandler *streamed_net = + test_stream_read.Get<TFCSSimpleLWTNNHandler>("test_net"); + + ATH_MSG_NOCLASS(logger, "Running copy streamed from root file"); + TFCSSimpleLWTNNHandler::NetworkOutputs streamed_out = + streamed_net->compute(inputs); + ATH_MSG_NOCLASS(logger, + VNetworkBase::representNetworkOutputs(streamed_out, 100)); + ATH_MSG_NOCLASS(logger, + "Outputs should before and after writing shoud be identical"); +} + +void test_fastCaloSim(std::string inputFile, ISF_FCS::MLogging logger) { + ATH_MSG_NOCLASS(logger, "Testing fastCaloSim format."); + TFCSGANLWTNNHandler my_net(inputFile); + ATH_MSG_NOCLASS(logger, "Made the net."); + + // Fake inputs + decltype(my_net)::NetworkInputs inputs; + for (int node = 0; node < 2; node++) { + std::map<std::string, double> node_input; + for (int i = 0; i <= node; i++) { + node_input[std::to_string(i)] = i; + } + inputs["node_" + std::to_string(node)] = node_input; + } + ATH_MSG_NOCLASS(logger, "Made the inputs."); + ATH_MSG_NOCLASS(logger, VNetworkBase::representNetworkInputs(inputs, 100)); + + // run the net + decltype(my_net)::NetworkOutputs outputs = my_net.compute(inputs); + + ATH_MSG_NOCLASS(logger, VNetworkBase::representNetworkOutputs(outputs, 100)); + // Save the net to root file + std::string output_root_file = "with_lwtnn_graph.root"; + ATH_MSG_NOCLASS(logger, "Writing to a root file; " << output_root_file); + my_net.writeNetToTTree(output_root_file); + + ATH_MSG_NOCLASS(logger, "Reading copy written to root"); + TFCSGANLWTNNHandler copy_net(output_root_file); + ATH_MSG_NOCLASS(logger, "Running copy from root file"); + decltype(copy_net)::NetworkOutputs other_out = copy_net.compute(inputs); + ATH_MSG_NOCLASS(logger, + VNetworkBase::representNetworkOutputs(other_out, 100)); + ATH_MSG_NOCLASS(logger, + "Outputs should before and after writing shoud be identical"); + + // Finally, save the network using a streamer. + ATH_MSG_NOCLASS(logger, "Writing with a streamer to; " << output_root_file); + TFile test_stream_write(output_root_file.c_str(), "RECREATE"); + test_stream_write.WriteObjectAny(&my_net, "TFCSGANLWTNNHandler", "test_net"); + test_stream_write.Close(); + + ATH_MSG_NOCLASS(logger, "Reading streamer copy written to root"); + TFile test_stream_read(output_root_file.c_str(), "READ"); + TFCSGANLWTNNHandler *streamed_net = + test_stream_read.Get<TFCSGANLWTNNHandler>("test_net"); + + ATH_MSG_NOCLASS(logger, "Running copy streamed from root file"); + TFCSGANLWTNNHandler::NetworkOutputs streamed_out = + streamed_net->compute(inputs); + ATH_MSG_NOCLASS(logger, + VNetworkBase::representNetworkOutputs(streamed_out, 100)); + ATH_MSG_NOCLASS(logger, + "Outputs should before and after writing shoud be identical"); +} + +void test_ONNX(ISF_FCS::MLogging logger) { + // Curiously, there is no easy way to generate an ONNX + // model from c++. It is expected you will convert an + // existing model, so creating one here would require + // additional imports. + std::string this_file = __FILE__; + std::string parent_dir = this_file.substr(0, this_file.find("/test/")); + std::string inputFile = parent_dir + "/share/toy_network.onnx"; + // Only proceed if that file can be read. + std::ifstream onnx_file(inputFile); + if (onnx_file.good()) { + // Read form regular onnx file + ATH_MSG_NOCLASS(logger, "Testing ONNX format."); + TFCSONNXHandler my_net(inputFile); + + // Fake inputs + decltype(my_net)::NetworkInputs inputs; + std::map<std::string, double> node_input; + for (int i = 0; i < 2; i++) { + node_input[std::to_string(i)] = i + 1; + } + inputs["inputs"] = node_input; + ATH_MSG_NOCLASS(logger, VNetworkBase::representNetworkInputs(inputs, 100)); + + // run the net + decltype(my_net)::NetworkOutputs outputs = my_net.compute(inputs); + ATH_MSG_NOCLASS(logger, + VNetworkBase::representNetworkOutputs(outputs, 100)); + + // Write a copy to a root file + std::string output_root_file = "with_serialized_network.root"; + ATH_MSG_NOCLASS(logger, "Writing to a root file; " << output_root_file); + my_net.writeNetToTTree(output_root_file); + + ATH_MSG_NOCLASS(logger, "Reading copy written to root"); + TFCSONNXHandler copy_net(output_root_file); + + ATH_MSG_NOCLASS(logger, "Running copy from root file"); + decltype(copy_net)::NetworkOutputs other_out = copy_net.compute(inputs); + ATH_MSG_NOCLASS(logger, + VNetworkBase::representNetworkOutputs(other_out, 100)); + + ATH_MSG_NOCLASS( + logger, "Outputs should before and after writing shoud be identical"); + + // Finally, save the network using a streamer. + ATH_MSG_NOCLASS(logger, "Writing with a streamer to; " << output_root_file); + TFile test_stream_write(output_root_file.c_str(), "RECREATE"); + test_stream_write.WriteObjectAny(&my_net, "TFCSONNXHandler", "test_net"); + test_stream_write.Close(); + + ATH_MSG_NOCLASS(logger, "Reading streamer copy written to root"); + TFile test_stream_read(output_root_file.c_str(), "READ"); + TFCSONNXHandler *streamed_net = + test_stream_read.Get<TFCSONNXHandler>("test_net"); + + ATH_MSG_NOCLASS(logger, "Running copy streamed from root file"); + TFCSONNXHandler::NetworkOutputs streamed_out = + streamed_net->compute(inputs); + ATH_MSG_NOCLASS(logger, + VNetworkBase::representNetworkOutputs(streamed_out)); + ATH_MSG_NOCLASS( + logger, "Outputs should before and after writing shoud be identical"); + + } else { + ATH_MSG_NOCLASS(logger, "Couldn't read file " + << inputFile << "\n Will skip ONNX tests."); + } +} + +int main() { + ISF_FCS::MLogging logger; + std::string gan_data_example("example_data_gan.json"); + setup_fastCaloGAN(gan_data_example); + test_fastCaloGAN(gan_data_example, logger); + + std::string sim_data_example("example_data_sim.json"); + setup_fastCaloSim(sim_data_example); + test_fastCaloSim(sim_data_example, logger); + + // For some reason this freezes... + //test_ONNX(logger); + ATH_MSG_NOCLASS(logger, "Program ends"); + return 0; +} diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/test/TFCSEnergyAndHitGANV2_test.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/test/TFCSEnergyAndHitGANV2_test.cxx new file mode 100644 index 0000000000000000000000000000000000000000..9084e8d72ddc92b2d95c192e7d9ac289a0093111 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/test/TFCSEnergyAndHitGANV2_test.cxx @@ -0,0 +1,28 @@ +// test interface for a wrapper module that can open various kinds of neural +// network graph file + +// include things to test +#include "ISF_FastCaloSimEvent/TFCSEnergyAndHitGANV2.h" +#include "ISF_FastCaloSimEvent/TFCSPredictExtrapWeights.h" + +// include generic utilities +#include "ISF_FastCaloSimEvent/MLogging.h" +#include <fstream> +#include <iostream> +#include <vector> + +using ISF_FCS::MLogging; + +int main() { + ISF_FCS::MLogging logger; + ATH_MSG_NOCLASS(logger, "Running TFCSEnergyAndHitGANV2 on LWTNN"); + std::string path = "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/FastCaloSim/LWTNNsample"; + TFCSEnergyAndHitGANV2::test_path(path, nullptr, nullptr, nullptr, "unnamed", 211); + + // This causes timeouts in CI + // ATH_MSG_NOCLASS(logger, "Running TFCSEnergyAndHitGANV2 on ONNX"); + // path = "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/FastCaloSim/ONNXsample"; + // TFCSEnergyAndHitGANV2::test_path(path, nullptr, nullptr, nullptr, "unnamed", 211); + ATH_MSG_NOCLASS(logger, "Program ends"); + return 0; +} diff --git a/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/test/TFCSPredictExtrapWeights_test.cxx b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/test/TFCSPredictExtrapWeights_test.cxx new file mode 100644 index 0000000000000000000000000000000000000000..78da7d4e6b1eed563cca6c89e738b502bba92e81 --- /dev/null +++ b/Simulation/ISF/ISF_FastCaloSim/ISF_FastCaloSimEvent/test/TFCSPredictExtrapWeights_test.cxx @@ -0,0 +1,31 @@ +// test interface for a wrapper module that can open various kinds of neural +// network graph file + +// include things to test +#include "ISF_FastCaloSimEvent/TFCSEnergyAndHitGANV2.h" +#include "ISF_FastCaloSimEvent/TFCSPredictExtrapWeights.h" + +// include generic utilities +#include "ISF_FastCaloSimEvent/MLogging.h" +#include <fstream> +#include <iostream> +#include <vector> + +using ISF_FCS::MLogging; + +int main() { + ISF_FCS::MLogging logger; + ATH_MSG_NOCLASS(logger, "Running TFCSPredictExtrapWeights"); + const std::string this_file = __FILE__; + const std::string parent_dir = this_file.substr(0, this_file.find("/test/")); + const std::string norm_path = parent_dir + "/share/NormPredExtrapSample/"; + std::string net_path = "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/" + "FastCaloSim/LWTNNPredExtrapSample/"; + TFCSPredictExtrapWeights::test_path(net_path, norm_path); + // This causes timeouts in CI + //net_path = "/cvmfs/atlas-nightlies.cern.ch/repo/data/data-art/FastCaloSim/" + // "ONNXPredExtrapSample/"; + //TFCSPredictExtrapWeights::test_path(net_path, norm_path); + ATH_MSG_NOCLASS(logger, "Program ends"); + return 0; +}