Skip to content

Commit

Permalink
Fix some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthieuHernandez committed Nov 3, 2024
1 parent 5081c5e commit 52b128b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
7 changes: 4 additions & 3 deletions tests/dataset_tests/Fashion-MNIST/FashionMnistTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ TEST_F(FashionMnistTest, convolutionNeuralNetwork)
{
StraightforwardNeuralNetwork neuralNetwork({
Input(1, 28, 28),
Convolution(5, 5, activation::ReLU),
Convolution(8, 3, activation::ReLU),
Convolution(24, 5, activation::ReLU),
FullyConnected(10)
},
StochasticGradientDescent(0.002f, 0.95f));
StochasticGradientDescent(0.0002f, 0.80f));
neuralNetwork.train(*data, 1_ep || 20_s);
auto accuracy = neuralNetwork.getGlobalClusteringRate();
ASSERT_ACCURACY(accuracy, 0.75f);
ASSERT_ACCURACY(accuracy, 0.70f);
}

TEST_F(FashionMnistTest, DISABLED_trainBestNeuralNetwork)
Expand Down
7 changes: 4 additions & 3 deletions tests/unit_tests/AdditionTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ TEST(Addition, WithMPL)
auto data = createDataForAdditionTests();
StraightforwardNeuralNetwork neuralNetwork({
Input(2),
FullyConnected(8, activation::sigmoid),
FullyConnected(16, activation::sigmoid),
FullyConnected(1, activation::identity)
});
neuralNetwork.train(*data, 1.0_acc || 2_s, 3, 4);
},
StochasticGradientDescent(0.01f));
neuralNetwork.train(*data, 1.0_acc || 1_s, 3, 4);
testNeuralNetworkForAddition(neuralNetwork);
}

Expand Down
10 changes: 7 additions & 3 deletions tests/unit_tests/IdentityTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,12 @@ TEST(Identity, WorksWithLotsOfNumbers)
Data data(problem::regression, inputData, expectedOutputs);
data.setPrecision(precision);

StraightforwardNeuralNetwork neuralNetwork({Input(1), FullyConnected(8), FullyConnected(1, snn::activation::identity)},
StochasticGradientDescent(0.0002f, 0.99f));
StraightforwardNeuralNetwork neuralNetwork({
Input(1),
FullyConnected(30),
FullyConnected(1, activation::identity)
},
StochasticGradientDescent(0.001f, 0.80f));

neuralNetwork.train(data, 1.00_acc || 2_s);

Expand All @@ -70,5 +74,5 @@ TEST(Identity, WorksWithLotsOfNumbers)
&& neuralNetwork.isValid() == 0)
ASSERT_SUCCESS();
else
ASSERT_FAIL("MAE > 1: " + to_string(mae));
ASSERT_FAIL("MAE > 0.4: " + to_string(mae));
}

0 comments on commit 52b128b

Please sign in to comment.