From c2e1e88e850cffcb9c34cd7c6296326a5cc8a15a Mon Sep 17 00:00:00 2001 From: Salva Corts Date: Wed, 5 Feb 2025 14:06:06 +0100 Subject: [PATCH] feat(policies): Add PoliciesStreamMapping to loghttp limits interface (#16105) --- clients/pkg/promtail/targets/lokipush/pushtarget.go | 2 +- pkg/distributor/distributor.go | 7 +++++++ pkg/distributor/http.go | 2 +- pkg/distributor/http_test.go | 1 + pkg/loghttp/push/otlp.go | 2 +- pkg/loghttp/push/push.go | 9 +++++---- pkg/loghttp/push/push_test.go | 9 +++++---- 7 files changed, 21 insertions(+), 11 deletions(-) diff --git a/clients/pkg/promtail/targets/lokipush/pushtarget.go b/clients/pkg/promtail/targets/lokipush/pushtarget.go index f6e33eb8f72d9..e1ebafc1bab2e 100644 --- a/clients/pkg/promtail/targets/lokipush/pushtarget.go +++ b/clients/pkg/promtail/targets/lokipush/pushtarget.go @@ -111,7 +111,7 @@ func (t *PushTarget) run() error { func (t *PushTarget) handleLoki(w http.ResponseWriter, r *http.Request) { logger := util_log.WithContext(r.Context(), util_log.Logger) userID, _ := tenant.TenantID(r.Context()) - req, err := push.ParseRequest(logger, userID, r, nil, push.EmptyLimits{}, push.ParseLokiRequest, nil, false) + req, err := push.ParseRequest(logger, userID, r, nil, push.EmptyLimits{}, push.ParseLokiRequest, nil, nil, false) if err != nil { level.Warn(t.logger).Log("msg", "failed to parse incoming push request", "err", err.Error()) http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/distributor/distributor.go b/pkg/distributor/distributor.go index 8c44a66832c44..40bc4f3b5b036 100644 --- a/pkg/distributor/distributor.go +++ b/pkg/distributor/distributor.go @@ -180,6 +180,7 @@ type Distributor struct { streamShardCount prometheus.Counter tenantPushSanitizedStructuredMetadata *prometheus.CounterVec + policyResolver push.PolicyResolver usageTracker push.UsageTracker ingesterTasks chan pushIngesterTask ingesterTaskWg sync.WaitGroup @@ -223,6 +224,11 @@ func New( return client.New(internalCfg, addr) } + policyResolver := push.PolicyResolver(func(userID string, lbs labels.Labels) string { + mappings := overrides.PoliciesStreamMapping(userID) + return mappings.PolicyFor(lbs) + }) + validator, err := NewValidator(overrides, usageTracker) if err != nil { return nil, err @@ -280,6 +286,7 @@ func New( healthyInstancesCount: atomic.NewUint32(0), rateLimitStrat: rateLimitStrat, tee: tee, + policyResolver: policyResolver, usageTracker: usageTracker, ingesterTasks: make(chan pushIngesterTask), ingesterAppends: promauto.With(registerer).NewCounterVec(prometheus.CounterOpts{ diff --git a/pkg/distributor/http.go b/pkg/distributor/http.go index 1b0cee2a9c62a..c6c87dbc74454 100644 --- a/pkg/distributor/http.go +++ b/pkg/distributor/http.go @@ -41,7 +41,7 @@ func (d *Distributor) pushHandler(w http.ResponseWriter, r *http.Request, pushRe } logPushRequestStreams := d.tenantConfigs.LogPushRequestStreams(tenantID) - req, err := push.ParseRequest(logger, tenantID, r, d.tenantsRetention, d.validator.Limits, pushRequestParser, d.usageTracker, logPushRequestStreams) + req, err := push.ParseRequest(logger, tenantID, r, d.tenantsRetention, d.validator.Limits, pushRequestParser, d.usageTracker, d.policyResolver, logPushRequestStreams) if err != nil { if !errors.Is(err, push.ErrAllLogsFiltered) { if d.tenantConfigs.LogPushRequest(tenantID) { diff --git a/pkg/distributor/http_test.go b/pkg/distributor/http_test.go index 7e1ee788994c4..a73a73fa5e2ab 100644 --- a/pkg/distributor/http_test.go +++ b/pkg/distributor/http_test.go @@ -128,6 +128,7 @@ func (p *fakeParser) parseRequest( _ push.TenantsRetention, _ push.Limits, _ push.UsageTracker, + _ push.PolicyResolver, _ bool, _ log.Logger, ) (*logproto.PushRequest, *push.Stats, error) { diff --git a/pkg/loghttp/push/otlp.go b/pkg/loghttp/push/otlp.go index dbb4ec8349e63..55e5b59174868 100644 --- a/pkg/loghttp/push/otlp.go +++ b/pkg/loghttp/push/otlp.go @@ -43,7 +43,7 @@ func newPushStats() *Stats { } } -func ParseOTLPRequest(userID string, r *http.Request, tenantsRetention TenantsRetention, limits Limits, tracker UsageTracker, logPushRequestStreams bool, logger log.Logger) (*logproto.PushRequest, *Stats, error) { +func ParseOTLPRequest(userID string, r *http.Request, tenantsRetention TenantsRetention, limits Limits, tracker UsageTracker, _ PolicyResolver, logPushRequestStreams bool, logger log.Logger) (*logproto.PushRequest, *Stats, error) { stats := newPushStats() otlpLogs, err := extractLogs(r, stats) if err != nil { diff --git a/pkg/loghttp/push/push.go b/pkg/loghttp/push/push.go index 37938fe2a8e89..bc0d7aa8f4112 100644 --- a/pkg/loghttp/push/push.go +++ b/pkg/loghttp/push/push.go @@ -90,9 +90,10 @@ func (EmptyLimits) DiscoverServiceName(string) []string { } type ( - RequestParser func(userID string, r *http.Request, tenantsRetention TenantsRetention, limits Limits, tracker UsageTracker, logPushRequestStreams bool, logger log.Logger) (*logproto.PushRequest, *Stats, error) + RequestParser func(userID string, r *http.Request, tenantsRetention TenantsRetention, limits Limits, tracker UsageTracker, policyResolver PolicyResolver, logPushRequestStreams bool, logger log.Logger) (*logproto.PushRequest, *Stats, error) RequestParserWrapper func(inner RequestParser) RequestParser ErrorWriter func(w http.ResponseWriter, error string, code int, logger log.Logger) + PolicyResolver func(userID string, lbs labels.Labels) string ) type Stats struct { @@ -113,8 +114,8 @@ type Stats struct { IsAggregatedMetric bool } -func ParseRequest(logger log.Logger, userID string, r *http.Request, tenantsRetention TenantsRetention, limits Limits, pushRequestParser RequestParser, tracker UsageTracker, logPushRequestStreams bool) (*logproto.PushRequest, error) { - req, pushStats, err := pushRequestParser(userID, r, tenantsRetention, limits, tracker, logPushRequestStreams, logger) +func ParseRequest(logger log.Logger, userID string, r *http.Request, tenantsRetention TenantsRetention, limits Limits, pushRequestParser RequestParser, tracker UsageTracker, policyResolver PolicyResolver, logPushRequestStreams bool) (*logproto.PushRequest, error) { + req, pushStats, err := pushRequestParser(userID, r, tenantsRetention, limits, tracker, policyResolver, logPushRequestStreams, logger) if err != nil && !errors.Is(err, ErrAllLogsFiltered) { return nil, err } @@ -171,7 +172,7 @@ func ParseRequest(logger log.Logger, userID string, r *http.Request, tenantsRete return req, err } -func ParseLokiRequest(userID string, r *http.Request, tenantsRetention TenantsRetention, limits Limits, tracker UsageTracker, logPushRequestStreams bool, logger log.Logger) (*logproto.PushRequest, *Stats, error) { +func ParseLokiRequest(userID string, r *http.Request, tenantsRetention TenantsRetention, limits Limits, tracker UsageTracker, _ PolicyResolver, logPushRequestStreams bool, logger log.Logger) (*logproto.PushRequest, *Stats, error) { // Body var body io.Reader // bodySize should always reflect the compressed size of the request body diff --git a/pkg/loghttp/push/push_test.go b/pkg/loghttp/push/push_test.go index 54618eb3480cc..2c6c1cefe31b6 100644 --- a/pkg/loghttp/push/push_test.go +++ b/pkg/loghttp/push/push_test.go @@ -270,6 +270,7 @@ func TestParseRequest(t *testing.T) { &fakeLimits{enabled: test.enableServiceDiscovery}, ParseLokiRequest, tracker, + nil, false, ) @@ -364,7 +365,7 @@ func Test_ServiceDetection(t *testing.T) { request := createRequest("/loki/api/v1/push", strings.NewReader(body)) limits := &fakeLimits{enabled: true, labels: []string{"foo"}} - data, err := ParseRequest(util_log.Logger, "fake", request, nil, limits, ParseLokiRequest, tracker, false) + data, err := ParseRequest(util_log.Logger, "fake", request, nil, limits, ParseLokiRequest, tracker, nil, false) require.NoError(t, err) require.Equal(t, labels.FromStrings("foo", "bar", LabelServiceName, "bar").String(), data.Streams[0].Labels) @@ -375,7 +376,7 @@ func Test_ServiceDetection(t *testing.T) { request := createRequest("/otlp/v1/push", bytes.NewReader(body)) limits := &fakeLimits{enabled: true} - data, err := ParseRequest(util_log.Logger, "fake", request, limits, limits, ParseOTLPRequest, tracker, false) + data, err := ParseRequest(util_log.Logger, "fake", request, limits, limits, ParseOTLPRequest, tracker, nil, false) require.NoError(t, err) require.Equal(t, labels.FromStrings("k8s_job_name", "bar", LabelServiceName, "bar").String(), data.Streams[0].Labels) }) @@ -389,7 +390,7 @@ func Test_ServiceDetection(t *testing.T) { labels: []string{"special"}, indexAttributes: []string{"special"}, } - data, err := ParseRequest(util_log.Logger, "fake", request, limits, limits, ParseOTLPRequest, tracker, false) + data, err := ParseRequest(util_log.Logger, "fake", request, limits, limits, ParseOTLPRequest, tracker, nil, false) require.NoError(t, err) require.Equal(t, labels.FromStrings("special", "sauce", LabelServiceName, "sauce").String(), data.Streams[0].Labels) }) @@ -403,7 +404,7 @@ func Test_ServiceDetection(t *testing.T) { labels: []string{"special"}, indexAttributes: []string{}, } - data, err := ParseRequest(util_log.Logger, "fake", request, limits, limits, ParseOTLPRequest, tracker, false) + data, err := ParseRequest(util_log.Logger, "fake", request, limits, limits, ParseOTLPRequest, tracker, nil, false) require.NoError(t, err) require.Equal(t, labels.FromStrings(LabelServiceName, ServiceUnknown).String(), data.Streams[0].Labels) })