From e16e7f9453ca4580d0fadb598d03fe884561fb22 Mon Sep 17 00:00:00 2001 From: Christian Haudum Date: Wed, 2 Oct 2024 21:39:04 +0200 Subject: [PATCH] chore: Remove ring client pool from JumpHashClientPool and replace it with a simpler implementation Signed-off-by: Christian Haudum --- pkg/bloomgateway/client.go | 17 ++--- pkg/bloomgateway/client_pool.go | 104 +++++++++++++++++---------- pkg/bloomgateway/client_pool_test.go | 3 +- 3 files changed, 74 insertions(+), 50 deletions(-) diff --git a/pkg/bloomgateway/client.go b/pkg/bloomgateway/client.go index 2529a678e7794..a873d04960b47 100644 --- a/pkg/bloomgateway/client.go +++ b/pkg/bloomgateway/client.go @@ -161,7 +161,7 @@ func NewClient( } } - poolFactory := func(addr string) (ringclient.PoolClient, error) { + clientFactory := func(addr string) (ringclient.PoolClient, error) { pool, err := NewBloomGatewayGRPCPool(addr, dialOpts) if err != nil { return nil, errors.Wrap(err, "new bloom gateway grpc pool") @@ -185,17 +185,10 @@ func NewClient( // Make an attempt to do one DNS lookup so we can start with addresses dnsProvider.RunOnce() - clientPool := ringclient.NewPool( - "bloom-gateway", - ringclient.PoolConfig(cfg.PoolConfig), - func() ([]string, error) { return dnsProvider.Addresses(), nil }, - ringclient.PoolAddrFunc(poolFactory), - metrics.clients, - logger, - ) - - pool := NewJumpHashClientPool(clientPool, dnsProvider, cfg.PoolConfig.CheckInterval, logger) - pool.Start() + pool, err := NewJumpHashClientPool(clientFactory, dnsProvider, cfg.PoolConfig.CheckInterval, logger) + if err != nil { + return nil, err + } return &GatewayClient{ cfg: cfg, diff --git a/pkg/bloomgateway/client_pool.go b/pkg/bloomgateway/client_pool.go index 989ced34c6730..e5163cb514c8d 100644 --- a/pkg/bloomgateway/client_pool.go +++ b/pkg/bloomgateway/client_pool.go @@ -4,6 +4,7 @@ import ( "context" "flag" "sort" + "sync" "time" "github.com/go-kit/log" @@ -34,19 +35,32 @@ func (cfg *PoolConfig) Validate() error { return nil } +// compiler check +var _ clientPool = &JumpHashClientPool{} + +type ClientFactory func(addr string) (client.PoolClient, error) + +func (f ClientFactory) New(addr string) (client.PoolClient, error) { + return f(addr) +} + type JumpHashClientPool struct { - *client.Pool + services.Service *jumphash.Selector + sync.RWMutex + + provider AddressProvider + logger log.Logger - done chan struct{} - logger log.Logger + clients map[string]client.PoolClient + clientFactory ClientFactory } type AddressProvider interface { Addresses() []string } -func NewJumpHashClientPool(pool *client.Pool, dnsProvider AddressProvider, updateInterval time.Duration, logger log.Logger) *JumpHashClientPool { +func NewJumpHashClientPool(clientFactory ClientFactory, dnsProvider AddressProvider, updateInterval time.Duration, logger log.Logger) (*JumpHashClientPool, error) { selector := jumphash.DefaultSelector() err := selector.SetServers(dnsProvider.Addresses()...) if err != nil { @@ -54,14 +68,18 @@ func NewJumpHashClientPool(pool *client.Pool, dnsProvider AddressProvider, updat } p := &JumpHashClientPool{ - Pool: pool, - Selector: selector, - done: make(chan struct{}), - logger: logger, + Selector: selector, + clientFactory: clientFactory, + provider: dnsProvider, + logger: logger, } - go p.updateLoop(dnsProvider, updateInterval) - return p + p.Service = services.NewTimerService(updateInterval, nil, p.updateLoop, nil) + return p, services.StartAndAwaitRunning(context.Background(), p.Service) +} + +func (p *JumpHashClientPool) Stop() { + _ = services.StopAndAwaitTerminated(context.Background(), p.Service) } func (p *JumpHashClientPool) AddrForFingerprint(fp uint64) (string, error) { @@ -80,35 +98,47 @@ func (p *JumpHashClientPool) Addr(key string) (string, error) { return addr.String(), nil } -func (p *JumpHashClientPool) Start() { - ctx := context.Background() - _ = services.StartAndAwaitRunning(ctx, p.Pool) +func (p *JumpHashClientPool) updateLoop(_ context.Context) error { + servers := p.provider.Addresses() + // ServerList deterministically maps keys to _index_ of the server list. + // Since DNS returns records in different order each time, we sort to + // guarantee best possible match between nodes. + sort.Strings(servers) + err := p.SetServers(servers...) + if err != nil { + level.Warn(p.logger).Log("msg", "error updating servers", "err", err) + } + return nil } -func (p *JumpHashClientPool) Stop() { - ctx := context.Background() - _ = services.StopAndAwaitTerminated(ctx, p.Pool) - close(p.done) -} +// GetClientFor implements clientPool. +func (p *JumpHashClientPool) GetClientFor(addr string) (client.PoolClient, error) { + client, ok := p.fromCache(addr) + if ok { + return client, nil + } + + // No client in cache so create one + p.Lock() + defer p.Unlock() -func (p *JumpHashClientPool) updateLoop(provider AddressProvider, updateInterval time.Duration) { - ticker := time.NewTicker(updateInterval) - defer ticker.Stop() - - for { - select { - case <-p.done: - return - case <-ticker.C: - servers := provider.Addresses() - // ServerList deterministically maps keys to _index_ of the server list. - // Since DNS returns records in different order each time, we sort to - // guarantee best possible match between nodes. - sort.Strings(servers) - err := p.SetServers(servers...) - if err != nil { - level.Warn(p.logger).Log("msg", "error updating servers", "err", err) - } - } + // Check if a client has been created just after checking the cache and before acquiring the lock. + client, ok = p.clients[addr] + if ok { + return client, nil } + + client, err := p.clientFactory.New(addr) + if err != nil { + return nil, err + } + p.clients[addr] = client + return client, nil +} + +func (p *JumpHashClientPool) fromCache(addr string) (client.PoolClient, bool) { + p.RLock() + defer p.RUnlock() + client, ok := p.clients[addr] + return client, ok } diff --git a/pkg/bloomgateway/client_pool_test.go b/pkg/bloomgateway/client_pool_test.go index 5e3792861f4c3..a592bf2417866 100644 --- a/pkg/bloomgateway/client_pool_test.go +++ b/pkg/bloomgateway/client_pool_test.go @@ -31,7 +31,8 @@ func TestJumpHashClientPool_UpdateLoop(t *testing.T) { provider := &provider{} provider.UpdateAddresses([]string{"localhost:9095"}) - pool := NewJumpHashClientPool(nil, provider, interval, log.NewNopLogger()) + pool, err := NewJumpHashClientPool(nil, provider, interval, log.NewNopLogger()) + require.NoError(t, err) require.Len(t, pool.Addrs(), 1) require.Equal(t, "127.0.0.1:9095", pool.Addrs()[0].String())