diff --git a/Core/include/Acts/Utilities/FiniteStateMachine.hpp b/Core/include/Acts/Utilities/FiniteStateMachine.hpp index c8fdcd66a74d94b72582c2a01f0a55835774869f..3ebaa2da7bbad6a177bcfb7d97b8a15dd52293dc 100644 --- a/Core/include/Acts/Utilities/FiniteStateMachine.hpp +++ b/Core/include/Acts/Utilities/FiniteStateMachine.hpp @@ -8,6 +8,8 @@ #pragma once +#include "Acts/Utilities/TypeTraits.hpp" + #include <optional> #include <string_view> #include <variant> @@ -37,21 +39,37 @@ class FiniteStateMachine { const StateVariant& getState() const noexcept { return m_state; } StateVariant& getState() noexcept { return m_state; } + private: + template <typename T, typename S, typename... Args> + using on_exit_t = decltype( + std::declval<T>().on_exit(std::declval<S&>(), std::declval<Args>()...)); - template <typename... Args> - void setState(StateVariant state, Args&&... args) { + template <typename T, typename S, typename... Args> + using on_enter_t = decltype( + std::declval<T>().on_enter(std::declval<S&>(), std::declval<Args>()...)); + + public: + template <typename State, typename... Args> + void setState(State state, Args&&... args) { Derived& child = static_cast<Derived&>(*this); // call on exit function - std::visit([&](auto& s) { child.on_exit(s, std::forward<Args>(args)...); }, - m_state); + std::visit( + [&](auto& s) { + using state_type = decltype(s); + if constexpr (concept ::exists<on_exit_t, Derived, state_type, + Args...>) { + child.on_exit(s, std::forward<Args>(args)...); + } + }, + m_state); - // no change state m_state = std::move(state); - // call on enter function - std::visit([&](auto& s) { child.on_enter(s, std::forward<Args>(args)...); }, - m_state); + // call on enter function, the type is known from the template argument. + if constexpr (concept ::exists<on_enter_t, Derived, State, Args...>) { + child.on_enter(std::get<State>(m_state), std::forward<Args>(args)...); + } } template <typename S> @@ -64,21 +82,57 @@ class FiniteStateMachine { bool terminated() const noexcept { return is(Terminated{}); } + private: + template <typename T, typename S, typename E, typename... Args> + using on_event_t = decltype(std::declval<T>().on_event( + std::declval<S&>(), std::declval<E&>(), std::declval<Args>()...)); + + template <typename T, typename... Args> + using on_process_t = + decltype(std::declval<T>().on_process(std::declval<Args>()...)); + + protected: template <typename Event, typename... Args> event_return process_event(Event&& event, Args&&... args) { Derived& child = static_cast<Derived&>(*this); - child.log(event); + + if constexpr (concept ::exists<on_process_t, Derived, Event>) { + child.on_process(event); + } + auto new_state = std::visit( [&](auto& s) -> std::optional<StateVariant> { - auto s2 = child.on_event(s, std::forward<Event>(event), - std::forward<Args>(args)...); - - if (s2) { - std::visit([&](auto& s2_) { child.log(s, event, s2_); }, *s2); + using state_type = decltype(s); + + if constexpr (concept ::exists<on_event_t, Derived, state_type, Event, + Args...>) { + auto s2 = child.on_event(s, std::forward<Event>(event), + std::forward<Args>(args)...); + + if (s2) { + std::visit( + [&](auto& s2_) { + if constexpr (concept ::exists<on_process_t, Derived, + state_type, Event, + decltype(s2_)>) { + child.on_process(s, event, s2_); + } + }, + *s2); + } else { + if constexpr (concept ::exists<on_process_t, Derived, state_type, + Event>) { + child.on_process(s, event); + } + } + return std::move(s2); } else { - child.log(s, event); + if constexpr (concept ::exists<on_process_t, Derived, state_type, + Event, Terminated>) { + child.on_process(s, event, Terminated{}); + } + return Terminated{}; } - return std::move(s2); }, m_state); return std::move(new_state); @@ -88,7 +142,9 @@ class FiniteStateMachine { void dispatch(Event&& event, Args&&... args) { auto new_state = process_event(std::forward<Event>(event), args...); if (new_state) { - setState(std::move(*new_state), std::forward<Args>(args)...); + std::visit( + [&](auto& s) { setState(std::move(s), std::forward<Args>(args)...); }, + *new_state); } } diff --git a/Tests/Core/Utilities/FiniteStateMachineTests.cpp b/Tests/Core/Utilities/FiniteStateMachineTests.cpp index 4b9dd5eb0092476884bbf7a1b2d3736e8c4ddc1c..a745f9168233692743731390de59841b815cfed9 100644 --- a/Tests/Core/Utilities/FiniteStateMachineTests.cpp +++ b/Tests/Core/Utilities/FiniteStateMachineTests.cpp @@ -69,20 +69,6 @@ struct fsm : FiniteStateMachine<fsm, states::Disconnected, states::Connecting, event_return on_event(const states::Connected&, const events::Disconnect&) { return states::Disconnected{}; } - - template <typename State, typename Event> - event_return on_event(const State&, const Event&) const { - return Terminated{}; - } - - template <typename State, typename... Args> - void on_enter(const State&, Args&&...) {} - - template <typename State, typename... Args> - void on_exit(const State&, Args&&...) {} - - template <typename... Args> - void log(Args&&...) {} }; BOOST_AUTO_TEST_SUITE(Utilities) @@ -107,7 +93,7 @@ BOOST_AUTO_TEST_CASE(Transitions) { BOOST_CHECK(sm.is(states::Disconnected{})); } -BOOST_AUTO_TEST_CASE(Terminted) { +BOOST_AUTO_TEST_CASE(Terminated) { fsm sm{}; BOOST_CHECK(sm.is(states::Disconnected{})); @@ -139,14 +125,6 @@ struct fsm2 void on_enter(const Terminated&, Args&&...) { throw std::runtime_error("FSM terminated!"); } - - template <typename State, typename... Args> - void on_enter(const State&, Args&&...) {} - - template <typename State, typename... Args> - void on_exit(const State&, Args&&...) {} - template <typename... Args> - void log(Args&&...) {} }; BOOST_AUTO_TEST_CASE(Arguments) { @@ -187,9 +165,11 @@ struct E3 {}; struct fsm3 : FiniteStateMachine<fsm3, S1, S2, S3> { bool on_exit_called = false; bool on_enter_called = false; + bool on_process_called = false; void reset() { on_exit_called = false; on_enter_called = false; + on_process_called = false; } // S1 + E1 = S2 @@ -210,13 +190,6 @@ struct fsm3 : FiniteStateMachine<fsm3, S1, S2, S3> { // external transition event_return on_event(const S2&, const E3&) { return S3{}; } - // catchers - - template <typename State, typename Event, typename... Args> - event_return on_event(const State&, const Event&, Args&&...) const { - return Terminated{}; - } - template <typename State, typename... Args> void on_enter(const State&, Args&&...) { on_enter_called = true; @@ -226,8 +199,11 @@ struct fsm3 : FiniteStateMachine<fsm3, S1, S2, S3> { void on_exit(const State&, Args&&...) { on_exit_called = true; } + template <typename... Args> - void log(Args&&...) {} + void on_process(Args&&...) { + on_process_called = true; + } }; BOOST_AUTO_TEST_CASE(InternalTransitions) { @@ -238,6 +214,7 @@ BOOST_AUTO_TEST_CASE(InternalTransitions) { BOOST_CHECK(sm.is(S2{})); BOOST_CHECK(sm.on_exit_called); BOOST_CHECK(sm.on_enter_called); + BOOST_CHECK(sm.on_process_called); sm.reset(); @@ -247,6 +224,7 @@ BOOST_AUTO_TEST_CASE(InternalTransitions) { // on_enter / exit should have been called BOOST_CHECK(sm.on_exit_called); BOOST_CHECK(sm.on_enter_called); + BOOST_CHECK(sm.on_process_called); sm.reset(); sm.dispatch(E2{}); @@ -255,6 +233,7 @@ BOOST_AUTO_TEST_CASE(InternalTransitions) { // on_enter / exit should NOT have been called BOOST_CHECK(!sm.on_exit_called); BOOST_CHECK(!sm.on_enter_called); + BOOST_CHECK(sm.on_process_called); sm.reset(); sm.dispatch(E3{}); @@ -262,7 +241,19 @@ BOOST_AUTO_TEST_CASE(InternalTransitions) { // on_enter / exit should have been called BOOST_CHECK(sm.on_exit_called); BOOST_CHECK(sm.on_enter_called); + BOOST_CHECK(sm.on_process_called); + + sm.setState(S1{}); sm.reset(); + BOOST_CHECK(sm.is(S1{})); + // dispatch invalid event + sm.dispatch(E3{}); + // should be terminated now + BOOST_CHECK(sm.terminated()); + // hooks should have fired + BOOST_CHECK(sm.on_exit_called); + BOOST_CHECK(sm.on_enter_called); + BOOST_CHECK(sm.on_process_called); } BOOST_AUTO_TEST_SUITE_END()