Skip to content

Commit 6da5045

Browse files
committed
Merge remote-tracking branch 'upstream/main' into rhoai-3.5
2 parents b48664f + fb9b52a commit 6da5045

4 files changed

Lines changed: 208 additions & 0 deletions

File tree

deploy/payload-processing/values.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ upstreamIpp:
2121
fieldName: model
2222
headerName: X-Gateway-Model-Name
2323
- type: model-provider-resolver
24+
- type: stream-usage-enforcer
2425
- type: api-translation
2526
- type: apikey-injection
2627
profiles:
@@ -30,6 +31,7 @@ upstreamIpp:
3031
- pluginRef: maas-headers-guard
3132
- pluginRef: model-extractor
3233
- pluginRef: model-provider-resolver
34+
- pluginRef: stream-usage-enforcer
3335
- pluginRef: api-translation
3436
- pluginRef: apikey-injection
3537
response:

pkg/plugins/plugins.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
maas_headers_guard "github.com/opendatahub-io/ai-gateway-payload-processing/pkg/plugins/maas-headers-guard"
2323
provider_resolver "github.com/opendatahub-io/ai-gateway-payload-processing/pkg/plugins/model-provider-resolver"
2424
"github.com/opendatahub-io/ai-gateway-payload-processing/pkg/plugins/nemo"
25+
stream_usage_enforcer "github.com/opendatahub-io/ai-gateway-payload-processing/pkg/plugins/stream-usage-enforcer"
2526
"github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/plugin"
2627
)
2728

