Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up legacy nats bus features #91

Merged
merged 3 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,8 @@ import (
"github.com/livekit/psrpc/internal/bus"
)

const (
LegacySubLegacyPub = bus.LegacySubLegacyPub
LegacySubCompatiblePub = bus.LegacySubCompatiblePub
WildcardSubCompatiblePub = bus.RouterSubCompatiblePub
WildcardSubWildcardPub = bus.RouterSubWildcardPub
)

func SetChannelMode(m uint32) {
if m <= WildcardSubWildcardPub {
bus.ChannelMode.Store(m)
}
}
// TODO: clean up
func SetChannelMode(m uint32) {}

type Channel = bus.Channel
type MessageBus bus.MessageBus
Expand Down
12 changes: 1 addition & 11 deletions internal/bus/bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package bus

import (
"context"
"sync/atomic"

"google.golang.org/protobuf/proto"
)
Expand All @@ -25,17 +24,8 @@ const (
DefaultChannelSize = 100
)

const (
LegacySubLegacyPub = iota
LegacySubCompatiblePub
RouterSubCompatiblePub
RouterSubWildcardPub
)

var ChannelMode atomic.Uint32

type Channel struct {
Legacy, Server, Local string
Legacy, Server, Server2, Local string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is Server2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missed cleanup

}

type MessageBus interface {
Expand Down
185 changes: 40 additions & 145 deletions internal/bus/bus_nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,23 @@ package bus
import (
"context"
"fmt"
"io"
"sync"

"github.com/nats-io/nats.go"
"go.uber.org/multierr"
"google.golang.org/protobuf/proto"
)

type natsMessageBus struct {
nc *nats.Conn

mu sync.Mutex
routers map[string]*natsWildcardRouter
routers map[string]*natsRouter
}

func NewNatsMessageBus(nc *nats.Conn) MessageBus {
return &natsMessageBus{
nc: nc,
routers: map[string]*natsWildcardRouter{},
routers: map[string]*natsRouter{},
}
}

Expand All @@ -44,45 +42,22 @@ func (n *natsMessageBus) Publish(_ context.Context, channel Channel, msg proto.M
if err != nil {
return err
}

if ChannelMode.Load() != RouterSubWildcardPub {
err = multierr.Append(err, n.nc.Publish(channel.Legacy, b))
}
if ChannelMode.Load() != LegacySubLegacyPub {
err = multierr.Append(err, n.nc.Publish(channel.Server, b))
}
return err
return n.nc.Publish(channel.Server, b)
}

func (n *natsMessageBus) Subscribe(_ context.Context, channel Channel, size int) (Reader, error) {
if channel.Local == "" {
if ChannelMode.Load() == RouterSubWildcardPub {
return n.subscribe(channel.Server, size, false)
} else {
return n.subscribeCompatible(channel, size, false)
}
return n.subscribe(channel.Server, size, false)
} else {
if ChannelMode.Load() == RouterSubWildcardPub {
return n.subscribeRouter(channel, size, false)
} else {
return n.subscribeCompatibleRouter(channel, size, false)
}
return n.subscribeRouter(channel, size, false)
}
}

func (n *natsMessageBus) SubscribeQueue(_ context.Context, channel Channel, size int) (Reader, error) {
if channel.Local == "" {
if ChannelMode.Load() == RouterSubWildcardPub {
return n.subscribe(channel.Server, size, true)
} else {
return n.subscribeCompatible(channel, size, true)
}
return n.subscribe(channel.Server, size, true)
} else {
if ChannelMode.Load() == RouterSubWildcardPub {
return n.subscribeRouter(channel, size, true)
} else {
return n.subscribeCompatibleRouter(channel, size, true)
}
return n.subscribeRouter(channel, size, true)
}
}

Expand All @@ -105,103 +80,59 @@ func (n *natsMessageBus) subscribe(channel string, size int, queue bool) (*natsS
}, nil
}

