Skip to content

Commit

Permalink
clean up legacy nats bus features (#91)
Browse files Browse the repository at this point in the history
* clean up legacy nats bus features

* cleanup

* cleanup
  • Loading branch information
paulwe authored Mar 26, 2024
1 parent ba2e5b9 commit 0d14376
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 346 deletions.
14 changes: 2 additions & 12 deletions bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,8 @@ import (
"github.com/livekit/psrpc/internal/bus"
)

const (
LegacySubLegacyPub = bus.LegacySubLegacyPub
LegacySubCompatiblePub = bus.LegacySubCompatiblePub
WildcardSubCompatiblePub = bus.RouterSubCompatiblePub
WildcardSubWildcardPub = bus.RouterSubWildcardPub
)

func SetChannelMode(m uint32) {
if m <= WildcardSubWildcardPub {
bus.ChannelMode.Store(m)
}
}
// TODO: clean up
func SetChannelMode(m uint32) {}

type Channel = bus.Channel
type MessageBus bus.MessageBus
Expand Down
10 changes: 0 additions & 10 deletions internal/bus/bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package bus

import (
"context"
"sync/atomic"

"google.golang.org/protobuf/proto"
)
Expand All @@ -25,15 +24,6 @@ const (
DefaultChannelSize = 100
)

const (
LegacySubLegacyPub = iota
LegacySubCompatiblePub
RouterSubCompatiblePub
RouterSubWildcardPub
)

var ChannelMode atomic.Uint32

type Channel struct {
Legacy, Server, Local string
}
Expand Down
185 changes: 40 additions & 145 deletions internal/bus/bus_nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,23 @@ package bus
import (
"context"
"fmt"
"io"
"sync"

"github.com/nats-io/nats.go"
"go.uber.org/multierr"
"google.golang.org/protobuf/proto"
)

type natsMessageBus struct {
nc *nats.Conn

mu sync.Mutex
routers map[string]*natsWildcardRouter
routers map[string]*natsRouter
}

func NewNatsMessageBus(nc *nats.Conn) MessageBus {
return &natsMessageBus{
nc: nc,
routers: map[string]*natsWildcardRouter{},
routers: map[string]*natsRouter{},
}
}

Expand All @@ -44,45 +42,22 @@ func (n *natsMessageBus) Publish(_ context.Context, channel Channel, msg proto.M
if err != nil {
return err
}

if ChannelMode.Load() != RouterSubWildcardPub {
err = multierr.Append(err, n.nc.Publish(channel.Legacy, b))
}
if ChannelMode.Load() != LegacySubLegacyPub {
err = multierr.Append(err, n.nc.Publish(channel.Server, b))
}
return err
return n.nc.Publish(channel.Server, b)
}

func (n *natsMessageBus) Subscribe(_ context.Context, channel Channel, size int) (Reader, error) {
if channel.Local == "" {
if ChannelMode.Load() == RouterSubWildcardPub {
return n.subscribe(channel.Server, size, false)
} else {
return n.subscribeCompatible(channel, size, false)
}
return n.subscribe(channel.Server, size, false)
} else {
if ChannelMode.Load() == RouterSubWildcardPub {
return n.subscribeRouter(channel, size, false)
} else {
return n.subscribeCompatibleRouter(channel, size, false)
}
return n.subscribeRouter(channel, size, false)
}
}

func (n *natsMessageBus) SubscribeQueue(_ context.Context, channel Channel, size int) (Reader, error) {
if channel.Local == "" {
if ChannelMode.Load() == RouterSubWildcardPub {
return n.subscribe(channel.Server, size, true)
} else {
return n.subscribeCompatible(channel, size, true)
}
return n.subscribe(channel.Server, size, true)
} else {
if ChannelMode.Load() == RouterSubWildcardPub {
return n.subscribeRouter(channel, size, true)
} else {
return n.subscribeCompatibleRouter(channel, size, true)
}
return n.subscribeRouter(channel, size, true)
}
}

Expand All @@ -105,103 +80,59 @@ func (n *natsMessageBus) subscribe(channel string, size int, queue bool) (*natsS
}, nil
}

