Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ catalog/pkg/openapi/model_catalog_model_artifact.go linguist-generated=true
catalog/pkg/openapi/model_catalog_model_list.go linguist-generated=true
catalog/pkg/openapi/model_catalog_source.go linguist-generated=true
catalog/pkg/openapi/model_catalog_source_list.go linguist-generated=true
catalog/pkg/openapi/model_catalog_source_preview_response.go linguist-generated=true
catalog/pkg/openapi/model_catalog_source_preview_response_all_of_summary.go linguist-generated=true
catalog/pkg/openapi/model_error.go linguist-generated=true
catalog/pkg/openapi/model_filter_option.go linguist-generated=true
catalog/pkg/openapi/model_filter_option_range.go linguist-generated=true
Expand All @@ -39,6 +41,7 @@ catalog/pkg/openapi/model_metadata_proto_value.go linguist-generated=true
catalog/pkg/openapi/model_metadata_string_value.go linguist-generated=true
catalog/pkg/openapi/model_metadata_struct_value.go linguist-generated=true
catalog/pkg/openapi/model_metadata_value.go linguist-generated=true
catalog/pkg/openapi/model_model_preview_result.go linguist-generated=true
catalog/pkg/openapi/model_order_by_field.go linguist-generated=true
catalog/pkg/openapi/model_sort_order.go linguist-generated=true
catalog/pkg/openapi/response.go linguist-generated=true
Expand Down
415 changes: 406 additions & 9 deletions api/openapi/catalog.yaml

Large diffs are not rendered by default.

417 changes: 408 additions & 9 deletions api/openapi/src/catalog.yaml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion catalog/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ The HuggingFace catalog source allows you to discover and import models from the

#### 1. Set Your API Key

The HuggingFace provider requires an API key for authentication. By default, the service reads the API key from the `HF_API_KEY` environment variable:
Setting a Hugging Face API key is optional. Hugging Face requires an API key for authentication for full access to data of models that are private and/or gated. If an API key is NOT set, private models will be entirely unavailable and gated models will have limited metadata. By default, the service reads the API key from the `HF_API_KEY` environment variable:

