Skip to content

Commit

Permalink
Improve logs
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthieuHernandez committed Dec 9, 2023
1 parent df196af commit dd46b51
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 35 deletions.
38 changes: 15 additions & 23 deletions src/neural_network/StraightforwardNeuralNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,8 @@ void StraightforwardNeuralNetwork::trainSync(Data& data, Wait wait, const int ba
this->numberOfTrainingsBetweenTwoEvaluations = data.sets[training].size;
this->wantToStopTraining = false;
this->isIdle = false;

if (evaluationFrequency > 0)
{
this->evaluate(data);
log<minimal>("Epoch: ", toConstSizeString(this->epoch, 2),
" - Accuracy: ", toConstSizeString(this->getGlobalClusteringRate(), 6),
" - MAE: ", toConstSizeString(this->getMeanAbsoluteError(), 7),
" - Time: ", toConstSizeString(wait.getDurationSinceLastTime(), 2), "s");
}

this->evaluate(data, &wait);
for (this->epoch = 1; this->continueTraining(wait); this->epoch++)
{
data.shuffle();
Expand All @@ -123,27 +115,16 @@ void StraightforwardNeuralNetwork::trainSync(Data& data, Wait wait, const int ba
data.getTrainingOutputs(this->index, batchSize), data.isFirstTrainingDataOfTemporalSequence(this->index));
else
this->outputForTraining(data.getTrainingData(this->index, batchSize), data.isFirstTrainingDataOfTemporalSequence(this->index));
this->logInProgress<minimal>(wait, data, training);
}
if (evaluationFrequency > 0 && this->epoch % evaluationFrequency == 0)
{
this->evaluate(data);
log<minimal, false>("Epoch: ", toConstSizeString(this->epoch, 2),
" - Accuracy: ", toConstSizeString(this->getGlobalClusteringRate(), 6),
" - MAE: ", toConstSizeString(this->getMeanAbsoluteError(), 7),
" - Time: ", toConstSizeString(wait.getDurationSinceLastTime(), 2), "s");
if (this->autoSaveWhenBetter && this->globalClusteringRateIsBetterThanMax)
{
this->saveSync(autoSaveFilePath);
log<minimal, false>(" - Saved");
}
log<minimal>();
}
this->evaluate(data, &wait);
}
this->resetTrainingValues();
log<minimal>("Stop training");
}

void StraightforwardNeuralNetwork::evaluate(const Data& data)
void StraightforwardNeuralNetwork::evaluate(const Data& data, Wait* wait)
{
this->startTesting();
for (this->index = 0; this->index < data.sets[testing].size; this->index++)
Expand All @@ -159,8 +140,19 @@ void StraightforwardNeuralNetwork::evaluate(const Data& data)
this->evaluateOnce(data);
else
this->output(data.getTestingData(this->index), data.isFirstTestingDataOfTemporalSequence(this->index));
this->logInProgress<minimal>(*wait, data, testing);
}
this->stopTesting();
if (this->autoSaveWhenBetter && this->globalClusteringRateIsBetterThanMax)
{
this->saveSync(autoSaveFilePath);
if(wait != nullptr)
this->logAccuracy<minimal>(*wait, true);
}
else
{
this->logAccuracy<minimal>(*wait, false);
}
}

inline
Expand Down
40 changes: 39 additions & 1 deletion src/neural_network/StraightforwardNeuralNetwork.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,17 @@ namespace snn

void trainSync(Data& data, Wait wait, int batchSize, int evaluationFrequency);
void saveSync(std::string filePath);
void evaluate(const Data& data, Wait* wait);
void evaluateOnce(const Data& data);

bool continueTraining(Wait wait) const;
void validData(const Data& data, int batchSize) const;

template <logLevel T>
void logAccuracy(Wait& wait, const bool hasSaved) const;
template <logLevel T>
void logInProgress(Wait& wait, const Data& data, set set) const;

friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, unsigned version);
Expand All @@ -56,7 +62,7 @@ namespace snn
void waitFor(Wait wait) const;
void train(Data& data, Wait wait, int batchSize = 1, int evaluationFrequency = 1);

void evaluate(const Data& data);
void evaluate(const Data& data) { return this->evaluate(data, nullptr); }

std::vector<float> computeOutput(const std::vector<float>& inputs, bool temporalReset = false);
int computeCluster(const std::vector<float>& inputs, bool temporalReset = false);
Expand All @@ -82,6 +88,38 @@ namespace snn
bool operator!=(const StraightforwardNeuralNetwork& neuralNetwork) const;
};


