Skip to content

Commit

Permalink
implement Stochastic gradient descent for HCE tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
ChizhovVadim committed Aug 7, 2022
1 parent f4ddf9f commit 1914935
Show file tree
Hide file tree
Showing 14 changed files with 1,195 additions and 17 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ counter/counter

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

#
.DS_Store
17 changes: 0 additions & 17 deletions cmd/fengen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/ChizhovVadim/CounterGo/internal/quiet"

"github.com/ChizhovVadim/CounterGo/pkg/common"
"github.com/ChizhovVadim/CounterGo/pkg/engine"
eval "github.com/ChizhovVadim/CounterGo/pkg/eval/counter"
)

Expand All @@ -24,22 +23,6 @@ func quietServiceBuilder() IQuietService {
return quiet.NewQuietService(eval.NewEvaluationService(), 30)
}

type IEngine interface {
Prepare()
Clear()
Search(ctx context.Context, searchParams common.SearchParams) common.SearchInfo
}

func engineBuilder() IEngine {
var eng = engine.NewEngine(func() engine.Evaluator {
return eval.NewEvaluationService()
})
eng.Hash = 32
eng.Threads = 1
eng.Prepare()
return eng
}

func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
var err = run()
Expand Down
111 changes: 111 additions & 0 deletions cmd/trainhce/dataset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package main

import (
"bufio"
"fmt"
"os"
"strconv"
"strings"

"github.com/ChizhovVadim/CounterGo/internal/domain"

"github.com/ChizhovVadim/CounterGo/pkg/common"
)

type Sample struct {
Target float32
domain.TuneEntry
}

func LoadDataset(filepath string, e ITunableEvaluator,
parser func(string, ITunableEvaluator) (Sample, error)) ([]Sample, error) {
file, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer file.Close()

var result []Sample

var scanner = bufio.NewScanner(file)
for scanner.Scan() {
var s = scanner.Text()
var sample, err = parser(s, e)
if err != nil {
return nil, fmt.Errorf("parse fail %v %w", s, err)
}
result = append(result, sample)
}

err = scanner.Err()
if err != nil {
return nil, err
}
return result, nil
}

func parseValidationSample(s string, e ITunableEvaluator) (Sample, error) {
var sample Sample

var index = strings.Index(s, "\"")
if index < 0 {
return Sample{}, fmt.Errorf("bad separator")
}

var fen = s[:index]
var strScore = s[index+1:]

var pos, err = common.NewPositionFromFEN(fen)
if err != nil {
return Sample{}, err
}
sample.TuneEntry = e.ComputeFeatures(&pos)

var prob float32
if strings.HasPrefix(strScore, "1/2-1/2") {
prob = 0.5
} else if strings.HasPrefix(strScore, "1-0") {
prob = 1.0
} else if strings.HasPrefix(strScore, "0-1") {
prob = 0.0
} else {
return Sample{}, fmt.Errorf("bad game result")
}
sample.Target = prob

return sample, nil
}

func parseTrainingSample(line string, e ITunableEvaluator) (Sample, error) {
var sample Sample

var fileds = strings.SplitN(line, ";", 3)
if len(fileds) < 3 {
return Sample{}, fmt.Errorf("Bad line")
}

var fen = fileds[0]
var pos, err = common.NewPositionFromFEN(fen)
if err != nil {
return Sample{}, err
}
sample.TuneEntry = e.ComputeFeatures(&pos)

var sScore = fileds[1]
score, err := strconv.Atoi(sScore)
if err != nil {
return Sample{}, err
}

var sResult = fileds[2]
gameResult, err := strconv.ParseFloat(sResult, 64)
if err != nil {
return Sample{}, err
}

const W = 0.75
var prob = W*Sigmoid(float64(score)) + (1-W)*gameResult
sample.Target = float32(prob)

return sample, nil
}
44 changes: 44 additions & 0 deletions cmd/trainhce/gradient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package main

import (
"math"
)

type Gradient struct {
Value float64
M1 float64
M2 float64
}

const (
Beta1 float64 = 0.9
Beta2 float64 = 0.999
)

// Implementing Gradient

func (g *Gradient) Update(delta float64) {
g.Value += delta
}

func (g *Gradient) Calculate() float64 {

if g.Value == 0 {
// nothing to calculate
return 0
}

g.M1 = g.M1*Beta1 + g.Value*(1-Beta1)
g.M2 = g.M2*Beta2 + (g.Value*g.Value)*(1-Beta2)

return LearningRate * g.M1 / (math.Sqrt(g.M2) + 1e-8)
}

func (g *Gradient) Reset() {
g.Value = 0.0
}

func (g *Gradient) Apply(elem *float64) {
*elem -= g.Calculate()
g.Reset()
}
74 changes: 74 additions & 0 deletions cmd/trainhce/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package main

import (
"fmt"
"log"
"math"
"math/rand"
"runtime"

"github.com/ChizhovVadim/CounterGo/internal/domain"
eval "github.com/ChizhovVadim/CounterGo/pkg/eval/linear"

"github.com/ChizhovVadim/CounterGo/pkg/common"
)

type ITunableEvaluator interface {
StartingWeights() []float64
ComputeFeatures(pos *common.Position) domain.TuneEntry
}

func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)

var e = eval.NewEvaluationService()
var trainingPath = "/Users/vadimchizhov/chess/fengen.txt"
var validationPath = "/Users/vadimchizhov/chess/tuner/quiet-labeled.epd"
var threads = 4
var epochs = 100

var err = run(e, trainingPath, validationPath, threads, epochs)
if err != nil {
log.Println(err)
}
}

func run(evaluator ITunableEvaluator, trainingPath, validationPath string, threads, epochs int) error {
td, err := LoadDataset(trainingPath, evaluator, parseTrainingSample)
if err != nil {
return err
}
log.Println("Loaded dataset", len(td))
runtime.GC()

vd, err := LoadDataset(validationPath, evaluator, parseValidationSample)
if err != nil {
return err
}
log.Println("Loaded validation", len(vd))

var weights = evaluator.StartingWeights()
log.Println("Num of weights", len(weights))

var trainer = &Trainer{
threads: threads,
weigths: weights,
gradients: make([]Gradient, len(weights)),
training: td,
validation: vd,
rnd: rand.New(rand.NewSource(0)),
}

err = trainer.Train(epochs)
if err != nil {
return err
}

var wInt = make([]int, len(trainer.weigths))
for i := range wInt {
wInt[i] = int(math.Round(100 * trainer.weigths[i]))
}
fmt.Printf("var w = %#v\n", wInt)

return nil
}
Loading

0 comments on commit 1914935

Please sign in to comment.