Skip to content

Commit

Permalink
feat: replication: propagate commands to replicas (#14)
Browse files Browse the repository at this point in the history
* feat: replication: propagate commands to replicas

* fix: lint error
  • Loading branch information
mhughdo authored Sep 23, 2024
1 parent 2590f2b commit 5c4fff9
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 4 deletions.
7 changes: 7 additions & 0 deletions internal/app/server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ const (
ReplicaOfKey = "REPLICAOF"
)

// WriteableCommands is the set of command names (stored in lower case) that
// mutate the keyspace. The server consults it to decide which client commands
// must be forwarded to connected replicas.
var WriteableCommands = map[string]struct{}{
	"del": {},
	"set": {},
}

var supportedOptions = map[string]struct{}{
ListenAddrKey: {},
DirKey: {},
Expand Down
32 changes: 32 additions & 0 deletions internal/app/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,19 @@ func (s *Server) handleMasterResponse(ctx context.Context, r *resp.Resp, reader
s.receiveRDB(ctx, reader)
}
case resp.Array:
cmdName := strings.ToLower(r.Data.([]*resp.Resp)[0].String())
args := r.Data.([]*resp.Resp)[1:]
cmd, err := s.cFactory.GetCommand(cmdName)
if err != nil {
logger.Error(ctx, "Unknown command from master: %s", cmdName)
return
}
logger.Info(ctx, "Received command from master, cmd: %s, args: %v", cmdName, args)
tmpWriter := resp.NewWriter(&bytes.Buffer{}, resp.RESP3)
err = cmd.Execute(s.masterClient, tmpWriter, args)
if err != nil {
logger.Error(ctx, "Failed to execute command from master: %v", err)
}
case resp.BulkString:
case resp.SimpleError, resp.BulkError:
errMsg := r.String()
Expand Down Expand Up @@ -478,13 +491,32 @@ func (s *Server) handleMessage(ctx context.Context, cl *client.Client, r *resp.R
cl.Writer.Reset()
return s.writeError(cl, err)
}
if _, ok := config.WriteableCommands[cmdName]; ok {
s.propagateCommand(ctx, r)
}
if cmdName == "replconf" && len(args) > 0 && args[0].String() == "listening-port" {
s.addReplica(cl)
}

return nil
}

// propagateCommand forwards a command, re-encoded in RESP form via r.RAW(),
// to every registered replica.
//
// The command is encoded once and the same bytes are written to each replica
// connection. A replica whose Write fails is logged, removed from s.replicas
// and closed, so later propagations no longer target it.
//
// All writes happen while s.mu is held. If a replica's Conn().Write can
// block, one slow replica stalls every other user of s.mu until it returns.
// NOTE(review): consider write deadlines or per-replica queues if that
// becomes a problem.
func (s *Server) propagateCommand(ctx context.Context, r *resp.Resp) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Nothing to do (and nothing worth encoding) without replicas.
	if len(s.replicas) == 0 {
		return
	}

	// The payload is identical for every replica, so serialize it only once.
	payload := r.RAW()

	// Deleting the current key while ranging over a map is permitted in Go.
	for replica := range s.replicas {
		if _, err := replica.Conn().Write(payload); err != nil {
			logger.Error(ctx, "Failed to send command to replica %s: %v", replica.ID, err)
			// Remove disconnected replica
			delete(s.replicas, replica)
			replica.Close(ctx)
		}
	}
}

