Skip to content

Commit

Permalink
allow multiple subs for the same nats topic (#93)
Browse files Browse the repository at this point in the history
* allow multiple subs for the same nats topic

* test
  • Loading branch information
paulwe authored Apr 3, 2024
1 parent cec3a0e commit 811331b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
29 changes: 21 additions & 8 deletions internal/bus/bus_nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package bus
import (
"context"
"fmt"
"slices"
"sync"

"github.com/nats-io/nats.go"
Expand Down Expand Up @@ -80,10 +81,10 @@ func (n *natsMessageBus) subscribe(channel string, size int, queue bool) (*natsS
}, nil
}

func (n *natsMessageBus) unsubscribeRouter(r *natsRouter, channel string) {
func (n *natsMessageBus) unsubscribeRouter(r *natsRouter, channel string, s *natsRouterSubscription) {
n.mu.Lock()
defer n.mu.Unlock()
if r.close(channel) {
if r.close(channel, s) {
delete(n.routers, r.channel)
}
}
Expand All @@ -98,7 +99,7 @@ func (n *natsMessageBus) subscribeRouter(channel Channel, size int, queue bool)
r, ok := n.routers[channel.Server]
if !ok {
r = &natsRouter{
routes: map[string]*natsRouterSubscription{},
routes: map[string][]*natsRouterSubscription{},
bus: n,
channel: channel.Server,
queue: queue,
Expand Down Expand Up @@ -155,7 +156,7 @@ func (n *natsSubscription) Close() error {
type natsRouter struct {
sub *nats.Subscription
mu sync.Mutex
routes map[string]*natsRouterSubscription
routes map[string][]*natsRouterSubscription
bus *natsMessageBus
channel string
queue bool
Expand All @@ -164,12 +165,24 @@ type natsRouter struct {
func (n *natsRouter) open(channel string, s *natsRouterSubscription) {
n.mu.Lock()
defer n.mu.Unlock()
n.routes[channel] = s
n.routes[channel] = append(n.routes[channel], s)
}

func (n *natsRouter) close(channel string) bool {
func (n *natsRouter) close(channel string, s *natsRouterSubscription) bool {
n.mu.Lock()
defer n.mu.Unlock()

subs := n.routes[channel]
i := slices.Index(n.routes[channel], s)
if i == -1 {
return false
}

if len(subs) > 1 {
n.routes[channel] = slices.Delete(subs, i, i+1)
return false
}

delete(n.routes, channel)
if len(n.routes) == 0 {
n.sub.Unsubscribe()
Expand All @@ -186,7 +199,7 @@ func (n *natsRouter) write(m *nats.Msg) {

n.mu.Lock()
defer n.mu.Unlock()
if s, ok := n.routes[channel]; ok {
for _, s := range n.routes[channel] {
s.write(m)
}
}
Expand All @@ -210,7 +223,7 @@ func (n *natsRouterSubscription) read() ([]byte, bool) {
}

func (n *natsRouterSubscription) Close() error {
n.router.bus.unsubscribeRouter(n.router, n.channel)
n.router.bus.unsubscribeRouter(n.router, n.channel, n)
close(n.msgChan)
return nil
}
1 change: 1 addition & 0 deletions internal/test/psrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ func testStream(t *testing.T, bus func() psrpc.MessageBus) {

err = server.RegisterStreamHandler[*internal.Response, *internal.Response](serverA, rpc, nil, handlePing, nil)
require.NoError(t, err)
time.Sleep(time.Second)

ctx := context.Background()
stream, err := client.OpenStream[*internal.Response, *internal.Response](
Expand Down

0 comments on commit 811331b

Please sign in to comment.