diff --git a/tests/dataset_tests/Fashion-MNIST/FashionMnistTest.cpp b/tests/dataset_tests/Fashion-MNIST/FashionMnistTest.cpp index a24ca820..99969651 100644 --- a/tests/dataset_tests/Fashion-MNIST/FashionMnistTest.cpp +++ b/tests/dataset_tests/Fashion-MNIST/FashionMnistTest.cpp @@ -55,13 +55,14 @@ TEST_F(FashionMnistTest, convolutionNeuralNetwork) { StraightforwardNeuralNetwork neuralNetwork({ Input(1, 28, 28), - Convolution(5, 5, activation::ReLU), + Convolution(8, 3, activation::ReLU), + Convolution(24, 5, activation::ReLU), FullyConnected(10) }, - StochasticGradientDescent(0.002f, 0.95f)); + StochasticGradientDescent(0.0002f, 0.80f)); neuralNetwork.train(*data, 1_ep || 20_s); auto accuracy = neuralNetwork.getGlobalClusteringRate(); - ASSERT_ACCURACY(accuracy, 0.75f); + ASSERT_ACCURACY(accuracy, 0.70f); } TEST_F(FashionMnistTest, DISABLED_trainBestNeuralNetwork) diff --git a/tests/unit_tests/AdditionTests.cpp b/tests/unit_tests/AdditionTests.cpp index 4013aa82..44d5d579 100644 --- a/tests/unit_tests/AdditionTests.cpp +++ b/tests/unit_tests/AdditionTests.cpp @@ -15,10 +15,11 @@ TEST(Addition, WithMPL) auto data = createDataForAdditionTests(); StraightforwardNeuralNetwork neuralNetwork({ Input(2), - FullyConnected(8, activation::sigmoid), + FullyConnected(16, activation::sigmoid), FullyConnected(1, activation::identity) - }); - neuralNetwork.train(*data, 1.0_acc || 2_s, 3, 4); + }, + StochasticGradientDescent(0.01f)); + neuralNetwork.train(*data, 1.0_acc || 1_s, 3, 4); testNeuralNetworkForAddition(neuralNetwork); } diff --git a/tests/unit_tests/IdentityTests.cpp b/tests/unit_tests/IdentityTests.cpp index 171a33a9..542d0d3e 100644 --- a/tests/unit_tests/IdentityTests.cpp +++ b/tests/unit_tests/IdentityTests.cpp @@ -57,8 +57,12 @@ TEST(Identity, WorksWithLotsOfNumbers) Data data(problem::regression, inputData, expectedOutputs); data.setPrecision(precision); - StraightforwardNeuralNetwork neuralNetwork({Input(1), FullyConnected(8), FullyConnected(1, snn::activation::identity)}, - StochasticGradientDescent(0.0002f, 0.99f)); + StraightforwardNeuralNetwork neuralNetwork({ + Input(1), + FullyConnected(30), + FullyConnected(1, activation::identity) + }, + StochasticGradientDescent(0.001f, 0.80f)); neuralNetwork.train(data, 1.00_acc || 2_s); @@ -70,5 +74,5 @@ TEST(Identity, WorksWithLotsOfNumbers) && neuralNetwork.isValid() == 0) ASSERT_SUCCESS(); else - ASSERT_FAIL("MAE > 1: " + to_string(mae)); + ASSERT_FAIL("MAE > 0.4: " + to_string(mae)); } \ No newline at end of file