-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
105 lines (81 loc) · 3.33 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package main
import (
	"fmt"
	"math/rand"
	"os"
)
// main generates an XOR data set with genXor, trains a multilayer perceptron
// on the first 90% of it by adapting to randomly drawn training samples, and
// then prints the thresholded prediction for every held-out sample followed by
// the overall test error rate.
func main() {
	// fmt.Printf("Loading labels...\n")
	// labels, err := read_labels("../data/train-labels-idx1-ubyte")
	// if err != nil {
	// fmt.Printf("Error reading the labels: %v\n", err)
	// os.Exit(-1)
	// }
	// fmt.Printf(
	// "\tMagic number: %d (%#x)\n\tData type: %s\n\tDimensionality: %d\n\tN labels: %d (%#x)\n\tFirst 5 labels: %v\n",
	// labels.magic_num, labels.magic_num, labels.data_type, labels.dimensionality, labels.n, labels.n, labels.labels[:10],
	// )
	// fmt.Printf("\nLoading images...\n")
	// imgs, err := read_imgs("../data/train-images-idx3-ubyte")
	// if err != nil {
	// fmt.Printf("Error reading the images: %v\n", err)
	// os.Exit(-1)
	// }
	// fmt.Printf(
	// "\tMagic number: %d (%#x)\n\tData type: %s\n\tDimensionality: %d\n\tN images: %d (%#x)\n\tRows / image: %d (%#x)\n\tCols / image: %d (%#x)\n\tFirst image: %v\n",
	// imgs.magic_num, imgs.magic_num, imgs.data_type, imgs.dimensionality, imgs.n, imgs.n, imgs.img_rows, imgs.img_rows, imgs.img_cols, imgs.img_cols, imgs.images[0],
	// )
	// fmt.Printf("\nStoring the 20 first images as PNG files...\n")
	// for i := 0; i < 20; i++ {
	// if err := dump_image(imgs, i, fmt.Sprintf("../data/imgs/img-%d.png", i)); err != nil {
	// fmt.Printf("Error generating the image: %v\n", err)
	// }
	// }

	const (
		dataSize    = 80      // number of points requested from genXor
		trainPasses = 1000000 // number of single-sample Adapt calls
		// trainSize is the 90% training share, computed with integer
		// arithmetic so the split cannot be skewed by float truncation.
		trainSize = dataSize * 9 / 10
		// adaptRate is the third argument handed to Adapt; presumably the
		// learning rate -- confirm against the Mlp implementation.
		adaptRate = 0.05
	)

	fmt.Printf("\nGenerating some XOR data...\n")
	xorData, xorLabels := genXor(dataSize, 0.1)
	// for i, p := range xorData {
	// fmt.Printf("\tData point %2d (Out -> %1.0f) = %6.3f\n", i, xorOutput(p), mat.Formatted(p))
	// }

	// The leading trainSize samples are used for training, the rest for testing.
	xorDataTrain, xorLabelsTrain := xorData[:trainSize], xorLabels[:trainSize]
	xorDataTest, xorLabelsTest := xorData[trainSize:], xorLabels[trainSize:]

	// A failed construction must not be ignored: every call below relies on m.
	m, err := NewMlp([]int{2, 2, 1}, Sigmoid, 1)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error creating the MLP: %v\n", err)
		os.Exit(1)
	}
	// m.SetWeights([][]float64{{6, 0, -2, 2, -2, 0}, {-4, 2, 2}})
	fmt.Printf("%s", m)
	// acts, net_acts := m.ComputeActivation([]float64{1, 0})
	// fmt.Printf("\nActivations for [1; 0]:\n")
	// for i, act := range acts {
	// fmt.Printf("\tActivation %2d -> %6.3f\n", i, mat.Formatted(act, mat.FormatMATLAB()))
	// }
	// fmt.Printf("\nNet Activations for [1; 0]:\n")
	// for i, net_act := range net_acts {
	// fmt.Printf("\tNet Activation %2d -> %6.3f\n", i, mat.Formatted(net_act, mat.FormatMATLAB()))
	// }

	// Training: each pass adapts the network to one randomly drawn training sample.
	for i := 0; i < trainPasses; i++ {
		rSample := rand.Intn(trainSize)
		m.Adapt(xorDataTrain[rSample], []float64{xorLabelsTrain[rSample]}, adaptRate)
		// output, _, _ := m.ComputeActivation(xorDataTrain[rSample].RawMatrix().Data)
		// if i < int(0.1*float64(train_passes)) || i > int(0.9*float64(train_passes)) {
		// fmt.Printf("Training output for %6.3f: %6.3f [%d]\n",
		// mat.Formatted(xorDataTrain[rSample], mat.FormatMATLAB()), output[0], int(xorLabelsTrain[rSample]))
		// }
	}

	// Without held-out samples the error rate below would be 0/0 (NaN).
	if len(xorDataTest) == 0 {
		fmt.Printf("No test samples available\n")
		return
	}

	// Testing: threshold the first output at 0.5 and count label mismatches.
	misses := 0
	for i, dp := range xorDataTest {
		// Only the first returned value (the network output) is used here.
		output, _, _ := m.ComputeActivation(dp)
		pred := 0.0
		if output[0] > 0.5 {
			pred = 1
		}
		if pred != xorLabelsTest[i] {
			misses++
		}
		// Layout: [input0; input1]: raw output [prediction] [label].
		fmt.Printf("Testing output for [%6.3f; %6.3f]: %6.3f [%d] [%d]\n",
			dp[0], dp[1], output[0], int(pred), int(xorLabelsTest[i]))
	}
	fmt.Printf("Testing error rate: %2.5f\n", float64(misses)/float64(len(xorDataTest)))
}