Skip to content

Commit cf638f5

Browse files
authored
Models extractor (#553)
* Models extractor Signed-off-by: irar2 <irar@il.ibm.com> * Update register.go Signed-off-by: Ira Rosen <irar@il.ibm.com> * Updated for the newer GIE Signed-off-by: irar2 <irar@il.ibm.com> * Review comments Signed-off-by: irar2 <irar@il.ibm.com> * Check the scheme Signed-off-by: irar2 <irar@il.ibm.com> --------- Signed-off-by: irar2 <irar@il.ibm.com> Signed-off-by: Ira Rosen <irar@il.ibm.com>
1 parent 415bbcc commit cf638f5

File tree

5 files changed

+336
-0
lines changed

5 files changed

+336
-0
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Package models provides a data source and an extractor for the model list returned by the /v1/models API.
2+
package models
3+
4+
import (
5+
"context"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
"k8s.io/apimachinery/pkg/types"
11+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
12+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http"
13+
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
14+
)
15+
16+
func TestDatasource(t *testing.T) {
17+
source := http.NewHTTPDataSource("https", "/models", true, ModelsDataSourceType,
18+
"models-data-source", parseModels, ModelsResponseType)
19+
extractor, err := NewModelExtractor()
20+
assert.Nil(t, err, "failed to create extractor")
21+
22+
err = source.AddExtractor(extractor)
23+
assert.Nil(t, err, "failed to add extractor")
24+
25+
err = source.AddExtractor(extractor)
26+
assert.NotNil(t, err, "expected to fail to add the same extractor twice")
27+
28+
extractors := source.Extractors()
29+
assert.Len(t, extractors, 1)
30+
assert.Equal(t, extractor.TypedName().String(), extractors[0])
31+
32+
err = datalayer.RegisterSource(source)
33+
assert.Nil(t, err, "failed to register")
34+
35+
ctx := context.Background()
36+
factory := datalayer.NewEndpointFactory([]fwkdl.DataSource{source}, 100*time.Hour)
37+
pod := &fwkdl.EndpointMetadata{
38+
NamespacedName: types.NamespacedName{
39+
Name: "pod1",
40+
Namespace: "default",
41+
},
42+
Address: "1.2.3.4:5678",
43+
}
44+
endpoint := factory.NewEndpoint(ctx, pod, nil)
45+
assert.NotNil(t, endpoint, "failed to create endpoint")
46+
47+
err = source.Collect(ctx, endpoint)
48+
assert.NotNil(t, err, "expected to fail to collect metrics")
49+
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package models
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"reflect"
7+
"strings"
8+
9+
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
10+
fwkplugin "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
11+
)
12+
13+
// modelsAttributeKey is the key under which ModelExtractor.Extract stores the
// extracted model list in an endpoint's attributes.
const modelsAttributeKey = "/v1/models"
14+
15+
// ModelInfoCollection is the list of model entries returned by the
// /v1/models API.
type ModelInfoCollection []ModelInfo

// ModelInfo describes a single model entry returned by the /v1/models API.
type ModelInfo struct {
	// ID is the value of the entry's "id" field.
	ID string `json:"id"`
	// Parent is the value of the entry's "parent" field; it is omitted from
	// the JSON encoding when empty.
	Parent string `json:"parent,omitempty"`
}

// String returns a string representation of the model info in "%+v" form,
// for example {ID:lora1 Parent:base}. A nil receiver yields "<nil>" rather
// than panicking on the dereference.
func (m *ModelInfo) String() string {
	if m == nil {
		return "<nil>"
	}
	return fmt.Sprintf("%+v", *m)
}
28+
29+
// Clone returns a full copy of the object
30+
func (m ModelInfoCollection) Clone() fwkdl.Cloneable {
31+
if m == nil {
32+
return nil
33+
}
34+
clone := make([]ModelInfo, len(m))
35+
copy(clone, m)
36+
return (*ModelInfoCollection)(&clone)
37+
}
38+
39+
func (m ModelInfoCollection) String() string {
40+
if m == nil {
41+
return "[]"
42+
}
43+
parts := make([]string, len(m))
44+
for i, p := range m {
45+
parts[i] = p.String()
46+
}
47+
return "[" + strings.Join(parts, ", ") + "]"
48+
}
49+
50+
// ModelResponse is the response from /v1/models API
51+
type ModelResponse struct {
52+
Object string `json:"object"`
53+
Data []ModelInfo `json:"data"`
54+
}
55+
56+
// ModelsResponseType is the type of models response
57+
var (
58+
ModelsResponseType = reflect.TypeOf(ModelResponse{})
59+
)
60+
61+
// ModelExtractor implements the models extraction.
62+
type ModelExtractor struct {
63+
typedName fwkplugin.TypedName
64+
}
65+
66+
// NewModelExtractor returns a new model extractor.
67+
func NewModelExtractor() (*ModelExtractor, error) {
68+
return &ModelExtractor{
69+
typedName: fwkplugin.TypedName{
70+
Type: ModelsExtractorType,
71+
Name: ModelsExtractorType,
72+
},
73+
}, nil
74+
}
75+
76+
// TypedName returns the type and name of the ModelExtractor.
77+
func (me *ModelExtractor) TypedName() fwkplugin.TypedName {
78+
return me.typedName
79+
}
80+
81+
// WithName sets the name of the extractor.
82+
func (me *ModelExtractor) WithName(name string) *ModelExtractor {
83+
me.typedName.Name = name
84+
return me
85+
}
86+
87+
// ExpectedInputType defines the type expected by ModelExtractor.
88+
func (me *ModelExtractor) ExpectedInputType() reflect.Type {
89+
return ModelsResponseType
90+
}
91+
92+
// Extract transforms the data source output into a concrete attribute that
93+
// is stored on the given endpoint.
94+
func (me *ModelExtractor) Extract(_ context.Context, data any, ep fwkdl.Endpoint) error {
95+
models, ok := data.(*ModelResponse)
96+
if !ok {
97+
return fmt.Errorf("unexpected input in Extract: %T", data)
98+
}
99+
100+
ep.GetAttributes().Put(modelsAttributeKey, ModelInfoCollection(models.Data))
101+
return nil
102+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package models
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/google/go-cmp/cmp"
8+
9+
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
10+
)
11+
12+
func TestExtractorExtract(t *testing.T) {
13+
ctx := context.Background()
14+
15+
extractor, err := NewModelExtractor()
16+
if err != nil {
17+
t.Fatalf("failed to create extractor: %v", err)
18+
}
19+
20+
if exType := extractor.TypedName().Type; exType == "" {
21+
t.Error("empty extractor type")
22+
}
23+
24+
if exName := extractor.TypedName().Name; exName == "" {
25+
t.Error("empty extractor name")
26+
}
27+
28+
if inputType := extractor.ExpectedInputType(); inputType != ModelsResponseType {
29+
t.Errorf("incorrect expected input type: %v", inputType)
30+
}
31+
32+
ep := fwkdl.NewEndpoint(nil, nil)
33+
if ep == nil {
34+
t.Fatal("expected non-nil endpoint")
35+
}
36+
37+
model := "food-review"
38+
39+
tests := []struct {
40+
name string
41+
data any
42+
wantErr bool
43+
updated bool // whether metrics are expected to change
44+
}{
45+
{
46+
name: "nil data",
47+
data: nil,
48+
wantErr: true,
49+
updated: false,
50+
},
51+
{
52+
name: "empty ModelsResponse",
53+
data: &ModelResponse{},
54+
wantErr: false,
55+
updated: false,
56+
},
57+
{
58+
name: "valid models response",
59+
data: &ModelResponse{
60+
Object: "list",
61+
Data: []ModelInfo{
62+
{
63+
ID: model,
64+
},
65+
{
66+
ID: "lora1",
67+
Parent: model,
68+
},
69+
},
70+
},
71+
wantErr: false,
72+
updated: true,
73+
},
74+
}
75+
76+
for _, tt := range tests {
77+
t.Run(tt.name, func(t *testing.T) {
78+
defer func() {
79+
if r := recover(); r != nil {
80+
t.Errorf("Extract panicked: %v", r)
81+
}
82+
}()
83+
84+
attr := ep.GetAttributes()
85+
before, ok := attr.Get(modelsAttributeKey)
86+
if ok && before != nil {
87+
t.Error("expected empty attributes")
88+
}
89+
err := extractor.Extract(ctx, tt.data, ep)
90+
after, ok := attr.Get(modelsAttributeKey)
91+
if !ok && tt.updated {
92+
t.Error("expected updated attributes")
93+
}
94+
95+
if tt.wantErr && err == nil {
96+
t.Errorf("expected error but got nil")
97+
}
98+
if !tt.wantErr && err != nil {
99+
t.Errorf("unexpected error: %v", err)
100+
}
101+
102+
if tt.updated {
103+
if diff := cmp.Diff(before, after); diff == "" {
104+
t.Errorf("expected models to be updated, but no change detected")
105+
}
106+
} else {
107+
if diff := cmp.Diff(before, after); diff != "" {
108+
t.Errorf("expected no models update, but got changes:\n%s", diff)
109+
}
110+
}
111+
})
112+
}
113+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package models
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
8+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http"
9+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
10+
)
11+
12+
// ModelsDataSourceType is the type name of the models data source.
const ModelsDataSourceType = "models-data-source"

// ModelsExtractorType is the type name of the models extractor.
const ModelsExtractorType = "model-server-protocol-models"
18+
19+
// Configuration parameters for models data source.
20+
type modelsDatasourceParams struct {
21+
// Scheme defines the protocol scheme used in models retrieval (e.g., "http").
22+
Scheme string `json:"scheme"`
23+
// Path defines the URL path used in models retrieval (e.g., "/v1/models").
24+
Path string `json:"path"`
25+
// InsecureSkipVerify defines whether model server certificate should be verified or not.
26+
InsecureSkipVerify bool `json:"insecureSkipVerify"`
27+
}
28+
29+
// ModelDataSourceFactory is a factory function used to instantiate data layer's
30+
// models data source plugins specified in a configuration.
31+
func ModelDataSourceFactory(name string, parameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
32+
cfg := defaultDataSourceConfigParams()
33+
if parameters != nil { // overlay the defaults with configured values
34+
if err := json.Unmarshal(parameters, cfg); err != nil {
35+
return nil, err
36+
}
37+
}
38+
if cfg.Scheme != "http" && cfg.Scheme != "https" {
39+
return nil, fmt.Errorf("unsupported scheme: %s", cfg.Scheme)
40+
}
41+
42+
ds := http.NewHTTPDataSource(cfg.Scheme, cfg.Path, cfg.InsecureSkipVerify, ModelsDataSourceType,
43+
name, parseModels, ModelsResponseType)
44+
return ds, nil
45+
}
46+
47+
// ModelServerExtractorFactory is a factory function used to instantiate data layer's models
48+
// Extractor plugins specified in a configuration.
49+
func ModelServerExtractorFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
50+
extractor, err := NewModelExtractor()
51+
if err != nil {
52+
return nil, err
53+
}
54+
return extractor.WithName(name), nil
55+
}
56+
57+
func defaultDataSourceConfigParams() *modelsDatasourceParams {
58+
return &modelsDatasourceParams{Scheme: "http", Path: "/v1/models", InsecureSkipVerify: true}
59+
}
60+
61+
func parseModels(data io.Reader) (any, error) {
62+
body, err := io.ReadAll(data)
63+
if err != nil {
64+
return nil, fmt.Errorf("failed to read response body: %v", err)
65+
}
66+
var modelsResponse ModelResponse
67+
err = json.Unmarshal(body, &modelsResponse)
68+
return &modelsResponse, err
69+
}

pkg/plugins/register.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package plugins
22

33
import (
4+
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/datalayer/models"
45
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
56
prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request"
67
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"
@@ -22,4 +23,6 @@ func RegisterAllPlugins() {
2223
plugin.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
2324
plugin.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
2425
plugin.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
26+
plugin.Register(models.ModelsDataSourceType, models.ModelDataSourceFactory)
27+
plugin.Register(models.ModelsExtractorType, models.ModelServerExtractorFactory)
2528
}

0 commit comments

Comments
 (0)