From 5d88a02c4491dbdef6015b1bf723a07994f3a484 Mon Sep 17 00:00:00 2001 From: guangzhixu Date: Thu, 4 Jan 2024 17:56:53 +0800 Subject: [PATCH] wrap wal module --- .gitignore | 3 +- Makefile | 2 +- benchmark/main.go | 25 +- rotom.go => db.go | 383 +++++++---------- rotom_test.go => db_test.go | 171 ++++---- example/main.go | 2 +- go.mod | 4 + go.sum | 8 + option.go => options.go | 14 +- wal/bench_test.go | 115 +++++ wal/wal.go | 821 ++++++++++++++++++++++++++++++++++++ wal/wal_test.go | 817 +++++++++++++++++++++++++++++++++++ 12 files changed, 2021 insertions(+), 344 deletions(-) rename rotom.go => db.go (65%) rename rotom_test.go => db_test.go (86%) rename option.go => options.go (74%) create mode 100644 wal/bench_test.go create mode 100644 wal/wal.go create mode 100644 wal/wal_test.go diff --git a/.gitignore b/.gitignore index cbcf8a9..5322719 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,2 @@ -rotom.db -*.db +tmp coverage.* \ No newline at end of file diff --git a/Makefile b/Makefile index 7a76d72..e08acdd 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ test-cover: go tool cover -html=coverage.txt -o coverage.html rm coverage.txt - rm *.db + rm -r tmp* pprof: go tool pprof -http=:18081 "http://localhost:6060/debug/pprof/profile?seconds=60" diff --git a/benchmark/main.go b/benchmark/main.go index 195242a..5e3246d 100644 --- a/benchmark/main.go +++ b/benchmark/main.go @@ -14,13 +14,10 @@ const ( KB = 1 << (10 * (iota + 1)) MB GB - TB ) func convertSize(size int64) string { switch { - case size >= TB: - return fmt.Sprintf("%.1fTB", float64(size)/TB) case size >= GB: return fmt.Sprintf("%.1fGB", float64(size)/GB) case size >= MB: @@ -41,12 +38,12 @@ func fileSize(filename string) string { return convertSize(size) } -func createDB() *rotom.Engine { - cfg := rotom.DefaultConfig - cfg.Logger = nil - cfg.Path = fmt.Sprintf("%d.db", time.Now().UnixNano()) +func createDB() *rotom.DB { + options := rotom.DefaultOptions + options.Logger = nil + options.DirPath = fmt.Sprintf("%d", time.Now().UnixNano()) - db, err := rotom.Open(cfg) + db, err := rotom.Open(options) if err != nil { panic(err) } @@ -81,7 +78,7 @@ func benchSet() { fmt.Printf("99th: %.0f ns\n", td.Quantile(0.99)) // wait for sync time.Sleep(time.Second) - fmt.Printf("db file size: %v\n", fileSize(db.Path)) + fmt.Printf("db file size: %v\n", fileSize(db.DirPath)) fmt.Println() } @@ -123,7 +120,7 @@ func benchSet8parallel() { fmt.Printf("99th: %.0f ns\n", td.Quantile(0.99)) // wait for sync time.Sleep(time.Second) - fmt.Printf("db file size: %v\n", fileSize(db.Path)) + fmt.Printf("db file size: %v\n", fileSize(db.DirPath)) fmt.Println() } @@ -154,7 +151,7 @@ func benchSetEx() { fmt.Printf("99th: %.0f ns\n", td.Quantile(0.99)) // wait for sync time.Sleep(time.Second) - fmt.Printf("db file size: %v\n", fileSize(db.Path)) + fmt.Printf("db file size: %v\n", fileSize(db.DirPath)) fmt.Println() } @@ -217,7 +214,7 @@ func benchHSet() { fmt.Printf("99th: %.0f ns\n", td.Quantile(0.99)) // wait for sync time.Sleep(time.Second) - fmt.Printf("db file size: %v\n", fileSize(db.Path)) + fmt.Printf("db file size: %v\n", fileSize(db.DirPath)) fmt.Println() } @@ -248,7 +245,7 @@ func benchLRPush() { fmt.Printf("99th: %.0f ns\n", td.Quantile(0.99)) // wait for sync time.Sleep(time.Second) - fmt.Printf("db file size: %v\n", fileSize(db.Path)) + fmt.Printf("db file size: %v\n", fileSize(db.DirPath)) fmt.Println() } @@ -310,7 +307,7 @@ func benchBitSet() { fmt.Printf("99th: %.0f ns\n", td.Quantile(0.99)) // wait for sync time.Sleep(time.Second) - 
fmt.Printf("db file size: %v\n", fileSize(db.Path)) + fmt.Printf("db file size: %v\n", fileSize(db.DirPath)) fmt.Println() } diff --git a/rotom.go b/db.go similarity index 65% rename from rotom.go rename to db.go index a0dcd06..d185c9e 100644 --- a/rotom.go +++ b/db.go @@ -2,25 +2,21 @@ package rotom import ( - "bytes" "context" "fmt" - "hash/crc32" - "os" "sync" + "sync/atomic" "time" cmap "github.com/orcaman/concurrent-map/v2" cache "github.com/xgzlucario/GigaCache" "github.com/xgzlucario/rotom/codeman" "github.com/xgzlucario/rotom/structx" + "github.com/xgzlucario/rotom/wal" ) const ( noTTL = 0 - - KB = 1024 - MB = 1024 * KB ) // Operations. @@ -58,36 +54,31 @@ const ( listDirectionRight ) -const ( - timeCarry = 1e9 -) - // Cmd type Cmd struct { op Operation - hook func(*Engine, *codeman.Parser) error + hook func(*DB, *codeman.Parser) error } // cmdTable defines the rNum and callback function required for the operation. var cmdTable = []Cmd{ - {OpSetTx, func(e *Engine, decoder *codeman.Parser) error { + {OpSetTx, func(e *DB, decoder *codeman.Parser) error { // type, key, ts, val tp := decoder.ParseVarint().Int64() key := decoder.Parse().Str() - ts := decoder.ParseVarint().Int64() * timeCarry + ts := decoder.ParseVarint().Int64() val := decoder.Parse() if decoder.Error != nil { return decoder.Error } - // check ttl - if ts < cache.GetNanoSec() && ts != noTTL { - return nil - } switch tp { case TypeString: - e.SetTx(key, val, ts) + // check ttl + if ts > cache.GetNanoSec() || ts == noTTL { + e.SetTx(key, val, ts) + } case TypeList: ls := structx.NewList[string]() @@ -131,7 +122,7 @@ var cmdTable = []Cmd{ return nil }}, - {OpRemove, func(e *Engine, decoder *codeman.Parser) error { + {OpRemove, func(e *DB, decoder *codeman.Parser) error { // keys keys := decoder.Parse().StrSlice() @@ -143,7 +134,7 @@ var cmdTable = []Cmd{ return nil }}, - {OpHSet, func(e *Engine, decoder *codeman.Parser) error { + {OpHSet, func(e *DB, decoder *codeman.Parser) error { // key, field, val key := decoder.Parse().Str() field := decoder.Parse().Str() @@ -156,7 +147,7 @@ var cmdTable = []Cmd{ return e.HSet(key, field, val) }}, - {OpHRemove, func(e *Engine, decoder *codeman.Parser) error { + {OpHRemove, func(e *DB, decoder *codeman.Parser) error { // key, fields key := decoder.Parse().Str() fields := decoder.Parse().StrSlice() @@ -169,7 +160,7 @@ var cmdTable = []Cmd{ return err }}, - {OpSAdd, func(e *Engine, decoder *codeman.Parser) error { + {OpSAdd, func(e *DB, decoder *codeman.Parser) error { // key, items key := decoder.Parse().Str() items := decoder.Parse().StrSlice() @@ -182,7 +173,7 @@ var cmdTable = []Cmd{ return err }}, - {OpSRemove, func(e *Engine, decoder *codeman.Parser) error { + {OpSRemove, func(e *DB, decoder *codeman.Parser) error { // key, items key := decoder.Parse().Str() items := decoder.Parse().StrSlice() @@ -194,7 +185,7 @@ var cmdTable = []Cmd{ return e.SRemove(key, items...) 
}}, - {OpSMerge, func(e *Engine, decoder *codeman.Parser) error { + {OpSMerge, func(e *DB, decoder *codeman.Parser) error { // op, key, items op := decoder.ParseVarint().Byte() key := decoder.Parse().Str() @@ -215,7 +206,7 @@ var cmdTable = []Cmd{ return ErrInvalidBitmapOperation }}, - {OpLPush, func(e *Engine, decoder *codeman.Parser) error { + {OpLPush, func(e *DB, decoder *codeman.Parser) error { // direct, key, items direct := decoder.ParseVarint().Byte() key := decoder.Parse().Str() @@ -234,7 +225,7 @@ var cmdTable = []Cmd{ return ErrInvalidListDirect }}, - {OpLPop, func(e *Engine, decoder *codeman.Parser) (err error) { + {OpLPop, func(e *DB, decoder *codeman.Parser) (err error) { // direct, key direct := decoder.ParseVarint().Byte() key := decoder.Parse().Str() @@ -254,7 +245,7 @@ var cmdTable = []Cmd{ return }}, - {OpBitSet, func(e *Engine, decoder *codeman.Parser) error { + {OpBitSet, func(e *DB, decoder *codeman.Parser) error { // key, offset, val key := decoder.Parse().Str() val := decoder.ParseVarint().Bool() @@ -268,7 +259,7 @@ var cmdTable = []Cmd{ return err }}, - {OpBitFlip, func(e *Engine, decoder *codeman.Parser) error { + {OpBitFlip, func(e *DB, decoder *codeman.Parser) error { // key, offset key := decoder.Parse().Str() offset := decoder.ParseVarint().Uint32() @@ -280,7 +271,7 @@ var cmdTable = []Cmd{ return e.BitFlip(key, offset) }}, - {OpBitMerge, func(e *Engine, decoder *codeman.Parser) error { + {OpBitMerge, func(e *DB, decoder *codeman.Parser) error { // op, key, items op := decoder.ParseVarint().Byte() key := decoder.Parse().Str() @@ -301,7 +292,7 @@ var cmdTable = []Cmd{ return ErrInvalidBitmapOperation }}, - {OpZAdd, func(e *Engine, decoder *codeman.Parser) error { + {OpZAdd, func(e *DB, decoder *codeman.Parser) error { // key, field, score key := decoder.Parse().Str() field := decoder.Parse().Str() @@ -314,7 +305,7 @@ var cmdTable = []Cmd{ return e.ZAdd(key, field, score) }}, - {OpZIncr, func(e *Engine, decoder *codeman.Parser) error { + {OpZIncr, func(e *DB, decoder *codeman.Parser) error { // key, field, score key := decoder.Parse().Str() field := decoder.Parse().Str() @@ -328,7 +319,7 @@ var cmdTable = []Cmd{ return err }}, - {OpZRemove, func(e *Engine, decoder *codeman.Parser) error { + {OpZRemove, func(e *DB, decoder *codeman.Parser) error { // key, field key := decoder.Parse().Str() field := decoder.Parse().Str() @@ -337,8 +328,7 @@ var cmdTable = []Cmd{ return decoder.Error } - _, err := e.ZRemove(key, field) - return err + return e.ZRemove(key, field) }}, } @@ -364,10 +354,10 @@ type ( BitMap = *structx.Bitmap ) -// Engine represents a rotom engine for storage. -type Engine struct { +// DB represents a rotom database engine. +type DB struct { sync.Mutex - Config + *Options // context. ctx context.Context @@ -377,92 +367,81 @@ type Engine struct { // if db loading encode not allowed. loading bool - buf *bytes.Buffer - rwbuf *bytes.Buffer + // write ahead log. + internalKey atomic.Uint64 + wal *wal.Log // data for bytes. m *cache.GigaCache // data for built-in structure. cm cmap.ConcurrentMap[string, any] - - // crc. - crc *crc32.Table } -// Open opens a database specified by config. -// The file will be created automatically if not exist. -func Open(conf Config) (*Engine, error) { +// Open opens a database by options. +func Open(options Options) (*DB, error) { + // create wal. 
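+	// Encoded write operations are appended to the wal under a monotonically
+	// increasing internal index; loadFromWal replays them on startup to
+	// rebuild the in-memory state.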
+ wal, err := wal.Open(options.DirPath, wal.DefaultOptions) + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(context.Background()) - e := &Engine{ - Config: conf, + db := &DB{ + Options: &options, ctx: ctx, cancel: cancel, loading: true, - buf: bytes.NewBuffer(make([]byte, 0, 1024)), - rwbuf: bytes.NewBuffer(make([]byte, 0, 1024)), tickers: [2]*Ticker{}, + wal: wal, m: cache.New(cache.DefaultOption), cm: cmap.New[any](), - crc: crc32.MakeTable(crc32.Castagnoli), } - // load db from disk. - if err := e.load(); err != nil { - e.logError("db load error: %v", err) + // load db from wal. + if err := db.loadFromWal(); err != nil { return nil, err } - e.loading = false + db.loading = false - if e.SyncPolicy == EverySecond { - // sync buffer to disk. - e.tickers[0] = NewTicker(ctx, time.Second, func() { - e.Lock() - _, err := e.writeTo(e.buf, e.Path) - e.Unlock() - if err != nil { - e.logError("writeTo buffer error: %v", err) - } + if db.SyncPolicy == EverySecond { + // sync to disk. + db.tickers[0] = NewTicker(ctx, time.Second, func() { + db.wal.Sync() }) // shrink db. - e.tickers[1] = NewTicker(ctx, e.ShrinkInterval, func() { - e.Lock() - e.shrink() - e.Unlock() + db.tickers[1] = NewTicker(ctx, db.ShrinkInterval, func() { + db.shrink() }) } - e.logInfo("rotom is ready to go") + db.logInfo("rotom is ready to go") - return e, nil + return db, nil } -// Close closes the db engine. -func (e *Engine) Close() error { +// Close closes the db. +func (db *DB) Close() error { select { - case <-e.ctx.Done(): + case <-db.ctx.Done(): return ErrDatabaseClosed default: - e.Lock() - _, err := e.writeTo(e.buf, e.Path) - e.Unlock() - e.cancel() - return err + db.wal.Sync() + db.cancel() + return db.wal.Close() } } // encode -func (e *Engine) encode(cd *codeman.Codec) { - if e.SyncPolicy == Never { +func (db *DB) encode(cd *codeman.Codec) { + if db.SyncPolicy == Never { return } - if e.loading { + if db.loading { return } - e.Lock() - e.buf.Write(cd.Content()) - e.Unlock() + db.wal.Write(db.internalKey.Add(1), cd.Content()) cd.Recycle() } @@ -472,7 +451,7 @@ func NewCodec(op Operation) *codeman.Codec { } // Get -func (e *Engine) Get(key string) ([]byte, int64, error) { +func (e *DB) Get(key string) ([]byte, int64, error) { // check if e.cm.Has(key) { return nil, 0, ErrTypeAssert @@ -486,26 +465,26 @@ func (e *Engine) Get(key string) ([]byte, int64, error) { } // Set store key-value pair. -func (e *Engine) Set(key string, val []byte) { +func (e *DB) Set(key string, val []byte) { e.SetTx(key, val, noTTL) } // SetEx store key-value pair with ttl. -func (e *Engine) SetEx(key string, val []byte, ttl time.Duration) { +func (e *DB) SetEx(key string, val []byte, ttl time.Duration) { e.SetTx(key, val, cache.GetNanoSec()+int64(ttl)) } // SetTx store key-value pair with deadline. 
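+// A negative ts is ignored; otherwise the operation is encoded to the wal
+// (unless SyncPolicy is Never) and then applied to the in-memory cache.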
-func (e *Engine) SetTx(key string, val []byte, ts int64) { +func (e *DB) SetTx(key string, val []byte, ts int64) { if ts < 0 { return } - e.encode(NewCodec(OpSetTx).Int(TypeString).Str(key).Int(ts / timeCarry).Bytes(val)) + e.encode(NewCodec(OpSetTx).Int(TypeString).Str(key).Int(ts).Bytes(val)) e.m.SetTx(key, val, ts) } // Remove -func (e *Engine) Remove(keys ...string) (n int) { +func (e *DB) Remove(keys ...string) (n int) { e.encode(NewCodec(OpRemove).StrSlice(keys)) for _, key := range keys { if e.m.Delete(key) { @@ -516,26 +495,26 @@ func (e *Engine) Remove(keys ...string) (n int) { } // Len -func (e *Engine) Len() uint64 { +func (e *DB) Len() uint64 { return e.m.Stat().Len + uint64(e.cm.Count()) } // GC triggers the garbage collection to evict expired kv datas. -func (e *Engine) GC() { +func (e *DB) GC() { e.Lock() e.m.Migrate() e.Unlock() } // Scan -func (e *Engine) Scan(f func(string, []byte, int64) bool) { +func (e *DB) Scan(f func(string, []byte, int64) bool) { e.m.Scan(func(key, value []byte, ttl int64) bool { return f(string(key), value, ttl) }) } // HGet -func (e *Engine) HGet(key, field string) ([]byte, error) { +func (e *DB) HGet(key, field string) ([]byte, error) { m, err := e.fetchMap(key) if err != nil { return nil, err @@ -548,7 +527,7 @@ func (e *Engine) HGet(key, field string) ([]byte, error) { } // HLen -func (e *Engine) HLen(key string) (int, error) { +func (e *DB) HLen(key string) (int, error) { m, err := e.fetchMap(key) if err != nil { return 0, err @@ -557,7 +536,7 @@ func (e *Engine) HLen(key string) (int, error) { } // HSet -func (e *Engine) HSet(key, field string, val []byte) error { +func (e *DB) HSet(key, field string, val []byte) error { m, err := e.fetchMap(key, true) if err != nil { return err @@ -569,7 +548,7 @@ func (e *Engine) HSet(key, field string, val []byte) error { } // HRemove -func (e *Engine) HRemove(key string, fields ...string) (n int, err error) { +func (e *DB) HRemove(key string, fields ...string) (n int, err error) { m, err := e.fetchMap(key) if err != nil { return 0, err @@ -584,7 +563,7 @@ func (e *Engine) HRemove(key string, fields ...string) (n int, err error) { } // HKeys -func (e *Engine) HKeys(key string) ([]string, error) { +func (e *DB) HKeys(key string) ([]string, error) { m, err := e.fetchMap(key) if err != nil { return nil, err @@ -593,7 +572,7 @@ func (e *Engine) HKeys(key string) ([]string, error) { } // SAdd -func (e *Engine) SAdd(key string, items ...string) (int, error) { +func (e *DB) SAdd(key string, items ...string) (int, error) { s, err := e.fetchSet(key, true) if err != nil { return 0, err @@ -603,7 +582,7 @@ func (e *Engine) SAdd(key string, items ...string) (int, error) { } // SRemove -func (e *Engine) SRemove(key string, items ...string) error { +func (e *DB) SRemove(key string, items ...string) error { s, err := e.fetchSet(key) if err != nil { return err @@ -614,7 +593,7 @@ func (e *Engine) SRemove(key string, items ...string) error { } // SHas returns whether the given items are all in the set. 
-func (e *Engine) SHas(key string, items ...string) (bool, error) { +func (e *DB) SHas(key string, items ...string) (bool, error) { s, err := e.fetchSet(key) if err != nil { return false, err @@ -623,7 +602,7 @@ func (e *Engine) SHas(key string, items ...string) (bool, error) { } // SCard -func (e *Engine) SCard(key string) (int, error) { +func (e *DB) SCard(key string) (int, error) { s, err := e.fetchSet(key) if err != nil { return 0, err @@ -632,7 +611,7 @@ func (e *Engine) SCard(key string) (int, error) { } // SMembers -func (e *Engine) SMembers(key string) ([]string, error) { +func (e *DB) SMembers(key string) ([]string, error) { s, err := e.fetchSet(key) if err != nil { return nil, err @@ -641,7 +620,7 @@ func (e *Engine) SMembers(key string) ([]string, error) { } // SUnion -func (e *Engine) SUnion(dst string, src ...string) error { +func (e *DB) SUnion(dst string, src ...string) error { srcSet, err := e.fetchSet(src[0]) if err != nil { return err @@ -662,7 +641,7 @@ func (e *Engine) SUnion(dst string, src ...string) error { } // SInter -func (e *Engine) SInter(dst string, src ...string) error { +func (e *DB) SInter(dst string, src ...string) error { srcSet, err := e.fetchSet(src[0]) if err != nil { return err @@ -683,7 +662,7 @@ func (e *Engine) SInter(dst string, src ...string) error { } // SDiff -func (e *Engine) SDiff(dst string, src ...string) error { +func (e *DB) SDiff(dst string, src ...string) error { srcSet, err := e.fetchSet(src[0]) if err != nil { return err @@ -704,7 +683,7 @@ func (e *Engine) SDiff(dst string, src ...string) error { } // LLPush -func (e *Engine) LLPush(key string, items ...string) error { +func (e *DB) LLPush(key string, items ...string) error { ls, err := e.fetchList(key, true) if err != nil { return err @@ -716,7 +695,7 @@ func (e *Engine) LLPush(key string, items ...string) error { } // LRPush -func (e *Engine) LRPush(key string, items ...string) error { +func (e *DB) LRPush(key string, items ...string) error { ls, err := e.fetchList(key, true) if err != nil { return err @@ -728,7 +707,7 @@ func (e *Engine) LRPush(key string, items ...string) error { } // LIndex -func (e *Engine) LIndex(key string, i int) (string, error) { +func (e *DB) LIndex(key string, i int) (string, error) { ls, err := e.fetchList(key) if err != nil { return "", err @@ -741,7 +720,7 @@ func (e *Engine) LIndex(key string, i int) (string, error) { } // LLPop -func (e *Engine) LLPop(key string) (string, error) { +func (e *DB) LLPop(key string) (string, error) { ls, err := e.fetchList(key) if err != nil { return "", err @@ -756,7 +735,7 @@ func (e *Engine) LLPop(key string) (string, error) { } // LRPop -func (e *Engine) LRPop(key string) (string, error) { +func (e *DB) LRPop(key string) (string, error) { ls, err := e.fetchList(key) if err != nil { return "", err @@ -771,7 +750,7 @@ func (e *Engine) LRPop(key string) (string, error) { } // LLen -func (e *Engine) LLen(key string) (int, error) { +func (e *DB) LLen(key string) (int, error) { ls, err := e.fetchList(key) if err != nil { return 0, err @@ -780,7 +759,7 @@ func (e *Engine) LLen(key string) (int, error) { } // BitTest -func (e *Engine) BitTest(key string, offset uint32) (bool, error) { +func (e *DB) BitTest(key string, offset uint32) (bool, error) { bm, err := e.fetchBitMap(key) if err != nil { return false, err @@ -789,7 +768,7 @@ func (e *Engine) BitTest(key string, offset uint32) (bool, error) { } // BitSet -func (e *Engine) BitSet(key string, val bool, offsets ...uint32) (int, error) { +func (e *DB) BitSet(key string, val 
bool, offsets ...uint32) (int, error) { bm, err := e.fetchBitMap(key, true) if err != nil { return 0, err @@ -807,7 +786,7 @@ func (e *Engine) BitSet(key string, val bool, offsets ...uint32) (int, error) { } // BitFlip -func (e *Engine) BitFlip(key string, offset uint32) error { +func (e *DB) BitFlip(key string, offset uint32) error { bm, err := e.fetchBitMap(key) if err != nil { return err @@ -819,7 +798,7 @@ func (e *Engine) BitFlip(key string, offset uint32) error { } // BitAnd -func (e *Engine) BitAnd(dst string, src ...string) error { +func (e *DB) BitAnd(dst string, src ...string) error { bm, err := e.fetchBitMap(src[0]) if err != nil { return err @@ -840,7 +819,7 @@ func (e *Engine) BitAnd(dst string, src ...string) error { } // BitOr -func (e *Engine) BitOr(dst string, src ...string) error { +func (e *DB) BitOr(dst string, src ...string) error { bm, err := e.fetchBitMap(src[0]) if err != nil { return err @@ -861,7 +840,7 @@ func (e *Engine) BitOr(dst string, src ...string) error { } // BitXor -func (e *Engine) BitXor(dst string, src ...string) error { +func (e *DB) BitXor(dst string, src ...string) error { bm, err := e.fetchBitMap(src[0]) if err != nil { return err @@ -882,7 +861,7 @@ func (e *Engine) BitXor(dst string, src ...string) error { } // BitArray -func (e *Engine) BitArray(key string) ([]uint32, error) { +func (e *DB) BitArray(key string) ([]uint32, error) { bm, err := e.fetchBitMap(key) if err != nil { return nil, err @@ -891,7 +870,7 @@ func (e *Engine) BitArray(key string) ([]uint32, error) { } // BitCount -func (e *Engine) BitCount(key string) (uint64, error) { +func (e *DB) BitCount(key string) (uint64, error) { bm, err := e.fetchBitMap(key) if err != nil { return 0, err @@ -900,7 +879,7 @@ func (e *Engine) BitCount(key string) (uint64, error) { } // ZGet -func (e *Engine) ZGet(zset, key string) (float64, error) { +func (e *DB) ZGet(zset, key string) (float64, error) { zs, err := e.fetchZSet(zset) if err != nil { return 0, err @@ -913,7 +892,7 @@ func (e *Engine) ZGet(zset, key string) (float64, error) { } // ZCard -func (e *Engine) ZCard(zset string) (int, error) { +func (e *DB) ZCard(zset string) (int, error) { zs, err := e.fetchZSet(zset) if err != nil { return 0, err @@ -922,7 +901,7 @@ func (e *Engine) ZCard(zset string) (int, error) { } // ZIter -func (e *Engine) ZIter(zset string, f func(string, float64) bool) error { +func (e *DB) ZIter(zset string, f func(string, float64) bool) error { zs, err := e.fetchZSet(zset) if err != nil { return err @@ -934,7 +913,7 @@ func (e *Engine) ZIter(zset string, f func(string, float64) bool) error { } // ZAdd -func (e *Engine) ZAdd(zset, key string, score float64) error { +func (e *DB) ZAdd(zset, key string, score float64) error { zs, err := e.fetchZSet(zset, true) if err != nil { return err @@ -946,7 +925,7 @@ func (e *Engine) ZAdd(zset, key string, score float64) error { } // ZIncr -func (e *Engine) ZIncr(zset, key string, incr float64) (float64, error) { +func (e *DB) ZIncr(zset, key string, incr float64) (float64, error) { zs, err := e.fetchZSet(zset, true) if err != nil { return 0, err @@ -957,109 +936,80 @@ func (e *Engine) ZIncr(zset, key string, incr float64) (float64, error) { } // ZRemove -func (e *Engine) ZRemove(zset string, key string) (float64, error) { +func (e *DB) ZRemove(zset string, key string) error { zs, err := e.fetchZSet(zset) if err != nil { - return 0, err + return err } e.encode(NewCodec(OpZRemove).Str(zset).Str(key)) - res, _ := zs.Delete(key) + zs.Delete(key) - return res, nil + return nil } -// 
writeTo writes the buffer into the file at the specified path. -func (s *Engine) writeTo(buf *bytes.Buffer, path string) (int, error) { - if buf.Len() == 0 { - return 0, nil - } - - fs, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) +// loadFromWal load data to mem from wal. +func (db *DB) loadFromWal() error { + // load wal file. + firstIndex, err := db.wal.FirstIndex() if err != nil { - return 0, err - } - defer fs.Close() - - // encode block. - data := codeman.Compress(buf.Bytes(), nil) - crc := crc32.Checksum(data, s.crc) - coder := codeman.NewCodec().Bytes(data).Uint(crc) - - n, err := fs.Write(coder.Content()) - if err != nil { - return 0, err + return err } - - // reset - buf.Reset() - coder.Recycle() - - return n, nil -} - -// load reads the persisted data from the shard file and loads it into memory. -func (e *Engine) load() error { - line, err := os.ReadFile(e.Path) + lastIndex, err := db.wal.LastIndex() if err != nil { - if os.IsNotExist(err) { - return nil - } return err } + if lastIndex == 0 { + return nil + } - e.logInfo("loading db file size %s", formatSize(len(line))) - - blkParser := codeman.NewParser(line) - for !blkParser.Done() { - // parse data block. - dataBlock := blkParser.Parse() - crc := uint32(blkParser.ParseVarint()) + db.logInfo("loading db from wal index[%d, %d]", firstIndex, lastIndex) - if blkParser.Error != nil { - return blkParser.Error - } - if crc != crc32.Checksum(dataBlock, e.crc) { - return ErrCheckSum - } - - data, err := codeman.Decompress(dataBlock, nil) + // read. + for i := firstIndex; i <= lastIndex; i++ { + data, err := db.wal.Read(i) if err != nil { return err } + db.internalKey.Store(i) - recParser := codeman.NewParser(data) - for !recParser.Done() { + parser := codeman.NewParser(data) + for !parser.Done() { // parse records. - op := Operation(recParser.ParseVarint()) + op := Operation(parser.ParseVarint()) - if recParser.Error != nil { - return recParser.Error + if parser.Error != nil { + return parser.Error } - if err := cmdTable[op].hook(e, recParser); err != nil { + if err := cmdTable[op].hook(db, parser); err != nil { return err } } } - e.logInfo("db load complete") + db.logInfo("db load complete") return nil } // rewrite write data to the file. -func (e *Engine) shrink() { - // Marshal String - e.m.Scan(func(key, value []byte, ttl int64) bool { - cd := NewCodec(OpSetTx).Int(TypeString).Bytes(key).Int(ttl / timeCarry).Bytes(value) - e.rwbuf.Write(cd.Content()) +func (db *DB) shrink() { + db.Lock() + defer db.Unlock() + + lastIndexBefore := db.internalKey.Load() + + // marshal bytes. + db.m.Scan(func(key, value []byte, ts int64) bool { + cd := NewCodec(OpSetTx).Int(TypeString).Bytes(key).Int(ts).Bytes(value) + db.wal.Write(db.internalKey.Add(1), cd.Content()) cd.Recycle() return false }) - // Marshal built-in type + // marshal built-in types. 
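+	// Each structure is snapshotted as a single OpSetTx record with ts=0 (no TTL).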
var _type Type - for t := range e.cm.IterBuffered() { + for t := range db.cm.IterBuffered() { switch t.Val.(type) { case Map: _type = TypeMap @@ -1072,66 +1022,66 @@ func (e *Engine) shrink() { case ZSet: _type = TypeZSet } - // SetTx if cd, err := NewCodec(OpSetTx).Int(_type).Str(t.Key).Int(0).Any(t.Val); err == nil { - e.rwbuf.Write(cd.Content()) + db.wal.Write(db.internalKey.Add(1), cd.Content()) cd.Recycle() } } - // Flush buffer to file - tmpPath := fmt.Sprintf("%v.tmp", time.Now()) - e.writeTo(e.rwbuf, tmpPath) - e.writeTo(e.buf, tmpPath) - - os.Rename(tmpPath, e.Path) + // sync + if err := db.wal.Sync(); err != nil { + panic(err) + } + if err := db.wal.TruncateFront(lastIndexBefore); err != nil { + panic(err) + } - e.logInfo("rotom rewrite done") + db.logInfo("rotom shrink done") } // Shrink forced to shrink db file. // Warning: will panic if SyncPolicy is never. -func (e *Engine) Shrink() error { +func (e *DB) Shrink() error { return e.tickers[1].Do() } // fetchMap -func (e *Engine) fetchMap(key string, setnx ...bool) (m Map, err error) { +func (e *DB) fetchMap(key string, setnx ...bool) (m Map, err error) { return fetch(e, key, func() Map { return structx.NewSyncMap() }, setnx...) } // fetchSet -func (e *Engine) fetchSet(key string, setnx ...bool) (s Set, err error) { +func (e *DB) fetchSet(key string, setnx ...bool) (s Set, err error) { return fetch(e, key, func() Set { return structx.NewSet[string]() }, setnx...) } // fetchList -func (e *Engine) fetchList(key string, setnx ...bool) (m List, err error) { +func (e *DB) fetchList(key string, setnx ...bool) (m List, err error) { return fetch(e, key, func() List { return structx.NewList[string]() }, setnx...) } // fetchBitMap -func (e *Engine) fetchBitMap(key string, setnx ...bool) (bm BitMap, err error) { +func (e *DB) fetchBitMap(key string, setnx ...bool) (bm BitMap, err error) { return fetch(e, key, func() BitMap { return structx.NewBitmap() }, setnx...) } // fetchZSet -func (e *Engine) fetchZSet(key string, setnx ...bool) (z ZSet, err error) { +func (e *DB) fetchZSet(key string, setnx ...bool) (z ZSet, err error) { return fetch(e, key, func() ZSet { return structx.NewZSet[string, float64]() }, setnx...) } // fetch -func fetch[T any](e *Engine, key string, new func() T, setnx ...bool) (v T, err error) { +func fetch[T any](e *DB, key string, new func() T, setnx ...bool) (v T, err error) { // check from bytes cache. 
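+	// A key that already holds a plain byte value cannot be fetched as a
+	// structure, so this fails fast with ErrWrongType.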
if _, _, ok := e.m.Get(key); ok { return v, ErrWrongType @@ -1153,28 +1103,9 @@ func fetch[T any](e *Engine, key string, new func() T, setnx ...bool) (v T, err return v, nil } -// formatSize -func formatSize[T codeman.Integer](size T) string { - switch { - case size < KB: - return fmt.Sprintf("%dB", size) - case size < MB: - return fmt.Sprintf("%.1fKB", float64(size)/KB) - default: - return fmt.Sprintf("%.1fMB", float64(size)/MB) - } -} - // logInfo -func (e *Engine) logInfo(msg string, r ...any) { +func (e *DB) logInfo(msg string, r ...any) { if e.Logger != nil { e.Logger.Info(fmt.Sprintf(msg, r...)) } } - -// logError -func (e *Engine) logError(msg string, r ...any) { - if e.Logger != nil { - e.Logger.Error(fmt.Sprintf(msg, r...)) - } -} diff --git a/rotom_test.go b/db_test.go similarity index 86% rename from rotom_test.go rename to db_test.go index d47fc09..8af5536 100644 --- a/rotom_test.go +++ b/db_test.go @@ -3,7 +3,6 @@ package rotom import ( "fmt" "math/rand" - "os" "strconv" "strings" "testing" @@ -20,14 +19,10 @@ var ( nilStrings []string ) -func unix(nanosec int64) int64 { - return (nanosec / timeCarry) * timeCarry -} - -func createDB() (*Engine, error) { - cfg := DefaultConfig - cfg.Path = fmt.Sprintf("%s-%d.db", gofakeit.UUID(), time.Now().UnixNano()) - return Open(cfg) +func createDB() (*DB, error) { + options := DefaultOptions + options.DirPath = fmt.Sprintf("tmp-%s", gofakeit.UUID()) + return Open(options) } func TestDB(t *testing.T) { @@ -109,48 +104,8 @@ func TestDB(t *testing.T) { assert.Equal(db.Close(), ErrDatabaseClosed) // Load Success - _, err = Open(db.Config) + _, err = Open(*db.Options) assert.Nil(err) - - // Load Error - os.WriteFile(db.Config.Path, []byte("fake"), 0644) - _, err = Open(db.Config) - assert.NotNil(err, ErrCheckSum) - - // error data type - db.encode(NewCodec(OpSetTx).Int(100).Str("key").Str("val")) - db.Close() - _, err = Open(db.Config) - assert.NotNil(err, ErrCheckSum) -} - -func TestSetTTL(t *testing.T) { - println("===== TestSetTTL =====") - assert := assert.New(t) - - cfg := DefaultConfig - cfg.Path = gofakeit.UUID() + ".db" - - db, err := Open(cfg) - assert.Nil(err) - - // SetTTL - for i := 0; i < 2000; i++ { - k := fmt.Sprintf("%08x", i) - db.SetEx(k, []byte(k), time.Second) - } - - for i := 0; i < 2000; i++ { - k := fmt.Sprintf("%08x", i) - val, ts, err := db.Get(k) - now := unix(cache.GetNanoSec()) - nowAddSed := unix(cache.GetNanoSec() + int64(time.Second)) - - assert.Nil(err) - assert.Equal(val, []byte(k)) - assert.True(ts >= now) - assert.True(ts <= nowAddSed) - } } func TestHmap(t *testing.T) { @@ -223,11 +178,15 @@ func TestHmap(t *testing.T) { assert.Equal(err, ErrFieldNotFound) } - db.Shrink() + // reload db.Close() + db, err = Open(*db.Options) + assert.Nil(err) - // Load - _, err = Open(db.Config) + // shrink and reload + db.Shrink() + db.Close() + _, err = Open(*db.Options) assert.Nil(err) } @@ -320,11 +279,15 @@ func TestList(t *testing.T) { db.LRPush("list", gofakeit.Animal()) } - db.Shrink() + // reload db.Close() + db, err = Open(*db.Options) + assert.Nil(err) - // Load - _, err = Open(db.Config) + // shrink and reload + db.Shrink() + db.Close() + _, err = Open(*db.Options) assert.Nil(err) } @@ -435,11 +398,15 @@ func TestSet(t *testing.T) { err = db.SInter("", "set", "map") assert.ErrorContains(err, ErrWrongType.Error()) - db.Shrink() + // reload db.Close() + db, err = Open(*db.Options) + assert.Nil(err) - // Load - _, err = Open(db.Config) + // shrink and reload + db.Shrink() + db.Close() + _, err = Open(*db.Options) 
assert.Nil(err) } @@ -543,11 +510,15 @@ func TestBitmap(t *testing.T) { assert.ElementsMatch(m1, m2) } - db.Shrink() + // reload db.Close() + db, err = Open(*db.Options) + assert.Nil(err) - // Load - _, err = Open(db.Config) + // shrink and reload + db.Shrink() + db.Close() + _, err = Open(*db.Options) assert.Nil(err) } @@ -558,45 +529,58 @@ func TestZSet(t *testing.T) { db, err := createDB() assert.Nil(err) - genKey := func(i int) string { - return fmt.Sprintf("key-%08x", i) - } + genKey := func(i int) string { return fmt.Sprintf("key-%06d", i) } // ZAdd for i := 0; i < 1000; i++ { err := db.ZAdd("zset", genKey(i), float64(i)) assert.Nil(err) + + // card n, err := db.ZCard("zset") assert.Nil(err) assert.Equal(n, i+1) } - // ZGet - for i := 0; i < 1000; i++ { - score, err := db.ZGet("zset", genKey(i)) + check := func() { + // exist + for i := 0; i < 1000; i++ { + // card + n, err := db.ZCard("zset") + assert.Nil(err) + assert.Equal(n, 1000) + + // zget + score, err := db.ZGet("zset", genKey(i)) + assert.Nil(err) + assert.Equal(score, float64(i)) + } + + // not exist + for i := 1000; i < 2000; i++ { + score, err := db.ZGet("zset", genKey(i)) + assert.Equal(err, ErrKeyNotFound) + assert.Equal(score, float64(0)) + } + + // iter + count := 0 + err = db.ZIter("zset", func(key string, score float64) bool { + count++ + return count >= 1000 + }) assert.Nil(err) - assert.Equal(score, float64(i)) - } - for i := 1000; i < 2000; i++ { - score, err := db.ZGet("zset", genKey(i)) - assert.Equal(err, ErrKeyNotFound) - assert.Equal(score, float64(0)) + assert.Equal(count, 1000) } + check() + // Reload - db.Shrink() db.Close() - db, err = Open(db.Config) + db, err = Open(*db.Options) assert.Nil(err) - // ZIter - count := 0 - err = db.ZIter("zset", func(key string, score float64) bool { - count++ - return count >= 1000 - }) - assert.Nil(err) - assert.Equal(count, 1000) + check() // ZIncr for i := 0; i < 1000; i++ { @@ -612,21 +596,24 @@ func TestZSet(t *testing.T) { // ZRemove for i := 0; i < 800; i++ { - res, err := db.ZRemove("zset", genKey(i)) + err := db.ZRemove("zset", genKey(i)) assert.Nil(err) - assert.Equal(res, float64(i+3)) } for i := 5000; i < 6000; i++ { - res, err := db.ZRemove("zset", genKey(i)) + err := db.ZRemove("zset", genKey(i)) assert.Nil(err) - assert.Equal(res, float64(0)) } - // Reload + // reload + db.Close() + db, err = Open(*db.Options) + assert.Nil(err) + + // shrink and reload db.Shrink() db.Close() - db, err = Open(db.Config) + db, err = Open(*db.Options) assert.Nil(err) // Test error @@ -647,7 +634,7 @@ func TestZSet(t *testing.T) { _, err = db.ZIncr("set", "key", 1) assert.ErrorContains(err, ErrWrongType.Error()) - _, err = db.ZRemove("set", "key") + err = db.ZRemove("set", "key") assert.ErrorContains(err, ErrWrongType.Error()) _, err = db.ZCard("set") @@ -668,7 +655,7 @@ func TestInvalidCodec(t *testing.T) { assert.Nil(err) db.encode(NewCodec(op).Int(100)) db.Close() - _, err = Open(db.Config) + _, err = Open(*db.Options) assert.NotNil(err) } diff --git a/example/main.go b/example/main.go index 509ff19..b589772 100644 --- a/example/main.go +++ b/example/main.go @@ -13,7 +13,7 @@ import ( func main() { go http.ListenAndServe("localhost:6060", nil) - db, err := rotom.Open(rotom.DefaultConfig) + db, err := rotom.Open(rotom.DefaultOptions) if err != nil { panic(err) } diff --git a/go.mod b/go.mod index 6500b2c..efea0cd 100644 --- a/go.mod +++ b/go.mod @@ -11,8 +11,10 @@ require ( github.com/influxdata/tdigest v0.0.1 github.com/klauspost/compress v1.17.4 
github.com/orcaman/concurrent-map/v2 v2.0.1 + github.com/rosedblabs/wal v1.3.6 github.com/sakeven/RbTree v0.0.0-20220710124251-94e35f9fed6c github.com/stretchr/testify v1.8.4 + github.com/tidwall/tinylru v1.2.1 github.com/xgzlucario/GigaCache v0.0.0-20231226161408-655de575becb github.com/zyedidia/generic v1.2.2-0.20230819223407-98022f93823f golang.org/x/exp v0.0.0-20231226003508-02704c960a9b @@ -24,12 +26,14 @@ require ( github.com/chenzhuoyu/iasm v0.9.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dolthub/maphash v0.1.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.2 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/kr/pretty v0.3.0 // indirect github.com/mschoch/smat v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/arch v0.6.0 // indirect golang.org/x/sys v0.15.0 // indirect diff --git a/go.sum b/go.sum index 8ad4490..c998c5b 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,8 @@ github.com/dolthub/swiss v0.2.1/go.mod h1:8AhKZZ1HK7g18j7v7k6c5cYIGEZJcPn0ARsai8 github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/golang-lru/v2 v2.0.2 h1:Dwmkdr5Nc/oBiXgJS3CDHNhJtIHkuZ3DZF5twqnfBdU= +github.com/hashicorp/golang-lru/v2 v2.0.2/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/influxdata/tdigest v0.0.1 h1:XpFptwYmnEKUqmkcDjrzffswZ3nvNeevbUSLPP/ZzIY= github.com/influxdata/tdigest v0.0.1/go.mod h1:Z0kXnxzbTC2qrx4NaIzYkE1k66+6oEDQTvL95hQFh5Y= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= @@ -53,6 +55,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rosedblabs/wal v1.3.6 h1:oxZYTPX/u4JuGDW98wQ1YamWqerlrlSUFKhgP6Gd/Ao= +github.com/rosedblabs/wal v1.3.6/go.mod h1:wdq54KJUyVTOv1uddMc6Cdh2d/YCIo8yjcwJAb1RCEM= github.com/sakeven/RbTree v0.0.0-20220710124251-94e35f9fed6c h1:u+rpw//zQsDgIhULb0RoiGuYMJwwii4pTWszctE+oKw= github.com/sakeven/RbTree v0.0.0-20220710124251-94e35f9fed6c/go.mod h1:zwEumjdcK6Q/ky/gFPqMviw1p7ZUb+B3pU4ybgOHvns= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -64,8 +68,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tidwall/tinylru v1.2.1 h1:VgBr72c2IEr+V+pCdkPZUwiQ0KJknnWIYbhxAVkYfQk= +github.com/tidwall/tinylru v1.2.1/go.mod h1:9bQnEduwB6inr2Y7AkBP7JPgCkyrhTV/ZpX0oOOpBI4= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod 
h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/xgzlucario/GigaCache v0.0.0-20231226161408-655de575becb h1:f28RNmxeBCBhVIUHx0AySMtJ6XgA/Drxq8gP1RDF8vk= github.com/xgzlucario/GigaCache v0.0.0-20231226161408-655de575becb/go.mod h1:kgFPedK/WBjnfOdhxNSLe2tlYuiN2oG4qHCk5L8WPWA= github.com/zeebo/assert v1.3.1 h1:vukIABvugfNMZMQO1ABsyQDJDTVQbn+LWSMy1ol1h6A= diff --git a/option.go b/options.go similarity index 74% rename from option.go rename to options.go index 3a65fc6..6088607 100644 --- a/option.go +++ b/options.go @@ -15,9 +15,8 @@ const ( ) var ( - // Default config for db - DefaultConfig = Config{ - Path: "rotom.db", + DefaultOptions = Options{ + DirPath: "rotom", ShardCount: 1024, SyncPolicy: EverySecond, ShrinkInterval: time.Minute, @@ -25,19 +24,18 @@ var ( Logger: slog.Default(), } - // No persistent config - NoPersistentConfig = Config{ + NoPersistentOptions = Options{ ShardCount: 1024, SyncPolicy: Never, Logger: slog.Default(), } ) -// Config represents the configuration for a Store. -type Config struct { +// Options represents the configuration for a Store. +type Options struct { ShardCount int - Path string // Path of db file. + DirPath string // Dir path of db file. SyncPolicy SyncPolicy // Data sync policy. ShrinkInterval time.Duration // Shrink db file interval. diff --git a/wal/bench_test.go b/wal/bench_test.go new file mode 100644 index 0000000..e3bf731 --- /dev/null +++ b/wal/bench_test.go @@ -0,0 +1,115 @@ +package wal + +import ( + "fmt" + "math" + "os" + "testing" + + "github.com/rosedblabs/wal" +) + +func BenchmarkWrite(b *testing.B) { + b.Run("wal-sync", func(b *testing.B) { + os.RemoveAll("tmp") + opt := DefaultOptions + wal, _ := Open("tmp", opt) + + for i := 0; i < b.N; i++ { + data := fmt.Sprintf("%12d", i) + wal.Write(uint64(i), []byte(data)) + } + wal.Close() + }) + + b.Run("wal-nosync", func(b *testing.B) { + os.RemoveAll("tmp") + opt := DefaultOptions + opt.NoSync = true + wal, _ := Open("tmp", opt) + + for i := 0; i < b.N; i++ { + data := fmt.Sprintf("%12d", i) + wal.Write(uint64(i), []byte(data)) + } + wal.Close() + }) + + b.Run("rosedb-sync", func(b *testing.B) { + os.RemoveAll("tmp") + opt := wal.DefaultOptions + opt.DirPath = "tmp" + opt.Sync = true + wal, _ := wal.Open(opt) + + for i := 0; i < b.N; i++ { + data := fmt.Sprintf("%12d", i) + wal.Write([]byte(data)) + } + + wal.Close() + }) + + b.Run("rosedb/wal-nosync", func(b *testing.B) { + os.RemoveAll("tmp") + opt := wal.DefaultOptions + opt.DirPath = "tmp" + wal, _ := wal.Open(opt) + + for i := 0; i < b.N; i++ { + data := fmt.Sprintf("%12d", i) + wal.Write([]byte(data)) + } + wal.Close() + }) +} + +func BenchmarkRead(b *testing.B) { + const NUM = 100 * 10000 + + b.Run("wal", func(b *testing.B) { + // init log data. + os.RemoveAll("tmp") + opt := DefaultOptions + opt.NoSync = true + w, _ := Open("tmp", opt) + for i := 0; i < NUM; i++ { + data := fmt.Sprintf("%12d", i) + w.Write(uint64(i), []byte(data)) + } + w.Close() + + // reopen. + w, _ = Open("tmp", opt) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + w.Read(uint64(i % NUM)) + } + }) + + b.Run("rosedb/wal", func(b *testing.B) { + // init log data. 
+		os.RemoveAll("tmp")
+		opt := wal.DefaultOptions
+		opt.DirPath = "tmp"
+		opt.BytesPerSync = math.MaxUint32
+		posSlice := make([]*wal.ChunkPosition, 0, NUM)
+		w, _ := wal.Open(opt)
+
+		for i := 0; i < NUM; i++ {
+			data := fmt.Sprintf("%12d", i)
+			pos, _ := w.Write([]byte(data))
+			posSlice = append(posSlice, pos)
+		}
+		w.Close()
+
+		// reopen.
+		w, _ = wal.Open(opt)
+		b.ResetTimer()
+
+		for i := 0; i < b.N; i++ {
+			w.Read(posSlice[i%NUM])
+		}
+	})
+}
diff --git a/wal/wal.go b/wal/wal.go
new file mode 100644
index 0000000..56d0f1d
--- /dev/null
+++ b/wal/wal.go
@@ -0,0 +1,821 @@
+package wal
+
+import (
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"sync"
+
+	"github.com/tidwall/tinylru"
+)
+
+var (
+	// ErrCorrupt is returned when the log is corrupt.
+	ErrCorrupt = errors.New("log corrupt")
+
+	// ErrClosed is returned when an operation cannot be completed because
+	// the log is closed.
+	ErrClosed = errors.New("log closed")
+
+	// ErrNotFound is returned when an entry is not found.
+	ErrNotFound = errors.New("not found")
+
+	// ErrOutOfOrder is returned from Write() when the index is not equal to
+	// LastIndex()+1. It's required that the log monotonically grows by one and
+	// has no gaps. Thus, the series 10,11,12,13,14 is valid, but 10,11,13,14 is
+	// not because there's a gap between 11 and 13. Also, 10,12,11,13 is not
+	// valid because 12 and 11 are out of order.
+	ErrOutOfOrder = errors.New("out of order")
+
+	// ErrOutOfRange is returned from TruncateFront() and TruncateBack() when
+	// the index is not in the range of the log's first and last index. Or, this
+	// may be returned when the caller is attempting to remove *all* entries;
+	// the log requires that at least one entry exists following a truncate.
+	ErrOutOfRange = errors.New("out of range")
+)
+
+// Options for Log
+type Options struct {
+	// NoSync disables fsync after writes. This is less durable and puts the
+	// log at risk of data loss when there's a server crash.
+	NoSync bool
+	// SegmentSize of each segment. This is just a target value, actual size
+	// may differ. Default is 20 MB.
+	SegmentSize int
+	// SegmentCacheSize is the maximum number of segments that will be held in
+	// memory for caching. Increasing this value may enhance performance for
+	// concurrent read operations. Default is 2.
+	SegmentCacheSize int
+	// NoCopy allows for the Read() operation to return the raw underlying data
+	// slice. This is an optimization to help minimize allocations. When this
+	// option is set, do not modify the returned data because it may affect
+	// other Read calls. Default false
+	NoCopy bool
+	// DirPerms and FilePerms represent the modes and permission bits used for
+	// the created directories and data files.
+	DirPerms  os.FileMode
+	FilePerms os.FileMode
+}
+
+// DefaultOptions for Open().
+var DefaultOptions = &Options{
+	NoSync:           true,     // No fsync after every write
+	SegmentSize:      20971520, // 20 MB log segment files.
+	SegmentCacheSize: 2,        // Number of cached in-memory segments
+	NoCopy:           false,    // Make a new copy of data for every Read call.
+ DirPerms: 0750, // Permissions for the created directories + FilePerms: 0640, // Permissions for the created data files +} + +// Log represents a write ahead log +type Log struct { + mu sync.RWMutex + path string // absolute path to log directory + opts Options // log options + closed bool // log is closed + corrupt bool // log may be corrupt + segments []*segment // all known log segments + firstIndex uint64 // index of the first entry in log + lastIndex uint64 // index of the last entry in log + sfile *os.File // tail segment file handle + wbatch Batch // reusable write batch + scache tinylru.LRU // segment entries cache +} + +// segment represents a single segment file. +type segment struct { + path string // path of segment file + index uint64 // first index of segment + ebuf []byte // cached entries buffer + epos []bpos // cached entries positions in buffer +} + +type bpos struct { + pos int // byte position + end int // one byte past pos +} + +// Open a new write ahead log +func Open(path string, opts *Options) (*Log, error) { + if opts == nil { + opts = DefaultOptions + } + if opts.SegmentCacheSize <= 0 { + opts.SegmentCacheSize = DefaultOptions.SegmentCacheSize + } + if opts.SegmentSize <= 0 { + opts.SegmentSize = DefaultOptions.SegmentSize + } + if opts.DirPerms == 0 { + opts.DirPerms = DefaultOptions.DirPerms + } + if opts.FilePerms == 0 { + opts.FilePerms = DefaultOptions.FilePerms + } + + var err error + path, err = abs(path) + if err != nil { + return nil, err + } + l := &Log{path: path, opts: *opts} + l.scache.Resize(l.opts.SegmentCacheSize) + if err := os.MkdirAll(path, l.opts.DirPerms); err != nil { + return nil, err + } + if err := l.load(); err != nil { + return nil, err + } + return l, nil +} + +func abs(path string) (string, error) { + return filepath.Abs(path) +} + +func (l *Log) pushCache(segIdx int) { + _, _, _, v, evicted := l.scache.SetEvicted(segIdx, l.segments[segIdx]) + if evicted { + s := v.(*segment) + s.ebuf = nil + s.epos = nil + } +} + +// load all the segments. This operation also cleans up any START/END segments. +func (l *Log) load() error { + fis, err := os.ReadDir(l.path) + if err != nil { + return err + } + startIdx := -1 + endIdx := -1 + for _, fi := range fis { + name := fi.Name() + if fi.IsDir() || len(name) < 20 { + continue + } + index, err := strconv.ParseUint(name[:20], 10, 64) + if err != nil || index == 0 { + continue + } + isStart := len(name) == 26 && strings.HasSuffix(name, ".START") + isEnd := len(name) == 24 && strings.HasSuffix(name, ".END") + if len(name) == 20 || isStart || isEnd { + if isStart { + startIdx = len(l.segments) + } else if isEnd && endIdx == -1 { + endIdx = len(l.segments) + } + l.segments = append(l.segments, &segment{ + index: index, + path: filepath.Join(l.path, name), + }) + } + } + if len(l.segments) == 0 { + // Create a new log + l.segments = append(l.segments, &segment{ + index: 1, + path: filepath.Join(l.path, segmentName(1)), + }) + l.firstIndex = 1 + l.lastIndex = 0 + l.sfile, err = os.OpenFile(l.segments[0].path, os.O_CREATE|os.O_RDWR|os.O_TRUNC, l.opts.FilePerms) + return err + } + // Open existing log. Clean up log if START of END segments exists. + if startIdx != -1 { + if endIdx != -1 { + // There should not be a START and END at the same time + return ErrCorrupt + } + // Delete all files leading up to START + for i := 0; i < startIdx; i++ { + if err := os.Remove(l.segments[i].path); err != nil { + return err + } + } + l.segments = append([]*segment{}, l.segments[startIdx:]...) 
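+		// The .START file is the truncated replacement segment written by
+		// truncateFront; renaming it below finishes a front-truncation that
+		// was interrupted before its final rename.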
+ // Rename the START segment + orgPath := l.segments[0].path + finalPath := orgPath[:len(orgPath)-len(".START")] + err := os.Rename(orgPath, finalPath) + if err != nil { + return err + } + l.segments[0].path = finalPath + } + if endIdx != -1 { + // Delete all files following END + for i := len(l.segments) - 1; i > endIdx; i-- { + if err := os.Remove(l.segments[i].path); err != nil { + return err + } + } + l.segments = append([]*segment{}, l.segments[:endIdx+1]...) + if len(l.segments) > 1 && l.segments[len(l.segments)-2].index == + l.segments[len(l.segments)-1].index { + // remove the segment prior to the END segment because it shares + // the same starting index. + l.segments[len(l.segments)-2] = l.segments[len(l.segments)-1] + l.segments = l.segments[:len(l.segments)-1] + } + // Rename the END segment + orgPath := l.segments[len(l.segments)-1].path + finalPath := orgPath[:len(orgPath)-len(".END")] + err := os.Rename(orgPath, finalPath) + if err != nil { + return err + } + l.segments[len(l.segments)-1].path = finalPath + } + l.firstIndex = l.segments[0].index + // Open the last segment for appending + lseg := l.segments[len(l.segments)-1] + l.sfile, err = os.OpenFile(lseg.path, os.O_WRONLY, l.opts.FilePerms) + if err != nil { + return err + } + if _, err := l.sfile.Seek(0, 2); err != nil { + return err + } + // Load the last segment entries + if err := l.loadSegmentEntries(lseg); err != nil { + return err + } + l.lastIndex = lseg.index + uint64(len(lseg.epos)) - 1 + return nil +} + +// segmentName returns a 20-byte textual representation of an index +// for lexical ordering. This is used for the file names of log segments. +func segmentName(index uint64) string { + return fmt.Sprintf("%020d", index) +} + +// Close the log. +func (l *Log) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + if l.closed { + if l.corrupt { + return ErrCorrupt + } + return ErrClosed + } + if err := l.sfile.Sync(); err != nil { + return err + } + if err := l.sfile.Close(); err != nil { + return err + } + l.closed = true + if l.corrupt { + return ErrCorrupt + } + return nil +} + +// Write an entry to the log. +func (l *Log) Write(index uint64, data []byte) error { + l.mu.Lock() + defer l.mu.Unlock() + if l.corrupt { + return ErrCorrupt + } else if l.closed { + return ErrClosed + } + l.wbatch.Clear() + l.wbatch.Write(index, data) + return l.writeBatch(&l.wbatch) +} + +func (l *Log) appendEntry(dst []byte, index uint64, data []byte) (out []byte, epos bpos) { + return appendBinaryEntry(dst, data) +} + +// Cycle the old segment for a new segment. +func (l *Log) cycle() error { + if err := l.sfile.Sync(); err != nil { + return err + } + if err := l.sfile.Close(); err != nil { + return err + } + // cache the previous segment + l.pushCache(len(l.segments) - 1) + s := &segment{ + index: l.lastIndex + 1, + path: filepath.Join(l.path, segmentName(l.lastIndex+1)), + } + var err error + l.sfile, err = os.OpenFile(s.path, os.O_CREATE|os.O_RDWR|os.O_TRUNC, l.opts.FilePerms) + if err != nil { + return err + } + l.segments = append(l.segments, s) + return nil +} + +func appendBinaryEntry(dst []byte, data []byte) (out []byte, epos bpos) { + // data_size + data + pos := len(dst) + dst = appendUvarint(dst, uint64(len(data))) + dst = append(dst, data...) + return dst, bpos{pos, len(dst)} +} + +func appendUvarint(dst []byte, x uint64) []byte { + var buf [10]byte + n := binary.PutUvarint(buf[:], x) + dst = append(dst, buf[:n]...) + return dst +} + +// Batch of entries. Used to write multiple entries at once using WriteBatch(). 
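+// Entry payloads are packed back-to-back into a single byte slice (datas)
+// rather than stored as one slice per entry.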
+type Batch struct { + entries []batchEntry + datas []byte +} + +type batchEntry struct { + index uint64 + size int +} + +// Write an entry to the batch +func (b *Batch) Write(index uint64, data []byte) { + b.entries = append(b.entries, batchEntry{index, len(data)}) + b.datas = append(b.datas, data...) +} + +// Clear the batch for reuse. +func (b *Batch) Clear() { + b.entries = b.entries[:0] + b.datas = b.datas[:0] +} + +// WriteBatch writes the entries in the batch to the log in the order that they +// were added to the batch. The batch is cleared upon a successful return. +func (l *Log) WriteBatch(b *Batch) error { + l.mu.Lock() + defer l.mu.Unlock() + if l.corrupt { + return ErrCorrupt + } else if l.closed { + return ErrClosed + } + if len(b.entries) == 0 { + return nil + } + return l.writeBatch(b) +} + +func (l *Log) writeBatch(b *Batch) error { + // check that all indexes in batch are sane + for i := 0; i < len(b.entries); i++ { + if b.entries[i].index != l.lastIndex+uint64(i+1) { + return ErrOutOfOrder + } + } + // load the tail segment + s := l.segments[len(l.segments)-1] + if len(s.ebuf) > l.opts.SegmentSize { + // tail segment has reached capacity. Close it and create a new one. + if err := l.cycle(); err != nil { + return err + } + s = l.segments[len(l.segments)-1] + } + + mark := len(s.ebuf) + datas := b.datas + for i := 0; i < len(b.entries); i++ { + data := datas[:b.entries[i].size] + var epos bpos + s.ebuf, epos = l.appendEntry(s.ebuf, b.entries[i].index, data) + s.epos = append(s.epos, epos) + if len(s.ebuf) >= l.opts.SegmentSize { + // segment has reached capacity, cycle now + if _, err := l.sfile.Write(s.ebuf[mark:]); err != nil { + return err + } + l.lastIndex = b.entries[i].index + if err := l.cycle(); err != nil { + return err + } + s = l.segments[len(l.segments)-1] + mark = 0 + } + datas = datas[b.entries[i].size:] + } + if len(s.ebuf)-mark > 0 { + if _, err := l.sfile.Write(s.ebuf[mark:]); err != nil { + return err + } + l.lastIndex = b.entries[len(b.entries)-1].index + } + if !l.opts.NoSync { + if err := l.sfile.Sync(); err != nil { + return err + } + } + b.Clear() + return nil +} + +// FirstIndex returns the index of the first entry in the log. Returns zero +// when log has no entries. +func (l *Log) FirstIndex() (index uint64, err error) { + l.mu.RLock() + defer l.mu.RUnlock() + if l.corrupt { + return 0, ErrCorrupt + } else if l.closed { + return 0, ErrClosed + } + // We check the lastIndex for zero because the firstIndex is always one or + // more, even when there's no entries + if l.lastIndex == 0 { + return 0, nil + } + return l.firstIndex, nil +} + +// LastIndex returns the index of the last entry in the log. Returns zero when +// log has no entries. 
+func (l *Log) LastIndex() (index uint64, err error) { + l.mu.RLock() + defer l.mu.RUnlock() + if l.corrupt { + return 0, ErrCorrupt + } else if l.closed { + return 0, ErrClosed + } + return l.lastIndex, nil +} + +// findSegment performs a bsearch on the segments +func (l *Log) findSegment(index uint64) int { + i, j := 0, len(l.segments) + for i < j { + h := i + (j-i)/2 + if index >= l.segments[h].index { + i = h + 1 + } else { + j = h + } + } + return i - 1 +} + +func (l *Log) loadSegmentEntries(s *segment) error { + data, err := os.ReadFile(s.path) + if err != nil { + return err + } + ebuf := data + var epos []bpos + var pos int + for exidx := s.index; len(data) > 0; exidx++ { + var n int + n, err = loadNextBinaryEntry(data) + if err != nil { + return err + } + data = data[n:] + epos = append(epos, bpos{pos, pos + n}) + pos += n + } + s.ebuf = ebuf + s.epos = epos + return nil +} + +func loadNextBinaryEntry(data []byte) (n int, err error) { + // data_size + data + size, n := binary.Uvarint(data) + if n <= 0 { + return 0, ErrCorrupt + } + if uint64(len(data)-n) < size { + return 0, ErrCorrupt + } + return n + int(size), nil +} + +// loadSegment loads the segment entries into memory, pushes it to the front +// of the lru cache, and returns it. +func (l *Log) loadSegment(index uint64) (*segment, error) { + // check the last segment first. + lseg := l.segments[len(l.segments)-1] + if index >= lseg.index { + return lseg, nil + } + // check the most recent cached segment + var rseg *segment + l.scache.Range(func(_, v interface{}) bool { + s := v.(*segment) + if index >= s.index && index < s.index+uint64(len(s.epos)) { + rseg = s + } + return false + }) + if rseg != nil { + return rseg, nil + } + // find in the segment array + idx := l.findSegment(index) + s := l.segments[idx] + if len(s.epos) == 0 { + // load the entries from cache + if err := l.loadSegmentEntries(s); err != nil { + return nil, err + } + } + // push the segment to the front of the cache + l.pushCache(idx) + return s, nil +} + +// Read an entry from the log. Returns a byte slice containing the data entry. +func (l *Log) Read(index uint64) (data []byte, err error) { + l.mu.RLock() + defer l.mu.RUnlock() + if l.corrupt { + return nil, ErrCorrupt + } else if l.closed { + return nil, ErrClosed + } + if index == 0 || index < l.firstIndex || index > l.lastIndex { + return nil, ErrNotFound + } + s, err := l.loadSegment(index) + if err != nil { + return nil, err + } + epos := s.epos[index-s.index] + edata := s.ebuf[epos.pos:epos.end] + // binary read + size, n := binary.Uvarint(edata) + if n <= 0 { + return nil, ErrCorrupt + } + if uint64(len(edata)-n) < size { + return nil, ErrCorrupt + } + if l.opts.NoCopy { + data = edata[n : uint64(n)+size] + } else { + data = make([]byte, size) + copy(data, edata[n:]) + } + return data, nil +} + +// ClearCache clears the segment cache +func (l *Log) ClearCache() error { + l.mu.Lock() + defer l.mu.Unlock() + if l.corrupt { + return ErrCorrupt + } else if l.closed { + return ErrClosed + } + l.clearCache() + return nil +} +func (l *Log) clearCache() { + l.scache.Range(func(_, v interface{}) bool { + s := v.(*segment) + s.ebuf = nil + s.epos = nil + return true + }) + l.scache = tinylru.LRU{} + l.scache.Resize(l.opts.SegmentCacheSize) +} + +// TruncateFront truncates the front of the log by removing all entries that +// are before the provided `index`. In other words the entry at +// `index` becomes the first entry in the log. 
+func (l *Log) TruncateFront(index uint64) error {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	if l.corrupt {
+		return ErrCorrupt
+	} else if l.closed {
+		return ErrClosed
+	}
+	return l.truncateFront(index)
+}
+func (l *Log) truncateFront(index uint64) (err error) {
+	if index == 0 || l.lastIndex == 0 ||
+		index < l.firstIndex || index > l.lastIndex {
+		return ErrOutOfRange
+	}
+	if index == l.firstIndex {
+		// nothing to truncate
+		return nil
+	}
+	segIdx := l.findSegment(index)
+	var s *segment
+	s, err = l.loadSegment(index)
+	if err != nil {
+		return err
+	}
+	epos := s.epos[index-s.index:]
+	ebuf := s.ebuf[epos[0].pos:]
+	// Create a temp file containing the truncated segment.
+	tempName := filepath.Join(l.path, "TEMP")
+	err = func() error {
+		f, err := os.OpenFile(tempName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, l.opts.FilePerms)
+		if err != nil {
+			return err
+		}
+		defer f.Close()
+		if _, err := f.Write(ebuf); err != nil {
+			return err
+		}
+		if err := f.Sync(); err != nil {
+			return err
+		}
+		return f.Close()
+	}()
+	if err != nil {
+		return err
+	}
+	// Rename the TEMP file to its START file name.
+	startName := filepath.Join(l.path, segmentName(index)+".START")
+	if err = os.Rename(tempName, startName); err != nil {
+		return err
+	}
+	// The log was truncated but still needs some file cleanup. Any errors
+	// following this message will not cause an on-disk data corruption, but
+	// may cause an inconsistency with the current program, so we'll return
+	// ErrCorrupt so that the user can attempt a recovery by calling Close()
+	// followed by Open().
+	defer func() {
+		if v := recover(); v != nil {
+			err = ErrCorrupt
+			l.corrupt = true
+		}
+	}()
+	if segIdx == len(l.segments)-1 {
+		// Close the tail segment file
+		if err = l.sfile.Close(); err != nil {
+			return err
+		}
+	}
+	// Delete truncated segment files
+	for i := 0; i <= segIdx; i++ {
+		if err = os.Remove(l.segments[i].path); err != nil {
+			return err
+		}
+	}
+	// Rename the START file to the final truncated segment name.
+	newName := filepath.Join(l.path, segmentName(index))
+	if err = os.Rename(startName, newName); err != nil {
+		return err
+	}
+	s.path = newName
+	s.index = index
+	if segIdx == len(l.segments)-1 {
+		// Reopen the tail segment file
+		if l.sfile, err = os.OpenFile(newName, os.O_WRONLY, l.opts.FilePerms); err != nil {
+			return err
+		}
+		var n int64
+		if n, err = l.sfile.Seek(0, 2); err != nil {
+			return err
+		}
+		if n != int64(len(ebuf)) {
+			err = errors.New("invalid seek")
+			return err
+		}
+		// Load the last segment entries
+		if err = l.loadSegmentEntries(s); err != nil {
+			return err
+		}
+	}
+	l.segments = append([]*segment{}, l.segments[segIdx:]...)
+	l.firstIndex = index
+	l.clearCache()
+	return nil
+}
+
+// TruncateBack truncates the back of the log by removing all entries that
+// are after the provided `index`. In other words the entry at `index`
+// becomes the last entry in the log.
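+//
+// Sketch of the intended use (the name `committed` is hypothetical): to roll
+// back an aborted tail, discard everything written after `committed`; the
+// next Write must then use committed+1:
+//
+//	if err := l.TruncateBack(committed); err != nil {
+//		// handle error (ErrOutOfRange if committed is outside [first, last])
+//	}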
+func (l *Log) TruncateBack(index uint64) error {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	if l.corrupt {
+		return ErrCorrupt
+	} else if l.closed {
+		return ErrClosed
+	}
+	return l.truncateBack(index)
+}
+
+func (l *Log) truncateBack(index uint64) (err error) {
+	if index == 0 || l.lastIndex == 0 ||
+		index < l.firstIndex || index > l.lastIndex {
+		return ErrOutOfRange
+	}
+	if index == l.lastIndex {
+		// nothing to truncate
+		return nil
+	}
+	segIdx := l.findSegment(index)
+	var s *segment
+	s, err = l.loadSegment(index)
+	if err != nil {
+		return err
+	}
+	epos := s.epos[:index-s.index+1]
+	ebuf := s.ebuf[:epos[len(epos)-1].end]
+	// Create a temp file containing the truncated segment.
+	tempName := filepath.Join(l.path, "TEMP")
+	err = func() error {
+		f, err := os.OpenFile(tempName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, l.opts.FilePerms)
+		if err != nil {
+			return err
+		}
+		defer f.Close()
+		if _, err := f.Write(ebuf); err != nil {
+			return err
+		}
+		if err := f.Sync(); err != nil {
+			return err
+		}
+		return f.Close()
+	}()
+	if err != nil {
+		return err
+	}
+	// Rename the TEMP file to its END file name.
+	endName := filepath.Join(l.path, segmentName(s.index)+".END")
+	if err = os.Rename(tempName, endName); err != nil {
+		return err
+	}
+	// The log was truncated but still needs some file cleanup. Any errors
+	// following this message will not cause an on-disk data corruption, but
+	// may cause an inconsistency with the current program, so we'll return
+	// ErrCorrupt so that the user can attempt a recovery by calling Close()
+	// followed by Open().
+	defer func() {
+		if v := recover(); v != nil {
+			err = ErrCorrupt
+			l.corrupt = true
+		}
+	}()
+
+	// Close the tail segment file
+	if err = l.sfile.Close(); err != nil {
+		return err
+	}
+	// Delete truncated segment files
+	for i := segIdx; i < len(l.segments); i++ {
+		if err = os.Remove(l.segments[i].path); err != nil {
+			return err
+		}
+	}
+	// Rename the END file to the final truncated segment name.
+	newName := filepath.Join(l.path, segmentName(s.index))
+	if err = os.Rename(endName, newName); err != nil {
+		return err
+	}
+	// Reopen the tail segment file
+	if l.sfile, err = os.OpenFile(newName, os.O_WRONLY, l.opts.FilePerms); err != nil {
+		return err
+	}
+	var n int64
+	n, err = l.sfile.Seek(0, 2)
+	if err != nil {
+		return err
+	}
+	if n != int64(len(ebuf)) {
+		err = errors.New("invalid seek")
+		return err
+	}
+	s.path = newName
+	l.segments = append([]*segment{}, l.segments[:segIdx+1]...)
+	l.lastIndex = index
+	l.clearCache()
+	if err = l.loadSegmentEntries(s); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Sync performs an fsync on the log. This is not necessary when the
+// NoSync option is set to false.
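+//
+// A common pattern (sketch; `pending` and `base` are made up) when running
+// with Options{NoSync: true} is to buffer many writes and fsync once at a
+// safe point:
+//
+//	for i, data := range pending { // pending [][]byte
+//		_ = l.Write(base+uint64(i), data)
+//	}
+//	if err := l.Sync(); err != nil {
+//		// handle error
+//	}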
+func (l *Log) Sync() error { + l.mu.Lock() + defer l.mu.Unlock() + if l.corrupt { + return ErrCorrupt + } else if l.closed { + return ErrClosed + } + return l.sfile.Sync() +} diff --git a/wal/wal_test.go b/wal/wal_test.go new file mode 100644 index 0000000..fb3bbe9 --- /dev/null +++ b/wal/wal_test.go @@ -0,0 +1,817 @@ +package wal + +import ( + "fmt" + "math/rand" + "os" + "strings" + "sync/atomic" + "testing" +) + +func dataStr(index uint64) string { + if index%2 == 0 { + return fmt.Sprintf("data-\"%d\"", index) + } + return fmt.Sprintf("data-'%d'", index) +} + +func testLog(t *testing.T, opts *Options, N int) { + logPath := "testlog/" + strings.Join(strings.Split(t.Name(), "/")[1:], "/") + l, err := Open(logPath, opts) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + // FirstIndex - should be zero + n, err := l.FirstIndex() + if err != nil { + t.Fatal(err) + } + if n != 0 { + t.Fatalf("expected %d, got %d", 0, n) + } + + // LastIndex - should be zero + n, err = l.LastIndex() + if err != nil { + t.Fatal(err) + } + if n != 0 { + t.Fatalf("expected %d, got %d", 0, n) + } + + for i := 1; i <= N; i++ { + // Write - try to append previous index, should fail + err = l.Write(uint64(i-1), nil) + if err != ErrOutOfOrder { + t.Fatalf("expected %v, got %v", ErrOutOfOrder, err) + } + // Write - append next item + err = l.Write(uint64(i), []byte(dataStr(uint64(i)))) + if err != nil { + t.Fatalf("expected %v, got %v", nil, err) + } + // Write - get next item + data, err := l.Read(uint64(i)) + if err != nil { + t.Fatalf("expected %v, got %v", nil, err) + } + if string(data) != dataStr(uint64(i)) { + t.Fatalf("expected %s, got %s", dataStr(uint64(i)), data) + } + } + + // Read -- should fail, not found + _, err = l.Read(0) + if err != ErrNotFound { + t.Fatalf("expected %v, got %v", ErrNotFound, err) + } + // Read -- read back all entries + for i := 1; i <= N; i++ { + data, err := l.Read(uint64(i)) + if err != nil { + t.Fatalf("error while getting %d", i) + } + if string(data) != dataStr(uint64(i)) { + t.Fatalf("expected %s, got %s", dataStr(uint64(i)), data) + } + } + // Read -- read back first half entries + for i := 1; i <= N/2; i++ { + data, err := l.Read(uint64(i)) + if err != nil { + t.Fatalf("error while getting %d", i) + } + if string(data) != dataStr(uint64(i)) { + t.Fatalf("expected %s, got %s", dataStr(uint64(i)), data) + } + } + // Read -- read second third entries + for i := N / 3; i <= N/3+N/3; i++ { + data, err := l.Read(uint64(i)) + if err != nil { + t.Fatalf("error while getting %d", i) + } + if string(data) != dataStr(uint64(i)) { + t.Fatalf("expected %s, got %s", dataStr(uint64(i)), data) + } + } + + // Read -- random access + for _, v := range rand.Perm(N) { + index := uint64(v + 1) + data, err := l.Read(index) + if err != nil { + t.Fatal(err) + } + if dataStr(index) != string(data) { + t.Fatalf("expected %v, got %v", dataStr(index), string(data)) + } + } + + // FirstIndex/LastIndex -- check valid first and last indexes + n, err = l.FirstIndex() + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected %d, got %d", 1, n) + } + n, err = l.LastIndex() + if err != nil { + t.Fatal(err) + } + if n != uint64(N) { + t.Fatalf("expected %d, got %d", N, n) + } + + // Close -- close the log + if err := l.Close(); err != nil { + t.Fatal(err) + } + + // Write - try while closed + err = l.Write(1, nil) + if err != ErrClosed { + t.Fatalf("expected %v, got %v", ErrClosed, err) + } + // WriteBatch - try while closed + err = l.WriteBatch(nil) + if err != ErrClosed { + 
t.Fatalf("expected %v, got %v", ErrClosed, err) + } + // FirstIndex - try while closed + _, err = l.FirstIndex() + if err != ErrClosed { + t.Fatalf("expected %v, got %v", ErrClosed, err) + } + // LastIndex - try while closed + _, err = l.LastIndex() + if err != ErrClosed { + t.Fatalf("expected %v, got %v", ErrClosed, err) + } + // Get - try while closed + _, err = l.Read(0) + if err != ErrClosed { + t.Fatalf("expected %v, got %v", ErrClosed, err) + } + // TruncateFront - try while closed + err = l.TruncateFront(0) + if err != ErrClosed { + t.Fatalf("expected %v, got %v", ErrClosed, err) + } + // TruncateBack - try while closed + err = l.TruncateBack(0) + if err != ErrClosed { + t.Fatalf("expected %v, got %v", ErrClosed, err) + } + + // Open -- reopen log + l, err = Open(logPath, opts) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + // Read -- read back all entries + for i := 1; i <= N; i++ { + data, err := l.Read(uint64(i)) + if err != nil { + t.Fatalf("error while getting %d, err=%s", i, err) + } + if string(data) != dataStr(uint64(i)) { + t.Fatalf("expected %s, got %s", dataStr(uint64(i)), data) + } + } + // FirstIndex/LastIndex -- check valid first and last indexes + n, err = l.FirstIndex() + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected %d, got %d", 1, n) + } + n, err = l.LastIndex() + if err != nil { + t.Fatal(err) + } + if n != uint64(N) { + t.Fatalf("expected %d, got %d", N, n) + } + // Write -- add 50 more items + for i := N + 1; i <= N+50; i++ { + index := uint64(i) + if err := l.Write(index, []byte(dataStr(index))); err != nil { + t.Fatal(err) + } + data, err := l.Read(index) + if err != nil { + t.Fatal(err) + } + if string(data) != dataStr(index) { + t.Fatalf("expected %v, got %v", dataStr(index), string(data)) + } + } + N += 50 + // FirstIndex/LastIndex -- check valid first and last indexes + n, err = l.FirstIndex() + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected %d, got %d", 1, n) + } + n, err = l.LastIndex() + if err != nil { + t.Fatal(err) + } + if n != uint64(N) { + t.Fatalf("expected %d, got %d", N, n) + } + // Batch -- test batch writes + b := new(Batch) + b.Write(1, nil) + b.Write(2, nil) + b.Write(3, nil) + // WriteBatch -- should fail out of order + err = l.WriteBatch(b) + if err != ErrOutOfOrder { + t.Fatalf("expected %v, got %v", ErrOutOfOrder, nil) + } + // Clear -- clear the batch + b.Clear() + // WriteBatch -- should succeed + err = l.WriteBatch(b) + if err != nil { + t.Fatal(err) + } + // Write 100 entries in batches of 10 + for i := 0; i < 10; i++ { + for i := N + 1; i <= N+10; i++ { + index := uint64(i) + b.Write(index, []byte(dataStr(index))) + } + err = l.WriteBatch(b) + if err != nil { + t.Fatal(err) + } + N += 10 + } + // Read -- read back all entries + for i := 1; i <= N; i++ { + data, err := l.Read(uint64(i)) + if err != nil { + t.Fatalf("error while getting %d", i) + } + if string(data) != dataStr(uint64(i)) { + t.Fatalf("expected %s, got %s", dataStr(uint64(i)), data) + } + } + + // Write -- one entry, so the buffer might be activated + err = l.Write(uint64(N+1), []byte(dataStr(uint64(N+1)))) + if err != nil { + t.Fatal(err) + } + N++ + // Read -- one random read, so there is an opened reader + data, err := l.Read(uint64(N / 2)) + if err != nil { + t.Fatal(err) + } + if string(data) != dataStr(uint64(N/2)) { + t.Fatalf("expected %v, got %v", dataStr(uint64(N/2)), string(data)) + } + + // TruncateFront -- should fail, out of range + for _, i := range []int{0, N + 1} { + index := uint64(i) + if 
err = l.TruncateFront(index); err != ErrOutOfRange { + t.Fatalf("expected %v, got %v", ErrOutOfRange, err) + } + testFirstLast(t, l, uint64(1), uint64(N), nil) + } + + // TruncateBack -- should fail, out of range + err = l.TruncateFront(0) + if err != ErrOutOfRange { + t.Fatalf("expected %v, got %v", ErrOutOfRange, err) + } + testFirstLast(t, l, uint64(1), uint64(N), nil) + + // TruncateFront -- Remove no entries + if err = l.TruncateFront(1); err != nil { + t.Fatal(err) + } + testFirstLast(t, l, uint64(1), uint64(N), nil) + + // TruncateFront -- Remove first 80 entries + if err = l.TruncateFront(81); err != nil { + t.Fatal(err) + } + testFirstLast(t, l, uint64(81), uint64(N), nil) + + // Write -- one entry, so the buffer might be activated + err = l.Write(uint64(N+1), []byte(dataStr(uint64(N+1)))) + if err != nil { + t.Fatal(err) + } + N++ + testFirstLast(t, l, uint64(81), uint64(N), nil) + + // Read -- one random read, so there is an opened reader + data, err = l.Read(uint64(N / 2)) + if err != nil { + t.Fatal(err) + } + if string(data) != dataStr(uint64(N/2)) { + t.Fatalf("expected %v, got %v", dataStr(uint64(N/2)), string(data)) + } + + // TruncateBack -- should fail, out of range + for _, i := range []int{0, 80} { + index := uint64(i) + if err = l.TruncateBack(index); err != ErrOutOfRange { + t.Fatalf("expected %v, got %v", ErrOutOfRange, err) + } + testFirstLast(t, l, uint64(81), uint64(N), nil) + } + + // TruncateBack -- Remove no entries + if err = l.TruncateBack(uint64(N)); err != nil { + t.Fatal(err) + } + testFirstLast(t, l, uint64(81), uint64(N), nil) + // TruncateBack -- Remove last 80 entries + if err = l.TruncateBack(uint64(N - 80)); err != nil { + t.Fatal(err) + } + N -= 80 + testFirstLast(t, l, uint64(81), uint64(N), nil) + + // Read -- read back all entries + for i := 81; i <= N; i++ { + data, err := l.Read(uint64(i)) + if err != nil { + t.Fatalf("error while getting %d", i) + } + if string(data) != dataStr(uint64(i)) { + t.Fatalf("expected %s, got %s", dataStr(uint64(i)), data) + } + } + + // Close -- close log after truncating + if err = l.Close(); err != nil { + t.Fatal(err) + } + + // Open -- open log after truncating + l, err = Open(logPath, opts) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + testFirstLast(t, l, uint64(81), uint64(N), nil) + + // Read -- read back all entries + for i := 81; i <= N; i++ { + data, err := l.Read(uint64(i)) + if err != nil { + t.Fatalf("error while getting %d", i) + } + if string(data) != dataStr(uint64(i)) { + t.Fatalf("expected %s, got %s", dataStr(uint64(i)), data) + } + } + + // TruncateFront -- truncate all entries but one + if err = l.TruncateFront(uint64(N)); err != nil { + t.Fatal(err) + } + testFirstLast(t, l, uint64(N), uint64(N), nil) + + // Write -- write on entry + err = l.Write(uint64(N+1), []byte(dataStr(uint64(N+1)))) + if err != nil { + t.Fatal(err) + } + N++ + testFirstLast(t, l, uint64(N-1), uint64(N), nil) + + // TruncateBack -- truncate all entries but one + if err = l.TruncateBack(uint64(N - 1)); err != nil { + t.Fatal(err) + } + N-- + testFirstLast(t, l, uint64(N), uint64(N), nil) + + if err = l.Write(uint64(N+1), []byte(dataStr(uint64(N+1)))); err != nil { + t.Fatal(err) + } + N++ + + l.Sync() + testFirstLast(t, l, uint64(N-1), uint64(N), nil) +} + +func testFirstLast(t *testing.T, l *Log, expectFirst, expectLast uint64, data func(index uint64) []byte) { + t.Helper() + fi, err := l.FirstIndex() + if err != nil { + t.Fatal(err) + } + li, err := l.LastIndex() + if err != nil { + t.Fatal(err) + } + if fi 
!= expectFirst || li != expectLast { + t.Fatalf("expected %v/%v, got %v/%v", expectFirst, expectLast, fi, li) + } + for i := fi; i <= li; i++ { + dt1, err := l.Read(i) + if err != nil { + t.Fatal(err) + } + if data != nil { + dt2 := data(i) + if string(dt1) != string(dt2) { + t.Fatalf("mismatch '%s' != '%s'", dt2, dt1) + } + } + } + +} + +func TestLog(t *testing.T) { + os.RemoveAll("testlog") + defer os.RemoveAll("testlog") + + t.Run("nil-opts", func(t *testing.T) { + testLog(t, nil, 100) + }) + t.Run("no-sync", func(t *testing.T) { + testLog(t, makeOpts(512, true), 100) + }) + t.Run("sync", func(t *testing.T) { + testLog(t, makeOpts(512, false), 100) + }) +} + +func TestOutliers(t *testing.T) { + // Create some scenarios where the log has been corrupted, operations + // fail, or various weirdnesses. + t.Run("fail-in-memory", func(t *testing.T) { + if l, err := Open(":memory:", nil); err == nil { + l.Close() + t.Fatal("expected error") + } + }) + t.Run("fail-not-a-directory", func(t *testing.T) { + defer os.RemoveAll("testlog/file") + if err := os.MkdirAll("testlog", 0777); err != nil { + t.Fatal(err) + } else if f, err := os.Create("testlog/file"); err != nil { + t.Fatal(err) + } else if err := f.Close(); err != nil { + t.Fatal(err) + } else if l, err := Open("testlog/file", nil); err == nil { + l.Close() + t.Fatal("expected error") + } + }) + t.Run("load-with-junk-files", func(t *testing.T) { + // junk should be ignored + defer os.RemoveAll("testlog/junk") + if err := os.MkdirAll("testlog/junk/other1", 0777); err != nil { + t.Fatal(err) + } + f, err := os.Create("testlog/junk/other2") + if err != nil { + t.Fatal(err) + } + f.Close() + f, err = os.Create("testlog/junk/" + strings.Repeat("A", 20)) + if err != nil { + t.Fatal(err) + } + f.Close() + l, err := Open("testlog/junk", nil) + if err != nil { + t.Fatal(err) + } + l.Close() + }) + + t.Run("start-marker-file", func(t *testing.T) { + lpath := "testlog/start-marker" + opts := makeOpts(512, true) + l := must(Open(lpath, opts)).(*Log) + defer l.Close() + for i := uint64(1); i <= 100; i++ { + must(nil, l.Write(i, []byte(dataStr(i)))) + } + path := l.segments[l.findSegment(35)].path + firstIndex := l.segments[l.findSegment(35)].index + must(nil, l.Close()) + data := must(os.ReadFile(path)).([]byte) + must(nil, os.WriteFile(path+".START", data, 0666)) + l = must(Open(lpath, opts)).(*Log) + defer l.Close() + testFirstLast(t, l, firstIndex, 100, nil) + }) +} + +func makeOpts(segSize int, noSync bool) *Options { + opts := *DefaultOptions + opts.SegmentSize = segSize + opts.NoSync = noSync + return &opts +} + +// https://github.com/tidwall/wal/issues/1 +func TestIssue1(t *testing.T) { + in := []byte{0, 0, 0, 0, 0, 0, 0, 1, 37, 108, 131, 178, 151, 17, 77, 32, + 27, 48, 23, 159, 63, 14, 240, 202, 206, 151, 131, 98, 45, 165, 151, 67, + 38, 180, 54, 23, 138, 238, 246, 16, 0, 0, 0, 0} + opts := *DefaultOptions + os.RemoveAll("testlog") + l, err := Open("testlog", &opts) + if err != nil { + t.Fatal(err) + } + defer l.Close() + if err := l.Write(1, in); err != nil { + t.Fatal(err) + } + out, err := l.Read(1) + if err != nil { + t.Fatal(err) + } + if string(in) != string(out) { + t.Fatal("data mismatch") + } +} + +func TestSimpleTruncateFront(t *testing.T) { + os.RemoveAll("testlog") + + opts := &Options{ + NoSync: true, + SegmentSize: 100, + } + + l, err := Open("testlog", opts) + if err != nil { + t.Fatal(err) + } + defer func() { + l.Close() + }() + + makeData := func(index uint64) []byte { + return []byte(fmt.Sprintf("data-%d", index)) + } + + 
valid := func(t *testing.T, first, last uint64) { + t.Helper() + index, err := l.FirstIndex() + if err != nil { + t.Fatal(err) + } + if index != first { + t.Fatalf("expected %v, got %v", first, index) + } + index, err = l.LastIndex() + if err != nil { + t.Fatal(err) + } + if index != last { + t.Fatalf("expected %v, got %v", last, index) + } + for i := first; i <= last; i++ { + data, err := l.Read(i) + if err != nil { + t.Fatal(err) + } + if string(data) != string(makeData(i)) { + t.Fatalf("expcted '%s', got '%s'", makeData(i), data) + } + } + } + validReopen := func(t *testing.T, first, last uint64) { + t.Helper() + valid(t, first, last) + if err := l.Close(); err != nil { + t.Fatal(err) + } + l, err = Open("testlog", opts) + if err != nil { + t.Fatal(err) + } + valid(t, first, last) + } + for i := 1; i <= 100; i++ { + err := l.Write(uint64(i), makeData(uint64(i))) + if err != nil { + t.Fatal(err) + } + } + validReopen(t, 1, 100) + + if err := l.TruncateFront(1); err != nil { + t.Fatal(err) + } + validReopen(t, 1, 100) + + if err := l.TruncateFront(2); err != nil { + t.Fatal(err) + } + validReopen(t, 2, 100) + + if err := l.TruncateFront(4); err != nil { + t.Fatal(err) + } + validReopen(t, 4, 100) + + if err := l.TruncateFront(5); err != nil { + t.Fatal(err) + } + validReopen(t, 5, 100) + + if err := l.TruncateFront(99); err != nil { + t.Fatal(err) + } + validReopen(t, 99, 100) + + if err := l.TruncateFront(100); err != nil { + t.Fatal(err) + } + validReopen(t, 100, 100) + +} + +func TestSimpleTruncateBack(t *testing.T) { + os.RemoveAll("testlog") + + opts := &Options{ + NoSync: true, + SegmentSize: 100, + } + + l, err := Open("testlog", opts) + if err != nil { + t.Fatal(err) + } + defer func() { + l.Close() + }() + + makeData := func(index uint64) []byte { + return []byte(fmt.Sprintf("data-%d", index)) + } + + valid := func(t *testing.T, first, last uint64) { + t.Helper() + index, err := l.FirstIndex() + if err != nil { + t.Fatal(err) + } + if index != first { + t.Fatalf("expected %v, got %v", first, index) + } + index, err = l.LastIndex() + if err != nil { + t.Fatal(err) + } + if index != last { + t.Fatalf("expected %v, got %v", last, index) + } + for i := first; i <= last; i++ { + data, err := l.Read(i) + if err != nil { + t.Fatal(err) + } + if string(data) != string(makeData(i)) { + t.Fatalf("expcted '%s', got '%s'", makeData(i), data) + } + } + } + validReopen := func(t *testing.T, first, last uint64) { + t.Helper() + valid(t, first, last) + if err := l.Close(); err != nil { + t.Fatal(err) + } + l, err = Open("testlog", opts) + if err != nil { + t.Fatal(err) + } + valid(t, first, last) + } + for i := 1; i <= 100; i++ { + err := l.Write(uint64(i), makeData(uint64(i))) + if err != nil { + t.Fatal(err) + } + } + validReopen(t, 1, 100) + + ///////////////////////////////////////////////////////////// + if err := l.TruncateBack(100); err != nil { + t.Fatal(err) + } + validReopen(t, 1, 100) + if err := l.Write(101, makeData(101)); err != nil { + t.Fatal(err) + } + validReopen(t, 1, 101) + + ///////////////////////////////////////////////////////////// + if err := l.TruncateBack(99); err != nil { + t.Fatal(err) + } + validReopen(t, 1, 99) + if err := l.Write(100, makeData(100)); err != nil { + t.Fatal(err) + } + validReopen(t, 1, 100) + + if err := l.TruncateBack(94); err != nil { + t.Fatal(err) + } + validReopen(t, 1, 94) + + if err := l.TruncateBack(93); err != nil { + t.Fatal(err) + } + validReopen(t, 1, 93) + + if err := l.TruncateBack(92); err != nil { + t.Fatal(err) + } + 
validReopen(t, 1, 92)
+
+}
+
+func TestConcurrency(t *testing.T) {
+	os.RemoveAll("testlog")
+
+	l, err := Open("testlog", &Options{
+		NoSync: true,
+		NoCopy: true,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer l.Close()
+
+	// Write 1000 entries
+	for i := 1; i <= 1000; i++ {
+		err := l.Write(uint64(i), []byte(dataStr(uint64(i))))
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	// Perform 100,000 reads (over 100 threads)
+	finished := int32(0)
+	maxIndex := int32(1000)
+	numReads := int32(0)
+	for i := 0; i < 100; i++ {
+		go func() {
+			defer atomic.AddInt32(&finished, 1)
+
+			for i := 0; i < 1_000; i++ {
+				index := rand.Int31n(atomic.LoadInt32(&maxIndex)) + 1
+				if _, err := l.Read(uint64(index)); err != nil {
+					panic(err)
+				}
+				atomic.AddInt32(&numReads, 1)
+			}
+		}()
+	}
+
+	// continue writing
+	for index := maxIndex + 1; atomic.LoadInt32(&finished) < 100; index++ {
+		err := l.Write(uint64(index), []byte(dataStr(uint64(index))))
+		if err != nil {
+			t.Fatal(err)
+		}
+		atomic.StoreInt32(&maxIndex, index)
+	}
+
+	// confirm total reads
+	if exp := int32(100_000); numReads != exp {
+		t.Fatalf("expected %d reads, but got %d", exp, numReads)
+	}
+}
+
+func must(v interface{}, err error) interface{} {
+	if err != nil {
+		panic(err)
+	}
+	return v
+}
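
Reviewer note (not part of the patch): a minimal sketch of how the new wal
package is expected to be driven. The directory, indexes, and payloads below
are made up; passing nil options falls back to DefaultOptions, as exercised by
the "nil-opts" test above.

	package main

	import (
		"fmt"

		"github.com/xgzlucario/rotom/wal"
	)

	func main() {
		l, err := wal.Open("tmp/wal-demo", nil) // nil -> DefaultOptions
		if err != nil {
			panic(err)
		}
		defer l.Close()

		// Indexes are contiguous and 1-based: each write must use LastIndex()+1.
		last, _ := l.LastIndex()
		if err := l.Write(last+1, []byte("hello")); err != nil {
			panic(err)
		}

		// Several entries can be staged in a Batch and flushed with one
		// WriteBatch call (a single fsync when NoSync is disabled).
		b := new(wal.Batch)
		b.Write(last+2, []byte("foo"))
		b.Write(last+3, []byte("bar"))
		if err := l.WriteBatch(b); err != nil {
			panic(err)
		}

		data, _ := l.Read(last + 1)
		fmt.Println(string(data)) // "hello"
	}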