Skip to content

Commit

Permalink
fix: xread: maintain the order of streams
Browse files Browse the repository at this point in the history
  • Loading branch information
mhughdo committed Aug 28, 2024
1 parent 5c907e8 commit 92b40b2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
32 changes: 16 additions & 16 deletions pkg/command/xread.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (
)

// XRead holds what the XREAD command handler needs: the keyval.KV store it
// fetches streams from (via GetStream) and the order in which the client
// listed the stream keys.
// NOTE(review): this struct is shown as a rendered diff; the first
// "kv keyval.KV" line is the pre-commit version of the field and the second is
// the re-aligned post-commit version, so the real struct has a single kv field.
type XRead struct {
kv keyval.KV
kv keyval.KV
// streamOrder lists the stream keys in the order they followed STREAMS in the
// request. parseArgs allocates and fills it while parsing the STREAMS
// arguments, and writeResultRESP2/writeResultRESP3 walk it so that streams
// are emitted in request order instead of Go's random map iteration order.
// NOTE(review): this is per-request data stored on the XRead value itself. If
// a single XRead instance handles requests from several clients concurrently,
// the field is shared between them and racy. Consider carrying it in
// XReadOptions instead.
streamOrder []string
}

type XReadOptions struct {
Expand Down Expand Up @@ -77,10 +78,12 @@ func (x *XRead) parseArgs(args []*resp.Resp) (*XReadOptions, error) {
if streamCount <= 0 || (len(args)-i)%2 != 0 {
return nil, errors.New("ERR Unbalanced 'xread' list of streams: for each stream key an ID or '$' must be specified.")
}
x.streamOrder = make([]string, streamCount)
for j := 0; j < streamCount; j++ {
key := args[i+j].String()
id := args[i+streamCount+j].String()
opts.Streams[key] = id
x.streamOrder[j] = key
}
return opts, nil
default:
Expand Down Expand Up @@ -110,7 +113,7 @@ func (x *XRead) readStreams(opts *XReadOptions) (map[string][]keyval.StreamEntry
// Ignore this case as $ means only new entries
} else if lastID == "+" {
entries = append(entries, stream.Range(stream.LastID(), "+", 1)...)
} else {
} else if opts.Block <= 0 {
entries = append(entries, stream.Range(lastID, "+", opts.Count)...)
}
if len(entries) > 0 {
Expand Down Expand Up @@ -142,21 +145,14 @@ func (x *XRead) blockingRead(opts *XReadOptions) (map[string][]keyval.StreamEntr
}
}()

for streamName, lastID := range opts.Streams {
for streamName := range opts.Streams {
stream, err := x.kv.GetStream(streamName, true)
if err != nil {
return nil, err
}

ch := stream.Subscribe()
subscriptions[streamName] = ch

// Check for existing entries
entries := stream.Range(lastID, "+", opts.Count)
if len(entries) > 0 {
result[streamName] = entries
return result, nil
}
}

timer := time.NewTimer(opts.Block)
Expand Down Expand Up @@ -191,19 +187,23 @@ func (x *XRead) writeResult(cl *client.Client, wr *resp.Writer, result map[strin

// writeResultRESP2 encodes result as a RESP2 array with one
// [streamName, formattedEntries] element per stream, written via WriteSlice.
// NOTE(review): this hunk is a rendered diff. The loop ranging over result is
// the pre-commit version (random map order) that this commit removes. The loop
// ranging over x.streamOrder is its replacement; it keeps the order in which
// the client listed the streams and skips streams that have no entry in result.
// NOTE(review): parseArgs fills x.streamOrder without de-duplicating keys, so
// a stream key listed twice after STREAMS is emitted twice here. Confirm this
// matches the intended XREAD behaviour.
func (x *XRead) writeResultRESP2(wr *resp.Writer, result map[string][]keyval.StreamEntry) error {
var response []any
// Pre-commit loop (removed by this commit).
for streamName, entries := range result {
streamResponse := []any{streamName, x.formatEntries(entries)}
response = append(response, streamResponse)
// Post-commit loop: emit streams in request order, only those with results.
for _, streamName := range x.streamOrder {
if entries, ok := result[streamName]; ok {
streamResponse := []any{streamName, x.formatEntries(entries)}
response = append(response, streamResponse)
}
}
return wr.WriteSlice(response)
}

// writeResultRESP3 encodes result as a RESP3 map from stream name to its
// formatted entries.
// NOTE(review): this hunk is a rendered diff. The loop ranging over result and
// the "return wr.WriteMap(response)" line are the pre-commit version. The loop
// ranging over x.streamOrder and the WriteMapOrdered call replace them. A Go
// map has no order, so the ordering actually comes from passing x.streamOrder
// to WriteMapOrdered.
// NOTE(review): WriteMapOrdered as added in this commit writes len(m) as the
// map header but writes one pair per matching key in orderedKeys. A stream key
// listed twice after STREAMS (parseArgs does not de-duplicate streamOrder)
// would therefore write more pairs than the header announces.
func (x *XRead) writeResultRESP3(wr *resp.Writer, result map[string][]keyval.StreamEntry) error {
response := make(map[string]any)
// Pre-commit loop (removed by this commit).
for streamName, entries := range result {
response[streamName] = x.formatEntries(entries)
// Post-commit loop: collect formatted entries for streams that have results.
for _, streamName := range x.streamOrder {
if entries, ok := result[streamName]; ok {
response[streamName] = x.formatEntries(entries)
}
}
// Pre-commit return (removed by this commit).
return wr.WriteMap(response)
// Post-commit return: write the map with keys in request order.
return wr.WriteMapOrdered(response, x.streamOrder)
}

func (x *XRead) formatEntries(entries []keyval.StreamEntry) []any {
Expand Down
25 changes: 25 additions & 0 deletions pkg/resp/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,31 @@ func (w *Writer) WriteSlice(v any) error {
return nil
}

// WriteMapOrdered writes m as a RESP3 map, emitting the keys listed in
// orderedKeys first and in that order.
//
// Keys in orderedKeys that are absent from m are skipped, and a key listed
// more than once is written only the first time. Entries of m whose keys do
// not appear in orderedKeys are written afterwards, in Go map iteration order.
// Every entry of m is therefore written exactly once, so the length header
// (len(m)) always matches the number of key/value pairs that follow it.
func (w *Writer) WriteMapOrdered(m map[string]any, orderedKeys []string) error {
	if err := w.WriteByte(Map); err != nil {
		return err
	}

	// The aggregate header promises exactly this many key/value pairs; the two
	// loops below are written so that each key of m is emitted exactly once.
	if err := w.writeLen(len(m)); err != nil {
		return err
	}

	written := make(map[string]struct{}, len(m))

	// writeEntry emits one key/value pair and records the key as written.
	writeEntry := func(k string) error {
		written[k] = struct{}{}
		if err := w.writeBytesWithType(BulkString, []byte(k)); err != nil {
			return err
		}
		return w.WriteValue(m[k])
	}

	for _, k := range orderedKeys {
		if _, ok := m[k]; !ok {
			continue // not in the map: nothing to write
		}
		if _, done := written[k]; done {
			continue // listed more than once: writing it again would exceed the header
		}
		if err := writeEntry(k); err != nil {
			return err
		}
	}

	// Entries not named in orderedKeys must still be written; otherwise the
	// header would announce more pairs than the stream actually contains.
	for k := range m {
		if _, done := written[k]; done {
			continue
		}
		if err := writeEntry(k); err != nil {
			return err
		}
	}

	return nil
}

func (w *Writer) WriteMap(m map[string]any) error {
if err := w.WriteByte(Map); err != nil {
return err
Expand Down

0 comments on commit 92b40b2

Please sign in to comment.