func (n *natsMessageBus) subscribeCompatible(channel Channel, size int, queue bool) (*natsCompatibleSubscription, error) {
sub, err := n.subscribe(channel.Server, size, queue)
if err != nil {
return nil, err
}
legacySub, err := n.subscribe(channel.Legacy, size, queue)
if err != nil {
sub.Close()
return nil, err
func (n *natsMessageBus) unsubscribeRouter(r *natsRouter, channel string) {
n.mu.Lock()
defer n.mu.Unlock()
if r.close(channel) {
delete(n.routers, r.channel)
}

return &natsCompatibleSubscription{
sub: sub,
legacySub: legacySub,
msgChan: sub.msgChan,
legacyMsgChan: legacySub.msgChan,
}, nil
}

func (n *natsMessageBus) subscribeWildcardRouter(channel string, sub *natsWildcardSubscription, queue bool) error {
func (n *natsMessageBus) subscribeRouter(channel Channel, size int, queue bool) (*natsRouterSubscription, error) {
sub := &natsRouterSubscription{
msgChan: make(chan *nats.Msg, size),
channel: channel.Local,
}

n.mu.Lock()
r, ok := n.routers[channel]
r, ok := n.routers[channel.Server]
if !ok {
r = &natsWildcardRouter{
routes: map[string]*natsWildcardSubscription{},
r = &natsRouter{
routes: map[string]*natsRouterSubscription{},
bus: n,
channel: channel,
channel: channel.Server,
queue: queue,
}
n.routers[channel] = r
n.routers[channel.Server] = r
} else if r.queue != queue {
n.mu.Unlock()
return fmt.Errorf("subscription type mismatch for channel %q %q", channel, sub.channel)
return nil, fmt.Errorf("subscription type mismatch for channel %q %q", channel, sub.channel)
}

r.open(sub.channel, sub)
sub.router = r
n.mu.Unlock()

if ok {
return nil
return sub, nil
}

var err error
if queue {
r.sub, err = n.nc.QueueSubscribe(channel, "bus", r.write)
r.sub, err = n.nc.QueueSubscribe(channel.Server, "bus", r.write)
} else {
r.sub, err = n.nc.Subscribe(channel, r.write)
r.sub, err = n.nc.Subscribe(channel.Server, r.write)
}
if err != nil {
n.mu.Lock()
delete(n.routers, channel)
delete(n.routers, channel.Server)
n.mu.Unlock()
}
return err
}

func (n *natsMessageBus) unsubscribeWildcardRouter(r *natsWildcardRouter, channel string) {
n.mu.Lock()
defer n.mu.Unlock()
if r.close(channel) {
delete(n.routers, r.channel)
}
}

func (n *natsMessageBus) subscribeRouter(channel Channel, size int, queue bool) (*natsWildcardSubscription, error) {
sub := &natsWildcardSubscription{
msgChan: make(chan *nats.Msg, size),
channel: channel.Local,
}

if err := n.subscribeWildcardRouter(channel.Server, sub, queue); err != nil {
return nil, err
}

return sub, nil
}

func (n *natsMessageBus) subscribeCompatibleRouter(channel Channel, size int, queue bool) (*natsCompatibleSubscription, error) {
sub, err := n.subscribeRouter(channel, size, queue)
if err != nil {
return nil, err
}
legacySub, err := n.subscribe(channel.Legacy, size, queue)
if err != nil {
sub.Close()
return nil, err
}

return &natsCompatibleSubscription{
sub: sub,
legacySub: legacySub,
msgChan: sub.msgChan,
legacyMsgChan: legacySub.msgChan,
}, nil
}

type natsSubscription struct {
sub *nats.Subscription
msgChan chan *nats.Msg
Expand All @@ -221,22 +152,22 @@ func (n *natsSubscription) Close() error {
return err
}

type natsWildcardRouter struct {
type natsRouter struct {
sub *nats.Subscription
mu sync.Mutex
routes map[string]*natsWildcardSubscription
routes map[string]*natsRouterSubscription
bus *natsMessageBus
channel string
queue bool
}

func (n *natsWildcardRouter) open(channel string, s *natsWildcardSubscription) {
func (n *natsRouter) open(channel string, s *natsRouterSubscription) {
n.mu.Lock()
defer n.mu.Unlock()
n.routes[channel] = s
}

func (n *natsWildcardRouter) close(channel string) bool {
func (n *natsRouter) close(channel string) bool {
n.mu.Lock()
defer n.mu.Unlock()
delete(n.routes, channel)
Expand All @@ -247,7 +178,7 @@ func (n *natsWildcardRouter) close(channel string) bool {
return false
}

func (n *natsWildcardRouter) write(m *nats.Msg) {
func (n *natsRouter) write(m *nats.Msg) {
channel, err := deserializeChannel(m.Data)
if err != nil {
return
Expand All @@ -260,62 +191,26 @@ func (n *natsWildcardRouter) write(m *nats.Msg) {
}
}

type natsWildcardSubscription struct {
type natsRouterSubscription struct {
msgChan chan *nats.Msg
router *natsWildcardRouter
router *natsRouter
channel string
}

func (n *natsWildcardSubscription) write(m *nats.Msg) {
select {
case n.msgChan <- m:
default:
}
func (n *natsRouterSubscription) write(m *nats.Msg) {
n.msgChan <- m
}

func (n *natsWildcardSubscription) read() ([]byte, bool) {
func (n *natsRouterSubscription) read() ([]byte, bool) {
msg, ok := <-n.msgChan
if !ok {
return nil, false
}
return msg.Data, true
}

func (n *natsWildcardSubscription) Close() error {
n.router.bus.unsubscribeWildcardRouter(n.router, n.channel)
func (n *natsRouterSubscription) Close() error {
n.router.bus.unsubscribeRouter(n.router, n.channel)
close(n.msgChan)
return nil
}

type natsCompatibleSubscription struct {
sub, legacySub io.Closer
msgChan, legacyMsgChan chan *nats.Msg
}

func (n *natsCompatibleSubscription) read() ([]byte, bool) {
for {
select {
case msg, ok := <-n.msgChan:
if !ok {
return nil, false
}
switch ChannelMode.Load() {
case RouterSubCompatiblePub, RouterSubWildcardPub:
return msg.Data, true
}

case msg, ok := <-n.legacyMsgChan:
if !ok {
return nil, false
}
switch ChannelMode.Load() {
case LegacySubLegacyPub, LegacySubCompatiblePub:
return msg.Data, true
}
}
}
}

func (n *natsCompatibleSubscription) Close() error {
return multierr.Combine(n.sub.Close(), n.legacySub.Close())
}
Loading
Loading