@@ -32,4 +33,5 @@ func RegisterPlugins() {
3233
plugin.Register(apikey_injection.APIKeyInjectionPluginType, apikey_injection.APIKeyInjectionFactory)
3334
plugin.Register(nemo.NemoRequestGuardPluginType, nemo.NemoRequestGuardFactory)
3435
plugin.Register(nemo.NemoResponseGuardPluginType, nemo.NemoResponseGuardFactory)
36+
plugin.Register(stream_usage_enforcer.PluginType, stream_usage_enforcer.Factory)
3537
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
Copyright 2026 The opendatahub.io Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package stream_usage_enforcer
18+
19+
import (
20+
"context"
21+
"encoding/json"
22+
23+
"sigs.k8s.io/controller-runtime/pkg/log"
24+
25+
logutil "github.com/llm-d/llm-d-inference-payload-processor/pkg/common/observability/logging"
26+
"github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/plugin"
27+
"github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/requesthandling"
28+
)
29+
30+
const PluginType = "stream-usage-enforcer"
31+
32+
var _ requesthandling.RequestProcessor = &Plugin{}
33+
34+
type Plugin struct {
35+
typedName plugin.TypedName
36+
}
37+
38+
func Factory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
39+
return (&Plugin{
40+
typedName: plugin.TypedName{Type: PluginType, Name: PluginType},
41+
}).WithName(name), nil
42+
}
43+
44+
func (p *Plugin) WithName(name string) *Plugin {
45+
p.typedName.Name = name
46+
return p
47+
}
48+
49+
// TypedName reports the plugin's type together with its instance name.
func (p *Plugin) TypedName() plugin.TypedName {
	return p.typedName
}
50+
51+
func (p *Plugin) ProcessRequest(ctx context.Context, _ *plugin.CycleState, request *requesthandling.InferenceRequest) error {
52+
logger := log.FromContext(ctx).V(logutil.VERBOSE)
53+
54+
stream, ok := request.Body["stream"].(bool)
55+
if !ok || !stream {
56+
return nil
57+
}
58+
59+
streamOptions, _ := request.Body["stream_options"].(map[string]any)
60+
if streamOptions == nil {
61+
streamOptions = map[string]any{}
62+
}
63+
64+
if existing, ok := streamOptions["include_usage"].(bool); ok && existing {
65+
logger.Info("stream_options.include_usage already set")
66+
return nil
67+
}
68+
69+
streamOptions["include_usage"] = true
70+
request.SetBodyField("stream_options", streamOptions)
71+
72+
logger.Info("enforced stream_options.include_usage=true")
73+
return nil
74+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
Copyright 2026 The opendatahub.io Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package stream_usage_enforcer
18+
19+
import (
20+
"context"
21+
"testing"
22+
23+
"github.com/stretchr/testify/assert"
24+
"github.com/stretchr/testify/require"
25+
26+
"github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/plugin"
27+
"github.com/llm-d/llm-d-inference-payload-processor/pkg/framework/interface/requesthandling"
28+
)
29+
30+
func TestFactory(t *testing.T) {
31+
p, err := Factory("test", nil, nil)
32+
require.NoError(t, err)
33+
assert.Equal(t, "test", p.TypedName().Name)
34+
assert.Equal(t, PluginType, p.TypedName().Type)
35+
}
36+
37+
func TestProcessRequest_NonStreamingRequest(t *testing.T) {
38+
instance, _ := Factory("test", nil, nil)
39+
req := requesthandling.NewInferenceRequest()
40+
req.Body["model"] = "gpt-oss-20b"
41+
req.Body["messages"] = []any{map[string]any{"role": "user", "content": "Hi"}}
42+
43+
err := instance.(*Plugin).ProcessRequest(context.Background(), plugin.NewCycleState(), req)
44+
require.NoError(t, err)
45+
assert.Nil(t, req.Body["stream_options"], "should not add stream_options to non-streaming request")
46+
assert.False(t, req.BodyMutated())
47+
}
48+
49+
func TestProcessRequest_StreamFalse(t *testing.T) {
50+
instance, _ := Factory("test", nil, nil)
51+
req := requesthandling.NewInferenceRequest()
52+
req.Body["stream"] = false
53+
54+
err := instance.(*Plugin).ProcessRequest(context.Background(), plugin.NewCycleState(), req)
55+
require.NoError(t, err)
56+
assert.Nil(t, req.Body["stream_options"])
57+
assert.False(t, req.BodyMutated())
58+
}
59+
60+
func TestProcessRequest_StreamTrueNoStreamOptions(t *testing.T) {
61+
instance, _ := Factory("test", nil, nil)
62+
req := requesthandling.NewInferenceRequest()
63+
req.Body["model"] = "gpt-oss-20b"
64+
req.Body["stream"] = true
65+
req.Body["messages"] = []any{map[string]any{"role": "user", "content": "Hello"}}
66+
67+
err := instance.(*Plugin).ProcessRequest(context.Background(), plugin.NewCycleState(), req)
68+
require.NoError(t, err)
69+
70+
opts, ok := req.Body["stream_options"].(map[string]any)
71+
require.True(t, ok, "stream_options should be set")
72+
assert.Equal(t, true, opts["include_usage"])
73+
assert.True(t, req.BodyMutated())
74+
}
75+
76+
func TestProcessRequest_StreamTrueIncludeUsageFalse(t *testing.T) {
77+
instance, _ := Factory("test", nil, nil)
78+
req := requesthandling.NewInferenceRequest()
79+
req.Body["stream"] = true
80+
req.Body["stream_options"] = map[string]any{"include_usage": false}
81+
82+
err := instance.(*Plugin).ProcessRequest(context.Background(), plugin.NewCycleState(), req)
83+
require.NoError(t, err)
84+
85+
opts := req.Body["stream_options"].(map[string]any)
86+
assert.Equal(t, true, opts["include_usage"], "should override false to true")
87+
assert.True(t, req.BodyMutated())
88+
}
89+
90+
func TestProcessRequest_StreamTrueIncludeUsageAlreadyTrue(t *testing.T) {
91+
instance, _ := Factory("test", nil, nil)
92+
req := requesthandling.NewInferenceRequest()
93+
req.Body["stream"] = true
94+
req.Body["stream_options"] = map[string]any{"include_usage": true}
95+
96+
err := instance.(*Plugin).ProcessRequest(context.Background(), plugin.NewCycleState(), req)
97+
require.NoError(t, err)
98+
99+
opts := req.Body["stream_options"].(map[string]any)
100+
assert.Equal(t, true, opts["include_usage"])
101+
assert.False(t, req.BodyMutated(), "should not mutate when already correct")
102+
}
103+
104+
func TestProcessRequest_PreservesExistingStreamOptions(t *testing.T) {
105+
instance, _ := Factory("test", nil, nil)
106+
req := requesthandling.NewInferenceRequest()
107+
req.Body["stream"] = true
108+
req.Body["stream_options"] = map[string]any{
109+
"continuous_usage_stats": true,
110+
}
111+
112+
err := instance.(*Plugin).ProcessRequest(context.Background(), plugin.NewCycleState(), req)
113+
require.NoError(t, err)
114+
115+
opts := req.Body["stream_options"].(map[string]any)
116+
assert.Equal(t, true, opts["include_usage"], "include_usage should be added")
117+
assert.Equal(t, true, opts["continuous_usage_stats"], "existing fields should be preserved")
118+
assert.True(t, req.BodyMutated())
119+
}
120+
121+
func TestProcessRequest_StreamFieldAbsent(t *testing.T) {
122+
instance, _ := Factory("test", nil, nil)
123+
req := requesthandling.NewInferenceRequest()
124+
req.Body["model"] = "gpt-oss-20b"
125+
126+
err := instance.(*Plugin).ProcessRequest(context.Background(), plugin.NewCycleState(), req)
127+
require.NoError(t, err)
128+
assert.Nil(t, req.Body["stream_options"])
129+
assert.False(t, req.BodyMutated())
130+
}

0 commit comments

Comments
 (0)