forked from ChizhovVadim/CounterGo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement Stochastic gradient descent for HCE tuning
- Loading branch information
1 parent
f4ddf9f
commit 1914935
Showing
14 changed files
with
1,195 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,6 @@ counter/counter | |
|
||
# Output of the go coverage tool, specifically when used with LiteIDE | ||
*.out | ||
|
||
# | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
package main | ||
|
||
import ( | ||
"bufio" | ||
"fmt" | ||
"os" | ||
"strconv" | ||
"strings" | ||
|
||
"github.com/ChizhovVadim/CounterGo/internal/domain" | ||
|
||
"github.com/ChizhovVadim/CounterGo/pkg/common" | ||
) | ||
|
||
type Sample struct { | ||
Target float32 | ||
domain.TuneEntry | ||
} | ||
|
||
func LoadDataset(filepath string, e ITunableEvaluator, | ||
parser func(string, ITunableEvaluator) (Sample, error)) ([]Sample, error) { | ||
file, err := os.Open(filepath) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer file.Close() | ||
|
||
var result []Sample | ||
|
||
var scanner = bufio.NewScanner(file) | ||
for scanner.Scan() { | ||
var s = scanner.Text() | ||
var sample, err = parser(s, e) | ||
if err != nil { | ||
return nil, fmt.Errorf("parse fail %v %w", s, err) | ||
} | ||
result = append(result, sample) | ||
} | ||
|
||
err = scanner.Err() | ||
if err != nil { | ||
return nil, err | ||
} | ||
return result, nil | ||
} | ||
|
||
func parseValidationSample(s string, e ITunableEvaluator) (Sample, error) { | ||
var sample Sample | ||
|
||
var index = strings.Index(s, "\"") | ||
if index < 0 { | ||
return Sample{}, fmt.Errorf("bad separator") | ||
} | ||
|
||
var fen = s[:index] | ||
var strScore = s[index+1:] | ||
|
||
var pos, err = common.NewPositionFromFEN(fen) | ||
if err != nil { | ||
return Sample{}, err | ||
} | ||
sample.TuneEntry = e.ComputeFeatures(&pos) | ||
|
||
var prob float32 | ||
if strings.HasPrefix(strScore, "1/2-1/2") { | ||
prob = 0.5 | ||
} else if strings.HasPrefix(strScore, "1-0") { | ||
prob = 1.0 | ||
} else if strings.HasPrefix(strScore, "0-1") { | ||
prob = 0.0 | ||
} else { | ||
return Sample{}, fmt.Errorf("bad game result") | ||
} | ||
sample.Target = prob | ||
|
||
return sample, nil | ||
} | ||
|
||
func parseTrainingSample(line string, e ITunableEvaluator) (Sample, error) { | ||
var sample Sample | ||
|
||
var fileds = strings.SplitN(line, ";", 3) | ||
if len(fileds) < 3 { | ||
return Sample{}, fmt.Errorf("Bad line") | ||
} | ||
|
||
var fen = fileds[0] | ||
var pos, err = common.NewPositionFromFEN(fen) | ||
if err != nil { | ||
return Sample{}, err | ||
} | ||
sample.TuneEntry = e.ComputeFeatures(&pos) | ||
|
||
var sScore = fileds[1] | ||
score, err := strconv.Atoi(sScore) | ||
if err != nil { | ||
return Sample{}, err | ||
} | ||
|
||
var sResult = fileds[2] | ||
gameResult, err := strconv.ParseFloat(sResult, 64) | ||
if err != nil { | ||
return Sample{}, err | ||
} | ||
|
||
const W = 0.75 | ||
var prob = W*Sigmoid(float64(score)) + (1-W)*gameResult | ||
sample.Target = float32(prob) | ||
|
||
return sample, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package main | ||
|
||
import ( | ||
"math" | ||
) | ||
|
||
type Gradient struct { | ||
Value float64 | ||
M1 float64 | ||
M2 float64 | ||
} | ||
|
||
const ( | ||
Beta1 float64 = 0.9 | ||
Beta2 float64 = 0.999 | ||
) | ||
|
||
// Implementing Gradient | ||
|
||
func (g *Gradient) Update(delta float64) { | ||
g.Value += delta | ||
} | ||
|
||
func (g *Gradient) Calculate() float64 { | ||
|
||
if g.Value == 0 { | ||
// nothing to calculate | ||
return 0 | ||
} | ||
|
||
g.M1 = g.M1*Beta1 + g.Value*(1-Beta1) | ||
g.M2 = g.M2*Beta2 + (g.Value*g.Value)*(1-Beta2) | ||
|
||
return LearningRate * g.M1 / (math.Sqrt(g.M2) + 1e-8) | ||
} | ||
|
||
func (g *Gradient) Reset() { | ||
g.Value = 0.0 | ||
} | ||
|
||
func (g *Gradient) Apply(elem *float64) { | ||
*elem -= g.Calculate() | ||
g.Reset() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package main | ||
|
||
import ( | ||
"fmt" | ||
"log" | ||
"math" | ||
"math/rand" | ||
"runtime" | ||
|
||
"github.com/ChizhovVadim/CounterGo/internal/domain" | ||
eval "github.com/ChizhovVadim/CounterGo/pkg/eval/linear" | ||
|
||
"github.com/ChizhovVadim/CounterGo/pkg/common" | ||
) | ||
|
||
type ITunableEvaluator interface { | ||
StartingWeights() []float64 | ||
ComputeFeatures(pos *common.Position) domain.TuneEntry | ||
} | ||
|
||
func main() { | ||
log.SetFlags(log.LstdFlags | log.Lshortfile) | ||
|
||
var e = eval.NewEvaluationService() | ||
var trainingPath = "/Users/vadimchizhov/chess/fengen.txt" | ||
var validationPath = "/Users/vadimchizhov/chess/tuner/quiet-labeled.epd" | ||
var threads = 4 | ||
var epochs = 100 | ||
|
||
var err = run(e, trainingPath, validationPath, threads, epochs) | ||
if err != nil { | ||
log.Println(err) | ||
} | ||
} | ||
|
||
func run(evaluator ITunableEvaluator, trainingPath, validationPath string, threads, epochs int) error { | ||
td, err := LoadDataset(trainingPath, evaluator, parseTrainingSample) | ||
if err != nil { | ||
return err | ||
} | ||
log.Println("Loaded dataset", len(td)) | ||
runtime.GC() | ||
|
||
vd, err := LoadDataset(validationPath, evaluator, parseValidationSample) | ||
if err != nil { | ||
return err | ||
} | ||
log.Println("Loaded validation", len(vd)) | ||
|
||
var weights = evaluator.StartingWeights() | ||
log.Println("Num of weights", len(weights)) | ||
|
||
var trainer = &Trainer{ | ||
threads: threads, | ||
weigths: weights, | ||
gradients: make([]Gradient, len(weights)), | ||
training: td, | ||
validation: vd, | ||
rnd: rand.New(rand.NewSource(0)), | ||
} | ||
|
||
err = trainer.Train(epochs) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
var wInt = make([]int, len(trainer.weigths)) | ||
for i := range wInt { | ||
wInt[i] = int(math.Round(100 * trainer.weigths[i])) | ||
} | ||
fmt.Printf("var w = %#v\n", wInt) | ||
|
||
return nil | ||
} |
Oops, something went wrong.