Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:MafiaHub/Framework into improvem…
Browse files Browse the repository at this point in the history
…ent/job-system-revamp
  • Loading branch information
Segfaultd committed Dec 27, 2024
2 parents 5c4a0ec + c649d22 commit fa2c15d
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 44 deletions.
98 changes: 65 additions & 33 deletions code/framework/src/utils/states/machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,78 +7,110 @@
*/

#include "machine.h"

#include <logging/logger.h>

namespace Framework::Utils::States {
Machine::Machine(): _currentState(nullptr), _nextState(nullptr), _currentContext(Context::Enter) {}
Machine::Machine(): _currentState(nullptr), _nextState(nullptr), _currentContext(Context::Enter), _isUpdating(false) {}

Machine::~Machine() {
std::lock_guard<std::mutex> lock(_mutex);
_states.clear();
}

bool Machine::RequestNextState(int32_t stateId) {
// Has the state been registered?
const auto it = _states.find(stateId);
std::lock_guard<std::mutex> lock(_mutex);

auto it = _states.find(stateId);
if (it == _states.end()) {
return false;
}

// The state has already been requested
if ((*it).second == _nextState) {
if (it->second.get() == _nextState) {
return false;
}

// Already transitionning to a new state, so we cannot request a new one.
if (_nextState != nullptr) {
return false;
}

// Mark it for processing and force the actual state to exit
_nextState = (*it).second;
_nextState = it->second.get();
_currentContext = Context::Exit;

Framework::Logging::GetInstance()->Get(FRAMEWORK_INNER_UTILS)->debug("[StateMachine] Requesting new state {}", _nextState->GetName());
return true;
}

bool Machine::Update() {
if (_currentState != nullptr) {
// Otherwise, we just process the current state
if (_currentContext == Context::Enter) {
// If init succeed, next context is obviously the update, otherwise it means that something failed and
// exit is required
_currentContext = _currentState->OnEnter(this) ? Context::Update : Context::Exit;
bool Machine::Update() {
std::unique_lock<std::mutex> lock(_mutex);

if (_isUpdating) {
return false;
}

_isUpdating = true;

IState *currentState = _currentState;
IState *nextState = _nextState;
Context currentContext = _currentContext;

if (!currentState && !nextState) {
_isUpdating = false;
return false;
}

if (!currentState && nextState) {
_currentState = nextState;
_currentContext = Context::Enter;
_nextState = nullptr;
_isUpdating = false;
return true;
}

lock.unlock();

bool result = false;
try {
switch (currentContext) {
case Context::Enter: {
result = currentState->OnEnter(this);
lock.lock();
_currentContext = result ? Context::Update : Context::Exit;
break;
}
else if (_currentContext == Context::Update) {
// If the state answer true to update call, it means that it willed only a single tick, otherwise we
// keep ticking
if (_currentState->OnUpdate(this)) {
case Context::Update: {
result = currentState->OnUpdate(this);
lock.lock();
if (result)
_currentContext = Context::Exit;
}
break;
}
else if (_currentContext == Context::Exit) {
_currentState->OnExit(this);
case Context::Exit: {
result = currentState->OnExit(this);
lock.lock();
_currentContext = Context::Next;
break;
}
else if (_currentContext == Context::Next) {
_currentState = _nextState;
case Context::Next: {
lock.lock();
_currentState = nextState;
_currentContext = Context::Enter;
_nextState = nullptr;
break;
}
else {
default:
lock.lock();
_isUpdating = false;
return false;
}
}
else if (_nextState != nullptr) {
_currentState = _nextState;
_currentContext = Context::Enter;
_nextState = nullptr;
}
else {
return false;
catch (const std::exception &e) {
Framework::Logging::GetInstance()->Get(FRAMEWORK_INNER_UTILS)->error("[StateMachine] Error in state {}: {}", currentState ? currentState->GetName() : "null", e.what());
lock.lock();
_isUpdating = false;
throw;
}

_isUpdating = false;
return true;
}
} // namespace Framework::Utils::States
38 changes: 28 additions & 10 deletions code/framework/src/utils/states/machine.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,61 @@

#include <map>
#include <memory>
#include <mutex>

namespace Framework::Utils::States {
class StateTransitionError : public std::runtime_error {
using std::runtime_error::runtime_error;
};

enum class Context {
Enter,
Update,
Exit,
Next // This one does not keep calling, it requests going to the next one
Next
};

class Machine {
private:
std::map<int32_t, std::shared_ptr<IState>> _states;

std::shared_ptr<IState> _currentState;
std::shared_ptr<IState> _nextState;

mutable std::mutex _mutex;
std::map<int32_t, std::unique_ptr<IState>> _states;
IState* _currentState;
IState* _nextState;
Context _currentContext;
bool _isUpdating;

public:
Machine();
~Machine();

// Prevent copying
Machine(const Machine&) = delete;
Machine& operator=(const Machine&) = delete;

bool RequestNextState(int32_t);

template <typename T>
void RegisterState() {
auto ptr = std::make_shared<T>();
_states.insert(std::make_pair(ptr->GetId(), ptr));
std::lock_guard<std::mutex> lock(_mutex);
auto ptr = std::make_unique<T>();
int32_t id = ptr->GetId();

if (_states.find(id) != _states.end()) {
throw std::runtime_error("State ID already registered");
}

_states.emplace(id, std::move(ptr));
}

bool Update();

std::shared_ptr<IState> GetCurrentState() const {
const IState* GetCurrentState() const {
std::lock_guard<std::mutex> lock(_mutex);
return _currentState;
}

std::shared_ptr<IState> GetNextState() const {
const IState* GetNextState() const {
std::lock_guard<std::mutex> lock(_mutex);
return _nextState;
}
};
Expand Down
4 changes: 3 additions & 1 deletion code/tests/framework_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
* See LICENSE file in the source repository for information regarding licensing.
*/

#define UNIT_MAX_MODULES 2
#define UNIT_MAX_MODULES 3
#include "logging/logger.h"
#include "unit.h"

/* TEST CATEGORIES */
#include "modules/interpolator_ut.h"
#include "modules/scripting_module_ut.h"
#include "modules/state_machine_ut.h"

int main() {
UNIT_CREATE("FrameworkTests");
Expand All @@ -21,6 +22,7 @@ int main() {

UNIT_MODULE(interpolator);
UNIT_MODULE(scripting_module);
UNIT_MODULE(state_machine);

return UNIT_RUN();
}
Loading

0 comments on commit fa2c15d

Please sign in to comment.