-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.go
48 lines (40 loc) · 1.46 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
package main
import (
"github.com/gokadin/ai-backpropagation-continued/core"
"github.com/gokadin/ai-backpropagation-continued/data"
"github.com/gokadin/ai-backpropagation-continued/layer"
"github.com/gokadin/ai-backpropagation-continued/runners"
"math/rand"
"runtime"
"time"
)
func main() {
	// main trains a small network on the four XOR input/target pairs.
	//
	// Startup uses every CPU for the scheduler and seeds the global math/rand
	// source from the wall clock, so each run starts from a different random
	// state. The call order below is kept deliberately: network construction
	// and the random datasets may all draw from that shared source.
	cpus := runtime.NumCPU()
	runtime.GOMAXPROCS(cpus)
	seed := time.Now().UTC().UnixNano()
	rand.Seed(seed)

	net := buildNetwork()

	// Random placeholder datasets. They are generated on every run but are not
	// used by the XOR training below.
	// NOTE(review): presumably FromRandom(rows, cols); if so, the 8-column
	// inputs would not fit the 2-unit input layer from buildNetwork — confirm
	// before switching training over to these sets.
	inputs := data.NewDataset()
	inputs.FromRandom(10, 8)
	targets := data.NewDataset()
	targets.FromRandom(10, 1).OneHotEncode()

	// MNIST alternative (replaces the two random datasets above):
	// inputs := data.NewDataset()
	// inputs.FromCsv("data/mnist_train_half.csv", 1, -1, -1).Normalize(0, 255)
	// targets := data.NewDataset()
	// targets.FromCsv("data/mnist_train_half.csv", -1, 0, -1).OneHotEncode()

	trainer := runners.NewNetworkRunner()
	// trainer.SetErrorFunction(runners.ErrorFunctionCrossEntropy)
	trainer.SetBatchSize(4) // the XOR table below has exactly four samples
	trainer.SetLearningRate(0.01)
	trainer.SetMaxError(0.001)
	trainer.SetValidOutputRange(0.05)

	// XOR truth table: the target is 1 exactly when the two inputs differ.
	xorInputs := [][]float64{{1, 0}, {1, 1}, {0, 1}, {0, 0}}
	xorTargets := [][]float64{{1}, {0}, {1}, {0}}
	trainer.Train(net, xorInputs, xorTargets)

	// Other runs that can be swapped in for the XOR training above:
	// trainer.Train(net, [][]float64{{1, 1}}, [][]float64{{0.5}})
	// trainer.Train(net, inputs.Data(), targets.Data())
	// trainer.Test(net, inputs.Data(), targets.Data())
}
// buildNetwork returns a freshly constructed three-layer network:
//   - an input layer with 2 units,
//   - a hidden layer with 2 units using layer.FunctionSigmoid,
//   - an output layer with 1 unit using layer.FunctionIdentity.
//
// Each Add* call is chained onto the value returned by the previous one, so
// the chain is kept intact rather than split into separate statements.
func buildNetwork() *core.Network {
	const (
		inputUnits  = 2
		hiddenUnits = 2
		outputUnits = 1
	)

	net := core.NewNetwork()
	net.AddInputLayer(inputUnits).
		AddHiddenLayer(hiddenUnits, layer.FunctionSigmoid).
		AddOutputLayer(outputUnits, layer.FunctionIdentity)

	return net
}