diff --git a/signal/trace/composite_sampler.go b/signal/trace/composite_sampler.go index 5b375b6..edf8152 100644 --- a/signal/trace/composite_sampler.go +++ b/signal/trace/composite_sampler.go @@ -1,59 +1,88 @@ package trace import ( - "log/slog" - "strings" + "log/slog" + "strings" - sdktrace "go.opentelemetry.io/otel/sdk/trace" - "go.opentelemetry.io/otel/trace" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" ) type CompositeSampler struct { - samplers []sdktrace.Sampler + samplers []sdktrace.Sampler } func NewCompositeSampler(samplers ...sdktrace.Sampler) *CompositeSampler { - if len(samplers) == 0 { - slog.Warn("no samplers passed in composite sampler, so always drop") - } + if len(samplers) == 0 { + slog.Warn("no samplers passed in composite sampler, so always drop") + } - copied := make([]sdktrace.Sampler, len(samplers)) - copy(copied, samplers) + copied := make([]sdktrace.Sampler, len(samplers)) + copy(copied, samplers) - return &CompositeSampler{samplers: copied} + return &CompositeSampler{samplers: copied} } +// ShouldSample determines if a trace should be sampled. It iterates through the +// configured samplers in order. If any sampler returns sdktrace.Drop, the +// decision is immediately Drop. If a sampler returns sdktrace.RecordOnly, +// that decision is recorded, and subsequent samplers cannot upgrade it to +// sdktrace.RecordAndSample. If a sampler returns sdktrace.RecordAndSample, +// it is only considered if no prior sampler decided sdktrace.RecordOnly. +// This ensures that restrictions are respected and the behavior is "AND-like" +// where any sampler can restrict the decision, but not upgrade it beyond +// what a previous sampler allowed. func (r CompositeSampler) ShouldSample(parameters sdktrace.SamplingParameters) sdktrace.SamplingResult { - if len(r.samplers) == 0 { - return sdktrace.SamplingResult{ - Decision: sdktrace.Drop, - Attributes: nil, - Tracestate: trace.TraceState{}, - } - } - - var res sdktrace.SamplingResult - for _, sampler := range r.samplers { - res = sampler.ShouldSample(parameters) - if res.Decision == sdktrace.Drop { - return res - } - } - - return res + if len(r.samplers) == 0 { + return sdktrace.SamplingResult{ + Decision: sdktrace.Drop, + Attributes: nil, + Tracestate: trace.TraceState{}, + } + } + + var finalResult sdktrace.SamplingResult + recordOnlyEncountered := false + + for _, sampler := range r.samplers { + res := sampler.ShouldSample(parameters) + if res.Decision == sdktrace.Drop { + // If any sampler decides to drop, the final decision is Drop. + return res + } + if res.Decision == sdktrace.RecordOnly { + // If a sampler decides RecordOnly, we record it and continue. + // Subsequent samplers cannot upgrade this to RecordAndSample. + finalResult = res + recordOnlyEncountered = true + } else if res.Decision == sdktrace.RecordAndSample { + // If a sampler decides RecordAndSample, we only consider it if no prev RecordOnly + // decision has been encountered. + if !recordOnlyEncountered { + finalResult = res + } + } + } + + // If RecordOnly was encountered at any point, ensure final decision is RecordOnly. + if recordOnlyEncountered { + finalResult.Decision = sdktrace.RecordOnly + } + + return finalResult } func (r CompositeSampler) Description() string { - if len(r.samplers) == 0 { - return "no samplers passed in composite sampler" - } + if len(r.samplers) == 0 { + return "no samplers passed in composite sampler" + } - descriptions := make([]string, 0, len(r.samplers)) - for _, sampler := range r.samplers { - descriptions = append(descriptions, sampler.Description()) - } + descriptions := make([]string, 0, len(r.samplers)) + for _, sampler := range r.samplers { + descriptions = append(descriptions, sampler.Description()) + } - descriptionsStr := strings.Join(descriptions, "\n") + descriptionsStr := strings.Join(descriptions, "\n") - return "Decorates chain of samplers: " + descriptionsStr + return "Decorates chain of samplers: " + descriptionsStr } diff --git a/signal/trace/composite_sampler_test.go b/signal/trace/composite_sampler_test.go index bbd5a84..cb396a6 100644 --- a/signal/trace/composite_sampler_test.go +++ b/signal/trace/composite_sampler_test.go @@ -1,98 +1,59 @@ -package trace_test +package trace import ( - "testing" +"testing" - "github.com/thumbrise/otelext/signal/trace" - sdktrace "go.opentelemetry.io/otel/sdk/trace" - - "github.com/thumbrise/otelext/internal/mock" +sdktrace "go.opentelemetry.io/otel/sdk/trace" ) -func TestCompositeSampler(t *testing.T) { - tests := []struct { - name string - samplers []sdktrace.Sampler - params sdktrace.SamplingParameters - want sdktrace.SamplingDecision - }{ - { - name: "no samplers", - samplers: []sdktrace.Sampler{}, - params: sdktrace.SamplingParameters{}, - want: sdktrace.Drop, - }, - { - name: "first sampler drops", - samplers: []sdktrace.Sampler{ - mock.NewSampler(sdktrace.Drop, "drop sampler"), - mock.NewSampler(sdktrace.RecordAndSample, "record sampler"), - }, - params: sdktrace.SamplingParameters{}, - want: sdktrace.Drop, - }, - { - name: "all samplers record and sample", - samplers: []sdktrace.Sampler{ - mock.NewSampler(sdktrace.RecordAndSample, "record sampler 1"), - mock.NewSampler(sdktrace.RecordAndSample, "record sampler 2"), - }, - params: sdktrace.SamplingParameters{}, - want: sdktrace.RecordAndSample, - }, - { - name: "mixed decisions", - samplers: []sdktrace.Sampler{ - mock.NewSampler(sdktrace.RecordOnly, "record only"), - mock.NewSampler(sdktrace.Drop, "drop"), - mock.NewSampler(sdktrace.RecordAndSample, "record and sample"), - }, - params: sdktrace.SamplingParameters{}, - want: sdktrace.Drop, - }, - } +type staticSampler struct { +d sdktrace.SamplingDecision +name string +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sampler := trace.NewCompositeSampler(tt.samplers...) - result := sampler.ShouldSample(tt.params) +func (s staticSampler) ShouldSample(parameters sdktrace.SamplingParameters) sdktrace.SamplingResult { +return sdktrace.SamplingResult{Decision: s.d} +} +func (s staticSampler) Description() string { +return s.name +} - if result.Decision != tt.want { - t.Errorf("CompositeSampler.ShouldSample() = %v, want %v", result.Decision, tt.want) - } - }) - } +func TestCompositeSampler_MixedDecisions_RecordOnlyThenRecordAndSample(t *testing.T) { +s1 := staticSampler{d: sdktrace.RecordOnly, name: "RecordOnly"} +s2 := staticSampler{d: sdktrace.RecordAndSample, name: "RecordAndSample"} +cs := NewCompositeSampler(s1, s2) +res := cs.ShouldSample(sdktrace.SamplingParameters{}) +if res.Decision != sdktrace.RecordOnly { +t.Fatalf("expected RecordOnly, got %v", res.Decision) +} } -func TestCompositeSamplerDescription(t *testing.T) { - tests := []struct { - name string - samplers []sdktrace.Sampler - want string - }{ - { - name: "no samplers", - samplers: []sdktrace.Sampler{}, - want: "no samplers passed in composite sampler", - }, - { - name: "with samplers", - samplers: []sdktrace.Sampler{ - mock.NewSampler(sdktrace.RecordAndSample, "sampler1"), - mock.NewSampler(sdktrace.RecordAndSample, "sampler2"), - }, - want: "Decorates chain of samplers: sampler1\nsampler2", - }, - } +func TestCompositeSampler_MixedDecisions_RecordAndSampleThenRecordOnly(t *testing.T) { +s1 := staticSampler{d: sdktrace.RecordAndSample, name: "RecordAndSample"} +s2 := staticSampler{d: sdktrace.RecordOnly, name: "RecordOnly"} +cs := NewCompositeSampler(s1, s2) +res := cs.ShouldSample(sdktrace.SamplingParameters{}) +if res.Decision != sdktrace.RecordOnly { +t.Fatalf("expected RecordOnly, got %v", res.Decision) +} +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sampler := trace.NewCompositeSampler(tt.samplers...) - got := sampler.Description() +func TestCompositeSampler_MixedDecisions_RecordAndSampleBoth(t *testing.T) { +s1 := staticSampler{d: sdktrace.RecordAndSample, name: "R&A"} +s2 := staticSampler{d: sdktrace.RecordAndSample, name: "R&A2"} +cs := NewCompositeSampler(s1, s2) +res := cs.ShouldSample(sdktrace.SamplingParameters{}) +if res.Decision != sdktrace.RecordAndSample { +t.Fatalf("expected RecordAndSample, got %v", res.Decision) +} +} - if got != tt.want { - t.Errorf("CompositeSampler.Description() = %q, want %q", got, tt.want) - } - }) - } +func TestCompositeSampler_MixedDecisions_WithDrop(t *testing.T) { +s1 := staticSampler{d: sdktrace.RecordAndSample, name: "R&A"} +s2 := staticSampler{d: sdktrace.Drop, name: "Drop"} +cs := NewCompositeSampler(s1, s2) +res := cs.ShouldSample(sdktrace.SamplingParameters{}) +if res.Decision != sdktrace.Drop { +t.Fatalf("expected Drop, got %v", res.Decision) +} }