From 2538a42bbd8f003ea1669e5e95485637362c4c83 Mon Sep 17 00:00:00 2001
From: Hugh Do
Date: Thu, 27 Jun 2024 23:20:16 +0900
Subject: [PATCH] feat: add .rdb file parser & KEYS command (#5)

* feat: .rdb parser

* feat: support pattern matching for KEYS

* refactor: default dir

* fix: lint err
---
 app/server.go                        |  13 +-
 go.mod                               |   1 +
 go.sum                               |   2 +
 internal/app/server/config/config.go |   2 +-
 internal/app/server/server.go        |  45 ++-
 internal/app/server/server_test.go   |   2 +-
 pkg/command/command.go               |   3 +-
 pkg/command/config.go                |   4 +-
 pkg/command/get.go                   |   2 +-
 pkg/command/info.go                  |   2 +-
 pkg/command/keys.go                  |  36 +++
 pkg/command/set.go                   |   4 +-
 pkg/crc64/crc64.go                   |  64 +++++
 pkg/keyval/kv.go                     |  23 ++
 pkg/rdb/rdb.go                       | 404 +++++++++++++++++++++++++++
 pkg/rdb/rdb_test.go                  |  66 +++++
 pkg/resp/pattern.go                  |  53 ++++
 pkg/resp/pattern_test.go             |  38 +++
 pkg/resp/writer.go                   |  89 +++---
 pkg/resp/writer_test.go              |   2 +-
 testdata/dump.rdb                    | Bin 0 -> 369 bytes
 21 files changed, 805 insertions(+), 50 deletions(-)
 create mode 100644 pkg/command/keys.go
 create mode 100644 pkg/crc64/crc64.go
 create mode 100644 pkg/rdb/rdb.go
 create mode 100644 pkg/rdb/rdb_test.go
 create mode 100644 pkg/resp/pattern.go
 create mode 100644 pkg/resp/pattern_test.go
 create mode 100644 testdata/dump.rdb

diff --git a/app/server.go b/app/server.go
index d49ce8f..9b5bb3e 100644
--- a/app/server.go
+++ b/app/server.go
@@ -22,9 +22,10 @@ var (
 )
 
 func run(ctx context.Context, _ io.Writer, _ []string) error {
-	flag.Parse()
-	ctx, cancel := context.WithCancel(context.Background())
+	ctx, cancel := context.WithCancel(ctx)
 	sigCh := make(chan os.Signal, 1)
+	logger.Info(ctx, "oO0OoO0OoO0Oo Redis is starting oO0OoO0OoO0Oo")
+	flag.Parse()
 	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
 	defer signal.Stop(sigCh)
 	cfg := config.NewConfig()
@@ -38,8 +39,8 @@ func run(ctx context.Context, _ io.Writer, _ []string) error {
 	}
 	server := server.NewServer(cfg)
 	go func() {
-		if err := server.Listen(ctx); err != nil {
-			logger.Error(ctx, "failed to listen, err: %v", err)
+		if err := server.Start(ctx); err != nil {
+			logger.Error(ctx, "failed to start the server, err: %v", err)
 		}
 		cancel()
}() @@ -48,7 +49,7 @@ func run(ctx context.Context, _ io.Writer, _ []string) error { case <-ctx.Done(): logger.Info(ctx, "context done, shutting down") case <-sigCh: - logger.Info(ctx, "received signal, shutting down") + logger.Info(ctx, "Received signal, shutting down") } ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -56,7 +57,7 @@ func run(ctx context.Context, _ io.Writer, _ []string) error { logger.Error(ctx, "failed to close server, err: %v", err) } - logger.Info(ctx, "server shutdown") + logger.Info(ctx, "Redis is now ready to exit, bye bye...") return nil } diff --git a/go.mod b/go.mod index 131bd90..bd80767 100644 --- a/go.mod +++ b/go.mod @@ -18,5 +18,6 @@ require ( github.com/redis/go-redis/v9 v9.5.3 // indirect github.com/stretchr/testify v1.9.0 // indirect github.com/test-go/testify v1.1.4 // indirect + github.com/zhuyie/golzf v0.0.0-20161112031142-8387b0307ade // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 5b0b009..caf1544 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/test-go/testify v1.1.4 h1:Tf9lntrKUMHiXQ07qBScBTSA0dhYQlu83hswqelv1iE= github.com/test-go/testify v1.1.4/go.mod h1:rH7cfJo/47vWGdi4GPj16x3/t1xGOj2YxzmNQzk2ghU= +github.com/zhuyie/golzf v0.0.0-20161112031142-8387b0307ade h1:bafvQukPrIYwYWcft4rl3WpHo3qO0/voaAgnCwgdhi0= +github.com/zhuyie/golzf v0.0.0-20161112031142-8387b0307ade/go.mod h1:juNhYdla04C276MyU4zR0BA7t90ziLKPwkjDgddGYV0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/app/server/config/config.go b/internal/app/server/config/config.go 
index 5ba29a2..a6b5a0e 100644
--- a/internal/app/server/config/config.go
+++ b/internal/app/server/config/config.go
@@ -63,7 +63,7 @@ func (c *Config) Get(key string) (string, error) {
 	seachKey := strings.ToUpper(key)
 	value, exists := c.options[seachKey]
 	if !exists {
-		return "", ErrOptionNotFound
+		return "", fmt.Errorf("%w for %s", ErrOptionNotFound, key)
 	}
 	return value, nil
 }
diff --git a/internal/app/server/server.go b/internal/app/server/server.go
index 65fc11b..4e9f103 100644
--- a/internal/app/server/server.go
+++ b/internal/app/server/server.go
@@ -4,12 +4,14 @@ import (
 	"context"
 	"fmt"
 	"net"
+	"os"
 	"sync"
 
 	"github.com/codecrafters-io/redis-starter-go/internal/app/server/config"
 	"github.com/codecrafters-io/redis-starter-go/internal/client"
 	"github.com/codecrafters-io/redis-starter-go/pkg/command"
 	"github.com/codecrafters-io/redis-starter-go/pkg/keyval"
+	"github.com/codecrafters-io/redis-starter-go/pkg/rdb"
 	"github.com/codecrafters-io/redis-starter-go/pkg/resp"
 	"github.com/codecrafters-io/redis-starter-go/pkg/telemetry/logger"
 )
@@ -23,31 +25,67 @@ type Server struct {
 	mu       sync.Mutex
 	cfg      *config.Config
 	done     chan struct{}
+	store    keyval.KV
 	clients  map[*client.Client]struct{}
 	cFactory *command.CommandFactory
 }
 
 func NewServer(cfg *config.Config) *Server {
+	store := keyval.NewStore()
 	return &Server{
 		mu:       sync.Mutex{},
 		cfg:      cfg,
 		done:     make(chan struct{}),
-		cFactory: command.NewCommandFactory(keyval.NewStore(), cfg),
+		store:    store,
+		cFactory: command.NewCommandFactory(store, cfg),
 		clients:  make(map[*client.Client]struct{}),
 	}
 }
 
+func (s *Server) Start(ctx context.Context) error {
+	logger.Info(ctx, "Server initialized")
+	dir, _ := s.cfg.Get(config.DirKey)
+	dbFilename, _ := s.cfg.Get(config.DBFilenameKey)
+	file, openErr := os.Open(dir + "/" + dbFilename)
+	if openErr != nil && !os.IsNotExist(openErr) {
+		return fmt.Errorf("failed to open file, err: %v", openErr)
+	}
+	if openErr == nil {
+		// Close on every path right after loading; a deferred Close would keep it open as long as Listen runs.
+		stat, err := file.Stat()
+		if err != nil {
+			_ = file.Close()
+			return fmt.Errorf("failed to get file stat, err: %v", err)
+		}
+		if !stat.IsDir() {
+			parser := rdb.NewRDBParser(file)
+			err = parser.ParseRDB(ctx)
+			if err == nil {
+				s.store.RestoreRDB(parser.GetData(), parser.GetExpiry())
+			}
+		}
+		_ = file.Close()
+		if err != nil {
+			return fmt.Errorf("failed to parse rdb, err: %v", err)
+		}
+	}
+	if err := s.Listen(ctx); err != nil {
+		return fmt.Errorf("failed to listen, err: %v", err)
+	}
+	return nil
+}
+
 func (s *Server) Listen(ctx context.Context) error {
 	listenAddr, err := s.cfg.Get(config.ListenAddrKey)
 	if err != nil {
 		listenAddr = defaultListenAddr
 	}
 	ln, err := net.Listen("tcp", listenAddr)
-	logger.Info(ctx, "listening, addr: %s", ln.Addr())
 	if err != nil {
 		return err
 	}
 	s.ln = ln
+	logger.Info(ctx, "Ready to accept connections tcp on %s", listenAddr)
 	return s.loop(ctx)
 }
 
@@ -127,6 +165,9 @@ func (s *Server) handleMessage(ctx context.Context, cl *client.Client, r *resp.R
 
 func (s *Server) Close(_ context.Context) error {
 	close(s.done)
+	if s.ln == nil {
+		return nil
+	}
 	if err := s.ln.Close(); err != nil {
 		return fmt.Errorf("failed to close listener: %w", err)
 	}
diff --git a/internal/app/server/server_test.go b/internal/app/server/server_test.go
index 55b8359..a61b201 100644
--- a/internal/app/server/server_test.go
+++ b/internal/app/server/server_test.go
@@ -38,7 +38,7 @@ func startServer(ctx context.Context) (*Server, int, error) {
 	}
 	srv := NewServer(cfg)
 	go func() {
-		if err := srv.Listen(ctx); err != nil {
+		if err := srv.Start(ctx); err != nil {
 			panic(err)
 		}
 	}()
diff --git a/pkg/command/command.go b/pkg/command/command.go
index 5400951..3906bf3 100644
--- a/pkg/command/command.go
+++ b/pkg/command/command.go
@@ -31,7 +31,7 @@ func (ec *EchoCommand) Execute(c *client.Client, wr *resp.Writer, args []*resp.R
 		return wr.WriteError(errors.New("wrong number of arguments for 'echo' command"))
 	}
 
-	return wr.WriteValue(args[0].Bytes())
+	return wr.WriteBytes(args[0].Bytes())
 }
 
 type PingCommand struct {
@@ -55,6 +55,7 @@ func NewCommandFactory(kv keyval.KV, cfg *config.Config) *CommandFactory {
 			"info":   &Info{},
 			"client": &ClientCmd{},
 			"config": &ConfigCmd{cfg: cfg},
+			"keys":   &Keys{kv: kv},
 		},
 	}
 }
diff --git a/pkg/command/config.go b/pkg/command/config.go
index 84caf64..c7e3d9f 100644
--- a/pkg/command/config.go
+++ b/pkg/command/config.go
@@ -64,11 +64,11 @@ func (c *ConfigCmd) handleGet(cl *client.Client, wr *resp.Writer, args []*resp.R
 		}
 		return wr.WriteMap(resp)
 	} else {
-		var resp []any
+		var resp []string
 		for key, value := range valueMap {
 			resp = append(resp, key)
 			resp = append(resp, value)
 		}
-		return wr.WriteArray(resp)
+		return wr.WriteStringSlice(resp)
 	}
 }
diff --git a/pkg/command/get.go b/pkg/command/get.go
index b39ed80..3272009 100644
--- a/pkg/command/get.go
+++ b/pkg/command/get.go
@@ -25,5 +25,5 @@ func (g *Get) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) erro
 		return wr.WriteNull(resp.BulkString)
 	}
 
