From 3bfc273dda2adcc09d92b7f4416683d2ed31f448 Mon Sep 17 00:00:00 2001 From: Matthieu HERNANDEZ Date: Sun, 10 Dec 2023 00:14:58 +0100 Subject: [PATCH] Improve logs --- .../StraightforwardNeuralNetwork.cpp | 38 +++++++----------- .../StraightforwardNeuralNetwork.hpp | 40 ++++++++++++++++++- src/neural_network/Wait.cpp | 27 ++++++++++--- src/neural_network/Wait.hpp | 9 +++-- src/tools/Tools.hpp | 18 +++++++-- 5 files changed, 97 insertions(+), 35 deletions(-) diff --git a/src/neural_network/StraightforwardNeuralNetwork.cpp b/src/neural_network/StraightforwardNeuralNetwork.cpp index 402ec566..7f86db71 100644 --- a/src/neural_network/StraightforwardNeuralNetwork.cpp +++ b/src/neural_network/StraightforwardNeuralNetwork.cpp @@ -95,16 +95,8 @@ void StraightforwardNeuralNetwork::trainSync(Data& data, Wait wait, const int ba this->numberOfTrainingsBetweenTwoEvaluations = data.sets[training].size; this->wantToStopTraining = false; this->isIdle = false; - if (evaluationFrequency > 0) - { - this->evaluate(data); - log("Epoch: ", toConstSizeString(this->epoch, 2), - " - Accuracy: ", toConstSizeString(this->getGlobalClusteringRate(), 6), - " - MAE: ", toConstSizeString(this->getMeanAbsoluteError(), 7), - " - Time: ", toConstSizeString(wait.getDurationSinceLastTime(), 2), "s"); - } - + this->evaluate(data, &wait); for (this->epoch = 1; this->continueTraining(wait); this->epoch++) { data.shuffle(); @@ -123,27 +115,16 @@ void StraightforwardNeuralNetwork::trainSync(Data& data, Wait wait, const int ba data.getTrainingOutputs(this->index, batchSize), data.isFirstTrainingDataOfTemporalSequence(this->index)); else this->outputForTraining(data.getTrainingData(this->index, batchSize), data.isFirstTrainingDataOfTemporalSequence(this->index)); + this->logInProgress(wait, data, training); } if (evaluationFrequency > 0 && this->epoch % evaluationFrequency == 0) - { - this->evaluate(data); - log("Epoch: ", toConstSizeString(this->epoch, 2), - " - Accuracy: ", toConstSizeString(this->getGlobalClusteringRate(), 6), - " - MAE: ", toConstSizeString(this->getMeanAbsoluteError(), 7), - " - Time: ", toConstSizeString(wait.getDurationSinceLastTime(), 2), "s"); - if (this->autoSaveWhenBetter && this->globalClusteringRateIsBetterThanMax) - { - this->saveSync(autoSaveFilePath); - log(" - Saved"); - } - log(); - } + this->evaluate(data, &wait); } this->resetTrainingValues(); log("Stop training"); } -void StraightforwardNeuralNetwork::evaluate(const Data& data) +void StraightforwardNeuralNetwork::evaluate(const Data& data, Wait* wait) { this->startTesting(); for (this->index = 0; this->index < data.sets[testing].size; this->index++) @@ -159,8 +140,19 @@ void StraightforwardNeuralNetwork::evaluate(const Data& data) this->evaluateOnce(data); else this->output(data.getTestingData(this->index), data.isFirstTestingDataOfTemporalSequence(this->index)); + this->logInProgress(*wait, data, testing); } this->stopTesting(); + if (this->autoSaveWhenBetter && this->globalClusteringRateIsBetterThanMax) + { + this->saveSync(autoSaveFilePath); + if(wait != nullptr) + this->logAccuracy(*wait, true); + } + else + { + this->logAccuracy(*wait, false); + } } inline diff --git a/src/neural_network/StraightforwardNeuralNetwork.hpp b/src/neural_network/StraightforwardNeuralNetwork.hpp index 47da6c87..6405917b 100644 --- a/src/neural_network/StraightforwardNeuralNetwork.hpp +++ b/src/neural_network/StraightforwardNeuralNetwork.hpp @@ -27,11 +27,17 @@ namespace snn void trainSync(Data& data, Wait wait, int batchSize, int evaluationFrequency); void saveSync(std::string filePath); + void evaluate(const Data& data, Wait* wait); void evaluateOnce(const Data& data); bool continueTraining(Wait wait) const; void validData(const Data& data, int batchSize) const; + template + void logAccuracy(Wait& wait, const bool hasSaved) const; + template + void logInProgress(Wait& wait, const Data& data, set set) const; + friend class boost::serialization::access; template void serialize(Archive& ar, unsigned version); @@ -56,7 +62,7 @@ namespace snn void waitFor(Wait wait) const; void train(Data& data, Wait wait, int batchSize = 1, int evaluationFrequency = 1); - void evaluate(const Data& data); + void evaluate(const Data& data) { return this->evaluate(data, nullptr); } std::vector computeOutput(const std::vector& inputs, bool temporalReset = false); int computeCluster(const std::vector& inputs, bool temporalReset = false); @@ -82,6 +88,38 @@ namespace snn bool operator!=(const StraightforwardNeuralNetwork& neuralNetwork) const; }; + + template + void StraightforwardNeuralNetwork::logAccuracy(Wait& wait, const bool hasSaved) const + { + if constexpr (T > none && T <= verbose) + { + tools::log("\rEpoch: ", tools::toConstSizeString(this->epoch, 2), + " - Accuracy: ", tools::toConstSizeString<2>(this->getGlobalClusteringRate(), 4), + " - MAE: ", tools::toConstSizeString<4>(this->getMeanAbsoluteError(), 7), + " - Time: ", tools::toConstSizeString<0>(wait.getDurationAndReset(), 2), "s"); + if (hasSaved) + tools::log(" - Saved"); + tools::log(); + } + } + + template + void StraightforwardNeuralNetwork::logInProgress(Wait& wait, const Data& data, set set) const + { + if constexpr (T > none && T <= verbose) + { + if (wait.tick() >= 100) + { + const std::string name = set == training ? "Training " : "Evaluation"; + const int progress = static_cast(this->index / static_cast(data.sets[set].size) * 100); + tools::log("\rEpoch: ", tools::toConstSizeString(this->epoch, 2), + " - ", name, " in progress... ", tools::toConstSizeString(progress, 2), "%", + " - Time: ", tools::toConstSizeString<0>(wait.getDuration(), 2), "s"); + } + } + } + template void StraightforwardNeuralNetwork::serialize(Archive& ar, [[maybe_unused]] const unsigned version) { diff --git a/src/neural_network/Wait.cpp b/src/neural_network/Wait.cpp index a1632260..36d7eaf2 100644 --- a/src/neural_network/Wait.cpp +++ b/src/neural_network/Wait.cpp @@ -58,7 +58,8 @@ Wait& Wait::operator&&(const Wait& wait) void Wait::startClock() { this->start = system_clock::now(); - this->last = this->start; + this->lastTick = this->start; + this->lastReset = this->start; } bool Wait::isOver(int currentEpochs, float CurrentAccuracy, float currentMae) const @@ -86,12 +87,28 @@ bool Wait::isOver(int currentEpochs, float CurrentAccuracy, float currentMae) co return false; } -int Wait::getDurationSinceLastTime() +int Wait::tick() { const auto now = system_clock::now(); - const auto currentDuration = static_cast(duration_cast(now - this->last).count()); - this->last = now; - return currentDuration; + const auto tickDuration = static_cast(duration_cast(now - this->lastTick).count()); + return tickDuration; +} + +float Wait::getDuration() +{ + const auto now = system_clock::now(); + const auto currentDuration = static_cast(duration_cast(now - this->lastReset).count()); + this->lastTick = now; + return currentDuration / 1000.0f; +} + +float Wait::getDurationAndReset() +{ + const auto now = system_clock::now(); + const auto currentDuration = static_cast(duration_cast(now - this->lastReset).count()); + this->lastTick = now; + this->lastReset = now; + return currentDuration / 1000.0f; } Wait snn::operator""_ep(unsigned long long value) diff --git a/src/neural_network/Wait.hpp b/src/neural_network/Wait.hpp index aa3c49a7..2b26d10b 100644 --- a/src/neural_network/Wait.hpp +++ b/src/neural_network/Wait.hpp @@ -17,13 +17,16 @@ namespace snn float mae = -1; int duration = -1; std::chrono::time_point start; - std::chrono::time_point last; + std::chrono::time_point lastReset; + std::chrono::time_point lastTick; waitOperator op = waitOperator::noneOp; Wait& operator||(const Wait& wait); Wait& operator&&(const Wait& wait); void startClock(); - bool isOver(int currentEpochs, float CurrentAccuracy, float currentMae) const; - int getDurationSinceLastTime(); + [[nodiscard]] bool isOver(int currentEpochs, float CurrentAccuracy, float currentMae) const; + [[nodiscard]] int tick(); // Time since last tick in milliseconds + [[nodiscard]] float getDuration(); + [[nodiscard]] float getDurationAndReset(); }; extern Wait operator""_ep(unsigned long long value); diff --git a/src/tools/Tools.hpp b/src/tools/Tools.hpp index f88c6220..4888f4b3 100644 --- a/src/tools/Tools.hpp +++ b/src/tools/Tools.hpp @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -117,6 +118,8 @@ namespace snn::tools (std::cout << ... << messages); if constexpr (endLine) std::cout << std::endl; + else + std::cout << std::flush; } } @@ -129,12 +132,21 @@ namespace snn::tools return str; } - inline + template std::string toConstSizeString(float value, size_t length) { - auto str = std::to_string(value); + std::string str; + if constexpr (T == 0) + str = std::format("{:.0f}", value); + else if constexpr (T == 2) + str = std::format("{:.2f}", value); + else if constexpr (T == 4) + str = std::format("{:.4f}", value); + else + throw std::exception(); + while (str.length() < length) - str += "0"; + str = " " + str; return str; }