Skip to content
Snippets Groups Projects
Commit 1ea35ef6 authored by Walter Lampl's avatar Walter Lampl
Browse files

Merge branch 'bootstrap_refactor' into 'main'

CPAlgorithms: refactor BootstrapGeneratorAlg, add unit tests

See merge request atlas/athena!66779
parents da78eafc 17b76d50
No related branches found
No related tags found
No related merge requests found
......@@ -17,33 +17,54 @@
namespace CP
{
/// \brief an algorithm to compute per-event bootstrap replica weights
class BootstrapGeneratorAlg final : public EL::AnaAlgorithm
/// \brief a class to generate random numbers with a unique seed
class BootstrapGenerator
{
/// \brief the standard constructor
public:
BootstrapGeneratorAlg(const std::string &name,
ISvcLocator *pSvcLocator);
BootstrapGenerator() {};
/// \brief implementation of the hash function from https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
public:
StatusCode initialize() override;
std::uint64_t fnv1a_64(const void *buffer, size_t size, std::uint64_t offset_basis);
/// \brief set the seed of the random number generator based on event properties
public:
StatusCode execute() override;
void setSeed(std::uint64_t eventNumber, std::uint32_t runNumber, std::uint32_t mcChannelNumber);
/// \brief generate a unique seed based on event identifiers
public:
std::uint64_t generateSeed(std::uint64_t eventNumber, std::uint32_t runNumber, std::uint32_t mcChannelNumber);
/// \brief implementation of the hash function from https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
/// \brief get the next bootstrap weight
public:
std::uint64_t fnv1a_64(const void *buffer, size_t size, std::uint64_t offset_basis);
std::uint8_t getBootstrap() { return m_rng.Poisson(1); };
/// \brief constants for seed generation
private:
static constexpr std::uint64_t m_offset = 14695981039346656037u;
static constexpr std::uint64_t m_prime = 1099511628211u;
/// \brief the random number generator (Ranlux++)
private:
TRandomRanluxpp m_rng;
};
/// \brief an algorithm to compute per-event bootstrap replica weights
class BootstrapGeneratorAlg final : public EL::AnaAlgorithm
{
/// \brief the standard constructor
public:
BootstrapGeneratorAlg(const std::string &name,
ISvcLocator *pSvcLocator);
public:
StatusCode initialize() override;
public:
StatusCode execute() override;
/// \brief the systematics list we run
private:
SysListHandle m_systematicsList{this};
......@@ -51,7 +72,7 @@ namespace CP
/// \brief the EventInfo container
private:
SysReadHandle<xAOD::EventInfo> m_eventInfoHandle{
this, "eventInfo", "EventInfo", "the EventInfo container"};
this, "eventInfo", "EventInfo", "the EventInfo container"};
/// \brief the number of bootstrap replicas
private:
......@@ -61,9 +82,9 @@ namespace CP
private:
Gaudi::Property<bool> m_data {this, "isData", false, "whether we are running on data"};
/// \brief the random number generator (Ranlux++)
/// \brief the bootstrap generator instance
private:
TRandomRanluxpp m_rng;
BootstrapGenerator m_bootstrap;
/// \brief the vector of bootstrap replica weights
private:
......@@ -72,7 +93,7 @@ namespace CP
/// \brief the output decoration
private:
SysWriteDecorHandle<std::vector<std::uint8_t>> m_decoration{
this, "decorationName", "bootstrapWeights_%SYS%", "decoration name for the vector of bootstrapped weights"};
this, "decorationName", "bootstrapWeights_%SYS%", "decoration name for the vector of bootstrapped weights"};
};
} // namespace CP
......
......@@ -16,6 +16,7 @@
<class name="CP::AsgxAODNTupleMakerAlg" />
<class name="CP::AsgxAODMetNTupleMakerAlg" />
<class name="CP::BootstrapGeneratorAlg" />
<class name="CP::BootstrapGenerator" />
<class name="CP::CopyNominalSelectionAlg" />
<class name="CP::EventFlagSelectionAlg" />
<class name="CP::EventStatusSelectionAlg" />
......
......@@ -32,6 +32,15 @@ atlas_install_python_modules( python/*.py )
atlas_install_joboptions( share/*_jobOptions.py )
atlas_install_scripts( share/*_eljob.py )
# Locate GoogleTest and GoogleMock for the C++ unit test declared below.
find_package( GTest )
find_package( GMock )
# Register the gt_BootstrapGenerator unit test, built from
# test/gt_BootstrapGenerator.cxx and linked against the GoogleTest libraries,
# AsgTestingLib, AsgAnalysisAlgorithmsLib and CxxUtils.
# NOTE(review): GMock is located above, but only the GTEST_* variables are used
# here; if the test source relies on GoogleMock (e.g. matchers), confirm that it
# is provided transitively (presumably via AsgTestingLib) or add the
# GMOCK_* include/library variables explicitly.
# NOTE(review): nopost.sh presumably disables the default post-execution log
# comparison for this test - confirm against the atlas_add_test documentation.
atlas_add_test( gt_BootstrapGenerator
SOURCES test/gt_BootstrapGenerator.cxx
INCLUDE_DIRS ${GTEST_INCLUDE_DIRS}
LINK_LIBRARIES ${GTEST_LIBRARIES} AsgTestingLib AsgAnalysisAlgorithmsLib CxxUtils
POST_EXEC_SCRIPT nopost.sh )
if( XAOD_STANDALONE )
atlas_add_test( EventAlgsTestJobData
......
......@@ -8,11 +8,11 @@
CP::BootstrapGeneratorAlg::BootstrapGeneratorAlg(const std::string &name,
ISvcLocator *pSvcLocator)
: EL::AnaAlgorithm(name, pSvcLocator)
: EL::AnaAlgorithm(name, pSvcLocator)
{
}
std::uint64_t CP::BootstrapGeneratorAlg::fnv1a_64(const void *buffer, size_t size, std::uint64_t offset_basis) {
std::uint64_t CP::BootstrapGenerator::fnv1a_64(const void *buffer, size_t size, std::uint64_t offset_basis) {
std::uint64_t h = offset_basis;
const unsigned char *p = static_cast<const unsigned char *>(buffer);
for (size_t i = 0; i < size; i++) {
......@@ -22,7 +22,7 @@ std::uint64_t CP::BootstrapGeneratorAlg::fnv1a_64(const void *buffer, size_t siz
return h;
}
std::uint64_t CP::BootstrapGeneratorAlg::generateSeed(std::uint64_t eventNumber, std::uint32_t runNumber, std::uint32_t mcChannelNumber)
std::uint64_t CP::BootstrapGenerator::generateSeed(std::uint64_t eventNumber, std::uint32_t runNumber, std::uint32_t mcChannelNumber)
{
std::uint64_t hash = fnv1a_64(&runNumber, sizeof(runNumber), m_offset);
hash = fnv1a_64(&eventNumber, sizeof(eventNumber), hash);
......@@ -30,13 +30,19 @@ std::uint64_t CP::BootstrapGeneratorAlg::generateSeed(std::uint64_t eventNumber,
return hash;
}
/// \brief seed the internal random number generator from event identifiers
///
/// The three identifiers are combined into a single 64-bit value by
/// generateSeed(), and that value is handed to the generator's SetSeed().
///
/// \param eventNumber     the event number
/// \param runNumber       the run number
/// \param mcChannelNumber the MC channel number
void CP::BootstrapGenerator::setSeed(std::uint64_t eventNumber, std::uint32_t runNumber, std::uint32_t mcChannelNumber)
{
  m_rng.SetSeed(generateSeed(eventNumber, runNumber, mcChannelNumber));
}
StatusCode CP::BootstrapGeneratorAlg::initialize()
{
if (m_nReplicas < 0)
{
ANA_MSG_ERROR("The number of bootstrapped weights (toys) cannot be negative!");
return StatusCode::FAILURE;
}
{
ANA_MSG_ERROR("The number of bootstrapped weights (toys) cannot be negative!");
return StatusCode::FAILURE;
}
ANA_CHECK(m_eventInfoHandle.initialize(m_systematicsList));
ANA_CHECK(m_decoration.initialize(m_systematicsList, m_eventInfoHandle));
......@@ -48,24 +54,24 @@ StatusCode CP::BootstrapGeneratorAlg::initialize()
StatusCode CP::BootstrapGeneratorAlg::execute()
{
for (const auto &sys : m_systematicsList.systematicsVector())
{
// retrieve the EventInfo
const xAOD::EventInfo *evtInfo = nullptr;
ANA_CHECK(m_eventInfoHandle.retrieve(evtInfo, sys));
{
// retrieve the EventInfo
const xAOD::EventInfo *evtInfo = nullptr;
ANA_CHECK(m_eventInfoHandle.retrieve(evtInfo, sys));
// generate a unique seed from runNumber, eventNumber and DSID!
m_rng.SetSeed(generateSeed(evtInfo->eventNumber(), evtInfo->runNumber(), m_data ? 0 : evtInfo->mcChannelNumber()));
// generate a unique seed from runNumber, eventNumber and DSID!
m_bootstrap.setSeed(evtInfo->eventNumber(), evtInfo->runNumber(), m_data ? 0 : evtInfo->mcChannelNumber());
m_weights.resize(m_nReplicas);
// and fill it with Poisson(1)
for (int i = 0; i < m_nReplicas; i++)
{
m_weights.at(i) = m_rng.Poisson(1);
}
m_weights.resize(m_nReplicas);
// and fill it with Poisson(1)
for (int i = 0; i < m_nReplicas; i++)
{
m_weights.at(i) = m_bootstrap.getBootstrap();
}
// decorate weights onto EventInfo
m_decoration.set(*evtInfo, m_weights, sys);
}
// decorate weights onto EventInfo
m_decoration.set(*evtInfo, m_weights, sys);
}
return StatusCode::SUCCESS;
}
/*
Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
*/
/// @author Baptiste Ravina
#include "CxxUtils/checker_macros.h"
ATLAS_NO_CHECK_FILE_THREAD_SAFETY;
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <AsgTesting/UnitTest.h>
#include <AsgMessaging/MessageCheck.h>
#include <AsgAnalysisAlgorithms/BootstrapGeneratorAlg.h>
//
// unit test
//
using namespace asg::msgUserCode;
namespace CP
{
/// Fixture shared by the BootstrapGenerator tests: every test case gets its
/// own, freshly constructed generator instance.
class BootstrapGeneratorTest : public ::testing::Test
{
protected:
  /// the generator under test
  BootstrapGenerator m_bootstrapGenerator{};
};
/// generateSeed() must reproduce a pinned reference value for a fixed
/// (eventNumber, runNumber, mcChannelNumber) triple.
TEST_F (BootstrapGeneratorTest, SeedGeneration)
{
  // reference value expected from the hash function for the inputs below
  constexpr std::uint64_t expectedSeed = 8089831591138695645u;

  // fixed event identifiers used as input
  const std::uint64_t eventNumber = 123;
  const std::uint32_t runNumber = 456;
  const std::uint32_t mcChannelNumber = 789;

  const std::uint64_t seed =
    m_bootstrapGenerator.generateSeed(eventNumber, runNumber, mcChannelNumber);

  EXPECT_EQ(seed, expectedSeed);
}
/// After seeding with a fixed set of event identifiers and a non-zero
/// mcChannelNumber, the first ten bootstrap weights must match the pinned
/// reference sequence.
TEST_F (BootstrapGeneratorTest, WeightGenerationMC)
{
  // fixed event identifiers used to seed the generator
  const std::uint64_t eventNumber = 306030154;
  const std::uint32_t runNumber = 310000;
  const std::uint32_t mcChannelNumber = 410470;

  m_bootstrapGenerator.setSeed(eventNumber, runNumber, mcChannelNumber);

  // draw the first ten weights, in order
  std::vector<std::uint8_t> weights(10);
  for (std::uint8_t &weight : weights)
  {
    weight = m_bootstrapGenerator.getBootstrap();
  }

  ASSERT_THAT(weights, ::testing::ElementsAre(2,0,0,1,4,2,0,1,1,2));
}
/// After seeding with a fixed set of event identifiers and
/// mcChannelNumber == 0, the first ten bootstrap weights must match the
/// pinned reference sequence.
TEST_F (BootstrapGeneratorTest, WeightGenerationData)
{
  // fixed event identifiers used to seed the generator
  const std::uint64_t eventNumber = 3772712513;
  const std::uint32_t runNumber = 438481;
  const std::uint32_t mcChannelNumber = 0;

  m_bootstrapGenerator.setSeed(eventNumber, runNumber, mcChannelNumber);

  // draw the first ten weights, in order
  std::vector<std::uint8_t> weights(10);
  for (std::uint8_t &weight : weights)
  {
    weight = m_bootstrapGenerator.getBootstrap();
  }

  ASSERT_THAT(weights, ::testing::ElementsAre(1,3,2,1,2,0,0,1,1,1));
}
} // namespace
/// Entry point of the test executable: hands the command line to GoogleTest
/// and returns the result of running all registered tests.
int main (int argc, char **argv)
{
#ifdef ROOTCORE
  // only compiled when ROOTCORE is defined: enable StatusCode failure
  // handling and initialise xAOD before any test runs
  StatusCode::enableFailure();
  ANA_CHECK (xAOD::Init ());
#endif

  ::testing::InitGoogleTest (&argc, argv);
  const int testStatus = RUN_ALL_TESTS();
  return testStatus;
}
......@@ -57,3 +57,8 @@ elseif( NOT "${CMAKE_PROJECT_NAME}" STREQUAL "AthDerivation" )
POST_EXEC_SCRIPT nopost.sh
PROPERTIES TIMEOUT 600 )
endif()
# Unit test for python modules
# Runs Python's unittest discovery in verbose mode (-v), starting from this
# package's test/ directory (-s), so the unittest modules found there are
# executed as a single test named "pymodules".
# NOTE(review): nopost.sh presumably disables the default post-execution log
# comparison for this test - confirm against the atlas_add_test documentation.
atlas_add_test( pymodules
SCRIPT python -m unittest discover -v -s ${CMAKE_CURRENT_SOURCE_DIR}/test
POST_EXEC_SCRIPT nopost.sh )
\ No newline at end of file
import unittest
from FTagAnalysisAlgorithms.FTagAnalysisConfig import parseTDPdatabase
from PathResolver import PathResolver
class TestParseTDPDatabase(unittest.TestCase):
    """Check parseTDPdatabase against reference results for known DSIDs."""

    def _check_database(self, database, reference):
        """Compare parseTDPdatabase(database, dsid) with the expected value for each DSID.

        Each DSID is checked in its own subTest, so one mismatch neither hides
        the remaining DSIDs nor leaves it unclear which DSID/database failed.

        :param database: path of the TDP database file handed to parseTDPdatabase
        :param reference: dict mapping DSID -> expected return value
        """
        for dsid, expected in reference.items():
            with self.subTest(database=database, dsid=dsid):
                self.assertEqual(parseTDPdatabase(database, dsid), expected)

    def test_parseTDPdatabase(self):
        """Run the reference checks on the Run 2 and the Run 3 databases."""
        # Run 2 tests
        # NOTE(review): 999999 -> None presumably covers a DSID that is absent
        # from the database; confirm against parseTDPdatabase.
        self._check_database(
            'dev/AnalysisTop/TopDataPreparation/XSection-MC16-13TeV_JESinfo.data',
            {
                410470: 'pythia8',
                700168: 'sherpa2210',
                600020: 'herwigpp713',
                502957: 'amcatnlopythia8',
                504337: 'herwigpp721',
                999999: None,
            },
        )

        # Run 3 tests
        self._check_database(
            'dev/AnalysisTop/TopDataPreparation/XSection-MC21-13p6TeV.data',
            {
                601229: 'pythia8',
                700660: 'sherpa2210',
                513105: 'amcatnlopythia8',
                999999: None,
            },
        )
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment