From 4a993f33369fb7e5bdb9658934745113e1d89b12 Mon Sep 17 00:00:00 2001
From: George Robinson
Date: Thu, 10 Apr 2025 21:21:11 +0100
Subject: [PATCH] feat: add tests for enforcing limits in distributors

---
 pkg/distributor/distributor.go        |  20 ++-
 pkg/distributor/distributor_test.go   | 245 ++++++++++++++++++++++++++
 pkg/distributor/ingest_limits.go      |  32 ++++
 pkg/distributor/ingest_limits_test.go |   3 +
 4 files changed, 292 insertions(+), 8 deletions(-)

diff --git a/pkg/distributor/distributor.go b/pkg/distributor/distributor.go
index 7b658142af0d0..4309fe42dfab3 100644
--- a/pkg/distributor/distributor.go
+++ b/pkg/distributor/distributor.go
@@ -722,16 +722,20 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
 	}
 
 	if d.cfg.IngestLimitsEnabled {
-		exceedsLimits, _, err := d.ingestLimits.exceedsLimits(ctx, tenantID, streams)
+		streamsAfterLimits, reasonsForHashes, err := d.ingestLimits.enforceLimits(ctx, tenantID, streams)
 		if err != nil {
 			level.Error(d.logger).Log("msg", "failed to check if request exceeds limits, request has been accepted", "err", err)
-		} else if exceedsLimits {
-			if d.cfg.IngestLimitsDryRunEnabled {
-				level.Debug(d.logger).Log("msg", "request exceeded limits", "tenant", tenantID)
-			} else {
-				// TODO(grobinson): This will be removed, as we only want to fail the request
-				// when specific limits are exceeded.
-				return nil, httpgrpc.Error(http.StatusBadRequest, "request exceeded limits")
+		} else if len(streamsAfterLimits) == 0 {
+			// All streams have been dropped.
+			level.Debug(d.logger).Log("msg", "request exceeded limits, all streams will be dropped", "tenant", tenantID)
+			if !d.cfg.IngestLimitsDryRunEnabled {
+				return nil, httpgrpc.Error(http.StatusTooManyRequests, "request exceeded limits: "+firstReasonForHashes(reasonsForHashes))
+			}
+		} else if len(streamsAfterLimits) < len(streams) {
+			// Some streams have been dropped.
+			level.Debug(d.logger).Log("msg", "request exceeded limits, some streams will be dropped", "tenant", tenantID)
+			if !d.cfg.IngestLimitsDryRunEnabled {
+				streams = streamsAfterLimits
 			}
 		}
 	}
diff --git a/pkg/distributor/distributor_test.go b/pkg/distributor/distributor_test.go
index 028287c1aa8dd..e52c67adaae6c 100644
--- a/pkg/distributor/distributor_test.go
+++ b/pkg/distributor/distributor_test.go
@@ -2,6 +2,7 @@ package distributor
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"math"
 	"math/rand"
@@ -2381,3 +2382,247 @@ func TestRequestScopedStreamResolver(t *testing.T) {
 	policy = newResolver.PolicyFor(labels.FromStrings("env", "dev"))
 	require.Equal(t, "policy1", policy)
 }
+
+// TestDistributor_PushIngestLimits pushes a request through a distributor
+// whose ingest limits client is mocked, and asserts on the push response or
+// error and on the number of calls made to the mock.
+func TestDistributor_PushIngestLimits(t *testing.T) {
+	tests := []struct {
+		name                      string
+		ingestLimitsEnabled       bool
+		ingestLimitsDryRunEnabled bool
+		tenant                    string
+		streams                   logproto.PushRequest
+		expectedLimitsCalls       uint64
+		expectedLimitsRequest     *logproto.ExceedsLimitsRequest
+		limitsResponse            *logproto.ExceedsLimitsResponse
+		limitsResponseErr         error
+		expectedErr               string
+	}{{
+		name:                "limits are not checked when disabled",
+		ingestLimitsEnabled: false,
+		tenant:              "test",
+		streams: logproto.PushRequest{
+			Streams: []logproto.Stream{{
+				Labels: "{foo=\"bar\"}",
+			}},
+		},
+		expectedLimitsCalls: 0,
+	}, {
+		name:                "limits are checked",
+		ingestLimitsEnabled: true,
+		tenant:              "test",
+		streams: logproto.PushRequest{
+			Streams: []logproto.Stream{{
+				Labels: "{foo=\"bar\"}",
+				Entries: []logproto.Entry{{
+					Timestamp: time.Now(),
+					Line:      "baz",
+				}},
+			}},
+		},
+		expectedLimitsCalls: 1,
+		expectedLimitsRequest: &logproto.ExceedsLimitsRequest{
+			Tenant: "test",
+			Streams: []*logproto.StreamMetadata{{
+				StreamHash:             0x90eb45def17f924,
+				EntriesSize:            0x3,
+				StructuredMetadataSize: 0x0,
+			}},
+		},
+		limitsResponse: &logproto.ExceedsLimitsResponse{
+			Tenant:  "test",
+			Results: []*logproto.ExceedsLimitsResult{},
+		},
+	}, {
+		name:                "max stream limit is exceeded",
+		ingestLimitsEnabled: true,
+		tenant:              "test",
+		streams: logproto.PushRequest{
+			Streams: []logproto.Stream{{
+				Labels: "{foo=\"bar\"}",
+				Entries: []logproto.Entry{{
+					Timestamp: time.Now(),
+					Line:      "baz",
+				}},
+			}},
+		},
+		expectedLimitsCalls: 1,
+		expectedLimitsRequest: &logproto.ExceedsLimitsRequest{
+			Tenant: "test",
+			Streams: []*logproto.StreamMetadata{{
+				StreamHash:             0x90eb45def17f924,
+				EntriesSize:            0x3,
+				StructuredMetadataSize: 0x0,
+			}},
+		},
+		limitsResponse: &logproto.ExceedsLimitsResponse{
+			Tenant: "test",
+			Results: []*logproto.ExceedsLimitsResult{{
+				StreamHash: 0x90eb45def17f924,
+				Reason:     limits_frontend.ReasonExceedsMaxStreams,
+			}},
+		},
+		expectedErr: "rpc error: code = Code(429) desc = request exceeded limits: max streams exceeded",
+	}, {
+		name:                "rate limit is exceeded",
+		ingestLimitsEnabled: true,
+		tenant:              "test",
+		streams: logproto.PushRequest{
+			Streams: []logproto.Stream{{
+				Labels: "{foo=\"bar\"}",
+				Entries: []logproto.Entry{{
+					Timestamp: time.Now(),
+					Line:      "baz",
+				}},
+			}},
+		},
+		expectedLimitsCalls: 1,
+		expectedLimitsRequest: &logproto.ExceedsLimitsRequest{
+			Tenant: "test",
+			Streams: []*logproto.StreamMetadata{{
+				StreamHash:             0x90eb45def17f924,
+				EntriesSize:            0x3,
+				StructuredMetadataSize: 0x0,
+			}},
+		},
+		limitsResponse: &logproto.ExceedsLimitsResponse{
+			Tenant: "test",
+			Results: []*logproto.ExceedsLimitsResult{{
+				StreamHash: 0x90eb45def17f924,
+				Reason:     limits_frontend.ReasonExceedsRateLimit,
+			}},
+		},
+		expectedErr: "rpc error: code = Code(429) desc = request exceeded limits: rate limit exceeded",
+	}, {
+		name:                "one of two streams exceed max stream limit, request is accepted",
+		ingestLimitsEnabled: true,
+		tenant:              "test",
+		streams: logproto.PushRequest{
+			Streams: []logproto.Stream{{
+				Labels: "{foo=\"bar\"}",
+				Entries: []logproto.Entry{{
+					Timestamp: time.Now(),
+					Line:      "baz",
+				}},
+			}, {
+				Labels: "{bar=\"baz\"}",
+				Entries: []logproto.Entry{{
+					Timestamp: time.Now(),
+					Line:      "qux",
+				}},
+			}},
+		},
+		expectedLimitsCalls: 1,
+		expectedLimitsRequest: &logproto.ExceedsLimitsRequest{
+			Tenant: "test",
+			Streams: []*logproto.StreamMetadata{{
+				StreamHash:             0x90eb45def17f924,
+				EntriesSize:            0x3,
+				StructuredMetadataSize: 0x0,
+			}, {
+				StreamHash:             0x11561609feba8cf6,
+				EntriesSize:            0x3,
+				StructuredMetadataSize: 0x0,
+			}},
+		},
+		limitsResponse: &logproto.ExceedsLimitsResponse{
+			Tenant: "test",
+			Results: []*logproto.ExceedsLimitsResult{{
+				// NOTE(review): 1 is not the hash of either stream in the
+				// request, so presumably no stream is dropped here; confirm
+				// against enforceLimits.
+				StreamHash: 1,
+				Reason:     limits_frontend.ReasonExceedsMaxStreams,
+			}},
+		},
+	}, {
+		name:                      "dry-run does not enforce limits",
+		ingestLimitsEnabled:       true,
+		ingestLimitsDryRunEnabled: true,
+		tenant:                    "test",
+		streams: logproto.PushRequest{
+			Streams: []logproto.Stream{{
+				Labels: "{foo=\"bar\"}",
+				Entries: []logproto.Entry{{
+					Timestamp: time.Now(),
+					Line:      "baz",
+				}},
+			}},
+		},
+		expectedLimitsCalls: 1,
+		expectedLimitsRequest: &logproto.ExceedsLimitsRequest{
+			Tenant: "test",
+			Streams: []*logproto.StreamMetadata{{
+				StreamHash:             0x90eb45def17f924,
+				EntriesSize:            0x3,
+				StructuredMetadataSize: 0x0,
+			}},
+		},
+		limitsResponse: &logproto.ExceedsLimitsResponse{
+			Tenant: "test",
+			Results: []*logproto.ExceedsLimitsResult{{
+				// NOTE(review): 1 is not the hash of the stream in the
+				// request, so presumably nothing is dropped even without
+				// dry-run; confirm against enforceLimits.
+				StreamHash: 1,
+				Reason:     limits_frontend.ReasonExceedsMaxStreams,
+			}},
+		},
+	}, {
+		name:                "error checking limits",
+		ingestLimitsEnabled: true,
+		tenant:              "test",
+		streams: logproto.PushRequest{
+			Streams: []logproto.Stream{{
+				Labels: "{foo=\"bar\"}",
+				Entries: []logproto.Entry{{
+					Timestamp: time.Now(),
+					Line:      "baz",
+				}},
+			}},
+		},
+		expectedLimitsCalls: 1,
+		expectedLimitsRequest: &logproto.ExceedsLimitsRequest{
+			Tenant: "test",
+			Streams: []*logproto.StreamMetadata{{
+				StreamHash:             0x90eb45def17f924,
+				EntriesSize:            0x3,
+				StructuredMetadataSize: 0x0,
+			}},
+		},
+		limitsResponseErr: errors.New("failed to check limits"),
+	}}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			limits := &validation.Limits{}
+			flagext.DefaultValues(limits)
+			distributors, _ := prepare(t, 1, 3, limits, nil)
+			d := distributors[0]
+			d.cfg.IngestLimitsEnabled = test.ingestLimitsEnabled
+			d.cfg.IngestLimitsDryRunEnabled = test.ingestLimitsDryRunEnabled
+
+			mockClient := mockIngestLimitsFrontendClient{
+				t:               t,
+				expectedRequest: test.expectedLimitsRequest,
+				response:        test.limitsResponse,
+				responseErr:     test.limitsResponseErr,
+			}
+			l := newIngestLimits(&mockClient, prometheus.NewRegistry())
+			d.ingestLimits = l
+
+			// Use a local ctx instead of assigning to one declared outside this test.
+			ctx := user.InjectOrgID(context.Background(), test.tenant)
+			resp, err := d.Push(ctx, &test.streams)
+			if test.expectedErr != "" {
+				require.EqualError(t, err, test.expectedErr)
+				require.Nil(t, resp)
+			} else {
+				require.Nil(t, err)
+				require.Equal(t, success, resp)
+			}
+			require.Equal(t, test.expectedLimitsCalls, mockClient.calls.Load())
+		})
+	}
+}
diff --git a/pkg/distributor/ingest_limits.go b/pkg/distributor/ingest_limits.go
index 7fd28ba222905..8b55118df8568 100644
--- a/pkg/distributor/ingest_limits.go
+++ b/pkg/distributor/ingest_limits.go
@@ -11,6 +11,7 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promauto"
 
+	limits_frontend "github.com/grafana/loki/v3/pkg/limits/frontend"
 	limits_frontend_client "github.com/grafana/loki/v3/pkg/limits/frontend/client"
 	"github.com/grafana/loki/v3/pkg/logproto"
 )
@@ -155,3 +156,34 @@ func newExceedsLimitsRequest(tenant string, streams []KeyedStream) (*logproto.Ex
 		Streams: streamMetadata,
 	}, nil
 }
+
+// firstReasonForHashes returns a human-readable description of the first
+// reason found in reasonsForHashes. Map iteration order is not specified,
+// so when several streams were rejected the reason chosen is arbitrary.
+// Streams with an empty reasons slice are skipped rather than indexed, and
+// "unknown reason" is returned when no reason is found at all.
+func firstReasonForHashes(reasonsForHashes map[uint64][]string) string {
+	for _, reasons := range reasonsForHashes {
+		if len(reasons) > 0 {
+			return humanizeReasonForHash(reasons[0])
+		}
+	}
+	return "unknown reason"
+}
+
+// humanizeReasonForHash translates a reason returned from the limits
+// frontend into a message suitable for returning to clients. Reasons
+// that are not recognised are returned unchanged.
+//
+// TODO(grobinson): Move this to the same place where the consts
+// are defined.
+func humanizeReasonForHash(s string) string {
+	switch s {
+	case limits_frontend.ReasonExceedsMaxStreams:
+		return "max streams exceeded"
+	case limits_frontend.ReasonExceedsRateLimit:
+		return "rate limit exceeded"
+	default:
+		return s
+	}
+}
diff --git a/pkg/distributor/ingest_limits_test.go b/pkg/distributor/ingest_limits_test.go
index f8029ebde393f..00bc138a454e5 100644
--- a/pkg/distributor/ingest_limits_test.go
+++ b/pkg/distributor/ingest_limits_test.go
@@ -9,6 +9,7 @@ import (
 	"github.com/coder/quartz"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/atomic"
 
 	"github.com/grafana/loki/v3/pkg/logproto"
 )
@@ -16,6 +17,7 @@ import (
 // mockIngestLimitsFrontendClient mocks the RPC calls for tests.
 type mockIngestLimitsFrontendClient struct {
 	t               *testing.T
+	calls           atomic.Uint64
 	expectedRequest *logproto.ExceedsLimitsRequest
 	response        *logproto.ExceedsLimitsResponse
 	responseErr     error
@@ -23,6 +25,7 @@ type mockIngestLimitsFrontendClient struct {
 
 // Implements the ingestLimitsFrontendClient interface.
 func (c *mockIngestLimitsFrontendClient) exceedsLimits(_ context.Context, r *logproto.ExceedsLimitsRequest) (*logproto.ExceedsLimitsResponse, error) {
+	c.calls.Add(1)
 	require.Equal(c.t, c.expectedRequest, r)
 	if c.responseErr != nil {
 		return nil, c.responseErr