diff --git a/Control/RootUtils/CMakeLists.txt b/Control/RootUtils/CMakeLists.txt index c301bcad6f9c47fcecf0e32dd7c1d44ef4312c3d..52a1d7ba36d3328df5c7d9c60a514343c029f4ec 100644 --- a/Control/RootUtils/CMakeLists.txt +++ b/Control/RootUtils/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (C) 2002-2021 CERN for the benefit of the ATLAS collaboration +# Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration # Declare the package name: atlas_subdir( RootUtils ) @@ -61,5 +61,11 @@ atlas_add_test( WithRootErrorHandler_test INCLUDE_DIRS ${ROOT_INCLUDE_DIRS} LINK_LIBRARIES ${ROOT_LIBRARIES} RootUtils ) +atlas_add_test( TRandomTLS_test + SOURCES test/TRandomTLS_test.cxx + INCLUDE_DIRS ${ROOT_INCLUDE_DIRS} + LINK_LIBRARIES ${ROOT_LIBRARIES} CxxUtils RootUtils + POST_EXEC_SCRIPT nopost.sh ) + # Install files from the package: atlas_install_python_modules( python/*.py POST_BUILD_CMD ${ATLAS_FLAKE8} ) diff --git a/Control/RootUtils/RootUtils/TRandomTLS.h b/Control/RootUtils/RootUtils/TRandomTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..96d5befd8a0272b88c46f9d3e6022c8c65aea575 --- /dev/null +++ b/Control/RootUtils/RootUtils/TRandomTLS.h @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration. + */ +/** + * @file RootUtils/TRandomTLS.h + * @author Frank Winklmeier + * @date Sept, 2022 + * @brief Thread-local TRandom generator. + */ + +#ifndef ROOTUTILS_TRANDOMTLS_H +#define ROOTUTILS_TRANDOMTLS_H + +#include <atomic> + +#include "Rtypes.h" +#include "boost/thread/tss.hpp" + +namespace RootUtils { + + /** + * Thread-local TRandom generator. + * + * This class provides a thread-local TRandom instance of type T. + * For each new instance/thread the seed is incremented to ensure + * independent random numbers. + */ + template <class T> + class TRandomTLS { + public: + /** + * Constructor + * + * @param seed The seed of the first TRandom instance. Additional instances + * are created with the seed incremented by one. If the seed is 0, + * it will not be incremented as ROOT then uses a time-based seed. + */ + TRandomTLS(UInt_t seed = 4357) : m_seed(seed) {} + + /// Destructor + ~TRandomTLS() = default; + + /// Get thread-specific TRandom + T* get() const; + T* operator->() const { return get(); } + T& operator*() const { return *get(); } + + private: + /// Thread-local TRandom + mutable boost::thread_specific_ptr<T> m_rand_tls; + + /// TRandom seed (incremented for each new instance/thread) + mutable std::atomic<UInt_t> m_seed; + }; + + + // + // Inline methods + // + + template <class T> + inline T* TRandomTLS<T>::get() const + { + T* random = m_rand_tls.get(); + if (!random) { + random = new T(m_seed > 0 ? m_seed++ : 0); + m_rand_tls.reset(random); + } + return random; + } + +} // namespace RootUtils + +#endif diff --git a/Control/RootUtils/test/TRandomTLS_test.cxx b/Control/RootUtils/test/TRandomTLS_test.cxx new file mode 100644 index 0000000000000000000000000000000000000000..b4ee35695d437c9a16b83da8802f5adbc0605be3 --- /dev/null +++ b/Control/RootUtils/test/TRandomTLS_test.cxx @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration. + */ + +#undef NDEBUG + +#include "CxxUtils/checker_macros.h" +#include "RootUtils/TRandomTLS.h" +#include "TRandom3.h" + +#include <cassert> +#include <iostream> +#include <mutex> +#include <set> +#include <thread> + +/// Helper class +struct Test { + void rnd() const + { + const int v = m_rnd->Integer(1000); + std::scoped_lock lock(m_mutex); + m_values.insert(v); + } + RootUtils::TRandomTLS<TRandom3> m_rnd; + mutable std::mutex m_mutex; + mutable std::set<int> m_values ATLAS_THREAD_SAFE; +}; + + +void test_compilation() +{ + RootUtils::TRandomTLS<TRandom3> rnd(42); + rnd->Rndm(); + (*rnd).Rndm(); + assert(rnd.get()); +} + + +void test_unique() +{ + Test test; + std::vector<std::thread> threads; + constexpr int Nthreads = 4; + + // Launch threads and wait + for (size_t i = 0; i < Nthreads; i++) { + threads.emplace_back(&Test::rnd, &test); + } + for (auto& th : threads) th.join(); + + // Each thread should have created a unique random number + assert(test.m_values.size() == Nthreads); + + for (int v : test.m_values) { + std::cout << v << std::endl; + } +} + + +int main() +{ + test_compilation(); + test_unique(); + + return 0; +}