diff --git a/pkg/plugins/datalayer/models/datasource_test.go b/pkg/plugins/datalayer/models/datasource_test.go
new file mode 100644
index 000000000..1c398ad02
--- /dev/null
+++ b/pkg/plugins/datalayer/models/datasource_test.go
@@ -0,0 +1,54 @@
+// Package models implements a data layer source and extractor for the
+// models listing API of a model server.
+package models
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http"
+	fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+)
+
+// TestDatasource wires a models data source to a ModelExtractor, registers the
+// source, and checks that collecting from the test endpoint reports an error.
+func TestDatasource(t *testing.T) {
+	source := http.NewHTTPDataSource("https", "/models", true, ModelsDataSourceType,
+		"models-data-source", parseModels, ModelsResponseType)
+	extractor, err := NewModelExtractor()
+	assert.NoError(t, err, "failed to create extractor")
+
+	err = source.AddExtractor(extractor)
+	assert.NoError(t, err, "failed to add extractor")
+
+	// Registering the very same extractor a second time must be rejected.
+	err = source.AddExtractor(extractor)
+	assert.Error(t, err, "expected to fail to add the same extractor twice")
+
+	extractors := source.Extractors()
+	assert.Len(t, extractors, 1)
+	assert.Equal(t, extractor.TypedName().String(), extractors[0])
+
+	err = datalayer.RegisterSource(source)
+	assert.NoError(t, err, "failed to register")
+
+	ctx := context.Background()
+	factory := datalayer.NewEndpointFactory([]fwkdl.DataSource{source}, 100*time.Hour)
+	pod := &fwkdl.EndpointMetadata{
+		NamespacedName: types.NamespacedName{
+			Name:      "pod1",
+			Namespace: "default",
+		},
+		Address: "1.2.3.4:5678",
+	}
+	endpoint := factory.NewEndpoint(ctx, pod, nil)
+	assert.NotNil(t, endpoint, "failed to create endpoint")
+
+	// The test expects collection against this address to fail.
+	err = source.Collect(ctx, endpoint)
+	assert.Error(t, err, "expected to fail to collect models")
+}
diff --git a/pkg/plugins/datalayer/models/extractor.go b/pkg/plugins/datalayer/models/extractor.go
new file mode 100644
index 000000000..317240ab2
--- /dev/null
+++ 
b/pkg/plugins/datalayer/models/extractor.go
@@ -0,0 +1,105 @@
+package models
+
+import (
+	"context"
+	"fmt"
+	"reflect"
+	"strings"
+
+	fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+	fwkplugin "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+)
+
+// modelsAttributeKey is the endpoint attribute key under which the extracted
+// models are stored.
+const modelsAttributeKey = "/v1/models"
+
+// ModelInfoCollection defines models' data returned from /v1/models API
+type ModelInfoCollection []ModelInfo
+
+// ModelInfo defines model's data returned from /v1/models API
+type ModelInfo struct {
+	ID     string `json:"id"`
+	Parent string `json:"parent,omitempty"`
+}
+
+// String returns a string representation of the model info
+func (m *ModelInfo) String() string {
+	return fmt.Sprintf("%+v", *m)
+}
+
+// Clone returns a full copy of the object. The copy has the same dynamic type
+// (ModelInfoCollection) as the value stored by ModelExtractor.Extract, so a
+// type assertion that works on the stored value also works on its clone.
+// A nil collection clones to nil.
+func (m ModelInfoCollection) Clone() fwkdl.Cloneable {
+	if m == nil {
+		return nil
+	}
+	clone := make(ModelInfoCollection, len(m))
+	copy(clone, m)
+	return clone
+}
+
+// String returns the models formatted as a bracketed, comma-separated list.
+func (m ModelInfoCollection) String() string {
+	parts := make([]string, len(m))
+	for i := range m {
+		parts[i] = m[i].String()
+	}
+	return "[" + strings.Join(parts, ", ") + "]"
+}
+
+// ModelResponse is the response from /v1/models API
+type ModelResponse struct {
+	Object string      `json:"object"`
+	Data   []ModelInfo `json:"data"`
+}
+
+// ModelsResponseType is the type of models response.
+// NOTE(review): this is the ModelResponse struct type, while parseModels
+// returns and Extract expects *ModelResponse — confirm this is intended.
+var ModelsResponseType = reflect.TypeOf(ModelResponse{})
+
+// ModelExtractor implements the models extraction.
+type ModelExtractor struct {
+	typedName fwkplugin.TypedName
+}
+
+// NewModelExtractor returns a new model extractor.
+func NewModelExtractor() (*ModelExtractor, error) {
+	return &ModelExtractor{
+		typedName: fwkplugin.TypedName{
+			Type: ModelsExtractorType,
+			Name: ModelsExtractorType,
+		},
+	}, nil
+}
+
+// TypedName returns the type and name of the ModelExtractor.
+func (me *ModelExtractor) TypedName() fwkplugin.TypedName {
+	return me.typedName
+}
+
+// WithName sets the name of the extractor and returns the same extractor.
+func (me *ModelExtractor) WithName(name string) *ModelExtractor {
+	me.typedName.Name = name
+	return me
+}
+
+// ExpectedInputType defines the type expected by ModelExtractor.
+func (me *ModelExtractor) ExpectedInputType() reflect.Type {
+	return ModelsResponseType
+}
+
+// Extract stores the models carried by data on the given endpoint. It returns
+// an error, storing nothing, unless data is a non-nil *ModelResponse.
+func (me *ModelExtractor) Extract(_ context.Context, data any, ep fwkdl.Endpoint) error {
+	models, ok := data.(*ModelResponse)
+	if !ok || models == nil { // a typed nil pointer would panic on models.Data
+		return fmt.Errorf("unexpected input in Extract: %T", data)
+	}
+
+	ep.GetAttributes().Put(modelsAttributeKey, ModelInfoCollection(models.Data))
+	return nil
+}
diff --git a/pkg/plugins/datalayer/models/extractor_test.go b/pkg/plugins/datalayer/models/extractor_test.go
new file mode 100644
index 000000000..4f075f3d0
--- /dev/null
+++ b/pkg/plugins/datalayer/models/extractor_test.go
@@ -0,0 +1,122 @@
+package models
+
+import (
+	"context"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+
+	fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+)
+
+// TestExtractorExtract checks the extractor's identity and expected input type,
+// then runs Extract against several inputs and verifies the error result and
+// whether the models attribute on the endpoint changed.
+func TestExtractorExtract(t *testing.T) {
+	ctx := context.Background()
+
+	extractor, err := NewModelExtractor()
+	if err != nil {
+		t.Fatalf("failed to create extractor: %v", err)
+	}
+
+	if exType := extractor.TypedName().Type; exType == "" {
+		t.Error("empty extractor type")
+	}
+
+	if exName := extractor.TypedName().Name; exName == "" {
+		t.Error("empty extractor name")
+	}
+
+	if inputType := extractor.ExpectedInputType(); inputType != ModelsResponseType {
+		t.Errorf("incorrect expected input type: %v", inputType)
+	}
+
+	model := "food-review"
+
+	tests := []struct {
+		name    string
+		data    any
+		wantErr bool
+		updated bool // whether the models attribute is expected to change
+	}{
+		{
+			name:    "nil data",
+			data:    nil,
+			wantErr: true,
+			updated: false,
+		},
+		{
+			name:    "nil ModelResponse pointer",
+			data:    (*ModelResponse)(nil),
+			wantErr: true,
+			updated: false,
+		},
+		{
+			name:    "empty ModelsResponse",
+			data:    &ModelResponse{},
+			wantErr: false,
+			updated: false,
+		},
+		{
+			name: "valid models response",
+			data: &ModelResponse{
+				Object: "list",
+				Data: []ModelInfo{
+					{
+						ID: model,
+					},
+					{
+						ID:     "lora1",
+						Parent: model,
+					},
+				},
+			},
+			wantErr: false,
+			updated: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			defer func() {
+				if r := recover(); r != nil {
+					t.Errorf("Extract panicked: %v", r)
+				}
+			}()
+
+			// Each case gets its own endpoint so that attributes stored by
+			// one case cannot leak into the next.
+			ep := fwkdl.NewEndpoint(nil, nil)
+			if ep == nil {
+				t.Fatal("expected non-nil endpoint")
+			}
+
+			attr := ep.GetAttributes()
+			before, ok := attr.Get(modelsAttributeKey)
+			if ok && before != nil {
+				t.Error("expected empty attributes")
+			}
+			err := extractor.Extract(ctx, tt.data, ep)
+			after, ok := attr.Get(modelsAttributeKey)
+			if !ok && tt.updated {
+				t.Error("expected updated attributes")
+			}
+
+			if tt.wantErr && err == nil {
+				t.Error("expected error but got nil")
+			}
+			if !tt.wantErr && err != nil {
+				t.Errorf("unexpected error: %v", err)
+			}
+
+			diff := cmp.Diff(before, after)
+			if tt.updated && diff == "" {
+				t.Error("expected models to be updated, but no change detected")
+			}
+			if !tt.updated && diff != "" {
+				t.Errorf("expected no models update, but got changes:\n%s", diff)
+			}
+		})
+	}
+}
diff --git a/pkg/plugins/datalayer/models/factories.go b/pkg/plugins/datalayer/models/factories.go
new file mode 100644
index 000000000..49cd4edad
--- /dev/null
+++ b/pkg/plugins/datalayer/models/factories.go
@@ -0,0 +1,84 @@
+package models
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+)
+
+const (
+	// ModelsDataSourceType is models data source type
+	ModelsDataSourceType = "models-data-source"
+	// ModelsExtractorType is models extractor type
+	ModelsExtractorType = "model-server-protocol-models"
+)
+
+// modelsDatasourceParams holds the configuration parameters for models data source.
+type modelsDatasourceParams struct {
+	// Scheme defines the protocol scheme used in models retrieval (e.g., "http").
+	Scheme string `json:"scheme"`
+	// Path defines the URL path used in models retrieval (e.g., "/v1/models").
+	Path string `json:"path"`
+	// InsecureSkipVerify defines whether model server certificate should be verified or not.
+	InsecureSkipVerify bool `json:"insecureSkipVerify"`
+}
+
+// ModelDataSourceFactory is a factory function used to instantiate data layer's
+// models data source plugins specified in a configuration.
+//
+// Configured parameters, when present, are overlaid on the defaults from
+// defaultDataSourceConfigParams. An error is returned when the parameters
+// cannot be decoded or when the scheme is neither "http" nor "https".
+func ModelDataSourceFactory(name string, parameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
+	cfg := defaultDataSourceConfigParams()
+	// An empty (not only nil) message carries no overrides; unmarshalling it
+	// would fail, so the defaults are kept as they are.
+	if len(parameters) > 0 {
+		if err := json.Unmarshal(parameters, cfg); err != nil {
+			return nil, fmt.Errorf("failed to parse models data source parameters: %w", err)
+		}
+	}
+	if cfg.Scheme != "http" && cfg.Scheme != "https" {
+		return nil, fmt.Errorf("unsupported scheme: %s", cfg.Scheme)
+	}
+
+	ds := http.NewHTTPDataSource(cfg.Scheme, cfg.Path, cfg.InsecureSkipVerify, ModelsDataSourceType,
+		name, parseModels, ModelsResponseType)
+	return ds, nil
+}
+
+// ModelServerExtractorFactory is a factory function used to instantiate data layer's models
+// Extractor plugins specified in a configuration.
+func ModelServerExtractorFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
+	extractor, err := NewModelExtractor()
+	if err != nil {
+		return nil, err
+	}
+	return extractor.WithName(name), nil
+}
+
+// defaultDataSourceConfigParams returns the parameters used when the
+// configuration does not override them.
+// NOTE(review): the default sets InsecureSkipVerify to true, which by its name
+// presumably disables certificate verification — confirm this is intended.
+func defaultDataSourceConfigParams() *modelsDatasourceParams {
+	return &modelsDatasourceParams{Scheme: "http", Path: "/v1/models", InsecureSkipVerify: true}
+}
+
+// parseModels reads the whole response body and decodes it into a
+// *ModelResponse. On any failure it returns a nil result together with the
+// wrapped error, so callers never receive a partially decoded response.
+func parseModels(data io.Reader) (any, error) {
+	body, err := io.ReadAll(data)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response body: %w", err)
+	}
+	var modelsResponse ModelResponse
+	if err = json.Unmarshal(body, &modelsResponse); err != nil {
+		return nil, fmt.Errorf("failed to parse models response: %w", err)
+	}
+	return &modelsResponse, nil
+}
diff --git a/pkg/plugins/register.go b/pkg/plugins/register.go
index 08bb7b0b3..e978de60c 100644
--- a/pkg/plugins/register.go
+++ b/pkg/plugins/register.go
@@ -1,6 +1,7 @@
 package plugins
 
 import (
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/datalayer/models"
 	"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
 	prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request"
 	"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"
@@ -22,4 +23,6 @@ func RegisterAllPlugins() {
 	plugin.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
 	plugin.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
 	plugin.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
+	plugin.Register(models.ModelsDataSourceType, models.ModelDataSourceFactory)
+	plugin.Register(models.ModelsExtractorType, models.ModelServerExtractorFactory)
 }