-	return wr.WriteValue(val)
+	return wr.WriteBytes(val)
 }
diff --git a/pkg/command/info.go b/pkg/command/info.go
index bdbe7a2..8d652a4 100644
--- a/pkg/command/info.go
+++ b/pkg/command/info.go
@@ -35,7 +35,7 @@ func (h *Info) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) err
 	str := strings.Builder{}
 	if argsLen == 1 {
 		str.WriteString(buildSectionString(args[0].String(), sections[args[0].String()]))
-		return wr.WriteValue(str.String())
+		return wr.WriteString(str.String())
 	}
 	for sectionName, section := range sections {
 		str.WriteString(buildSectionString(sectionName, section))
diff --git a/pkg/command/keys.go b/pkg/command/keys.go
new file mode 100644
index 0000000..34dead2
--- /dev/null
+++ b/pkg/command/keys.go
@@ -0,0 +1,36 @@
+package command
+
+import (
+	"errors"
+
+	"github.com/codecrafters-io/redis-starter-go/internal/client"
+	"github.com/codecrafters-io/redis-starter-go/pkg/keyval"
+	"github.com/codecrafters-io/redis-starter-go/pkg/resp"
+)
+
+type Keys struct {
+	kv keyval.KV // store whose key list the KEYS command filters by pattern
+}
+
+func (k *Keys) Execute(_ *client.Client, wr *resp.Writer, args []*resp.Resp) error {
+	if len(args) != 1 {
+		return wr.WriteError(errors.New("wrong number of arguments for 'keys' command"))
+	}
+	pattern := args[0].String()
+	keys := k.kv.Keys()
+	if pattern == "*" {
+		return wr.WriteStringSlice(keys)
+	}
+	re, err := resp.TranslatePattern(pattern)
+	if err != nil {
+		return wr.WriteError(err)
+	}
+	matchedKeys := make([]string, 0, len(keys)) // non-nil, like the slice kv.Keys returns for "*"
+	for _, key := range keys {
+		if re.MatchString(key) {
+			matchedKeys = append(matchedKeys, key)
+		}
+	}
+
+	return wr.WriteStringSlice(matchedKeys)
+}
diff --git a/pkg/command/set.go b/pkg/command/set.go
index 98a9674..c64391b 100644
--- a/pkg/command/set.go
+++ b/pkg/command/set.go
@@ -64,7 +64,7 @@ func (s *Set) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) erro
 
 	if opts.NX {
 		if opts.GET && keyExists {
-			return wr.WriteValue(oldVal)
+			return wr.WriteBytes(oldVal)
 		}
 		if keyExists {
 			return wr.WriteNull(resp.BulkString)
@@ -97,7 +97,7 @@ func (s *Set) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) erro
 
 	if opts.GET {
 		if keyExists {
-			return wr.WriteValue(oldVal)
+			return wr.WriteBytes(oldVal)
 		} else {
 			return wr.WriteNull(resp.BulkString)
 		}
diff --git a/pkg/crc64/crc64.go b/pkg/crc64/crc64.go
new file mode 100644
index 0000000..54fed9c
--- /dev/null
+++ b/pkg/crc64/crc64.go
@@ -0,0 +1,64 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package crc64 implements the CRC-64 variant with Jones coefficients and an init value of 0.
+package crc64
+
+import "hash"
+
+// Redis uses the CRC64 variant with "Jones" coefficients and init value of 0.
+//
+// Specification of this CRC64 variant follows:
+// Name: crc-64-jones
+// Width: 64 bits
+// Poly: 0xad93d23594c935a9
+// Reflected In: True
+// Xor_In: 0x0 (Digest, New and Reset all start the register at 0)
+// Reflected_Out: True
+// Xor_Out: 0x0
+
+var table = [256]uint64{0x0000000000000000, 0x7ad870c830358979, 0xf5b0e190606b12f2, 0x8f689158505e9b8b, 0xc038e5739841b68f, 0xbae095bba8743ff6, 0x358804e3f82aa47d, 0x4f50742bc81f2d04, 0xab28ecb46814fe75, 0xd1f09c7c5821770c, 0x5e980d24087fec87, 0x24407dec384a65fe, 0x6b1009c7f05548fa, 0x11c8790fc060c183, 0x9ea0e857903e5a08, 0xe478989fa00bd371, 0x7d08ff3b88be6f81, 0x07d08ff3b88be6f8, 0x88b81eabe8d57d73, 0xf2606e63d8e0f40a, 0xbd301a4810ffd90e, 0xc7e86a8020ca5077, 0x4880fbd87094cbfc, 0x32588b1040a14285, 0xd620138fe0aa91f4, 0xacf86347d09f188d, 0x2390f21f80c18306, 0x594882d7b0f40a7f, 0x1618f6fc78eb277b, 0x6cc0863448deae02, 0xe3a8176c18803589, 0x997067a428b5bcf0, 0xfa11fe77117cdf02, 0x80c98ebf2149567b, 0x0fa11fe77117cdf0, 0x75796f2f41224489, 0x3a291b04893d698d, 0x40f16bccb908e0f4, 0xcf99fa94e9567b7f, 0xb5418a5cd963f206, 0x513912c379682177, 0x2be1620b495da80e, 0xa489f35319033385, 0xde51839b2936bafc, 0x9101f7b0e12997f8, 0xebd98778d11c1e81, 0x64b116208142850a, 0x1e6966e8b1770c73, 0x8719014c99c2b083, 0xfdc17184a9f739fa, 0x72a9e0dcf9a9a271, 0x08719014c99c2b08, 0x4721e43f0183060c, 0x3df994f731b68f75, 0xb29105af61e814fe, 0xc849756751dd9d87, 0x2c31edf8f1d64ef6, 0x56e99d30c1e3c78f, 0xd9810c6891bd5c04, 0xa3597ca0a188d57d, 0xec09088b6997f879, 0x96d1784359a27100, 0x19b9e91b09fcea8b, 0x636199d339c963f2, 0xdf7adabd7a6e2d6f, 0xa5a2aa754a5ba416, 0x2aca3b2d1a053f9d, 0x50124be52a30b6e4, 0x1f423fcee22f9be0, 0x659a4f06d21a1299, 0xeaf2de5e82448912, 0x902aae96b271006b, 0x74523609127ad31a, 0x0e8a46c1224f5a63, 0x81e2d7997211c1e8, 0xfb3aa75142244891, 0xb46ad37a8a3b6595, 0xceb2a3b2ba0eecec, 0x41da32eaea507767, 0x3b024222da65fe1e, 0xa2722586f2d042ee, 0xd8aa554ec2e5cb97, 0x57c2c41692bb501c, 0x2d1ab4dea28ed965, 0x624ac0f56a91f461, 0x1892b03d5aa47d18, 0x97fa21650afae693, 0xed2251ad3acf6fea, 0x095ac9329ac4bc9b, 0x7382b9faaaf135e2, 0xfcea28a2faafae69, 0x8632586aca9a2710, 0xc9622c4102850a14, 0xb3ba5c8932b0836d, 0x3cd2cdd162ee18e6, 0x460abd1952db919f, 0x256b24ca6b12f26d, 0x5fb354025b277b14, 0xd0dbc55a0b79e09f, 0xaa03b5923b4c69e6, 0xe553c1b9f35344e2, 0x9f8bb171c366cd9b, 0x10e3202993385610, 0x6a3b50e1a30ddf69, 0x8e43c87e03060c18, 0xf49bb8b633338561, 0x7bf329ee636d1eea, 0x012b592653589793, 0x4e7b2d0d9b47ba97, 0x34a35dc5ab7233ee, 0xbbcbcc9dfb2ca865, 0xc113bc55cb19211c, 0x5863dbf1e3ac9dec, 0x22bbab39d3991495, 0xadd33a6183c78f1e, 0xd70b4aa9b3f20667, 0x985b3e827bed2b63, 0xe2834e4a4bd8a21a, 0x6debdf121b863991, 0x1733afda2bb3b0e8, 0xf34b37458bb86399, 0x8993478dbb8deae0, 0x06fbd6d5ebd3716b, 0x7c23a61ddbe6f812, 0x3373d23613f9d516, 0x49aba2fe23cc5c6f, 0xc6c333a67392c7e4, 0xbc1b436e43a74e9d, 0x95ac9329ac4bc9b5, 0xef74e3e19c7e40cc, 0x601c72b9cc20db47, 0x1ac40271fc15523e, 0x5594765a340a7f3a, 0x2f4c0692043ff643, 0xa02497ca54616dc8, 0xdafce7026454e4b1, 0x3e847f9dc45f37c0, 0x445c0f55f46abeb9, 0xcb349e0da4342532, 0xb1eceec59401ac4b, 0xfebc9aee5c1e814f, 0x8464ea266c2b0836, 0x0b0c7b7e3c7593bd, 0x71d40bb60c401ac4, 0xe8a46c1224f5a634, 0x927c1cda14c02f4d, 0x1d148d82449eb4c6, 0x67ccfd4a74ab3dbf, 0x289c8961bcb410bb, 0x5244f9a98c8199c2, 0xdd2c68f1dcdf0249, 0xa7f41839ecea8b30, 0x438c80a64ce15841, 0x3954f06e7cd4d138, 0xb63c61362c8a4ab3, 0xcce411fe1cbfc3ca, 0x83b465d5d4a0eece, 0xf96c151de49567b7, 0x76048445b4cbfc3c, 0x0cdcf48d84fe7545, 0x6fbd6d5ebd3716b7, 0x15651d968d029fce, 0x9a0d8ccedd5c0445, 0xe0d5fc06ed698d3c, 0xaf85882d2576a038, 0xd55df8e515432941, 0x5a3569bd451db2ca, 0x20ed197575283bb3, 0xc49581ead523e8c2, 0xbe4df122e51661bb, 0x3125607ab548fa30, 0x4bfd10b2857d7349, 0x04ad64994d625e4d, 0x7e7514517d57d734, 0xf11d85092d094cbf, 0x8bc5f5c11d3cc5c6, 0x12b5926535897936, 0x686de2ad05bcf04f, 0xe70573f555e26bc4, 0x9ddd033d65d7e2bd, 0xd28d7716adc8cfb9, 0xa85507de9dfd46c0, 0x273d9686cda3dd4b, 0x5de5e64efd965432, 0xb99d7ed15d9d8743, 0xc3450e196da80e3a, 0x4c2d9f413df695b1, 0x36f5ef890dc31cc8, 0x79a59ba2c5dc31cc, 0x037deb6af5e9b8b5, 0x8c157a32a5b7233e, 0xf6cd0afa9582aa47, 0x4ad64994d625e4da, 0x300e395ce6106da3, 0xbf66a804b64ef628, 0xc5bed8cc867b7f51, 0x8aeeace74e645255, 0xf036dc2f7e51db2c, 0x7f5e4d772e0f40a7, 0x05863dbf1e3ac9de, 0xe1fea520be311aaf, 0x9b26d5e88e0493d6, 0x144e44b0de5a085d, 0x6e963478ee6f8124, 0x21c640532670ac20, 0x5b1e309b16452559, 0xd476a1c3461bbed2, 0xaeaed10b762e37ab, 0x37deb6af5e9b8b5b, 0x4d06c6676eae0222, 0xc26e573f3ef099a9, 0xb8b627f70ec510d0, 0xf7e653dcc6da3dd4, 0x8d3e2314f6efb4ad, 0x0256b24ca6b12f26, 0x788ec2849684a65f, 0x9cf65a1b368f752e, 0xe62e2ad306bafc57, 0x6946bb8b56e467dc, 0x139ecb4366d1eea5, 0x5ccebf68aecec3a1, 0x2616cfa09efb4ad8, 0xa97e5ef8cea5d153, 0xd3a62e30fe90582a, 0xb0c7b7e3c7593bd8, 0xca1fc72bf76cb2a1, 0x45775673a732292a, 0x3faf26bb9707a053, 0x70ff52905f188d57, 0x0a2722586f2d042e, 0x854fb3003f739fa5, 0xff97c3c80f4616dc, 0x1bef5b57af4dc5ad, 0x61372b9f9f784cd4, 0xee5fbac7cf26d75f, 0x9487ca0fff135e26, 0xdbd7be24370c7322, 0xa10fceec0739fa5b, 0x2e675fb4576761d0, 0x54bf2f7c6752e8a9, 0xcdcf48d84fe75459, 0xb71738107fd2dd20, 0x387fa9482f8c46ab, 0x42a7d9801fb9cfd2, 0x0df7adabd7a6e2d6, 0x772fdd63e7936baf, 0xf8474c3bb7cdf024, 0x829f3cf387f8795d, 0x66e7a46c27f3aa2c, 0x1c3fd4a417c62355, 0x935745fc4798b8de, 0xe98f353477ad31a7, 0xa6df411fbfb21ca3, 0xdc0731d78f8795da, 0x536fa08fdfd90e51, 0x29b7d047efec8728}
+
+func crc64(crc uint64, b []byte) uint64 {
+	for _, v := range b {
+		crc = table[byte(crc)^v] ^ (crc >> 8)
+	}
+	return crc
+}
+
+func Digest(b []byte) uint64 {
+	return crc64(0, b)
+}
+
+type digest struct {
+	crc uint64
+}
+
+func New() hash.Hash64 {
+	return &digest{}
+}
+
+func (d *digest) Write(p []byte) (int, error) {
+	d.crc = crc64(d.crc, p)
+	return len(p), nil
+}
+
+// Sum appends the checksum to in, least-significant byte first.
+func (d *digest) Sum(in []byte) []byte {
+	s := d.Sum64()
+	in = append(in, byte(s))
+	in = append(in, byte(s>>8))
+	in = append(in, byte(s>>16))
+	in = append(in, byte(s>>24))
+	in = append(in, byte(s>>32))
+	in = append(in, byte(s>>40))
+	in = append(in, byte(s>>48))
+	in = append(in, byte(s>>56))
+	return in
+}
+
+func (d *digest) Sum64() uint64 { return d.crc }
+func (d *digest) BlockSize() int { return 1 }
+func (d *digest) Size() int { return 8 }
+func (d *digest) Reset() { d.crc = 0 }
diff --git a/pkg/keyval/kv.go b/pkg/keyval/kv.go
index 16a929a..6e9ef99 100644
--- a/pkg/keyval/kv.go
+++ b/pkg/keyval/kv.go
@@ -6,6 +6,7 @@ import (
 )
 
 type KV interface {
+	RestoreRDB(data map[string]string, expiry map[string]uint64)
 	Get(key string) ([]byte, error)
 	Set(key string, value []byte) error
 	Expire(key string, duration time.Duration)
@@ -14,6 +15,7 @@ type KV interface {
 	PExpireAt(key string, expiry time.Time)
 	Exists(key string) bool
 	DeleteTTL(key string)
+	Keys() []string
 }
 
 type kv struct {
@@ -58,6 +60,27 @@ func (kv *kv) Expire(key string, duration time.Duration) {
 	kv.expiry[key] = time.Now().Add(duration)
 }
 
+func (kv *kv) Keys() []string {
+	kv.mu.RLock()
+	defer kv.mu.RUnlock()
+	keys := make([]string, 0, len(kv.store))
+	for key := range kv.store { // NOTE(review): kv.expiry is not consulted, so keys past their TTL are listed unless purged elsewhere — confirm
+		keys = append(keys, key)
+	}
+	return keys
+}
+
+func (kv *kv) RestoreRDB(data map[string]string, expiry map[string]uint64) {
+	kv.mu.Lock()
+	defer kv.mu.Unlock()
+	for key, value := range data {
+		kv.store[key] = []byte(value)
+		if exp, ok := expiry[key]; ok { // expiry values are absolute Unix milliseconds
+			kv.expiry[key] = time.UnixMilli(int64(exp))
+		}
+	}
+}
+
 func (kv *kv) PExpire(key string, duration time.Duration) {
 	kv.Expire(key, duration)
 }
diff --git a/pkg/rdb/rdb.go b/pkg/rdb/rdb.go
new file mode 100644
index 0000000..b334992
--- /dev/null
+++ b/pkg/rdb/rdb.go
@@ -0,0 +1,404 @@
+package rdb
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"strconv"
+	"time"
+
+	lzf "github.com/zhuyie/golzf"
+
+	"github.com/codecrafters-io/redis-starter-go/pkg/telemetry/logger"
+)
+
+const (
+	RDBTypeString             = 0
+	RDBTypeList               = 1
+	RDBTypeSet                = 2
+	RDBTypeSortedSet          = 3
+	RDBTypeHash               = 4
+	RDBTypeZipMap             = 9
+	RDBTypeZipList            = 10
+	RDBTypeIntSet             = 11
+	RDBTypeSortedSetInZipList = 12
+	RDBTypeHashMapInZipList   = 13
+	RDBTypeListInQuickList    = 14
+
+	RDBMarkerAuxiliaryField = 0xFA // FA: Auxiliary field
+	RDBMarkerResizeDB       = 0xFB // FB: Resize DB
+	RDBMarkerExpiryMS       = 0xFC // FC: Expire time in milliseconds
+	RDBMarkerExpiry         = 0xFD // FD: Expire time in seconds
+	RDBMarkerSelectDB       = 0xFE // FE: Select database
+	RDBMarkerEOF            = 0xFF // FF: End of the RDB file
+)
+
+var ErrInvalidRDBHeader = errors.New("invalid RDB header")
+
+type RDBParser struct {
+	rd     io.ReadSeeker
+	length uint64            // first length of the most recent RESIZEDB entry
+	db     uint64            // index from the most recent SELECTDB entry
+	data   map[string]string
+	expiry map[string]uint64 // absolute expiry per key, compared against Unix milliseconds
+}
+
+func NewRDBParser(rd io.ReadSeeker) *RDBParser {
+	return &RDBParser{
+		rd:     rd,
+		data:   make(map[string]string),
+		expiry: make(map[string]uint64),
+	}
+}
+
+func (p *RDBParser) readByte() (byte, error) {
+	var b [1]byte
+	_, err := io.ReadFull(p.rd, b[:])
+	return b[0], err
+}
+
+func (p *RDBParser) readBytes(n int) ([]byte, error) {
+	b := make([]byte, n)
+	_, err := io.ReadFull(p.rd, b)
+	return b, err
+}
+
+func (p *RDBParser) readLength() (length uint64, special bool, err error) {
+	b, err := p.readByte()
+	if err != nil {
+		return 0, false, fmt.Errorf("failed to read length: %w", err)
+	}
+
+	switch { // the two high bits of b select the encoding
+	case b>>6 == 0: // 00xxxxxx: 6-bit length
+		return uint64(b & 0x3F), false, nil
+	case b>>6 == 1: // 01xxxxxx: 14-bit length, low 8 bits in the next byte
+		nextByte, err := p.readByte()
+		if err != nil {
+			return 0, false, fmt.Errorf("failed to read next byte: %w", err)
+		}
+		return uint64(b&0x3F)<<8 | uint64(nextByte), false, nil
+	case b == 0x80 || b == 0x81: // 32-bit (0x80) or 64-bit (0x81) big-endian length follows
+		bb, err := p.readBytes(4 << (b & 0x01))
+		if err != nil {
+			return 0, false, fmt.Errorf("failed to read %d bytes: %w", len(bb), err)
+		}
+		// Left-pad to 8 bytes so a single decoder handles both widths.
+		return binary.BigEndian.Uint64(append(make([]byte, 8-len(bb)), bb...)), false, nil
+	case b>>6 == 3: // 11xxxxxx: special encoding; the remaining 6 bits indicate the format.
+		return uint64(b & 0x3F), true, nil
+	}
+	return 0, false, fmt.Errorf("invalid length encoding: %#x", b)
+}
+
+func (p *RDBParser) readString() (string, error) {
+	length, special, err := p.readLength()
+	if err != nil {
+		return "", err
+	}
+	if special {
+		// 0, 1 and 2 indicate that a signed 8, 16 or 32 bit little-endian
+		// integer follows; the raw bits are reinterpreted as signed so that
+		// negative values are not printed as large unsigned numbers.
+		switch length {
+		case 0:
+			b, err := p.readByte()
+			if err != nil {
+				return "", fmt.Errorf("failed to read 8 bit integer: %w", err)
+			}
+			return strconv.Itoa(int(int8(b))), nil
+		case 1:
+			bb, err := p.readBytes(2)
+			if err != nil {
+				return "", fmt.Errorf("failed to read 16 bit integer: %w", err)
+			}
+			return strconv.Itoa(int(int16(binary.LittleEndian.Uint16(bb)))), nil
+		case 2:
+			bb, err := p.readBytes(4)
+			if err != nil {
+				return "", fmt.Errorf("failed to read 32 bit integer: %w", err)
+			}
+			return strconv.Itoa(int(int32(binary.LittleEndian.Uint32(bb)))), nil
+		case 3:
+			// If the value of those 6 bits is 3, it indicates that a compressed string follows.
+			// The compressed string is read as follows:
+			// The compressed length clen is read from the stream using Length Encoding
+			// The uncompressed length is read from the stream using Length Encoding
+			// The next clen bytes are read from the stream
+			// Finally, these bytes are decompressed using LZF algorithm
+			compressedLength, _, err := p.readLength()
+			if err != nil {
+				return "", fmt.Errorf("failed to read compressed length: %w", err)
+			}
+			uncompressedLength, _, err := p.readLength()
+			if err != nil {
+				return "", fmt.Errorf("failed to read uncompressed length: %w", err)
+			}
+			compressed, err := p.readBytes(int(compressedLength))
+			if err != nil {
+				return "", fmt.Errorf("failed to read compressed data: %w", err)
+			}
+			decompressed := make([]byte, int(uncompressedLength)) // output must hold the full uncompressed size
+			n, err := lzf.Decompress(compressed, decompressed)
+			if err != nil {
+				return "", fmt.Errorf("failed to decompress data: %w", err)
+			}
+			return string(decompressed[:n]), nil
+		default:
+			return "", fmt.Errorf("unsupported special string encoding: %d", length)
+		}
+	}
+	b, err := p.readBytes(int(length))
+	if err != nil {
+		return "", fmt.Errorf("failed to read string: %w", err)
+	}
+	return string(b), nil
+}
+
+func (p *RDBParser) readHeader() error {
+	header := make([]byte, 9) // "REDIS" magic plus 4 more bytes (presumably the version; not validated)
+	if _, err := io.ReadFull(p.rd, header); err != nil {
+		return fmt.Errorf("failed to read header: %w", err)
+	}
+	if !bytes.Equal(header[:5], []byte("REDIS")) {
+		return ErrInvalidRDBHeader
+	}
+	return nil
+}
+
+func (p *RDBParser) handlerAuxiliaryField(ctx context.Context) error {
+	key, err := p.readString()
+	if err != nil {
+		return err
+	}
+	value, err := p.readString()
+	if err != nil {
+		return err
+	}
+	switch key {
+	case "redis-ver":
+		logger.Info(ctx, "Loading RDB produced by version %s", value)
+	case "ctime":
+		ctime, err := strconv.ParseInt(value, 10, 64)
+		if err != nil {
+			return fmt.Errorf("failed to parse ctime: %w", err)
+		}
+		logger.Info(ctx, "RDB age %d seconds", time.Now().Unix()-ctime)
+	case "used-mem":
+		usedMem, err := strconv.ParseInt(value, 10, 64)
+		if err != nil {
+			return fmt.Errorf("failed to parse used-mem: %w", err)
+		}
+		logger.Info(ctx, "RDB memory usage when created %.2f MB", float64(usedMem)/1024/1024)
+	}
+
+	return nil
+}
+
+func (p *RDBParser) handleSelectDB() error {
+	db, _, err := p.readLength()
+	if err != nil {
+		return err
+	}
+	p.db = db
+	return nil
+}
+
+func (p *RDBParser) handleResizeDB() error {
+	length, _, err := p.readLength()
+	if err != nil {
+		return err
+	}
+	p.length = length
+	_, _, err = p.readLength() // second length is read and discarded
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (p *RDBParser) handleExpiryTime(ms bool) error {
+	var expiry uint64
+
+	if ms {
+		expBytes, err := p.readBytes(8)
+		if err != nil {
+			return err
+		}
+		expiry = binary.LittleEndian.Uint64(expBytes)
+	} else {
+		expBytes, err := p.readBytes(4)
+		if err != nil {
+			return err
+		}
+		expiry = uint64(binary.LittleEndian.Uint32(expBytes)) * 1000 // FD is in seconds; handleObject works in ms
+	}
+	objType, err := p.readByte()
+	if err != nil {
+		return err
+	}
+
+	return p.handleObject(objType, expiry)
+}
+
+func (p *RDBParser) handleObject(objType byte, expiry uint64) error {
+	if objType != RDBTypeString {
+		return fmt.Errorf("unsupported object type: %s", objTypeName(objType))
+	}
+	key, err := p.readString()
+	if err != nil {
+		return err
+	}
+	value, err := p.readString()
+	if err != nil {
+		return err
+	}
+	if expiry > 0 && expiry < uint64(time.Now().UnixMilli()) { // expiry is absolute Unix ms; 0 means no TTL
+		return nil
+	}
+	p.data[key] = value
+	if expiry > 0 {
+		p.expiry[key] = expiry
+	}
+	return nil
+}
+
+func objTypeName(objType byte) string {
+	switch objType {
+	case RDBTypeString:
+		return "string"
+	case RDBTypeList:
+		return "list"
+	case RDBTypeSet:
+		return "set"
+	case RDBTypeSortedSet:
+		return "sorted set"
+	case RDBTypeHash:
+		return "hash"
+	case RDBTypeZipMap:
+		return "zip map"
+	case RDBTypeZipList:
+		return "zip list"
+	case RDBTypeIntSet:
+		return "int set"
+	case RDBTypeSortedSetInZipList:
+		return "sorted set in zip list"
+	case RDBTypeHashMapInZipList:
+		return "hash map in zip list"
+	case RDBTypeListInQuickList:
+		return "list in quick list"
+	default:
+		return "unknown"
+	}
+}
+
+// func (p *RDBParser) readFileContent() ([]byte, error) {
+// 	currentOffset, err := p.rd.Seek(0, io.SeekCurrent)
+// 	if err != nil {
+// 		return nil, err
+// 	}
+
+// 	// Go to the start of the file
+// 	if _, err := p.rd.Seek(0, io.SeekStart); err != nil {
+// 		return nil, err
+// 	}
+
+// 	content, err := io.ReadAll(p.rd)
+// 	if err != nil {
+// 		return nil, err
+// 	}
+
+// 	// Restore reader to current offset
+// 	if _, err := p.rd.Seek(currentOffset, io.SeekStart); err != nil {
+// 		return nil, err
+// 	}
+
+// 	return content, nil
+// }
+
+// func (p *RDBParser) verifyChecksum() error {
+// 	content, err := p.readFileContent()
+// 	if err != nil {
+// 		return err
+// 	}
+
+// 	// Exclude the last 8 bytes (checksum)
+// 	if len(content) < 8 {
+// 		return errors.New("RDB file too short")
+// 	}
+// 	data := content[:len(content)-8]
+// 	expectedChecksum := binary.LittleEndian.Uint64(content[len(content)-8:])
+
+// 	crc64Checksum := crc64.New()
+// 	_, err = crc64Checksum.Write(data)
+// 	if err != nil {
+// 		return fmt.Errorf("failed to calculate checksum: %w", err)
+// 	}
+// 	if calculatedChecksum := crc64Checksum.Sum64(); calculatedChecksum != expectedChecksum {
+// 		return fmt.Errorf("checksum mismatch: expected %v, got %v", expectedChecksum, calculatedChecksum)
+// 	}
+// 	return nil
+// }
+
+func (p *RDBParser) ParseRDB(ctx context.Context) error {
+	if err := p.readHeader(); err != nil {
+		return err
+	}
+fileLoop:
+	for {
+		objType, err := p.readByte()
+		if errors.Is(err, io.EOF) {
+			logger.Info(ctx, "Parsing RDB file completed")
+			break
+		}
+		if err != nil {
+			return err
+		}
+		switch objType {
+		case RDBMarkerEOF:
+			// TODO: enable this after finishing codecrafters course
+			// err := p.verifyChecksum()
+			// if err != nil {
+			// 	return fmt.Errorf("failed to verify checksum: %w", err)
+			// }
+			break fileLoop
+		case RDBMarkerAuxiliaryField:
+			if err := p.handlerAuxiliaryField(ctx); err != nil {
+				return fmt.Errorf("failed to handle auxiliary field: %w", err)
+			}
+		case RDBMarkerSelectDB:
+			if err := p.handleSelectDB(); err != nil {
+				return fmt.Errorf("failed to handle select db: %w", err)
+			}
+		case RDBMarkerResizeDB:
+			if err := p.handleResizeDB(); err != nil {
+				return fmt.Errorf("failed to handle resize db: %w", err)
+			}
+		case RDBMarkerExpiry, RDBMarkerExpiryMS:
+			if err := p.handleExpiryTime(objType == RDBMarkerExpiryMS); err != nil {
+				return fmt.Errorf("failed to handle expiry time: %w", err)
+			}
+		default:
+			if err := p.handleObject(objType, 0); err != nil {
+				return fmt.Errorf("failed to handle object: %w", err)
+			}
+		}
+	}
+	expired := uint64(0)
+	if loaded := uint64(len(p.data)); p.length > loaded { // p.length is only the last RESIZEDB value; clamp instead of wrapping
+		expired = p.length - loaded
+	}
+	logger.Info(ctx, "Done loading RDB, keys loaded: %d, keys expired: %d", len(p.data), expired)
+	return nil
+}
+
+func (p *RDBParser) GetData() map[string]string {
+	return p.data
+}
+
+func (p *RDBParser) GetExpiry() map[string]uint64 {
+	return p.expiry
+}
diff --git a/pkg/rdb/rdb_test.go b/pkg/rdb/rdb_test.go
new file mode 100644
index 0000000..254f1b1
--- /dev/null
+++ b/pkg/rdb/rdb_test.go
@@ -0,0 +1,66 @@
+package rdb
+
+import (
+	"context"
+	"os"
+	"reflect"
+	"testing"
+)
+
+func TestRDBParser(t *testing.T) {
+	tests := []struct {
+		name       string
+		filepath   string
+		want       map[string]string
+		wantExpiry map[string]uint64
+		wantErr    error
+	}{
+		{
+			name:     "Normal",
+			filepath: "../../testdata/dump.rdb",
+			wantErr:  nil,
+			want: map[string]string{
+				"double":        "3.1456789",
+				"str":           "thisisastring",
+				"emptystr":      "",
+				"int64":         "-9223372036854775808",
+				"int":           "1000",
+				"boolean":       "true",
+				"bignumber":     "734589327589437589324758943278954389",
+				"strwithexpire": "not_expired",
+				"strwithPX":     "not_expired_with_PX",
+				"uint64":        "18446744073709551615",
+			},
+			wantExpiry: map[string]uint64{
+				"strwithexpire": 11719245191989,
"strwithPX": 101719245280555, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + file, err := os.Open(tt.filepath) + if err != nil { + t.Errorf("failed to open file, err: %v", err) + } + defer file.Close() + + rdb := NewRDBParser(file) + err = rdb.ParseRDB(ctx) + if err != tt.wantErr { + t.Errorf("RDBParser.Parse() error = %v, wantErr %v", err, tt.wantErr) + return + } + gotData := rdb.GetData() + gotExpiry := rdb.GetExpiry() + if !reflect.DeepEqual(gotData, tt.want) { + t.Errorf("RDBParser.Parse() got = %v, want %v", gotData, tt.want) + } + if !reflect.DeepEqual(gotExpiry, tt.wantExpiry) { + t.Errorf("RDBParser.Parse() gotExpiry = %v, wantExpiry %v", gotExpiry, tt.wantExpiry) + } + }) + } +} diff --git a/pkg/resp/pattern.go b/pkg/resp/pattern.go new file mode 100644 index 0000000..3c08f07 --- /dev/null +++ b/pkg/resp/pattern.go @@ -0,0 +1,53 @@ +package resp + +import ( + "fmt" + "regexp" + "strings" +) + +// TranslatePattern translates a glob-style pattern to a regular expression, +// correctly handling escape sequences. 
+func TranslatePattern(pattern string) (*regexp.Regexp, error) { + var rePattern strings.Builder + + rePattern.WriteString("^") + + inEscape := false + for i := 0; i < len(pattern); i++ { + c := pattern[i] + if inEscape { + // After a backslash, treat the next character as a literal + rePattern.WriteString(regexp.QuoteMeta(string(c))) + inEscape = false + } else { + switch c { + case '\\': + // Start escape sequence + inEscape = true + case '*': + rePattern.WriteString(".*") + case '?': + rePattern.WriteString(".") + case '[': + rePattern.WriteString("[") + case ']': + rePattern.WriteString("]") + case '^': + // Handle the beginning of set negation in character classes + rePattern.WriteString("^") + case '-': + rePattern.WriteString("-") + default: + rePattern.WriteString(regexp.QuoteMeta(string(c))) + } + } + } + + if inEscape { + return nil, fmt.Errorf("invalid escape sequence at end of pattern") + } + + rePattern.WriteString("$") + return regexp.Compile(rePattern.String()) +} diff --git a/pkg/resp/pattern_test.go b/pkg/resp/pattern_test.go new file mode 100644 index 0000000..de823e2 --- /dev/null +++ b/pkg/resp/pattern_test.go @@ -0,0 +1,38 @@ +package resp + +import ( + "testing" +) + +func TestTranslatePattern(t *testing.T) { + tests := []struct { + name string + pattern string + input string + shouldMatch bool + }{ + {"Basic wildcard", "h*llo", "hello", true}, + {"Wildcard middle", "h*llo", "heeeello", true}, + {"Single character", "h?llo", "hallo", true}, + {"Character class", "h[ae]llo", "hello", true}, + {"Negation in character class", "h[^e]llo", "hallo", true}, + {"Range in character class", "h[a-b]llo", "hallo", true}, + {"Escaped wildcard", "h*llo\\*", "hello*", true}, + {"Escaped wildcard no match", "h*llo\\*", "hello", false}, + {"Escaped Special Char", "h\\[e\\]llo", "h[e]llo", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + re, err := TranslatePattern(tt.pattern) + if err != nil { + t.Fatalf("TranslatePattern() 
error = %v", err) + } + + matched := re.MatchString(tt.input) + if matched != tt.shouldMatch { + t.Errorf("TranslatePattern() = %v, want %v for pattern %s and input %s", matched, tt.shouldMatch, tt.pattern, tt.input) + } + }) + } +} diff --git a/pkg/resp/writer.go b/pkg/resp/writer.go index a162bcb..e96501e 100644 --- a/pkg/resp/writer.go +++ b/pkg/resp/writer.go @@ -77,13 +77,13 @@ func (w *Writer) WriteValue(value any) error { func (w *Writer) writeResp2Value(value any) error { switch v := value.(type) { case nil: - return w.writeString("") + return w.WriteString("") case string: - return w.writeString(v) + return w.WriteString(v) case []byte: - return w.writeBytes(BulkString, v) + return w.writeBytesWithType(BulkString, v) case *string: - return w.writeString(*v) + return w.WriteString(*v) case int: return w.writeResp2Int(int64(v)) case *int: @@ -144,7 +144,7 @@ func (w *Writer) writeResp2Value(value any) error { return w.writeResp2Int(0) case time.Time: w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano) - return w.writeBytes(BulkString, w.numBuf) + return w.writeBytesWithType(BulkString, w.numBuf) case time.Duration: return w.writeResp2Int(v.Nanoseconds()) case encoding.BinaryMarshaler: @@ -152,23 +152,23 @@ func (w *Writer) writeResp2Value(value any) error { if err != nil { return err } - return w.writeBytes(BulkString, b) + return w.writeBytesWithType(BulkString, b) case net.IP: - return w.writeString(v.String()) + return w.WriteString(v.String()) case error: return w.WriteSimpleValue(SimpleError, []byte(v.Error())) case RESPVersion: return w.writeInt(int64(v)) default: if reflect.ValueOf(value).Kind() == reflect.Slice { - return w.WriteArray(value) + return w.WriteSlice(value) } var rawMessage []byte err := json.NewEncoder(w).Encode(value) if err != nil { return fmt.Errorf("failed to encode value: %w", err) } - return w.writeBytes(BulkString, rawMessage) + return w.writeBytesWithType(BulkString, rawMessage) } } @@ -177,11 +177,11 @@ func (w *Writer) 
writeResp3Value(value any) error { case nil: return w.WriteSimpleValue(Null, []byte{}) case []byte: - return w.writeBytes(BulkString, v) + return w.writeBytesWithType(BulkString, v) case string: - return w.writeString(v) + return w.WriteString(v) case *string: - return w.writeString(*v) + return w.WriteString(*v) case int: return w.writeInt(int64(v)) case *int: @@ -242,7 +242,7 @@ func (w *Writer) writeResp3Value(value any) error { return w.WriteSimpleValue(Boolean, []byte{'f'}) case time.Time: w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano) - return w.writeBytes(BulkString, w.numBuf) + return w.writeBytesWithType(BulkString, w.numBuf) case time.Duration: return w.writeInt(v.Nanoseconds()) case encoding.BinaryMarshaler: @@ -250,55 +250,76 @@ func (w *Writer) writeResp3Value(value any) error { if err != nil { return err } - return w.writeBytes(BulkString, b) + return w.writeBytesWithType(BulkString, b) case net.IP: - return w.writeString(v.String()) + return w.WriteString(v.String()) case error: - return w.writeBytes(BulkError, []byte(v.Error())) + return w.writeBytesWithType(BulkError, []byte(v.Error())) case RESPVersion: return w.writeInt(int64(v)) default: if reflect.ValueOf(value).Kind() == reflect.Slice { - return w.WriteArray(value) + return w.WriteSlice(value) } var rawMessage []byte err := json.NewEncoder(w).Encode(value) if err != nil { return fmt.Errorf("failed to encode value: %w", err) } - return w.writeBytes(BulkString, rawMessage) + return w.writeBytesWithType(BulkString, rawMessage) } } func (w *Writer) writeResp2Int(n int64) error { w.numBuf = append(w.numBuf[:0], []byte(strconv.FormatInt(n, 10))...) - return w.writeBytes(BulkString, w.numBuf) + return w.writeBytesWithType(BulkString, w.numBuf) } func (w *Writer) writeResp2Uint(n uint64) error { w.numBuf = append(w.numBuf[:0], []byte(strconv.FormatUint(n, 10))...) 
- return w.writeBytes(BulkString, w.numBuf) + return w.writeBytesWithType(BulkString, w.numBuf) } func (w *Writer) writeResp2Float(f float64) error { w.numBuf = append(w.numBuf[:0], []byte(strconv.FormatFloat(f, 'f', -1, 64))...) - return w.writeBytes(BulkString, w.numBuf) + return w.writeBytesWithType(BulkString, w.numBuf) } -func (w *Writer) WriteArray(v any) error { - values, _ := v.([]any) - +func (w *Writer) WriteStringSlice(v []string) error { if err := w.WriteByte(Array); err != nil { return err } + if err := w.writeLen(len(v)); err != nil { + return err + } + for i := range v { + if err := w.WriteString(v[i]); err != nil { + return err + } + } + return nil +} + +func (w *Writer) WriteSlice(v any) error { + reflectType := reflect.TypeOf(v) + if reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + + if reflectType.Kind() != reflect.Slice { + return fmt.Errorf("expected slice, got %s", reflectType.Kind()) + } - sLen := len(values) + if err := w.WriteByte(Array); err != nil { + return err + } + values := reflect.ValueOf(v) + sLen := values.Len() if err := w.writeLen(sLen); err != nil { return err } - - for _, value := range values { - if err := w.WriteValue(value); err != nil { + for i := 0; i < sLen; i++ { + if err := w.WriteValue(values.Index(i).Interface()); err != nil { return err } } @@ -316,7 +337,7 @@ func (w *Writer) WriteMap(m map[string]any) error { } for k, v := range m { - if err := w.writeBytes(BulkString, []byte(k)); err != nil { + if err := w.writeBytesWithType(BulkString, []byte(k)); err != nil { return err } if err := w.WriteValue(v); err != nil { @@ -378,11 +399,15 @@ func (w *Writer) writeFloat(f float64) error { return w.WriteSimpleValue(Double, w.numBuf) } -func (w *Writer) writeString(s string) error { - return w.writeBytes(BulkString, []byte(s)) +func (w *Writer) WriteString(s string) error { + return w.writeBytesWithType(BulkString, []byte(s)) +} + +func (w *Writer) WriteBytes(b []byte) error { + return 
w.writeBytesWithType(BulkString, b) } -func (w *Writer) writeBytes(valType byte, b []byte) error { +func (w *Writer) writeBytesWithType(valType byte, b []byte) error { if err := w.WriteByte(valType); err != nil { return err } diff --git a/pkg/resp/writer_test.go b/pkg/resp/writer_test.go index 3956226..21d7712 100644 --- a/pkg/resp/writer_test.go +++ b/pkg/resp/writer_test.go @@ -91,7 +91,7 @@ func TestWriteArray(t *testing.T) { t.Run(tt.name, func(t *testing.T) { buf := &bytes.Buffer{} w := NewWriter(buf, RESP3) - err := w.WriteArray(tt.values) + err := w.WriteSlice(tt.values) if err != tt.wantErr { t.Errorf("WriteArray() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/testdata/dump.rdb b/testdata/dump.rdb new file mode 100644 index 0000000000000000000000000000000000000000..34273e70604b2e2e2eddb3510113224476bc433c GIT binary patch literal 369 zcmYk0ze>YE9LIkpwl*zw(FbsFlawTPmrFJW9o&i{6e^U9y|xEUQgRoqba3zmTzvy4 z!8fQK#0L-@baC?qTw+nLq=UY(Eu9JXpNFU$t&Z08rzTvCzM?Z!pH7%A>Km)C!<1V&&m2D{x(esB0t7Czw*= z+dddNd;BzDJxr1jXYra!GY)1b25}aJJl&xP6F)$%t_K(`CvtHiT+Ku<@`DBGq_=l9 zP0`Q~FN7KvF!eZ5-4)kj_hj0+KGhv_)oHai|K;7~S+~`ON><=HKE@u!*rtfu0U?g( Okl(M5r~8NJ*X=(jY