feat(ingest-limits): Request stream usage for owned partitions only #16136

Closed
wants to merge 1 commit into from
2 changes: 1 addition & 1 deletion pkg/distributor/distributor.go
@@ -1159,7 +1159,7 @@ func (d *Distributor) sendStreamToKafka(ctx context.Context, stream KeyedStream,
}

// Add metadata record
metadataRecord := kafka.EncodeStreamMetadata(partitionID, d.cfg.KafkaConfig.Topic, tenant, stream.HashNoShard)
metadataRecord := kafka.EncodeStreamMetadata(partitionID, d.cfg.KafkaConfig.Topic, tenant, stream.HashNoShard, stream.RingToken)
records = append(records, metadataRecord)

d.kafkaRecordsPerRequest.Observe(float64(len(records)))
8 changes: 6 additions & 2 deletions pkg/kafka/encoding.go
@@ -203,19 +203,23 @@ func sovPush(x uint64) (n int) {

// EncodeStreamMetadata encodes the stream metadata into a Kafka record
// using the tenantID as the key and partition as the target partition
func EncodeStreamMetadata(partition int32, topic string, tenantID string, streamHash uint64) *kgo.Record {
func EncodeStreamMetadata(partition int32, topic string, tenantID string, streamHash uint64, ringToken uint32) *kgo.Record {
// Validate stream hash
if streamHash == 0 {
return nil
}

if ringToken == 0 {
return nil
}

// Get metadata from pool
metadata := metadataPool.Get().(*logproto.StreamMetadata)
defer metadataPool.Put(metadata)

// Set stream hash
metadata.StreamHash = streamHash

metadata.RingToken = ringToken
// Encode the metadata into a byte slice
value, err := metadata.Marshal()
if err != nil {
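Not part of the diff: a minimal caller-side sketch of the new EncodeStreamMetadata signature. The Loki module path and the franz-go kgo import below are assumptions for illustration, and the nil guard is an extra safeguard rather than the distributor's own code; the function now returns nil when either the stream hash or the ring token is zero.

```go
package example

import (
	"github.com/grafana/loki/v3/pkg/kafka" // assumed module path for the package shown above
	"github.com/twmb/franz-go/pkg/kgo"     // assumed import for the kgo.Record type
)

// appendStreamMetadata builds a metadata record with the new ring token
// argument and skips streams that EncodeStreamMetadata rejects, i.e. a zero
// stream hash or a zero ring token, both of which now yield a nil record.
func appendStreamMetadata(records []*kgo.Record, partition int32, topic, tenant string, streamHash uint64, ringToken uint32) []*kgo.Record {
	record := kafka.EncodeStreamMetadata(partition, topic, tenant, streamHash, ringToken)
	if record == nil {
		return records
	}
	return append(records, record)
}
```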
14 changes: 13 additions & 1 deletion pkg/kafka/encoding_test.go
@@ -155,6 +155,7 @@ func TestEncodeDecodeStreamMetadata(t *testing.T) {
tests := []struct {
name string
hash uint64
ringToken uint32
partition int32
topic string
tenantID string
@@ -163,6 +164,7 @@
{
name: "Valid metadata",
hash: 12345,
ringToken: 1,
partition: 1,
topic: "logs",
tenantID: "tenant-1",
@@ -171,17 +173,27 @@
{
name: "Zero hash - should error",
hash: 0,
ringToken: 1,
partition: 3,
topic: "traces",
tenantID: "tenant-3",
expectErr: true,
},
{
name: "Zero ring token - should error",
hash: 12345,
ringToken: 0,
partition: 1,
topic: "logs",
tenantID: "tenant-1",
expectErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Encode metadata
record := EncodeStreamMetadata(tt.partition, tt.topic, tt.tenantID, tt.hash)
record := EncodeStreamMetadata(tt.partition, tt.topic, tt.tenantID, tt.hash, tt.ringToken)
if tt.expectErr {
require.Nil(t, record)
return
155 changes: 109 additions & 46 deletions pkg/limits/frontend/service.go
@@ -2,6 +2,9 @@ package frontend

import (
"context"
"fmt"
"sort"
"strings"

"github.com/go-kit/log"
"github.com/grafana/dskit/ring"
@@ -42,10 +45,9 @@
)

type metrics struct {
tenantExceedsLimits *prometheus.CounterVec
tenantActiveStreams *prometheus.GaugeVec
tenantDuplicateStreamsFound *prometheus.CounterVec
tenantRejectedStreams *prometheus.CounterVec
tenantExceedsLimits *prometheus.CounterVec
tenantActiveStreams *prometheus.GaugeVec
tenantRejectedStreams *prometheus.CounterVec
}

func newMetrics(reg prometheus.Registerer) *metrics {
@@ -60,11 +62,6 @@ func newMetrics(reg prometheus.Registerer) *metrics {
Name: "ingest_limits_frontend_streams_active",
Help: "The current number of active streams (seen within the window) per tenant.",
}, []string{"tenant"}),
tenantDuplicateStreamsFound: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Namespace: constants.Loki,
Name: "ingest_limits_frontend_streams_duplicate_total",
Help: "The total number of duplicate streams found per tenant.",
}, []string{"tenant"}),
tenantRejectedStreams: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Namespace: constants.Loki,
Name: "ingest_limits_frontend_streams_rejected_total",
@@ -73,7 +70,7 @@ func newMetrics(reg prometheus.Registerer) *metrics {
}
}

type ringFunc func(context.Context, logproto.IngestLimitsClient) (*logproto.GetStreamUsageResponse, error)
type ringGetUsageFunc func(context.Context, logproto.IngestLimitsClient, []int32) (*logproto.GetStreamUsageResponse, error)

// RingIngestLimitsService is an IngestLimitsService that uses the ring to read the responses
// from all limits backends.
@@ -99,41 +96,115 @@ func NewRingIngestLimitsService(ring ring.ReadRing, pool *ring_client.Pool, limi
}
}

func (s *RingIngestLimitsService) forAllBackends(ctx context.Context, f ringFunc) ([]Response, error) {
func (s *RingIngestLimitsService) forAllBackends(ctx context.Context, f ringGetUsageFunc) ([]GetStreamUsageResponse, error) {
replicaSet, err := s.ring.GetAllHealthy(LimitsRead)
if err != nil {
return nil, err
}
return s.forGivenReplicaSet(ctx, replicaSet, f)
}

func (s *RingIngestLimitsService) forGivenReplicaSet(ctx context.Context, replicaSet ring.ReplicationSet, f ringFunc) ([]Response, error) {
func (s *RingIngestLimitsService) forGivenReplicaSet(ctx context.Context, replicaSet ring.ReplicationSet, f ringGetUsageFunc) ([]GetStreamUsageResponse, error) {
partitions, err := s.perReplicaSetPartitions(ctx, replicaSet)
if err != nil {
return nil, err
}

g, ctx := errgroup.WithContext(ctx)
responses := make([]Response, len(replicaSet.Instances))
responses := make([]GetStreamUsageResponse, len(replicaSet.Instances))

for i, instance := range replicaSet.Instances {
g.Go(func() error {
client, err := s.pool.GetClientFor(instance.Addr)
if err != nil {
return err
}
resp, err := f(ctx, client.(logproto.IngestLimitsClient))

var partitionStr strings.Builder
for _, partition := range partitions[instance.Addr] {
partitionStr.WriteString(fmt.Sprintf("%d,", partition))
}

resp, err := f(ctx, client.(logproto.IngestLimitsClient), partitions[instance.Addr])
if err != nil {
return err
}
responses[i] = Response{Addr: instance.Addr, Response: resp}
responses[i] = GetStreamUsageResponse{Addr: instance.Addr, Response: resp}
return nil
})
}

if err := g.Wait(); err != nil {
return nil, err
}

return responses, nil
}

func (s *RingIngestLimitsService) perReplicaSetPartitions(ctx context.Context, replicaSet ring.ReplicationSet) (map[string][]int32, error) {
g, ctx := errgroup.WithContext(ctx)
responses := make([]GetAssignedPartitionsResponse, len(replicaSet.Instances))
for i, instance := range replicaSet.Instances {
g.Go(func() error {
client, err := s.pool.GetClientFor(instance.Addr)
if err != nil {
return err
}
resp, err := client.(logproto.IngestLimitsClient).GetAssignedPartitions(ctx, &logproto.GetAssignedPartitionsRequest{})
if err != nil {
return err
}
responses[i] = GetAssignedPartitionsResponse{Addr: instance.Addr, Response: resp}
return nil
})
}

if err := g.Wait(); err != nil {
return nil, err
}

partitions := make(map[string][]int32)
// Track highest value seen for each partition
highestValues := make(map[int32]int64)
// Track which addr has the highest value for each partition
highestAddr := make(map[int32]string)

// First pass - find highest values for each partition
for _, resp := range responses {
for partition, value := range resp.Response.AssignedPartitions {
if currentHighest, exists := highestValues[partition]; !exists || value > currentHighest {
highestValues[partition] = value
highestAddr[partition] = resp.Addr
}
}
}

// Second pass - assign partitions to addrs that have the highest values
for partition, addr := range highestAddr {
partitions[addr] = append(partitions[addr], partition)
}

// Sort partition IDs for each address for consistent ordering
for addr := range partitions {
sort.Slice(partitions[addr], func(i, j int) bool {
return partitions[addr][i] < partitions[addr][j]
})
}

return partitions, nil
}

func (s *RingIngestLimitsService) ExceedsLimits(ctx context.Context, req *logproto.ExceedsLimitsRequest) (*logproto.ExceedsLimitsResponse, error) {
resps, err := s.forAllBackends(ctx, func(_ context.Context, client logproto.IngestLimitsClient) (*logproto.GetStreamUsageResponse, error) {
reqStreams := make([]uint64, 0, len(req.Streams))
for _, stream := range req.Streams {
reqStreams = append(reqStreams, stream.StreamHash)
}

resps, err := s.forAllBackends(ctx, func(_ context.Context, client logproto.IngestLimitsClient, partitions []int32) (*logproto.GetStreamUsageResponse, error) {
return client.GetStreamUsage(ctx, &logproto.GetStreamUsageRequest{
Tenant: req.Tenant,
Tenant: req.Tenant,
Partitions: partitions,
StreamHashes: reqStreams,
})
})
if err != nil {
Expand All @@ -142,28 +213,9 @@ func (s *RingIngestLimitsService) ExceedsLimits(ctx context.Context, req *logpro

maxGlobalStreams := s.limits.MaxGlobalStreamsPerUser(req.Tenant)

var (
activeStreamsTotal uint64
uniqueStreamHashes = make(map[uint64]bool)
)
var activeStreamsTotal uint64
for _, resp := range resps {
var duplicates uint64
// Record the unique stream hashes
// and count duplicates active streams
// to be subtracted from the total
for _, stream := range resp.Response.RecordedStreams {
if uniqueStreamHashes[stream.StreamHash] {
duplicates++
continue
}
uniqueStreamHashes[stream.StreamHash] = true
}

activeStreamsTotal += resp.Response.ActiveStreams - duplicates

if duplicates > 0 {
s.metrics.tenantDuplicateStreamsFound.WithLabelValues(req.Tenant).Inc()
}
activeStreamsTotal += resp.Response.ActiveStreams
}

s.metrics.tenantActiveStreams.WithLabelValues(req.Tenant).Set(float64(activeStreamsTotal))
@@ -174,13 +226,19 @@ func (s *RingIngestLimitsService) ExceedsLimits(ctx context.Context, req *logpro
}, nil
}

var rejectedStreams []*logproto.RejectedStream
for _, stream := range req.Streams {
if !uniqueStreamHashes[stream.StreamHash] {
rejectedStreams = append(rejectedStreams, &logproto.RejectedStream{
StreamHash: stream.StreamHash,
Reason: RejectedStreamReasonExceedsGlobalLimit,
})
var (
rejectedStreams []*logproto.RejectedStream
uniqueStreamHashes = make(map[uint64]bool)
)
for _, resp := range resps {
for _, unknownStream := range resp.Response.UnknownStreams {
if !uniqueStreamHashes[unknownStream] {
uniqueStreamHashes[unknownStream] = true
rejectedStreams = append(rejectedStreams, &logproto.RejectedStream{
StreamHash: unknownStream,
Reason: RejectedStreamReasonExceedsGlobalLimit,
})
}
}
}

@@ -195,7 +253,12 @@ func (s *RingIngestLimitsService) ExceedsLimits(ctx context.Context, req *logpro
}, nil
}

type Response struct {
type GetStreamUsageResponse struct {
Addr string
Response *logproto.GetStreamUsageResponse
}

type GetAssignedPartitionsResponse struct {
Addr string
Response *logproto.GetAssignedPartitionsResponse
}
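Not part of the diff: a self-contained sketch of the ownership rule implemented by perReplicaSetPartitions above. The backend names and assigned-at values are illustrative; the point is that when two backends both report a partition, the one with the highest assigned-at value is treated as the owner, and only that backend is asked for that partition's stream usage.

```go
package main

import (
	"fmt"
	"sort"
)

// resolveOwnedPartitions mirrors the two-pass logic in perReplicaSetPartitions:
// first find the highest assigned-at value per partition, then group each
// partition under the winning address and sort for deterministic ordering.
func resolveOwnedPartitions(assigned map[string]map[int32]int64) map[string][]int32 {
	highestValue := make(map[int32]int64)
	highestAddr := make(map[int32]string)
	for addr, parts := range assigned {
		for partition, value := range parts {
			if current, ok := highestValue[partition]; !ok || value > current {
				highestValue[partition] = value
				highestAddr[partition] = addr
			}
		}
	}
	owned := make(map[string][]int32)
	for partition, addr := range highestAddr {
		owned[addr] = append(owned[addr], partition)
	}
	for addr := range owned {
		sort.Slice(owned[addr], func(i, j int) bool { return owned[addr][i] < owned[addr][j] })
	}
	return owned
}

func main() {
	// Partition 1 was reassigned to backend-b (higher assigned-at value),
	// so only backend-b is queried for its usage.
	fmt.Println(resolveOwnedPartitions(map[string]map[int32]int64{
		"backend-a": {0: 100, 1: 100},
		"backend-b": {1: 200, 2: 150},
	}))
	// Output: map[backend-a:[0] backend-b:[1 2]]
}
```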