template <logLevel T>
void StraightforwardNeuralNetwork::logAccuracy(Wait& wait, const bool hasSaved) const
{
if constexpr (T > none && T <= verbose)
{
tools::log<T, false>("\rEpoch: ", tools::toConstSizeString(this->epoch, 2),
" - Accuracy: ", tools::toConstSizeString<2>(this->getGlobalClusteringRate(), 4),
" - MAE: ", tools::toConstSizeString<4>(this->getMeanAbsoluteError(), 7),
" - Time: ", tools::toConstSizeString<0>(wait.getDurationAndReset(), 2), "s");
if (hasSaved)
tools::log<T, false>(" - Saved");
tools::log<T>();
}
}

template <logLevel T>
void StraightforwardNeuralNetwork::logInProgress(Wait& wait, const Data& data, set set) const
{
if constexpr (T > none && T <= verbose)
{
if (wait.tick() >= 100)
{
const std::string name = set == training ? "Training " : "Evaluation";
const int progress = static_cast<int>(this->index / static_cast<float>(data.sets[set].size) * 100);
tools::log<T, false>("\rEpoch: ", tools::toConstSizeString(this->epoch, 2),
" - ", name, " in progress... ", tools::toConstSizeString(progress, 2), "%",
" - Time: ", tools::toConstSizeString<0>(wait.getDuration(), 2), "s");
}
}
}

template <class Archive>
void StraightforwardNeuralNetwork::serialize(Archive& ar, [[maybe_unused]] const unsigned version)
{
Expand Down
27 changes: 22 additions & 5 deletions src/neural_network/Wait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ Wait& Wait::operator&&(const Wait& wait)
void Wait::startClock()
{
this->start = system_clock::now();
this->last = this->start;
this->lastTick = this->start;
this->lastReset = this->start;
}

bool Wait::isOver(int currentEpochs, float CurrentAccuracy, float currentMae) const
Expand Down Expand Up @@ -86,12 +87,28 @@ bool Wait::isOver(int currentEpochs, float CurrentAccuracy, float currentMae) co
return false;
}

int Wait::getDurationSinceLastTime()
int Wait::tick()
{
const auto now = system_clock::now();
const auto currentDuration = static_cast<int>(duration_cast<seconds>(now - this->last).count());
this->last = now;
return currentDuration;
const auto tickDuration = static_cast<int>(duration_cast<milliseconds>(now - this->lastTick).count());
return tickDuration;
}

float Wait::getDuration()
{
const auto now = system_clock::now();
const auto currentDuration = static_cast<int>(duration_cast<milliseconds>(now - this->lastReset).count());
this->lastTick = now;
return currentDuration / 1000.0f;
}

float Wait::getDurationAndReset()
{
const auto now = system_clock::now();
const auto currentDuration = static_cast<int>(duration_cast<milliseconds>(now - this->lastReset).count());
this->lastTick = now;
this->lastReset = now;
return currentDuration / 1000.0f;
}

Wait snn::operator""_ep(unsigned long long value)
Expand Down
9 changes: 6 additions & 3 deletions src/neural_network/Wait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ namespace snn
float mae = -1;
int duration = -1;
std::chrono::time_point<std::chrono::system_clock> start;
std::chrono::time_point<std::chrono::system_clock> last;
std::chrono::time_point<std::chrono::system_clock> lastReset;
std::chrono::time_point<std::chrono::system_clock> lastTick;
waitOperator op = waitOperator::noneOp;
Wait& operator||(const Wait& wait);
Wait& operator&&(const Wait& wait);
void startClock();
bool isOver(int currentEpochs, float CurrentAccuracy, float currentMae) const;
int getDurationSinceLastTime();
[[nodiscard]] bool isOver(int currentEpochs, float CurrentAccuracy, float currentMae) const;
[[nodiscard]] int tick(); // Time since last tick in milliseconds
[[nodiscard]] float getDuration();
[[nodiscard]] float getDurationAndReset();
};

extern Wait operator""_ep(unsigned long long value);
Expand Down
18 changes: 15 additions & 3 deletions src/tools/Tools.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include <charconv>
#include <chrono>
#include <iostream>
#include <random>
Expand Down Expand Up @@ -117,6 +118,8 @@ namespace snn::tools
(std::cout << ... << messages);
if constexpr (endLine)
std::cout << std::endl;
else
std::cout << std::flush;
}
}

Expand All @@ -129,12 +132,21 @@ namespace snn::tools
return str;
}

inline
template<int T>
std::string toConstSizeString(float value, size_t length)
{
auto str = std::to_string(value);
std::string str;
if constexpr (T == 0)
str = std::format("{:.0f}", value);
else if constexpr (T == 2)
str = std::format("{:.2f}", value);
else if constexpr (T == 4)
str = std::format("{:.4f}", value);
else
throw std::exception();

while (str.length() < length)
str += "0";
str = " " + str;
return str;
}

Expand Down

0 comments on commit dd46b51

Please sign in to comment.