```bash
export HF_API_KEY="your-huggingface-api-key-here"
Expand Down
20 changes: 20 additions & 0 deletions catalog/internal/catalog/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ type ListArtifactsParams struct {
ArtifactTypesFilter []string
}

type ListPerformanceArtifactsParams struct {
FilterQuery string
PageSize int32
OrderBy string
SortOrder model.SortOrder
NextPageToken *string
TargetRPS int32
Recommendations bool
RPSProperty string // configurable "requests_per_second"
LatencyProperty string // configurable "ttft_p90"
HardwareCountProperty string // configurable "hardware_count"
HardwareTypeProperty string // configurable "hardware_type"
}

// APIProvider implements the API endpoints.
type APIProvider interface {
// GetModel returns model metadata for a single model by its name. If
Expand All @@ -43,6 +57,12 @@ type APIProvider interface {
// found, but has no artifacts, an empty list is returned.
GetArtifacts(ctx context.Context, modelName string, sourceID string, params ListArtifactsParams) (model.CatalogArtifactList, error)

// GetPerformanceArtifacts returns all performance-metrics artifacts for a particular model.
// It filters artifacts by metricsType=performance-metrics and calculates custom properties
// for targetRPS when specified. If no model is found with that name, it returns nil.
// If the model is found but has no performance artifacts, an empty list is returned.
GetPerformanceArtifacts(ctx context.Context, modelName string, sourceID string, params ListPerformanceArtifactsParams) (model.CatalogArtifactList, error)

// GetFilterOptions returns all available filter options for models.
// This includes field names, data types, and available values or ranges.
GetFilterOptions(ctx context.Context) (*model.FilterOptionsList, error)
Expand Down
148 changes: 146 additions & 2 deletions catalog/internal/catalog/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
apimodels "github.com/kubeflow/model-registry/catalog/pkg/openapi"
"github.com/kubeflow/model-registry/internal/apiutils"
mrmodels "github.com/kubeflow/model-registry/internal/db/models"
"github.com/stretchr/testify/assert"
)

func TestLoadCatalogSources(t *testing.T) {
Expand All @@ -28,8 +29,7 @@ func TestLoadCatalogSources(t *testing.T) {
{
name: "test-catalog-sources",
args: args{catalogsPath: "testdata/test-catalog-sources.yaml"},
want: []string{"catalog1"},
wantErr: false,
want: []string{"catalog1", "catalog2"},
},
}
for _, tt := range tests {
Expand Down Expand Up @@ -62,6 +62,7 @@ func TestLoadCatalogSources(t *testing.T) {

func TestLoadCatalogSourcesEnabledDisabled(t *testing.T) {
trueValue := true
falseValue := false
type args struct {
catalogsPath string
}
Expand All @@ -81,6 +82,12 @@ func TestLoadCatalogSourcesEnabledDisabled(t *testing.T) {
Enabled: &trueValue,
Labels: []string{},
},
"catalog2": {
Id: "catalog2",
Name: "Catalog 2",
Enabled: &falseValue,
Labels: []string{},
},
},
wantErr: false,
},
Expand Down Expand Up @@ -469,6 +476,87 @@ func TestLoadCatalogSourcesWithRepositoryErrors(t *testing.T) {
}
}

func TestLoadCatalogSourcesWithNilEnabled(t *testing.T) {
// Test that nil Enabled field is treated as enabled (per OpenAPI spec default: true)
mockModelRepo := &MockCatalogModelRepository{}
mockArtifactRepo := &MockCatalogArtifactRepository{}
mockModelArtifactRepo := &MockCatalogModelArtifactRepository{}
mockMetricsArtifactRepo := &MockCatalogMetricsArtifactRepository{}

services := service.NewServices(
mockModelRepo,
mockArtifactRepo,
mockModelArtifactRepo,
mockMetricsArtifactRepo,
&MockPropertyOptionsRepository{},
)

// Register a test provider
testProviderName := "test-nil-enabled-provider"
RegisterModelProvider(testProviderName, func(ctx context.Context, source *Source, reldir string) (<-chan ModelProviderRecord, error) {
ch := make(chan ModelProviderRecord, 1)

modelName := "test-model-nil-enabled"
model := &dbmodels.CatalogModelImpl{
Attributes: &dbmodels.CatalogModelAttributes{
Name: &modelName,
},
}

ch <- ModelProviderRecord{
Model: model,
Artifacts: []dbmodels.CatalogArtifact{},
}
close(ch)

return ch, nil
})

testConfig := &sourceConfig{
Catalogs: []Source{
{
CatalogSource: apimodels.CatalogSource{
Id: "test-catalog-nil-enabled",
Name: "Test Catalog Nil Enabled",
Enabled: nil, // Nil should be treated as enabled
},
Type: testProviderName,
},
},
}

l := NewLoader(services, []string{})
ctx := context.Background()

// First call updateSources to populate the SourceCollection
err := l.updateSources("test-path", testConfig)
if err != nil {
t.Fatalf("updateSources() error = %v", err)
}

err = l.updateDatabase(ctx)
if err != nil {
t.Fatalf("updateDatabase() error = %v", err)
}

// Wait for processing
time.Sleep(100 * time.Millisecond)

// Verify that the model WAS saved (because nil Enabled is treated as enabled)
if len(mockModelRepo.SavedModels) != 1 {
t.Errorf("Expected 1 model to be saved (nil Enabled should be treated as enabled), got %d", len(mockModelRepo.SavedModels))
}

if len(mockModelRepo.SavedModels) > 0 {
savedModel := mockModelRepo.SavedModels[0]
if savedModel.GetAttributes() == nil || savedModel.GetAttributes().Name == nil {
t.Error("Saved model should have attributes with name")
} else if *savedModel.GetAttributes().Name != "test-model-nil-enabled" {
t.Errorf("Expected model name 'test-model-nil-enabled', got '%s'", *savedModel.GetAttributes().Name)
}
}
}

func TestMockRepositoryBehavior(t *testing.T) {
mockRepo := &MockCatalogModelRepository{}

Expand Down Expand Up @@ -791,3 +879,59 @@ func (m *MockPropertyOptionsRepository) SetMockOptions(t dbmodels.PropertyOption
}
m.MockOptions[t][typeID] = options
}

func TestAPIProviderGetPerformanceArtifacts(t *testing.T) {
// This test verifies that the APIProvider interface has GetPerformanceArtifacts method
// The actual implementation is tested in db_catalog_test.go

// Create a mock provider to verify interface compliance
services := service.NewServices(
&MockCatalogModelRepository{},
&MockCatalogArtifactRepository{},
&MockCatalogModelArtifactRepository{},
&MockCatalogMetricsArtifactRepository{},
&MockPropertyOptionsRepository{},
)
provider := NewDBCatalog(services, nil)

// Verify provider implements APIProvider interface with GetPerformanceArtifacts
var _ APIProvider = provider

// Basic test - should return error for non-existent model
ctx := context.Background()
_, err := provider.GetPerformanceArtifacts(ctx, "non-existent-model", "source-1", ListPerformanceArtifactsParams{
TargetRPS: 100,
Recommendations: true,
PageSize: 10,
})

// Should get an error since the model doesn't exist
assert.Error(t, err)
}

// TestAPIProviderInterface verifies that the APIProvider interface supports
// all required fields in ListPerformanceArtifactsParams
func TestAPIProviderInterface(t *testing.T) {
services := service.NewServices(
&MockCatalogModelRepository{},
&MockCatalogArtifactRepository{},
&MockCatalogModelArtifactRepository{},
&MockCatalogMetricsArtifactRepository{},
&MockPropertyOptionsRepository{},
)
var provider APIProvider = NewDBCatalog(services, nil)

params := ListPerformanceArtifactsParams{
TargetRPS: 100,
Recommendations: true,
RPSProperty: "custom_rps",
LatencyProperty: "custom_latency",
HardwareCountProperty: "custom_hw_count",
HardwareTypeProperty: "custom_hw_type",
}

// Should compile without errors and be callable
ctx := context.Background()
_, err := provider.GetPerformanceArtifacts(ctx, "test-model", "source-1", params)
assert.Error(t, err) // Expected error since model doesn't exist
}
65 changes: 64 additions & 1 deletion catalog/internal/catalog/db_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type dbCatalogImpl struct {
catalogModelRepository dbmodels.CatalogModelRepository
catalogArtifactRepository dbmodels.CatalogArtifactRepository
propertyOptionsRepository dbmodels.PropertyOptionsRepository
performanceService *dbmodels.PerformanceArtifactService
sources *SourceCollection
}

Expand All @@ -31,6 +32,7 @@ func NewDBCatalog(services service.Services, sources *SourceCollection) APIProvi
catalogArtifactRepository: services.CatalogArtifactRepository,
catalogModelRepository: services.CatalogModelRepository,
propertyOptionsRepository: services.PropertyOptionsRepository,
performanceService: dbmodels.NewPerformanceArtifactService(services.CatalogArtifactRepository),
sources: sources,
}
}
Expand Down Expand Up @@ -128,7 +130,7 @@ func (d *dbCatalogImpl) ListModels(ctx context.Context, params ListModelsParams)
}

func (d *dbCatalogImpl) GetArtifacts(ctx context.Context, modelName string, sourceID string, params ListArtifactsParams) (apimodels.CatalogArtifactList, error) {
pageSize := int32(params.PageSize)
pageSize := params.PageSize

// Use consistent defaults to match pagination logic
orderBy := string(params.OrderBy)
Expand Down Expand Up @@ -240,6 +242,67 @@ func (d *dbCatalogImpl) GetFilterOptions(ctx context.Context) (*apimodels.Filter
}, nil
}

func (d *dbCatalogImpl) GetPerformanceArtifacts(ctx context.Context, modelName string, sourceID string, params ListPerformanceArtifactsParams) (apimodels.CatalogArtifactList, error) {
// Get the model to validate it exists and get its ID
modelsList, err := d.catalogModelRepository.List(dbmodels.CatalogModelListOptions{
Name: &modelName,
SourceIDs: &[]string{sourceID},
})
if err != nil {
return apimodels.CatalogArtifactList{}, err
}

if len(modelsList.Items) == 0 {
return apimodels.CatalogArtifactList{}, fmt.Errorf("no models found for name=%v: %w", modelName, api.ErrNotFound)
}

if len(modelsList.Items) > 1 {
return apimodels.CatalogArtifactList{}, fmt.Errorf("multiple models found for name=%v: %w", modelName, api.ErrNotFound)
}

model := modelsList.Items[0]

serviceParams := dbmodels.PerformanceArtifactParams{
ModelID: *model.GetID(),
TargetRPS: params.TargetRPS,
Recommendations: params.Recommendations,
FilterQuery: params.FilterQuery,
PageSize: params.PageSize,
OrderBy: params.OrderBy,
SortOrder: string(params.SortOrder),
NextPageToken: params.NextPageToken,
RPSProperty: params.RPSProperty,
LatencyProperty: params.LatencyProperty,
HardwareCountProperty: params.HardwareCountProperty,
HardwareTypeProperty: params.HardwareTypeProperty,
}

artifactsList, err := d.performanceService.GetArtifacts(serviceParams)
if err != nil {
return apimodels.CatalogArtifactList{}, fmt.Errorf("failed to get performance artifacts: %w", err)
}

artifactList := &apimodels.CatalogArtifactList{
Items: make([]apimodels.CatalogArtifact, 0, len(artifactsList.Items)),
}

for _, artifact := range artifactsList.Items {
mappedArtifact, err := mapDBArtifactToAPIArtifact(dbmodels.CatalogArtifact{
CatalogMetricsArtifact: artifact,
})
if err != nil {
return apimodels.CatalogArtifactList{}, err
}
artifactList.Items = append(artifactList.Items, mappedArtifact)
}

artifactList.NextPageToken = artifactsList.NextPageToken
artifactList.PageSize = params.PageSize
artifactList.Size = int32(len(artifactList.Items))

return *artifactList, nil
}

func dbPropToAPIOption(prop dbmodels.PropertyOption) *apimodels.FilterOption {
var option apimodels.FilterOption

Expand Down
Loading
Loading