diff --git a/code/framework/src/utils/states/machine.cpp b/code/framework/src/utils/states/machine.cpp index 5e048282d..40a2dfedf 100644 --- a/code/framework/src/utils/states/machine.cpp +++ b/code/framework/src/utils/states/machine.cpp @@ -7,78 +7,110 @@ */ #include "machine.h" - #include namespace Framework::Utils::States { - Machine::Machine(): _currentState(nullptr), _nextState(nullptr), _currentContext(Context::Enter) {} + Machine::Machine(): _currentState(nullptr), _nextState(nullptr), _currentContext(Context::Enter), _isUpdating(false) {} Machine::~Machine() { + std::lock_guard lock(_mutex); _states.clear(); } bool Machine::RequestNextState(int32_t stateId) { - // Has the state been registered? - const auto it = _states.find(stateId); + std::lock_guard lock(_mutex); + + auto it = _states.find(stateId); if (it == _states.end()) { return false; } - // The state has already been requested - if ((*it).second == _nextState) { + if (it->second.get() == _nextState) { return false; } - // Already transitionning to a new state, so we cannot request a new one. if (_nextState != nullptr) { return false; } - // Mark it for processing and force the actual state to exit - _nextState = (*it).second; + _nextState = it->second.get(); _currentContext = Context::Exit; Framework::Logging::GetInstance()->Get(FRAMEWORK_INNER_UTILS)->debug("[StateMachine] Requesting new state {}", _nextState->GetName()); return true; } - bool Machine::Update() { - if (_currentState != nullptr) { - // Otherwise, we just process the current state - if (_currentContext == Context::Enter) { - // If init succeed, next context is obviously the update, otherwise it means that something failed and - // exit is required - _currentContext = _currentState->OnEnter(this) ? Context::Update : Context::Exit; + bool Machine::Update() { + std::unique_lock lock(_mutex); + + if (_isUpdating) { + return false; + } + + _isUpdating = true; + + IState *currentState = _currentState; + IState *nextState = _nextState; + Context currentContext = _currentContext; + + if (!currentState && !nextState) { + _isUpdating = false; + return false; + } + + if (!currentState && nextState) { + _currentState = nextState; + _currentContext = Context::Enter; + _nextState = nullptr; + _isUpdating = false; + return true; + } + + lock.unlock(); + + bool result = false; + try { + switch (currentContext) { + case Context::Enter: { + result = currentState->OnEnter(this); + lock.lock(); + _currentContext = result ? Context::Update : Context::Exit; + break; } - else if (_currentContext == Context::Update) { - // If the state answer true to update call, it means that it willed only a single tick, otherwise we - // keep ticking - if (_currentState->OnUpdate(this)) { + case Context::Update: { + result = currentState->OnUpdate(this); + lock.lock(); + if (result) _currentContext = Context::Exit; - } + break; } - else if (_currentContext == Context::Exit) { - _currentState->OnExit(this); + case Context::Exit: { + result = currentState->OnExit(this); + lock.lock(); _currentContext = Context::Next; + break; } - else if (_currentContext == Context::Next) { - _currentState = _nextState; + case Context::Next: { + lock.lock(); + _currentState = nextState; _currentContext = Context::Enter; _nextState = nullptr; + break; } - else { + default: + lock.lock(); + _isUpdating = false; return false; } } - else if (_nextState != nullptr) { - _currentState = _nextState; - _currentContext = Context::Enter; - _nextState = nullptr; - } - else { - return false; + catch (const std::exception &e) { + Framework::Logging::GetInstance()->Get(FRAMEWORK_INNER_UTILS)->error("[StateMachine] Error in state {}: {}", currentState ? currentState->GetName() : "null", e.what()); + lock.lock(); + _isUpdating = false; + throw; } + _isUpdating = false; return true; } } // namespace Framework::Utils::States diff --git a/code/framework/src/utils/states/machine.h b/code/framework/src/utils/states/machine.h index 64dc26c0e..456832dca 100644 --- a/code/framework/src/utils/states/machine.h +++ b/code/framework/src/utils/states/machine.h @@ -12,43 +12,61 @@ #include #include +#include namespace Framework::Utils::States { + class StateTransitionError : public std::runtime_error { + using std::runtime_error::runtime_error; + }; + enum class Context { Enter, Update, Exit, - Next // This one does not keep calling, it requests going to the next one + Next }; class Machine { private: - std::map> _states; - - std::shared_ptr _currentState; - std::shared_ptr _nextState; - + mutable std::mutex _mutex; + std::map> _states; + IState* _currentState; + IState* _nextState; Context _currentContext; + bool _isUpdating; public: Machine(); ~Machine(); + // Prevent copying + Machine(const Machine&) = delete; + Machine& operator=(const Machine&) = delete; + bool RequestNextState(int32_t); template void RegisterState() { - auto ptr = std::make_shared(); - _states.insert(std::make_pair(ptr->GetId(), ptr)); + std::lock_guard lock(_mutex); + auto ptr = std::make_unique(); + int32_t id = ptr->GetId(); + + if (_states.find(id) != _states.end()) { + throw std::runtime_error("State ID already registered"); + } + + _states.emplace(id, std::move(ptr)); } bool Update(); - std::shared_ptr GetCurrentState() const { + const IState* GetCurrentState() const { + std::lock_guard lock(_mutex); return _currentState; } - std::shared_ptr GetNextState() const { + const IState* GetNextState() const { + std::lock_guard lock(_mutex); return _nextState; } }; diff --git a/code/tests/framework_ut.cpp b/code/tests/framework_ut.cpp index 89530bb4d..8c6d34186 100644 --- a/code/tests/framework_ut.cpp +++ b/code/tests/framework_ut.cpp @@ -6,13 +6,14 @@ * See LICENSE file in the source repository for information regarding licensing. */ -#define UNIT_MAX_MODULES 2 +#define UNIT_MAX_MODULES 3 #include "logging/logger.h" #include "unit.h" /* TEST CATEGORIES */ #include "modules/interpolator_ut.h" #include "modules/scripting_module_ut.h" +#include "modules/state_machine_ut.h" int main() { UNIT_CREATE("FrameworkTests"); @@ -21,6 +22,7 @@ int main() { UNIT_MODULE(interpolator); UNIT_MODULE(scripting_module); + UNIT_MODULE(state_machine); return UNIT_RUN(); } diff --git a/code/tests/modules/state_machine_ut.h b/code/tests/modules/state_machine_ut.h new file mode 100644 index 000000000..9165092fe --- /dev/null +++ b/code/tests/modules/state_machine_ut.h @@ -0,0 +1,235 @@ +/* + * MafiaHub OSS license + * Copyright (c) 2021-2023, MafiaHub. All rights reserved. + * + * This file comes from MafiaHub, hosted at https://github.com/MafiaHub/Framework. + * See LICENSE file in the source repository for information regarding licensing. + */ + +#pragma once + +#include "utils/states/machine.h" +#include +#include + +// Test states +class InitialState: public Framework::Utils::States::IState { +public: + const char *GetName() const override { + return "Initial"; + } + int32_t GetId() const override { return 1; } + bool OnEnter(Framework::Utils::States::Machine* machine) override { return true; } + bool OnUpdate(Framework::Utils::States::Machine* machine) override { return false; } + bool OnExit(Framework::Utils::States::Machine* machine) override { return true; } +}; + +class ProcessingState: public Framework::Utils::States::IState { +private: + static std::atomic _counter; +public: + static void ResetCounter() { _counter = 0; } + static int GetCounter() { return _counter.load(); } + + const char *GetName() const override { + return "Processing"; + } + int32_t GetId() const override { return 2; } + bool OnEnter(Framework::Utils::States::Machine* machine) override { + _counter++; + return true; + } + bool OnUpdate(Framework::Utils::States::Machine* machine) override { return false; } + bool OnExit(Framework::Utils::States::Machine* machine) override { return true; } +}; + +class FailingState: public Framework::Utils::States::IState { +private: + static std::atomic _failures; +public: + static void ResetFailures() { _failures = 0; } + static int GetFailures() { return _failures.load(); } + + const char *GetName() const override { + return "Failing"; + } + int32_t GetId() const override { return 3; } + bool OnEnter(Framework::Utils::States::Machine* machine) override { return true; } + bool OnUpdate(Framework::Utils::States::Machine* machine) override { + _failures++; + return true; + } + bool OnExit(Framework::Utils::States::Machine* machine) override { return true; } +}; + +// Initialize static members +std::atomic ProcessingState::_counter(0); +std::atomic FailingState::_failures(0); + +MODULE(state_machine, { + using namespace Framework::Utils::States; + + IT("can register and transition between states", { + auto machine = std::make_unique(); + machine->RegisterState(); + + EQUALS(machine->GetCurrentState(), nullptr); + + // Request initial state + EQUALS(machine->RequestNextState(1), true); + machine->Update(); // Enter InitialState + + auto state = machine->GetCurrentState(); + EQUALS(state != nullptr, true); + EQUALS(state->GetId(), 1); + EQUALS(state->GetName(), "Initial"); + }); + + IT("handles invalid state transitions gracefully", { + auto machine = std::make_unique(); + machine->RegisterState(); + + // Try to request non-existent state + EQUALS(machine->RequestNextState(999), false); + + // Request valid state + EQUALS(machine->RequestNextState(1), true); + + // Try to request another state while transition is pending + EQUALS(machine->RequestNextState(1), false); + }); + + IT("executes state lifecycle correctly", { + auto machine = std::make_unique(); + ProcessingState::ResetCounter(); + + machine->RegisterState(); + machine->RegisterState(); + + // Start Initial state + EQUALS(machine->RequestNextState(1), true); + machine->Update(); // Enter + EQUALS(machine->GetCurrentState()->GetId(), 1); + + // Request transition while in Initial state + EQUALS(machine->RequestNextState(2), true); + machine->Update(); // Should move to Exit state since we requested next + machine->Update(); // Should move to Next state + machine->Update(); // Should Enter ProcessingState + + EQUALS(ProcessingState::GetCounter(), 1); + EQUALS(machine->GetCurrentState()->GetId(), 2); + }); + + IT("handles failing states properly", { + auto machine = std::make_unique(); + FailingState::ResetFailures(); + + machine->RegisterState(); + machine->RegisterState(); + + // Initial state + EQUALS(machine->RequestNextState(1), true); + machine->Update(); // Enter + machine->Update(); // Update + + // Transition to failing state + EQUALS(machine->RequestNextState(3), true); + machine->Update(); // Exit Initial + machine->Update(); // Enter Failing + machine->Update(); // Update Failing (should increment counter and request exit) + machine->Update(); // Exit Failing + + EQUALS(FailingState::GetFailures(), 1); + }); + + IT("can handle rapid state transitions", { + auto machine = std::make_unique(); + ProcessingState::ResetCounter(); + + machine->RegisterState(); + machine->RegisterState(); + + // Perform rapid transitions + for(int i = 0; i < 1000; i++) { + machine->RequestNextState(1); + machine->Update(); // Enter + machine->Update(); // Update + + machine->RequestNextState(2); + machine->Update(); // Exit Initial + machine->Update(); // Enter Processing + machine->Update(); // Update Processing + } + + EQUALS(ProcessingState::GetCounter(), 1000); + }); + + IT("maintains thread safety under concurrent access", { + auto machine = std::make_unique(); + ProcessingState::ResetCounter(); + std::atomic running = true; + + machine->RegisterState(); + machine->RegisterState(); + + // Thread constantly updating the state machine + std::thread updater([&]() { + while(running) { + machine->Update(); + std::this_thread::yield(); + } + }); + + // Thread requesting state transitions + std::thread requester([&]() { + for(int i = 0; i < 100; i++) { + machine->RequestNextState(1); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + machine->RequestNextState(2); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + }); + + requester.join(); + running = false; + updater.join(); + + EQUALS(ProcessingState::GetCounter() > 0, true); + }); + + IT("handles recursive state updates safely", { + auto machine = std::make_unique(); + ProcessingState::ResetCounter(); + + // Create a state that tries to perform state machine operations during callbacks + class RecursiveState: public Framework::Utils::States::IState { + public: + const char *GetName() const override { + return "Recursive"; + } + int32_t GetId() const override { + return 4; + } + bool OnEnter(Machine *machine) override { + // Try recursive update + machine->Update(); + return true; + } + bool OnUpdate(Machine *machine) override { + return true; + } + bool OnExit(Machine *machine) override { + return true; + } + }; + + machine->RegisterState(); + machine->RegisterState(); + + // This should not deadlock + EQUALS(machine->RequestNextState(4), true); + machine->Update(); + machine->Update(); + }); +});