Skip to content

Commit

Permalink
feat: replication: implement WAIT command (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhughdo authored Sep 23, 2024
1 parent 0014bcf commit 1de5b44
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 32 deletions.
90 changes: 58 additions & 32 deletions internal/app/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const (

type Server struct {
ln net.Listener
mu sync.Mutex
mu sync.RWMutex
cfg *config.Config
done chan struct{}
store keyval.KV
Expand Down Expand Up @@ -74,7 +74,7 @@ func NewServer(cfg *config.Config) *Server {
replicationID = utils.GenerateRandomAlphanumeric(40)
}
s := &Server{
mu: sync.Mutex{},
mu: sync.RWMutex{},
cfg: cfg,
done: make(chan struct{}),
store: store,
Expand Down Expand Up @@ -506,6 +506,9 @@ func (s *Server) handleMessage(ctx context.Context, cl *client.Client, r *resp.R
if err != nil {
return s.writeError(cl, err)
}
if _, isWait := cmd.(*command.Wait); isWait {
s.sendReplConfGetAckToAllReplicas(ctx)
}
isBlocking := cmd.IsBlocking(args)
if isBlocking {
go s.handleBlockingCommand(ctx, cl, cmd, writer, args)
Expand All @@ -517,7 +520,7 @@ func (s *Server) handleMessage(ctx context.Context, cl *client.Client, r *resp.R
return s.writeError(cl, err)
}
if _, ok := config.WriteableCommands[cmdName]; ok {
s.propagateCommand(ctx, r)
s.propagateCommand(ctx, cl, r)
}
if cmdName == "replconf" && len(args) > 0 && args[0].String() == "listening-port" {
s.addReplica(cl)
Expand All @@ -526,13 +529,16 @@ func (s *Server) handleMessage(ctx context.Context, cl *client.Client, r *resp.R
return nil
}

func (s *Server) propagateCommand(ctx context.Context, r *resp.Resp) {
func (s *Server) propagateCommand(ctx context.Context, cl *client.Client, r *resp.Resp) {
s.mu.Lock()
defer s.mu.Unlock()
rawResp := r.RAW()
s.offset += uint64(len(rawResp))
cl.SetLastWriteOffset(s.offset)

// Send the command to each replica
for replica := range s.replicas {
_, err := replica.Conn().Write(r.RAW())
_, err := replica.Conn().Write(rawResp)
if err != nil {
logger.Error(ctx, "Failed to send command to replica %s: %v", replica.ID, err)
// Remove disconnected replica
Expand Down Expand Up @@ -575,40 +581,48 @@ func (s *Server) sendReplConfGetAck(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
s.mu.Lock()
replicas := make([]*client.Client, 0, len(s.replicas))
for replica := range s.replicas {
replicas = append(replicas, replica)
}
s.mu.Unlock()
s.sendReplConfGetAckToAllReplicas(ctx)
}
}
}

for _, replica := range replicas {
go func(replica *client.Client) {
getAckCmd := resp.CreateCommand("REPLCONF", "GETACK", "*")
_, err := replica.Writer.Write(getAckCmd)
if err != nil {
logger.Error(ctx, "Failed to write REPLCONF GETACK to replica %s: %v", replica.ID, err)
s.removeReplica(ctx, replica)
return
}
err = replica.Writer.Flush()
if err != nil {
if errors.Is(err, net.ErrClosed) {
logger.Info(ctx, "Failed to flush REPLCONF GETACK to replica %s, replica disconnected", replica.ID)
} else {
logger.Error(ctx, "Failed to flush REPLCONF GETACK to replica %s: %v", replica.ID, err)
}
s.removeReplica(ctx, replica)
return
}
}(replica)
func (s *Server) sendReplConfGetAckToAllReplicas(ctx context.Context) {
s.mu.Lock()
replicas := make([]*client.Client, 0, len(s.replicas))
for replica := range s.replicas {
replicas = append(replicas, replica)
}
getAckCmd := resp.CreateCommand("REPLCONF", "GETACK", "*")
s.offset += uint64(len(getAckCmd))
s.mu.Unlock()

for _, replica := range replicas {
go func(replica *client.Client) {
_, err := replica.Writer.Write(getAckCmd)
if err != nil {
logger.Error(ctx, "Failed to write REPLCONF GETACK to replica %s: %v", replica.ID, err)
s.removeReplica(ctx, replica)
return
}
}
err = replica.Writer.Flush()
if err != nil {
if errors.Is(err, net.ErrClosed) {
logger.Info(ctx, "Failed to flush REPLCONF GETACK to replica %s, replica disconnected", replica.ID)
} else {
logger.Error(ctx, "Failed to flush REPLCONF GETACK to replica %s: %v", replica.ID, err)
}
s.removeReplica(ctx, replica)
return
}
}(replica)
}
}

func (s *Server) sendPing(replica *client.Client) error {
pingCmd := resp.CreatePingCommand()
s.mu.Lock()
s.offset += uint64(len(pingCmd))
s.mu.Unlock()

_, err := replica.Conn().Write(pingCmd)
if err != nil {
Expand Down Expand Up @@ -699,3 +713,15 @@ func (s *Server) GetReplicationID() string {
defer s.mu.Unlock()
return s.replicationID
}

// GetReplicaAcknowledgedCount reports how many registered replicas currently
// have an Offset() greater than or equal to offset. The replica set is read
// under the server's read lock.
func (s *Server) GetReplicaAcknowledgedCount(offset uint64) int {
	s.mu.RLock()
	defer s.mu.RUnlock()

	acked := 0
	for r := range s.replicas {
		if r.Offset() < offset {
			continue
		}
		acked++
	}
	return acked
}
14 changes: 14 additions & 0 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type Client struct {
messageChan chan<- Message
Writer *resp.Writer
closed bool
lastWriteOffset uint64
}

func NewClient(conn net.Conn, messageChan chan<- Message) *Client {
Expand All @@ -51,9 +52,22 @@ func NewClient(conn net.Conn, messageChan chan<- Message) *Client {
messageChan: messageChan,
bw: bw,
Writer: resp.NewWriter(bw, resp.DefaultVersion),
mu: sync.RWMutex{},
}
}

// GetLastWriteOffset returns the offset most recently stored with
// SetLastWriteOffset, read under the client's read lock.
func (c *Client) GetLastWriteOffset() uint64 {
	c.mu.RLock()
	off := c.lastWriteOffset
	c.mu.RUnlock()
	return off
}

// SetLastWriteOffset stores offset as the client's last write offset,
// overwriting whatever value was there before.
func (c *Client) SetLastWriteOffset(offset uint64) {
	c.mu.Lock()
	c.lastWriteOffset = offset
	c.mu.Unlock()
}

func (c *Client) UpdateOffset(offset uint64) {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down
4 changes: 4 additions & 0 deletions pkg/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var (
// ServerInfoProvider is the view of server replication state that commands
// (e.g. WAIT) are given when the command factory is built.
type ServerInfoProvider interface {
	// GetReplicationID returns the server's replication ID.
	GetReplicationID() string

	// GetReplicaInfo returns per-replica details as string maps; the exact
	// keys are defined by the implementation.
	GetReplicaInfo() []map[string]string

	// GetReplicaAcknowledgedCount returns how many replicas have acknowledged
	// a replication offset of at least offset.
	GetReplicaAcknowledgedCount(offset uint64) int
}

type Command interface {
Expand Down Expand Up @@ -79,6 +80,9 @@ func NewCommandFactory(kv keyval.KV, cfg *config.Config, serverInfo ServerInfoPr
serverInfo: serverInfo,
},
"save": &Save{kv: kv},
"wait": &Wait{
serverInfo: serverInfo,
},
},
}
}
Expand Down
64 changes: 64 additions & 0 deletions pkg/command/wait.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package command

import (
"errors"
"strconv"
"time"

"github.com/codecrafters-io/redis-starter-go/internal/client"
"github.com/codecrafters-io/redis-starter-go/pkg/resp"
)

// Wait implements the WAIT command. It replies with the number of replicas
// that have acknowledged the calling client's last write, waiting (by
// polling) until enough replicas have done so or the timeout expires.
type Wait struct {
	serverInfo ServerInfoProvider
}

// Execute handles "WAIT numreplicas timeout".
//
// args[0] is the number of replicas to wait for and args[1] the timeout in
// milliseconds; a timeout of 0 waits indefinitely. The reply is an integer:
// the number of replicas whose acknowledged offset has reached the client's
// last write offset at the moment of replying. It replies as soon as that
// count reaches numReplicas, or with the current count once the timeout
// elapses.
//
// A client that has written nothing has a last write offset of 0, which every
// replica has trivially reached, so the connected replicas are counted rather
// than replying with a hard-coded 0.
//
// NOTE(review): there is no cancellation path, so with a 0 timeout this keeps
// polling even if the client disconnects; plumbing a context in would require
// changing the Command interface.
func (w *Wait) Execute(c *client.Client, wr *resp.Writer, args []*resp.Resp) error {
	if len(args) != 2 {
		return wr.WriteError(errors.New("ERR wrong number of arguments for 'wait' command"))
	}

	numReplicas, err := strconv.Atoi(args[0].String())
	if err != nil || numReplicas < 0 {
		return wr.WriteError(errors.New("ERR invalid number of replicas"))
	}

	timeoutMillis, err := strconv.Atoi(args[1].String())
	if err != nil || timeoutMillis < 0 {
		return wr.WriteError(errors.New("ERR invalid timeout"))
	}

	reply := func(n int) error {
		return wr.WriteSimpleValue(resp.Integer, []byte(strconv.Itoa(n)))
	}

	clientOffset := c.GetLastWriteOffset()

	const pollInterval = 10 * time.Millisecond
	// Largest millisecond timeout representable as a time.Duration; anything
	// larger would overflow into a negative duration and time out instantly,
	// so it is treated as an unbounded wait instead.
	const maxTimeoutMillis = int64(1<<63-1) / int64(time.Millisecond)

	// A zero deadline means "wait forever".
	var deadline time.Time
	if timeoutMillis > 0 && int64(timeoutMillis) <= maxTimeoutMillis {
		deadline = time.Now().Add(time.Duration(timeoutMillis) * time.Millisecond)
	}

	for {
		acked := w.serverInfo.GetReplicaAcknowledgedCount(clientOffset)
		if acked >= numReplicas {
			return reply(acked)
		}

		sleep := pollInterval
		if !deadline.IsZero() {
			remaining := time.Until(deadline)
			if remaining <= 0 {
				return reply(acked)
			}
			// Do not oversleep past the deadline.
			if remaining < sleep {
				sleep = remaining
			}
		}
		time.Sleep(sleep)
	}
}

// IsBlocking reports that WAIT always blocks, regardless of its arguments.
func (w *Wait) IsBlocking(_ []*resp.Resp) bool {
	return true
}

0 comments on commit 1de5b44

Please sign in to comment.