Skip to content

Commit

Permalink
unblock subscriptions on close (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulwe authored Apr 26, 2024
1 parent 811331b commit 8ba067a
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 46 deletions.
52 changes: 33 additions & 19 deletions internal/bus/bus_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ func (l *localMessageBus) Publish(_ context.Context, channel Channel, msg proto.
return nil
}

func (l *localMessageBus) Subscribe(_ context.Context, channel Channel, size int) (Reader, error) {
return l.subscribe(l.subs, channel.Legacy, size, false)
func (l *localMessageBus) Subscribe(ctx context.Context, channel Channel, size int) (Reader, error) {
return l.subscribe(ctx, l.subs, channel.Legacy, size, false)
}

func (l *localMessageBus) SubscribeQueue(_ context.Context, channel Channel, size int) (Reader, error) {
return l.subscribe(l.queues, channel.Legacy, size, true)
func (l *localMessageBus) SubscribeQueue(ctx context.Context, channel Channel, size int) (Reader, error) {
return l.subscribe(ctx, l.queues, channel.Legacy, size, true)
}

func (l *localMessageBus) subscribe(subLists map[string]*localSubList, channel string, size int, queue bool) (Reader, error) {
func (l *localMessageBus) subscribe(ctx context.Context, subLists map[string]*localSubList, channel string, size int, queue bool) (Reader, error) {
l.Lock()
defer l.Unlock()

Expand All @@ -74,7 +74,6 @@ func (l *localMessageBus) subscribe(subLists map[string]*localSubList, channel s
l.Lock()
subList.Lock()

close(subList.subs[index])
subList.subs[index] = nil
subList.subCount--
if subList.subCount == 0 {
Expand All @@ -87,20 +86,25 @@ func (l *localMessageBus) subscribe(subLists map[string]*localSubList, channel s
subLists[channel] = subList
}

return subList.create(size), nil
return subList.create(ctx, size), nil
}

type localSubList struct {
sync.RWMutex // locking while holding localMessageBus lock is allowed
subs []chan []byte
subs []*localSubscription
subCount int
queue bool
next int
onUnsubscribe func(int)
}

func (l *localSubList) create(size int) *localSubscription {
msgChan := make(chan []byte, size)
func (l *localSubList) create(ctx context.Context, size int) *localSubscription {
ctx, cancel := context.WithCancel(ctx)
sub := &localSubscription{
ctx: ctx,
cancel: cancel,
msgChan: make(chan []byte, size),
}

l.Lock()
defer l.Unlock()
Expand All @@ -112,22 +116,21 @@ func (l *localSubList) create(size int) *localSubscription {
if s == nil {
added = true
index = i
l.subs[i] = msgChan
l.subs[i] = sub
break
}
}

if !added {
index = len(l.subs)
l.subs = append(l.subs, msgChan)
l.subs = append(l.subs, sub)
}

return &localSubscription{
msgChan: msgChan,
onClose: func() {
l.onUnsubscribe(index)
},
sub.onClose = func() {
l.onUnsubscribe(index)
}

return sub
}

func (l *localSubList) dispatch(b []byte) {
Expand All @@ -143,7 +146,7 @@ func (l *localSubList) dispatch(b []byte) {
s := l.subs[l.next]
l.next++
if s != nil {
s <- b
s.write(b)
return
}
}
Expand All @@ -154,17 +157,26 @@ func (l *localSubList) dispatch(b []byte) {
// send to all
for _, s := range l.subs {
if s != nil {
s <- b
s.write(b)
}
}
}
}

type localSubscription struct {
ctx context.Context
cancel context.CancelFunc
msgChan chan []byte
onClose func()
}

func (l *localSubscription) write(b []byte) {
select {
case l.msgChan <- b:
case <-l.ctx.Done():
}
}

func (l *localSubscription) read() ([]byte, bool) {
msg, ok := <-l.msgChan
if !ok {
Expand All @@ -174,6 +186,8 @@ func (l *localSubscription) read() ([]byte, bool) {
}

func (l *localSubscription) Close() error {
l.cancel()
l.onClose()
close(l.msgChan)
return nil
}
55 changes: 38 additions & 17 deletions internal/bus/bus_nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,39 +46,41 @@ func (n *natsMessageBus) Publish(_ context.Context, channel Channel, msg proto.M
return n.nc.Publish(channel.Server, b)
}

func (n *natsMessageBus) Subscribe(_ context.Context, channel Channel, size int) (Reader, error) {
func (n *natsMessageBus) Subscribe(ctx context.Context, channel Channel, size int) (Reader, error) {
if channel.Local == "" {
return n.subscribe(channel.Server, size, false)
return n.subscribe(ctx, channel.Server, size, false)
} else {
return n.subscribeRouter(channel, size, false)
return n.subscribeRouter(ctx, channel, size, false)
}
}

func (n *natsMessageBus) SubscribeQueue(_ context.Context, channel Channel, size int) (Reader, error) {
func (n *natsMessageBus) SubscribeQueue(ctx context.Context, channel Channel, size int) (Reader, error) {
if channel.Local == "" {
return n.subscribe(channel.Server, size, true)
return n.subscribe(ctx, channel.Server, size, true)
} else {
return n.subscribeRouter(channel, size, true)
return n.subscribeRouter(ctx, channel, size, true)
}
}

func (n *natsMessageBus) subscribe(channel string, size int, queue bool) (*natsSubscription, error) {
msgChan := make(chan *nats.Msg, size)
var sub *nats.Subscription
func (n *natsMessageBus) subscribe(ctx context.Context, channel string, size int, queue bool) (*natsSubscription, error) {
ctx, cancel := context.WithCancel(ctx)
sub := &natsSubscription{
ctx: ctx,
cancel: cancel,
msgChan: make(chan *nats.Msg, size),
}

var err error
if queue {
sub, err = n.nc.ChanQueueSubscribe(channel, "bus", msgChan)
sub.sub, err = n.nc.QueueSubscribe(channel, "bus", sub.write)
} else {
sub, err = n.nc.ChanSubscribe(channel, msgChan)
sub.sub, err = n.nc.Subscribe(channel, sub.write)
}
if err != nil {
return nil, err
}

return &natsSubscription{
sub: sub,
msgChan: msgChan,
}, nil
return sub, nil
}

func (n *natsMessageBus) unsubscribeRouter(r *natsRouter, channel string, s *natsRouterSubscription) {
Expand All @@ -89,8 +91,11 @@ func (n *natsMessageBus) unsubscribeRouter(r *natsRouter, channel string, s *nat
}
}

func (n *natsMessageBus) subscribeRouter(channel Channel, size int, queue bool) (*natsRouterSubscription, error) {
func (n *natsMessageBus) subscribeRouter(ctx context.Context, channel Channel, size int, queue bool) (*natsRouterSubscription, error) {
ctx, cancel := context.WithCancel(ctx)
sub := &natsRouterSubscription{
ctx: ctx,
cancel: cancel,
msgChan: make(chan *nats.Msg, size),
channel: channel.Local,
}
Expand Down Expand Up @@ -135,10 +140,19 @@ func (n *natsMessageBus) subscribeRouter(channel Channel, size int, queue bool)
}

type natsSubscription struct {
ctx context.Context
cancel context.CancelFunc
sub *nats.Subscription
msgChan chan *nats.Msg
}

func (n *natsSubscription) write(msg *nats.Msg) {
select {
case n.msgChan <- msg:
case <-n.ctx.Done():
}
}

func (n *natsSubscription) read() ([]byte, bool) {
msg, ok := <-n.msgChan
if !ok {
Expand All @@ -148,6 +162,7 @@ func (n *natsSubscription) read() ([]byte, bool) {
}

func (n *natsSubscription) Close() error {
n.cancel()
err := n.sub.Unsubscribe()
close(n.msgChan)
return err
Expand Down Expand Up @@ -205,13 +220,18 @@ func (n *natsRouter) write(m *nats.Msg) {
}

type natsRouterSubscription struct {
ctx context.Context
cancel context.CancelFunc
msgChan chan *nats.Msg
router *natsRouter
channel string
}

func (n *natsRouterSubscription) write(m *nats.Msg) {
n.msgChan <- m
select {
case n.msgChan <- m:
case <-n.ctx.Done():
}
}

func (n *natsRouterSubscription) read() ([]byte, bool) {
Expand All @@ -223,6 +243,7 @@ func (n *natsRouterSubscription) read() ([]byte, bool) {
}

func (n *natsRouterSubscription) Close() error {
n.cancel()
n.router.bus.unsubscribeRouter(n.router, n.channel, n)
close(n.msgChan)
return nil
Expand Down
30 changes: 20 additions & 10 deletions internal/bus/bus_redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ func (r *redisMessageBus) SubscribeQueue(ctx context.Context, channel Channel, s
}

func (r *redisMessageBus) subscribe(ctx context.Context, channel string, size int, subLists map[string]*redisSubList, queue bool) (Reader, error) {
ctx, cancel := context.WithCancel(ctx)
sub := &redisSubscription{
bus: r,
ctx: ctx,
cancel: cancel,
channel: channel,
msgChan: make(chan *redis.Message, size),
queue: queue,
Expand All @@ -120,12 +122,12 @@ func (r *redisMessageBus) subscribe(ctx context.Context, channel string, size in
subLists[channel] = subList
r.reconcileSubscriptions(channel)
}
subList.subs = append(subList.subs, sub.msgChan)
subList.subs = append(subList.subs, sub)

return sub, nil
}

func (r *redisMessageBus) unsubscribe(channel string, queue bool, msgChan chan *redis.Message) {
func (r *redisMessageBus) unsubscribe(channel string, queue bool, sub *redisSubscription) {
r.mu.Lock()
defer r.mu.Unlock()

Expand All @@ -140,13 +142,12 @@ func (r *redisMessageBus) unsubscribe(channel string, queue bool, msgChan chan *
if !ok {
return
}
i := slices.Index(subList.subs, msgChan)
i := slices.Index(subList.subs, sub)
if i == -1 {
return
}

subList.subs = slices.Delete(subList.subs, i, i+1)
close(msgChan)

if len(subList.subs) == 0 {
delete(subLists, channel)
Expand Down Expand Up @@ -320,32 +321,40 @@ func (r *redisReconcileSubscriptionsOp) run() error {
}

type redisSubList struct {
subs []chan *redis.Message
subs []*redisSubscription
next int
}

func (r *redisSubList) dispatchQueue(msg *redis.Message) {
if r.next >= len(r.subs) {
r.next = 0
}
r.subs[r.next] <- msg
r.subs[r.next].write(msg)
r.next++
}

func (r *redisSubList) dispatch(msg *redis.Message) {
for _, ch := range r.subs {
ch <- msg
for _, sub := range r.subs {
sub.write(msg)
}
}

type redisSubscription struct {
bus *redisMessageBus
ctx context.Context
cancel context.CancelFunc
channel string
msgChan chan *redis.Message
queue bool
}

func (r *redisSubscription) write(msg *redis.Message) {
select {
case r.msgChan <- msg:
case <-r.ctx.Done():
}
}

func (r *redisSubscription) read() ([]byte, bool) {
for {
var msg *redis.Message
Expand All @@ -356,7 +365,6 @@ func (r *redisSubscription) read() ([]byte, bool) {
return nil, false
}
case <-r.ctx.Done():
r.Close()
return nil, false
}

Expand All @@ -374,6 +382,8 @@ func (r *redisSubscription) read() ([]byte, bool) {
}

func (r *redisSubscription) Close() error {
r.bus.unsubscribe(r.channel, r.queue, r.msgChan)
r.cancel()
r.bus.unsubscribe(r.channel, r.queue, r)
close(r.msgChan)
return nil
}

0 comments on commit 8ba067a

Please sign in to comment.