func (s *Server) addReplica(c *client.Client) {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down
2 changes: 1 addition & 1 deletion internal/app/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestGetAndSetCommand(t *testing.T) {
{
name: "Set and Get with EXAT",
setup: func() error {
return rdb.SetArgs(ctx, "foo-exat", "bar", redis.SetArgs{ExpireAt: time.Now().Add(100 * time.Millisecond)}).Err()
return rdb.SetArgs(ctx, "foo-exat", "bar", redis.SetArgs{ExpireAt: time.Now().Add(500 * time.Millisecond)}).Err()
},
action: func() (interface{}, error) {
return rdb.Get(ctx, "foo-exat").Result()
Expand Down
6 changes: 3 additions & 3 deletions pkg/command/replconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ func (rc *ReplConf) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp
port := args[1].String()
c.ListeningPort = port
case "capa":
if len(args) < 2 {
return wr.WriteError(errors.New("wrong number of arguments for 'replconf capa' command"))
}
// if len(args) < 2 {
// return wr.WriteError(errors.New("wrong number of arguments for 'replconf capa' command"))
// }
// We don't need to handle/save the capa arguments
default:
return wr.WriteError(fmt.Errorf("unknown replconf subcommand: %s", subCommand))
Expand Down
36 changes: 36 additions & 0 deletions pkg/resp/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,42 @@ func parseLen(line []byte) (int, error) {
return n, nil
}

// RAW re-encodes r into its RESP wire representation.
//
// Encoding rules:
//   - SimpleString, SimpleError, Integer, Null, Boolean, Double, BigNumber:
//     the type byte, then r.String(), then CRLF.
//   - BulkString, BulkError, VerbatimString: the type byte, the payload
//     length, CRLF, the raw payload bytes, CRLF. A nil payload is written as
//     "$-1\r\n" regardless of the concrete type.
//   - Array, Pushes, Set: the type byte, the element count, CRLF, then each
//     element encoded recursively. A nil payload is written as "*-1\r\n"
//     regardless of the concrete type.
//   - Map: the type byte, the entry count, CRLF, then each key as a simple
//     string ("+key\r\n") followed by its encoded value. A nil map is written
//     as "%0\r\n".
//   - Any other type encodes to an empty (non-nil) slice.
//
// Bulk payloads are expected to be []byte, aggregates []*Resp and maps
// map[string]*Resp; any other Data type panics on the type assertion.
//
// NOTE(review): map entries are emitted in Go's randomized map iteration
// order, so encoding the same Map twice may yield different byte sequences.
func (r *Resp) RAW() []byte {
	// Build the whole encoding in one growing buffer rather than allocating
	// an intermediate slice for every nested element.
	return r.appendRAW(make([]byte, 0, 64))
}

// appendRAW appends the RESP encoding of r (as described on RAW) to dst and
// returns the extended slice.
func (r *Resp) appendRAW(dst []byte) []byte {
	switch r.Type {
	case SimpleString, SimpleError, Integer, Null, Boolean, Double, BigNumber:
		return fmt.Appendf(dst, "%c%s\r\n", r.Type, r)
	case BulkString, BulkError, VerbatimString:
		if r.Data == nil {
			return append(dst, "$-1\r\n"...)
		}
		data := r.Data.([]byte)
		dst = fmt.Appendf(dst, "%c%d\r\n", r.Type, len(data))
		dst = append(dst, data...)
		return append(dst, "\r\n"...)
	case Array, Pushes, Set:
		if r.Data == nil {
			return append(dst, "*-1\r\n"...)
		}
		elements := r.Data.([]*Resp)
		dst = fmt.Appendf(dst, "%c%d\r\n", r.Type, len(elements))
		for _, elem := range elements {
			dst = elem.appendRAW(dst)
		}
		return dst
	case Map:
		if r.Data == nil {
			return append(dst, "%0\r\n"...)
		}
		m := r.Data.(map[string]*Resp)
		dst = fmt.Appendf(dst, "%c%d\r\n", r.Type, len(m))
		for k, v := range m {
			// Keys are always emitted as RESP simple strings.
			dst = append(dst, '+')
			dst = append(dst, k...)
			dst = append(dst, "\r\n"...)
			dst = v.appendRAW(dst)
		}
		return dst
	default:
		return dst
	}
}

func (r *Resp) ToResponse() []byte {
switch r.Type {
case SimpleString, SimpleError, Integer, Null, Boolean, Double, BigNumber:
Expand Down

0 comments on commit 5c4fff9

Please sign in to comment.