func (n *natsMessageBus) subscribeCompatible(channel Channel, size int, queue bool) (*natsCompatibleSubscription, error) {
sub, err := n.subscribe(channel.Server, size, queue)
if err != nil {
return nil, err
}
legacySub, err := n.subscribe(channel.Legacy, size, queue)
if err != nil {
sub.Close()
return nil, err
func (n *natsMessageBus) unsubscribeRouter(r *natsRouter, channel string) {
n.mu.Lock()
defer n.mu.Unlock()
if r.close(channel) {
delete(n.routers, r.channel)
}

return &natsCompatibleSubscription{
sub: sub,
legacySub: legacySub,
msgChan: sub.msgChan,
legacyMsgChan: legacySub.msgChan,
}, nil
}

func (n *natsMessageBus) subscribeWildcardRouter(channel string, sub *natsWildcardSubscription, queue bool) error {
func (n *natsMessageBus) subscribeRouter(channel Channel, size int, queue bool) (*natsRouterSubscription, error) {
sub := &natsRouterSubscription{
msgChan: make(chan *nats.Msg, size),
channel: channel.Local,
}

n.mu.Lock()
r, ok := n.routers[channel]
r, ok := n.routers[channel.Server]
if !ok {
r = &natsWildcardRouter{
routes: map[string]*natsWildcardSubscription{},
r = &natsRouter{
routes: map[string]*natsRouterSubscription{},
bus: n,
channel: channel,
channel: channel.Server,
queue: queue,
}
n.routers[channel] = r
n.routers[channel.Server] = r
} else if r.queue != queue {
n.mu.Unlock()
return fmt.Errorf("subscription type mismatch for channel %q %q", channel, sub.channel)
return nil, fmt.Errorf("subscription type mismatch for channel %q %q", channel, sub.channel)
}

r.open(sub.channel, sub)
sub.router = r
n.mu.Unlock()

if ok {
return nil
return sub, nil
}

var err error
if queue {
r.sub, err = n.nc.QueueSubscribe(channel, "bus", r.write)
r.sub, err = n.nc.QueueSubscribe(channel.Server, "bus", r.write)
} else {
r.sub, err = n.nc.Subscribe(channel, r.write)
r.sub, err = n.nc.Subscribe(channel.Server, r.write)
}
if err != nil {
n.mu.Lock()
delete(n.routers, channel)
delete(n.routers, channel.Server)
n.mu.Unlock()
}
return err
}

func (n *natsMessageBus) unsubscribeWildcardRouter(r *natsWildcardRouter, channel string) {
n.mu.Lock()
defer n.mu.Unlock()
if r.close(channel) {
delete(n.routers, r.channel)
}
}

func (n *natsMessageBus) subscribeRouter(channel Channel, size int, queue bool) (*natsWildcardSubscription, error) {
sub := &natsWildcardSubscription{
msgChan: make(chan *nats.Msg, size),
channel: channel.Local,
}

if err := n.subscribeWildcardRouter(channel.Server, sub, queue); err != nil {
return nil, err
}

return sub, nil
}

func (n *natsMessageBus) subscribeCompatibleRouter(channel Channel, size int, queue bool) (*natsCompatibleSubscription, error) {
sub, err := n.subscribeRouter(channel, size, queue)
if err != nil {
return nil, err
}
legacySub, err := n.subscribe(channel.Legacy, size, queue)
if err != nil {
sub.Close()
return nil, err
}

return &natsCompatibleSubscription{
sub: sub,
legacySub: legacySub,
msgChan: sub.msgChan,
legacyMsgChan: legacySub.msgChan,
}, nil
}

type natsSubscription struct {
sub *nats.Subscription
msgChan chan *nats.Msg
Expand All @@ -221,22 +152,22 @@ func (n *natsSubscription) Close() error {
return err
}

type natsWildcardRouter struct {
type natsRouter struct {
sub *nats.Subscription
mu sync.Mutex
routes map[string]*natsWildcardSubscription
routes map[string]*natsRouterSubscription
bus *natsMessageBus
channel string
queue bool
}

func (n *natsWildcardRouter) open(channel string, s *natsWildcardSubscription) {
func (n *natsRouter) open(channel string, s *natsRouterSubscription) {
n.mu.Lock()
defer n.mu.Unlock()
n.routes[channel] = s
}

func (n *natsWildcardRouter) close(channel string) bool {
func (n *natsRouter) close(channel string) bool {
n.mu.Lock()
defer n.mu.Unlock()
delete(n.routes, channel)
Expand All @@ -247,7 +178,7 @@ func (n *natsWildcardRouter) close(channel string) bool {
return false
}

func (n *natsWildcardRouter) write(m *nats.Msg) {
func (n *natsRouter) write(m *nats.Msg) {
channel, err := deserializeChannel(m.Data)
if err != nil {
return
Expand All @@ -260,62 +191,26 @@ func (n *natsWildcardRouter) write(m *nats.Msg) {
}
}

type natsWildcardSubscription struct {
type natsRouterSubscription struct {
msgChan chan *nats.Msg
router *natsWildcardRouter
router *natsRouter
channel string
}

func (n *natsWildcardSubscription) write(m *nats.Msg) {
select {
case n.msgChan <- m:
default:
}
func (n *natsRouterSubscription) write(m *nats.Msg) {
n.msgChan <- m
}

func (n *natsWildcardSubscription) read() ([]byte, bool) {
func (n *natsRouterSubscription) read() ([]byte, bool) {
msg, ok := <-n.msgChan
if !ok {
return nil, false
}
return msg.Data, true
}

func (n *natsWildcardSubscription) Close() error {
n.router.bus.unsubscribeWildcardRouter(n.router, n.channel)
func (n *natsRouterSubscription) Close() error {
n.router.bus.unsubscribeRouter(n.router, n.channel)
close(n.msgChan)
return nil
}

type natsCompatibleSubscription struct {
sub, legacySub io.Closer
msgChan, legacyMsgChan chan *nats.Msg
}

func (n *natsCompatibleSubscription) read() ([]byte, bool) {
for {
select {
case msg, ok := <-n.msgChan:
if !ok {
return nil, false
}
switch ChannelMode.Load() {
case RouterSubCompatiblePub, RouterSubWildcardPub:
return msg.Data, true
}

case msg, ok := <-n.legacyMsgChan:
if !ok {
return nil, false
}
switch ChannelMode.Load() {
case LegacySubLegacyPub, LegacySubCompatiblePub:
return msg.Data, true
}
}
}
}

func (n *natsCompatibleSubscription) Close() error {
return multierr.Combine(n.sub.Close(), n.legacySub.Close())
}
Loading

0 comments on commit 0d14376

Please sign in to comment.