diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go index 8c5ad2d..9d8eb61 100644 --- a/internal/agent/agent_test.go +++ b/internal/agent/agent_test.go @@ -11,6 +11,7 @@ import ( api "github.com/justagabriel/proglog/api/v1" "github.com/justagabriel/proglog/internal" "github.com/justagabriel/proglog/internal/config" + "github.com/justagabriel/proglog/internal/loadbalance" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -87,6 +88,8 @@ func TestAgent(t *testing.T) { createResp, err := leaderClient.Create(context.Background(), &createReq) require.NoError(t, err) + time.Sleep(3 * time.Second) + getReq := api.GetRecordRequest{ Offset: createResp.Offset, } @@ -119,7 +122,7 @@ func client(t *testing.T, agent *Agent, tlsConfig *tls.Config) api.LogClient { rpcAddr, err := agent.Config.RPCAddr() require.NoError(t, err) - conn, err := grpc.Dial(rpcAddr, opts...) + conn, err := grpc.Dial(fmt.Sprintf("%s:///%s", loadbalance.Name, rpcAddr), opts...) require.NoError(t, err) return api.NewLogClient(conn) diff --git a/internal/loadbalance/picker.go b/internal/loadbalance/picker.go new file mode 100644 index 0000000..4baa3f1 --- /dev/null +++ b/internal/loadbalance/picker.go @@ -0,0 +1,65 @@ +package loadbalance + +import ( + "strings" + "sync" + "sync/atomic" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" +) + +var _ base.PickerBuilder = (*Picker)(nil) + +type Picker struct { + mu sync.Mutex + leader balancer.SubConn + followers []balancer.SubConn + current uint64 +} + +func (p *Picker) Build(buildInfo base.PickerBuildInfo) balancer.Picker { + p.mu.Lock() + defer p.mu.Unlock() + var followers []balancer.SubConn + for sc, scInfo := range buildInfo.ReadySCs { + isLeader := scInfo.Address.Attributes.Value("is_leader").(bool) + if isLeader { + p.leader = sc + continue + } + followers = append(followers, sc) + } + p.followers = followers + return p +} + +var _ balancer.Picker = (*Picker)(nil) + +func (p *Picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + p.mu.Lock() + defer p.mu.Unlock() + var result balancer.PickResult + if strings.Contains(info.FullMethodName, "Create") || len(p.followers) == 0 { + result.SubConn = p.leader + } else if strings.Contains(info.FullMethodName, "Get") { + result.SubConn = p.nextFollower() + } + if result.SubConn == nil { + return result, balancer.ErrNoSubConnAvailable + } + return result, nil +} + +func (p *Picker) nextFollower() balancer.SubConn { + cur := atomic.AddUint64(&p.current, uint64(1)) + len := uint64(len(p.followers)) + idx := int(cur % len) + return p.followers[idx] +} + +func init() { + balancer.Register( + base.NewBalancerBuilder(Name, &Picker{}, base.Config{}), + ) +} diff --git a/internal/loadbalance/picker_test.go b/internal/loadbalance/picker_test.go new file mode 100644 index 0000000..9a8b59f --- /dev/null +++ b/internal/loadbalance/picker_test.go @@ -0,0 +1,73 @@ +package loadbalance_test + +import ( + "testing" + + "github.com/justagabriel/proglog/internal/loadbalance" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/resolver" +) + +func TestPickerNoSubConnAvailable(t *testing.T) { + picker := &loadbalance.Picker{} + for _, method := range []string{ + "/log.vX.Log/Create", + "/log.vC.Log/Get", + } { + info := balancer.PickInfo{ + FullMethodName: method, + } + result, err := picker.Pick(info) + require.Equal(t, balancer.ErrNoSubConnAvailable, err) + require.Nil(t, result.SubConn) + } +} + +type subConn struct { + addrs []resolver.Address +} + +func (sc *subConn) UpdateAddresses(addrs []resolver.Address) { + sc.addrs = addrs +} + +func (s *subConn) Connect() {} +func (s *subConn) Shutdown() {} +func (s *subConn) GetOrBuildProducer(p balancer.ProducerBuilder) (balancer.Producer, func()) { + return nil, func() {} +} + +func setupTest() (*loadbalance.Picker, []*subConn) { + var subConns []*subConn + buildInfo := base.PickerBuildInfo{ + ReadySCs: make(map[balancer.SubConn]base.SubConnInfo), + } + for i := 0; i < 3; i++ { + sc := &subConn{} + addr := resolver.Address{ + Attributes: attributes.New("is_leader", i == 0), + } + // 9th sub conn is the leader + sc.UpdateAddresses([]resolver.Address{addr}) + buildInfo.ReadySCs[sc] = base.SubConnInfo{Address: addr} + subConns = append(subConns, sc) + } + picker := &loadbalance.Picker{} + picker.Build(buildInfo) + return picker, subConns +} + +func TestPickerCreatesToLeader(t *testing.T) { + picker, subConns := setupTest() + info := balancer.PickInfo{ + FullMethodName: "/log.vX.Log/Create", + } + for i := 0; i < 5; i++ { + gotPick, err := picker.Pick(info) + require.NoError(t, err) + require.Equal(t, subConns[0], gotPick.SubConn) + } +}