diff --git a/.gitattributes b/.gitattributes index 6c1891dd04..f81833b940 100644 --- a/.gitattributes +++ b/.gitattributes @@ -28,6 +28,8 @@ catalog/pkg/openapi/model_catalog_model_artifact.go linguist-generated=true catalog/pkg/openapi/model_catalog_model_list.go linguist-generated=true catalog/pkg/openapi/model_catalog_source.go linguist-generated=true catalog/pkg/openapi/model_catalog_source_list.go linguist-generated=true +catalog/pkg/openapi/model_catalog_source_preview_response.go linguist-generated=true +catalog/pkg/openapi/model_catalog_source_preview_response_all_of_summary.go linguist-generated=true catalog/pkg/openapi/model_error.go linguist-generated=true catalog/pkg/openapi/model_filter_option.go linguist-generated=true catalog/pkg/openapi/model_filter_option_range.go linguist-generated=true @@ -39,6 +41,7 @@ catalog/pkg/openapi/model_metadata_proto_value.go linguist-generated=true catalog/pkg/openapi/model_metadata_string_value.go linguist-generated=true catalog/pkg/openapi/model_metadata_struct_value.go linguist-generated=true catalog/pkg/openapi/model_metadata_value.go linguist-generated=true +catalog/pkg/openapi/model_model_preview_result.go linguist-generated=true catalog/pkg/openapi/model_order_by_field.go linguist-generated=true catalog/pkg/openapi/model_sort_order.go linguist-generated=true catalog/pkg/openapi/response.go linguist-generated=true diff --git a/api/openapi/catalog.yaml b/api/openapi/catalog.yaml index 46e98a2949..4c4103df76 100644 --- a/api/openapi/catalog.yaml +++ b/api/openapi/catalog.yaml @@ -160,6 +160,156 @@ paths: $ref: "#/components/responses/InternalServerError" operationId: findSources description: Gets a list of all `CatalogSource` entities. + /api/model_catalog/v1alpha1/sources/preview: + description: >- + The REST endpoint/path used to preview a catalog source configuration. 
This endpoint accepts a catalog source definition and returns a list of models with their inclusion/exclusion status based on the configured filters. + post: + summary: Preview catalog source configuration + description: |- + Accepts a catalog source configuration and returns a list of models showing + which would be included or excluded based on the configured filters. This allows + users to test and validate their source configurations before applying them. + + **Two modes of operation:** + + 1. **Stateless mode (recommended for new sources):** Upload both `config` and + `catalogData` files via multipart form. The models are read directly from + the uploaded `catalogData`, enabling preview of new sources before saving + anything to the server. This is ideal for testing configurations. + + 2. **Path mode (for existing sources):** Upload only `config` with a `yamlCatalogPath` + property. The models are read from the specified file path on the server. + Use this for previewing changes to existing saved sources. + tags: + - ModelCatalogService + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - config + properties: + config: + type: string + format: binary + description: |- + YAML file containing the catalog source configuration. + The file should contain a source definition with type and properties + fields, including optional includedModels and excludedModels filters. + + Model filter patterns support the `*` wildcard only and are case-insensitive. + Patterns match the entire model name (e.g., `ibm-granite/*` matches all + models starting with "ibm-granite/"). + catalogData: + type: string + format: binary + description: |- + Optional YAML file containing the catalog data (models). + + This field enables stateless preview of new sources before saving them. + When provided, the catalog data is read directly from this file instead of + from the `yamlCatalogPath` property in the config. 
+ + **Two modes of operation:** + 1. **Stateless mode (recommended for new sources):** Upload both `config` and + `catalogData` files. The models are read from `catalogData`, allowing preview + without saving anything to the server. + 2. **Path mode (for existing sources):** Upload only `config` with a `yamlCatalogPath` + property pointing to a catalog file on the server filesystem. + + If both `catalogData` and `yamlCatalogPath` are provided, `catalogData` takes precedence. + examples: + statelessPreview: + summary: Stateless preview with uploaded catalog data + description: |- + Upload both config and catalogData files to preview a new source + before saving. This is the recommended approach for testing new configurations. + value: | + # config file content: + type: yaml + includedModels: + - "ibm-granite/*" + - "meta-llama/*" + excludedModels: + - "*-draft" + - "*-experimental" + + # catalogData file content (separate file): + models: + - name: ibm-granite/granite-3.0-8b-instruct + description: Granite 8B Instruct model + - name: ibm-granite/granite-3.0-2b-draft + description: Draft version + - name: meta-llama/Llama-2-7b-hf + description: Llama 2 7B + pathBasedPreview: + summary: Path-based preview using server-side catalog + description: |- + Upload only config file with yamlCatalogPath pointing to an + existing catalog file on the server. Use this for previewing + changes to existing saved sources. + value: | + type: yaml + includedModels: + - "ibm-granite/*" + - "meta-llama/*" + - "mistralai/*" + excludedModels: + - "*-draft" + - "*-experimental" + properties: + yamlCatalogPath: "models-catalog.yaml" + huggingfaceSource: + summary: HuggingFace catalog source + description: |- + Upload configuration for HuggingFace source with API credentials. + The API key is passed per-request and not persisted anywhere. 
+ value: | + type: hf + includedModels: + - "microsoft/*" + - "google/*" + excludedModels: + - "*-gguf" + properties: + apiKey: "your-huggingface-api-key-here" # notsecret + modelLimit: 100 + parameters: + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/nextPageToken" + - name: filterStatus + description: |- + Filter the response to show specific model statuses: + - `all` (default): Show all models regardless of inclusion status + - `included`: Show only models that pass the configured filters + - `excluded`: Show only models that are filtered out + schema: + type: string + enum: + - all + - included + - excluded + default: all + in: query + required: false + responses: + "200": + $ref: "#/components/responses/CatalogSourcePreviewResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "422": + description: Unprocessable Entity - Invalid source configuration + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: previewCatalogSource /api/model_catalog/v1alpha1/sources/{source_id}/models/{model_name+}: description: >- The REST endpoint/path used to get a `CatalogModel`. @@ -577,22 +727,47 @@ components: type: string includedModels: description: |- - Optional allow-list of models that are eligible for this source. Entries can be - exact model names or patterns that use `*` as a wildcard. When provided, only - models matching at least one pattern are considered. + Optional list of glob patterns for models to include. If specified, only models matching + at least one pattern will be included. If omitted, all models are considered for inclusion. 
+ + Pattern Syntax: + - Only the `*` wildcard is supported (matches zero or more characters) + - Patterns are case-insensitive (e.g., `Granite/*` matches `granite/model` and `GRANITE/model`) + - Patterns match the entire model name (anchored at start and end) + - Wildcards can appear anywhere: `Granite/*`, `*-beta`, `*deprecated*`, `*/old*` - Pattern matching is case-insensitive, so `Granite/*` will match `granite/model`, - `Granite/model`, and `GRANITE/model`. + Examples: + - `ibm-granite/*` - matches all models starting with "ibm-granite/" + - `meta-llama/*` - matches all models in the meta-llama namespace + - `*` - matches all models + + Constraints: + - Patterns cannot be empty or whitespace-only + - A pattern cannot appear in both includedModels and excludedModels type: array items: type: string excludedModels: description: |- - Optional block-list of models that should be removed from the catalog even if - they match `includedModels`. Patterns support the `*` wildcard. + Optional list of glob patterns for models to exclude. Models matching any pattern + will be excluded even if they match an includedModels pattern. Exclusions take + precedence over inclusions. + + Pattern Syntax: + - Only the `*` wildcard is supported (matches zero or more characters) + - Patterns are case-insensitive + - Patterns match the entire model name (anchored at start and end) + - Wildcards can appear anywhere in the pattern - Pattern matching is case-insensitive, so `*-beta` will match `Model-Beta`, - `model-beta`, and `MODEL-BETA`. 
+ Examples: + - `*-draft` - excludes all models ending with "-draft" + - `*-experimental` - excludes experimental models + - `*deprecated*` - excludes models with "deprecated" anywhere in the name + - `*/beta-*` - excludes models with "/beta-" in the path + + Constraints: + - Patterns cannot be empty or whitespace-only + - A pattern cannot appear in both includedModels and excludedModels type: array items: type: string @@ -607,6 +782,40 @@ components: items: $ref: "#/components/schemas/CatalogSource" - $ref: "#/components/schemas/BaseResourceList" + CatalogSourcePreviewResponse: + description: Response containing models and their inclusion/exclusion status. + allOf: + - type: object + properties: + items: + description: Array of model preview results. + type: array + items: + $ref: "#/components/schemas/ModelPreviewResult" + summary: + description: Summary of the preview results + type: object + properties: + totalModels: + type: integer + description: Total number of models evaluated + example: 1500 + includedModels: + type: integer + description: Number of models that would be included + example: 850 + excludedModels: + type: integer + description: Number of models that would be excluded + example: 650 + required: + - totalModels + - includedModels + - excludedModels + required: + - items + - summary + - $ref: "#/components/schemas/BaseResourceList" Error: description: Error code and message. required: @@ -763,6 +972,22 @@ components: example: string_value: my_value metadataType: MetadataStringValue + ModelPreviewResult: + description: |- + A model with its inclusion/exclusion status based on the + configured catalog source filters. 
+ type: object + required: + - name + - included + properties: + name: + type: string + description: Name of the model + example: microsoft/DialoGPT-medium + included: + type: boolean + description: Whether this model would be included based on the source configuration OrderByField: description: |- Supported fields for ordering result entities. @@ -815,6 +1040,103 @@ components: schema: $ref: "#/components/schemas/CatalogSourceList" description: A response containing a list of CatalogSource entities. + CatalogSourcePreviewResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogSourcePreviewResponse" + examples: + allModels: + summary: All models with inclusion status + description: Response showing all models when filterStatus=all + value: + nextPageToken: "" + pageSize: 10 + size: 5 + items: + - name: "ibm-granite/granite-3.0-8b-instruct" + included: true + - name: "ibm-granite/granite-3.0-2b-instruct" + included: true + - name: "meta-llama/Llama-2-7b-hf" + included: true + - name: "mistralai/Mistral-7B-v0.1-draft" + included: false + - name: "microsoft/phi-2-experimental" + included: false + summary: + totalModels: 5 + includedModels: 3 + excludedModels: 2 + includedOnly: + summary: Only included models + description: Response when filterStatus=included + value: + nextPageToken: "" + pageSize: 10 + size: 3 + items: + - name: "ibm-granite/granite-3.0-8b-instruct" + included: true + - name: "ibm-granite/granite-3.0-2b-instruct" + included: true + - name: "meta-llama/Llama-2-7b-hf" + included: true + summary: + totalModels: 5 + includedModels: 3 + excludedModels: 2 + excludedOnly: + summary: Only excluded models + description: Response when filterStatus=excluded + value: + nextPageToken: "" + pageSize: 10 + size: 2 + items: + - name: "mistralai/Mistral-7B-v0.1-draft" + included: false + - name: "microsoft/phi-2-experimental" + included: false + summary: + totalModels: 5 + includedModels: 3 + excludedModels: 2 + withPagination: + 
summary: Paginated response + description: Response with pagination when pageSize is smaller than total + value: + nextPageToken: "eyJvZmZzZXQiOjEwfQ==" # notsecret + pageSize: 10 + size: 10 + items: + - name: "ibm-granite/granite-3.0-8b-instruct" + included: true + - name: "ibm-granite/granite-3.0-2b-instruct" + included: true + - name: "meta-llama/Llama-2-7b-hf" + included: true + - name: "meta-llama/Llama-2-13b-hf" + included: true + - name: "mistralai/Mistral-7B-Instruct-v0.3" + included: true + - name: "mistralai/Mistral-7B-v0.1" + included: true + - name: "microsoft/phi-2" + included: true + - name: "google/gemma-7b" + included: true + - name: "mistralai/Mistral-7B-v0.1-draft" + included: false + - name: "microsoft/phi-2-experimental" + included: false + summary: + totalModels: 150 + includedModels: 85 + excludedModels: 65 + description: |- + A response containing a list of models with their inclusion/exclusion + status based on the provided catalog source configuration. CatalogSourceResponse: content: application/json: diff --git a/api/openapi/src/catalog.yaml b/api/openapi/src/catalog.yaml index 3c26ff38b5..3dc95fd7a2 100644 --- a/api/openapi/src/catalog.yaml +++ b/api/openapi/src/catalog.yaml @@ -300,6 +300,158 @@ paths: - $ref: "#/components/parameters/artifactOrderBy" - $ref: "#/components/parameters/sortOrder" - $ref: "#/components/parameters/nextPageToken" + /api/model_catalog/v1alpha1/sources/preview: + description: >- + The REST endpoint/path used to preview a catalog source configuration. + This endpoint accepts a catalog source definition and returns a list of + models with their inclusion/exclusion status based on the configured filters. + post: + summary: Preview catalog source configuration + description: |- + Accepts a catalog source configuration and returns a list of models showing + which would be included or excluded based on the configured filters. This allows + users to test and validate their source configurations before applying them. 
+ + **Two modes of operation:** + + 1. **Stateless mode (recommended for new sources):** Upload both `config` and + `catalogData` files via multipart form. The models are read directly from + the uploaded `catalogData`, enabling preview of new sources before saving + anything to the server. This is ideal for testing configurations. + + 2. **Path mode (for existing sources):** Upload only `config` with a `yamlCatalogPath` + property. The models are read from the specified file path on the server. + Use this for previewing changes to existing saved sources. + tags: + - ModelCatalogService + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - config + properties: + config: + type: string + format: binary + description: |- + YAML file containing the catalog source configuration. + The file should contain a source definition with type and properties + fields, including optional includedModels and excludedModels filters. + + Model filter patterns support the `*` wildcard only and are case-insensitive. + Patterns match the entire model name (e.g., `ibm-granite/*` matches all + models starting with "ibm-granite/"). + catalogData: + type: string + format: binary + description: |- + Optional YAML file containing the catalog data (models). + + This field enables stateless preview of new sources before saving them. + When provided, the catalog data is read directly from this file instead of + from the `yamlCatalogPath` property in the config. + + **Two modes of operation:** + 1. **Stateless mode (recommended for new sources):** Upload both `config` and + `catalogData` files. The models are read from `catalogData`, allowing preview + without saving anything to the server. + 2. **Path mode (for existing sources):** Upload only `config` with a `yamlCatalogPath` + property pointing to a catalog file on the server filesystem. + + If both `catalogData` and `yamlCatalogPath` are provided, `catalogData` takes precedence. 
+ examples: + statelessPreview: + summary: Stateless preview with uploaded catalog data + description: |- + Upload both config and catalogData files to preview a new source + before saving. This is the recommended approach for testing new configurations. + value: | + # config file content: + type: yaml + includedModels: + - "ibm-granite/*" + - "meta-llama/*" + excludedModels: + - "*-draft" + - "*-experimental" + + # catalogData file content (separate file): + models: + - name: ibm-granite/granite-3.0-8b-instruct + description: Granite 8B Instruct model + - name: ibm-granite/granite-3.0-2b-draft + description: Draft version + - name: meta-llama/Llama-2-7b-hf + description: Llama 2 7B + pathBasedPreview: + summary: Path-based preview using server-side catalog + description: |- + Upload only config file with yamlCatalogPath pointing to an + existing catalog file on the server. Use this for previewing + changes to existing saved sources. + value: | + type: yaml + includedModels: + - "ibm-granite/*" + - "meta-llama/*" + - "mistralai/*" + excludedModels: + - "*-draft" + - "*-experimental" + properties: + yamlCatalogPath: "models-catalog.yaml" + huggingfaceSource: + summary: HuggingFace catalog source + description: |- + Upload configuration for HuggingFace source with API credentials. + The API key is passed per-request and not persisted anywhere. 
+ value: | + type: hf + includedModels: + - "microsoft/*" + - "google/*" + excludedModels: + - "*-gguf" + properties: + apiKey: "your-huggingface-api-key-here" # notsecret + modelLimit: 100 + parameters: + - $ref: "#/components/parameters/pageSize" + - $ref: "#/components/parameters/nextPageToken" + - name: filterStatus + description: |- + Filter the response to show specific model statuses: + - `all` (default): Show all models regardless of inclusion status + - `included`: Show only models that pass the configured filters + - `excluded`: Show only models that are filtered out + schema: + type: string + enum: + - all + - included + - excluded + default: all + in: query + required: false + responses: + "200": + $ref: "#/components/responses/CatalogSourcePreviewResponse" + "400": + $ref: "#/components/responses/BadRequest" + "401": + $ref: "#/components/responses/Unauthorized" + "422": + description: Unprocessable Entity - Invalid source configuration + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + "500": + $ref: "#/components/responses/InternalServerError" + operationId: previewCatalogSource components: schemas: CatalogArtifact: @@ -460,22 +612,47 @@ components: type: string includedModels: description: |- - Optional allow-list of models that are eligible for this source. Entries can be - exact model names or patterns that use `*` as a wildcard. When provided, only - models matching at least one pattern are considered. + Optional list of glob patterns for models to include. If specified, only models matching + at least one pattern will be included. If omitted, all models are considered for inclusion. 
+ + Pattern Syntax: + - Only the `*` wildcard is supported (matches zero or more characters) + - Patterns are case-insensitive (e.g., `Granite/*` matches `granite/model` and `GRANITE/model`) + - Patterns match the entire model name (anchored at start and end) + - Wildcards can appear anywhere: `Granite/*`, `*-beta`, `*deprecated*`, `*/old*` - Pattern matching is case-insensitive, so `Granite/*` will match `granite/model`, - `Granite/model`, and `GRANITE/model`. + Examples: + - `ibm-granite/*` - matches all models starting with "ibm-granite/" + - `meta-llama/*` - matches all models in the meta-llama namespace + - `*` - matches all models + + Constraints: + - Patterns cannot be empty or whitespace-only + - A pattern cannot appear in both includedModels and excludedModels type: array items: type: string excludedModels: description: |- - Optional block-list of models that should be removed from the catalog even if - they match `includedModels`. Patterns support the `*` wildcard. + Optional list of glob patterns for models to exclude. Models matching any pattern + will be excluded even if they match an includedModels pattern. Exclusions take + precedence over inclusions. + + Pattern Syntax: + - Only the `*` wildcard is supported (matches zero or more characters) + - Patterns are case-insensitive + - Patterns match the entire model name (anchored at start and end) + - Wildcards can appear anywhere in the pattern - Pattern matching is case-insensitive, so `*-beta` will match `Model-Beta`, - `model-beta`, and `MODEL-BETA`. 
+ Examples: + - `*-draft` - excludes all models ending with "-draft" + - `*-experimental` - excludes experimental models + - `*deprecated*` - excludes models with "deprecated" anywhere in the name + - `*/beta-*` - excludes models with "/beta-" in the path + + Constraints: + - Patterns cannot be empty or whitespace-only + - A pattern cannot appear in both includedModels and excludedModels type: array items: type: string @@ -490,6 +667,56 @@ components: items: $ref: "#/components/schemas/CatalogSource" - $ref: "#/components/schemas/BaseResourceList" + CatalogSourcePreviewResponse: + description: Response containing models and their inclusion/exclusion status. + allOf: + - type: object + properties: + items: + description: Array of model preview results. + type: array + items: + $ref: "#/components/schemas/ModelPreviewResult" + summary: + description: Summary of the preview results + type: object + properties: + totalModels: + type: integer + description: Total number of models evaluated + example: 1500 + includedModels: + type: integer + description: Number of models that would be included + example: 850 + excludedModels: + type: integer + description: Number of models that would be excluded + example: 650 + required: + - totalModels + - includedModels + - excludedModels + required: + - items + - summary + - $ref: "#/components/schemas/BaseResourceList" + ModelPreviewResult: + description: |- + A model with its inclusion/exclusion status based on the + configured catalog source filters. + type: object + required: + - name + - included + properties: + name: + type: string + description: Name of the model + example: microsoft/DialoGPT-medium + included: + type: boolean + description: Whether this model would be included based on the source configuration FilterOption: type: object required: @@ -579,6 +806,103 @@ components: schema: $ref: "#/components/schemas/FilterOptionsList" description: A response containing options for a `filterQuery` parameter. 
+ CatalogSourcePreviewResponse: + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogSourcePreviewResponse" + examples: + allModels: + summary: All models with inclusion status + description: Response showing all models when filterStatus=all + value: + nextPageToken: "" + pageSize: 10 + size: 5 + items: + - name: "ibm-granite/granite-3.0-8b-instruct" + included: true + - name: "ibm-granite/granite-3.0-2b-instruct" + included: true + - name: "meta-llama/Llama-2-7b-hf" + included: true + - name: "mistralai/Mistral-7B-v0.1-draft" + included: false + - name: "microsoft/phi-2-experimental" + included: false + summary: + totalModels: 5 + includedModels: 3 + excludedModels: 2 + includedOnly: + summary: Only included models + description: Response when filterStatus=included + value: + nextPageToken: "" + pageSize: 10 + size: 3 + items: + - name: "ibm-granite/granite-3.0-8b-instruct" + included: true + - name: "ibm-granite/granite-3.0-2b-instruct" + included: true + - name: "meta-llama/Llama-2-7b-hf" + included: true + summary: + totalModels: 5 + includedModels: 3 + excludedModels: 2 + excludedOnly: + summary: Only excluded models + description: Response when filterStatus=excluded + value: + nextPageToken: "" + pageSize: 10 + size: 2 + items: + - name: "mistralai/Mistral-7B-v0.1-draft" + included: false + - name: "microsoft/phi-2-experimental" + included: false + summary: + totalModels: 5 + includedModels: 3 + excludedModels: 2 + withPagination: + summary: Paginated response + description: Response with pagination when pageSize is smaller than total + value: + nextPageToken: "eyJvZmZzZXQiOjEwfQ==" # notsecret + pageSize: 10 + size: 10 + items: + - name: "ibm-granite/granite-3.0-8b-instruct" + included: true + - name: "ibm-granite/granite-3.0-2b-instruct" + included: true + - name: "meta-llama/Llama-2-7b-hf" + included: true + - name: "meta-llama/Llama-2-13b-hf" + included: true + - name: "mistralai/Mistral-7B-Instruct-v0.3" + included: true + - 
name: "mistralai/Mistral-7B-v0.1" + included: true + - name: "microsoft/phi-2" + included: true + - name: "google/gemma-7b" + included: true + - name: "mistralai/Mistral-7B-v0.1-draft" + included: false + - name: "microsoft/phi-2-experimental" + included: false + summary: + totalModels: 150 + includedModels: 85 + excludedModels: 65 + description: |- + A response containing a list of models with their inclusion/exclusion + status based on the provided catalog source configuration. parameters: filterQuery: diff --git a/catalog/internal/catalog/hf_catalog.go b/catalog/internal/catalog/hf_catalog.go index 00b459b338..d2470740e0 100644 --- a/catalog/internal/catalog/hf_catalog.go +++ b/catalog/internal/catalog/hf_catalog.go @@ -24,6 +24,16 @@ const ( defaultAPIKeyEnvVar = "HF_API_KEY" urlKey = "url" apiKeyEnvVarKey = "apiKeyEnvVar" + maxModelsKey = "maxModels" + + // defaultMaxModels is the default limit for models fetched PER PATTERN. + // This limit is applied independently to each pattern in includedModels + // (e.g., "ibm-granite/*", "meta-llama/*") to prevent overloading the + // HuggingFace API with too many requests and to respect rate limiting. + // + // Example: with maxModels=100 and 3 patterns, up to 300 models total may be fetched. + // Set to 0 to disable the limit (not recommended for large organizations). + defaultMaxModels = 500 ) // gatedString is a custom type that can unmarshal both boolean and string values from JSON @@ -71,6 +81,10 @@ type hfModelProvider struct { baseURL string includedModels []string filter *ModelFilter + // maxModels limits how many models to fetch PER PATTERN (e.g., per "org/*"). + // This is applied independently to each pattern to respect HuggingFace API rate limits. + // A value of 0 means no limit. + maxModels int } // hfModelInfo represents the structure of HuggingFace API model information @@ -660,3 +674,307 @@ func init() { panic(err) } } + +// NewHFPreviewProvider creates an hfModelProvider for preview use. 
+// It initializes the provider from a PreviewConfig without starting the full model loading. +func NewHFPreviewProvider(config *PreviewConfig) (*hfModelProvider, error) { + p := &hfModelProvider{ + client: &http.Client{Timeout: 30 * time.Second}, + baseURL: defaultHuggingFaceURL, + maxModels: defaultMaxModels, + } + + // Parse API key from environment variable + apiKeyEnvVar := defaultAPIKeyEnvVar + if envVar, ok := config.Properties[apiKeyEnvVarKey].(string); ok && envVar != "" { + apiKeyEnvVar = envVar + } + apiKey := os.Getenv(apiKeyEnvVar) + if apiKey == "" { + return nil, fmt.Errorf("missing %s environment variable for HuggingFace preview", apiKeyEnvVar) + } + p.apiKey = apiKey + + // Parse base URL (optional, defaults to huggingface.co) + if url, ok := config.Properties[urlKey].(string); ok && url != "" { + p.baseURL = strings.TrimSuffix(url, "/") + } + + // Parse maxModels limit (optional, defaults to 500) + // This limit is applied PER PATTERN (e.g., each "org/*" pattern gets its own limit) + // to prevent overloading the HuggingFace API and respect rate limiting. + // Set to 0 to disable the limit. + if maxModels, ok := config.Properties[maxModelsKey]; ok { + switch v := maxModels.(type) { + case int: + p.maxModels = v + case int64: + p.maxModels = int(v) + case float64: + p.maxModels = int(v) + } + } + + return p, nil +} + +// hfListResponse represents a single model in the HuggingFace list API response. +type hfListModel struct { + ID string `json:"id"` + ModelID string `json:"modelId,omitempty"` + Author string `json:"author,omitempty"` + Private bool `json:"private,omitempty"` + Downloads int `json:"downloads,omitempty"` +} + +// PatternType indicates how to handle an includedModels pattern. 
type PatternType int

const (
	PatternExact     PatternType = iota // e.g., "org/model-name" - direct fetch
	PatternOrgAll                       // e.g., "org/*" - list all from org
	PatternOrgPrefix                    // e.g., "org/prefix*" - list from org with search
	PatternInvalid                      // e.g., "*", "*/*", "ibm*/*" - not supported
)

// parseModelPattern analyzes a model identifier to determine how to fetch it.
//
// It returns the pattern type together with the organization and the search
// prefix. The organization is set for PatternOrgAll and PatternOrgPrefix, the
// search prefix only for PatternOrgPrefix; both are empty for PatternExact and
// PatternInvalid. Surrounding whitespace in pattern is ignored.
//
// Any pattern whose organization part is empty or contains a wildcard is
// rejected as PatternInvalid: listing requires one concrete organization.
func parseModelPattern(pattern string) (PatternType, string, string) {
	pattern = strings.TrimSpace(pattern)

	// Reject unsupported wildcard patterns that would try to list all HuggingFace models.
	// HuggingFace has millions of models, so we require a specific organization.
	if pattern == "*" || pattern == "*/*" || strings.HasPrefix(pattern, "*/") {
		return PatternInvalid, "", ""
	}

	// "org/*": list everything published by the organization.
	if strings.HasSuffix(pattern, "/*") {
		org := strings.TrimSuffix(pattern, "/*")
		// pattern is already trimmed, so an all-whitespace org is an empty org.
		// A wildcard inside the org (e.g. "ibm*/*") cannot be used as an author.
		if org == "" || strings.Contains(org, "*") {
			return PatternInvalid, "", ""
		}
		return PatternOrgAll, org, ""
	}

	// "org/prefix*": list the organization, narrowed by a name prefix.
	if slash := strings.Index(pattern, "/"); slash >= 0 && strings.HasSuffix(pattern, "*") {
		org := pattern[:slash]
		if org == "" || strings.Contains(org, "*") {
			return PatternInvalid, "", ""
		}
		// NOTE(review): a prefix that itself contains "*" (e.g. "org/a*b*") is
		// still passed through unchanged, as before - confirm whether that is intended.
		if prefix := strings.TrimSuffix(pattern[slash+1:], "*"); prefix != "" {
			return PatternOrgPrefix, org, prefix
		}
	}

	// Everything else is treated as an exact model name.
	return PatternExact, "", ""
}

// listModelsByAuthor fetches all models from an organization using the HuggingFace list API with pagination.
// If searchPrefix is provided, it filters models that start with that prefix.
+func (p *hfModelProvider) listModelsByAuthor(ctx context.Context, author string, searchPrefix string) ([]string, error) { + var allModels []string + limit := 100 // Max allowed by HF API + cursor := "" + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Check if we've reached the maxModels limit for this pattern + // (maxModels is applied per-pattern to respect HF API rate limits) + if p.maxModels > 0 && len(allModels) >= p.maxModels { + glog.Warningf("Reached maxModels limit (%d) for pattern author=%s, stopping pagination", p.maxModels, author) + break + } + + // Build API URL + apiURL := fmt.Sprintf("%s/api/models?author=%s&limit=%d", p.baseURL, author, limit) + if searchPrefix != "" { + apiURL += "&search=" + searchPrefix + } + if cursor != "" { + apiURL += "&cursor=" + cursor + } + + glog.V(2).Infof("Fetching HuggingFace models list: %s", apiURL) + + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create list request: %w", err) + } + + req.Header.Set("User-Agent", "model-registry-catalog") + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to list models for author %s: %w", author, err) + } + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, fmt.Errorf("HuggingFace API returned status %d for author %s: %s", resp.StatusCode, author, string(bodyBytes)) + } + + var models []hfListModel + if err := json.NewDecoder(resp.Body).Decode(&models); err != nil { + resp.Body.Close() + return nil, fmt.Errorf("failed to decode models list for author %s: %w", author, err) + } + resp.Body.Close() + + // Extract model IDs + for _, m := range models { + // Check limit before adding each model + if p.maxModels > 0 && len(allModels) >= p.maxModels { + break + } + + modelID := m.ID + if modelID == "" { 
+ modelID = m.ModelID + } + if modelID == "" { + continue + } + + // If we have a search prefix, double-check it matches + // (HF search is fuzzy, so we need to verify) + if searchPrefix != "" { + // Extract the model name part (after org/) + parts := strings.SplitN(modelID, "/", 2) + if len(parts) == 2 { + modelName := parts[1] + if !strings.HasPrefix(strings.ToLower(modelName), strings.ToLower(searchPrefix)) { + continue + } + } + } + + allModels = append(allModels, modelID) + } + + // Check for next page via Link header + linkHeader := resp.Header.Get("Link") + nextCursor := parseNextCursor(linkHeader) + if nextCursor == "" || len(models) < limit { + // No more pages + break + } + cursor = nextCursor + } + + glog.Infof("Listed %d models from author %s (maxModels: %d)", len(allModels), author, p.maxModels) + return allModels, nil +} + +// parseNextCursor extracts the cursor for the next page from the Link header. +// Link header format: ; rel="next" +func parseNextCursor(linkHeader string) string { + if linkHeader == "" { + return "" + } + + // Parse Link header for rel="next" + for _, link := range strings.Split(linkHeader, ",") { + link = strings.TrimSpace(link) + if strings.Contains(link, `rel="next"`) { + // Extract URL between < and > + start := strings.Index(link, "<") + end := strings.Index(link, ">") + if start >= 0 && end > start { + nextURL := link[start+1 : end] + // Extract cursor parameter from URL + if idx := strings.Index(nextURL, "cursor="); idx >= 0 { + cursor := nextURL[idx+7:] + // Handle if there are more parameters after cursor + if ampIdx := strings.Index(cursor, "&"); ampIdx >= 0 { + cursor = cursor[:ampIdx] + } + return cursor + } + } + } + } + return "" +} + +// FetchModelNamesForPreview fetches model info from HuggingFace API for the given model identifiers +// and returns the actual model names. This is used for preview functionality. +// Supports patterns like "org/*" and "org/prefix*" which use the paginated list API. 
+func (p *hfModelProvider) FetchModelNamesForPreview(ctx context.Context, modelIdentifiers []string) ([]string, error) { + if len(modelIdentifiers) == 0 { + return nil, fmt.Errorf("includedModels is required for HuggingFace source preview") + } + + // Validate credentials first + if err := p.validateCredentials(ctx); err != nil { + return nil, fmt.Errorf("failed to validate HuggingFace credentials: %w", err) + } + + names := make([]string, 0) + + for _, pattern := range modelIdentifiers { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + patternType, org, searchPrefix := parseModelPattern(pattern) + + switch patternType { + case PatternInvalid: + // Reject unsupported wildcard patterns + return nil, fmt.Errorf("wildcard pattern %q is not supported - HuggingFace requires a specific organization (e.g., 'ibm-granite/*' or 'meta-llama/Llama-2-*')", pattern) + + case PatternOrgAll, PatternOrgPrefix: + // Use paginated list API + glog.Infof("Using HuggingFace list API for pattern: %s (org=%s, prefix=%s)", pattern, org, searchPrefix) + models, err := p.listModelsByAuthor(ctx, org, searchPrefix) + if err != nil { + glog.Warningf("Failed to list models for pattern %s: %v", pattern, err) + // Don't fail completely, just skip this pattern + continue + } + names = append(names, models...) 
+ + case PatternExact: + // Direct fetch for exact model name + modelInfo, err := p.fetchModelInfo(ctx, pattern) + if err != nil { + glog.Warningf("Failed to fetch model info for preview: %s: %v", pattern, err) + names = append(names, pattern) + continue + } + + actualName := modelInfo.ID + if actualName == "" { + actualName = pattern + } + names = append(names, actualName) + } + } + + return names, nil +} diff --git a/catalog/internal/catalog/hf_catalog_test.go b/catalog/internal/catalog/hf_catalog_test.go index 7cf032f22b..0b177f0326 100644 --- a/catalog/internal/catalog/hf_catalog_test.go +++ b/catalog/internal/catalog/hf_catalog_test.go @@ -2,9 +2,16 @@ package catalog import ( "context" + "encoding/json" + "fmt" "net/http" + "net/http/httptest" + "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + apimodels "github.com/kubeflow/model-registry/catalog/pkg/openapi" "github.com/kubeflow/model-registry/internal/db/models" ) @@ -556,3 +563,446 @@ func TestPopulateFromHFInfoWithCustomProperties(t *testing.T) { func stringPtr(s string) *string { return &s } + +func TestParseModelPattern(t *testing.T) { + tests := []struct { + pattern string + expectedType PatternType + expectedOrg string + expectedPfx string + }{ + // Exact patterns + {"meta-llama/Llama-2-7b-chat", PatternExact, "", ""}, + {"gpt2", PatternExact, "", ""}, + {"openai-community/gpt2", PatternExact, "", ""}, + + // Org/* patterns + {"ibm-granite/*", PatternOrgAll, "ibm-granite", ""}, + {"meta-llama/*", PatternOrgAll, "meta-llama", ""}, + {"openai/*", PatternOrgAll, "openai", ""}, + + // Org/prefix* patterns + {"meta-llama/Llama-2-*", PatternOrgPrefix, "meta-llama", "Llama-2-"}, + {"ibm-granite/granite-3*", PatternOrgPrefix, "ibm-granite", "granite-3"}, + {"mistralai/Mistral-*", PatternOrgPrefix, "mistralai", "Mistral-"}, + + // Invalid patterns - would try to list all HuggingFace models + {"*", PatternInvalid, "", ""}, + {"*/*", PatternInvalid, "", ""}, + 
{"*/something", PatternInvalid, "", ""}, + {"*/prefix*", PatternInvalid, "", ""}, + } + + for _, tt := range tests { + t.Run(tt.pattern, func(t *testing.T) { + pType, org, prefix := parseModelPattern(tt.pattern) + assert.Equal(t, tt.expectedType, pType, "pattern type mismatch") + assert.Equal(t, tt.expectedOrg, org, "org mismatch") + assert.Equal(t, tt.expectedPfx, prefix, "prefix mismatch") + }) + } +} + +func TestParseNextCursor(t *testing.T) { + tests := []struct { + name string + header string + expected string + }{ + { + name: "empty header", + header: "", + expected: "", + }, + { + name: "valid next link", + header: `; rel="next"`, + expected: "abc123", + }, + { + name: "next link with other params", + header: `; rel="next"`, + expected: "xyz789", + }, + { + name: "multiple links", + header: `; rel="first", ; rel="next"`, + expected: "page2", + }, + { + name: "no next link", + header: `; rel="first"`, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseNextCursor(tt.header) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestListModelsByAuthor(t *testing.T) { + // Setup mock HF server + mux := http.NewServeMux() + callCount := 0 + + // Mock /api/whoami-v2 for credential validation + mux.HandleFunc("/api/whoami-v2", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"name": "test-user"}) + }) + + // Mock /api/models list endpoint with pagination + mux.HandleFunc("/api/models", func(w http.ResponseWriter, r *http.Request) { + callCount++ + author := r.URL.Query().Get("author") + cursor := r.URL.Query().Get("cursor") + search := r.URL.Query().Get("search") + + if author == "test-org" { + switch cursor { + case "": + // First page - return 100 items to simulate full page (triggers pagination) + models := make([]map[string]interface{}, 100) + for i := range 100 { + models[i] = map[string]interface{}{"id": 
fmt.Sprintf("test-org/model-%d", i+1)} + } + // Add Link header for next page + w.Header().Set("Link", `; rel="next"`) + _ = json.NewEncoder(w).Encode(models) + case "page2": + // Second page (last) - return fewer than 100 to indicate end + models := []map[string]interface{}{ + {"id": "test-org/model-101"}, + {"id": "test-org/model-102"}, + } + _ = json.NewEncoder(w).Encode(models) + } + } else if author == "search-org" && search != "" { + // Search results + models := []map[string]interface{}{ + {"id": "search-org/" + search + "-match1"}, + {"id": "search-org/" + search + "-match2"}, + {"id": "search-org/other-model"}, // Should be filtered out + } + _ = json.NewEncoder(w).Encode(models) + } else { + _ = json.NewEncoder(w).Encode([]map[string]interface{}{}) + } + }) + + server := httptest.NewServer(mux) + defer server.Close() + + os.Setenv("HF_API_KEY", "test-api-key") + defer os.Unsetenv("HF_API_KEY") + + t.Run("lists all models from org with pagination", func(t *testing.T) { + callCount = 0 + config := &PreviewConfig{ + Type: "hf", + Properties: map[string]any{ + "url": server.URL, + }, + } + + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + + models, err := provider.listModelsByAuthor(context.Background(), "test-org", "") + require.NoError(t, err) + + // Should have 100 from first page + 2 from second page = 102 + assert.Len(t, models, 102) + assert.Contains(t, models, "test-org/model-1") + assert.Contains(t, models, "test-org/model-100") + assert.Contains(t, models, "test-org/model-101") + assert.Contains(t, models, "test-org/model-102") + + // Should have made 2 API calls (2 pages) + assert.Equal(t, 2, callCount) + }) + + t.Run("filters by search prefix", func(t *testing.T) { + config := &PreviewConfig{ + Type: "hf", + Properties: map[string]any{ + "url": server.URL, + }, + } + + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + + models, err := provider.listModelsByAuthor(context.Background(), "search-org", 
"prefix") + require.NoError(t, err) + + // Should only include models starting with "prefix" + assert.Len(t, models, 2) + assert.Contains(t, models, "search-org/prefix-match1") + assert.Contains(t, models, "search-org/prefix-match2") + // "other-model" should be filtered out + assert.NotContains(t, models, "search-org/other-model") + }) + + t.Run("respects maxModels limit", func(t *testing.T) { + callCount = 0 + config := &PreviewConfig{ + Type: "hf", + Properties: map[string]any{ + "url": server.URL, + "maxModels": 50, // Limit to 50 models + }, + } + + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + assert.Equal(t, 50, provider.maxModels) + + models, err := provider.listModelsByAuthor(context.Background(), "test-org", "") + require.NoError(t, err) + + // Should stop at 50 models (first page has 100, but we limit to 50) + assert.Len(t, models, 50) + + // Should have only made 1 API call (stopped before second page) + assert.Equal(t, 1, callCount) + }) + + t.Run("uses default maxModels when not specified", func(t *testing.T) { + config := &PreviewConfig{ + Type: "hf", + Properties: map[string]any{ + "url": server.URL, + }, + } + + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + + // Should use default (500) + assert.Equal(t, 500, provider.maxModels) + }) + + t.Run("maxModels 0 means no limit", func(t *testing.T) { + callCount = 0 + config := &PreviewConfig{ + Type: "hf", + Properties: map[string]any{ + "url": server.URL, + "maxModels": 0, // No limit + }, + } + + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + assert.Equal(t, 0, provider.maxModels) + + models, err := provider.listModelsByAuthor(context.Background(), "test-org", "") + require.NoError(t, err) + + // Should get all 102 models (100 from page 1 + 2 from page 2) + assert.Len(t, models, 102) + + // Should have made 2 API calls + assert.Equal(t, 2, callCount) + }) +} + +func TestFetchModelNamesForPreviewWithPatterns(t *testing.T) 
{ + // Setup mock HF server + mux := http.NewServeMux() + + mux.HandleFunc("/api/whoami-v2", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"name": "test-user"}) + }) + + // Mock list API + mux.HandleFunc("/api/models", func(w http.ResponseWriter, r *http.Request) { + author := r.URL.Query().Get("author") + if author == "test-org" { + models := []map[string]interface{}{ + {"id": "test-org/model-a"}, + {"id": "test-org/model-b"}, + } + _ = json.NewEncoder(w).Encode(models) + } else { + _ = json.NewEncoder(w).Encode([]map[string]interface{}{}) + } + }) + + // Mock individual model endpoints + mux.HandleFunc("/api/models/exact-org/exact-model", func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"id": "exact-org/exact-model"}) + }) + + server := httptest.NewServer(mux) + defer server.Close() + + os.Setenv("HF_API_KEY", "test-api-key") + defer os.Unsetenv("HF_API_KEY") + + t.Run("mixed patterns: org/* and exact", func(t *testing.T) { + config := &PreviewConfig{ + Type: "hf", + IncludedModels: []string{ + "test-org/*", // Should use list API + "exact-org/exact-model", // Should use direct fetch + }, + Properties: map[string]any{ + "url": server.URL, + }, + } + + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + + names, err := provider.FetchModelNamesForPreview(context.Background(), config.IncludedModels) + require.NoError(t, err) + + assert.Len(t, names, 3) + assert.Contains(t, names, "test-org/model-a") + assert.Contains(t, names, "test-org/model-b") + assert.Contains(t, names, "exact-org/exact-model") + }) + + t.Run("rejects * wildcard pattern", func(t *testing.T) { + config := &PreviewConfig{ + Type: "hf", + IncludedModels: []string{ + "*", + }, + Properties: map[string]any{ + "url": server.URL, + }, + } + + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + + _, err = 
provider.FetchModelNamesForPreview(context.Background(), config.IncludedModels) + require.Error(t, err) + assert.Contains(t, err.Error(), "wildcard pattern") + assert.Contains(t, err.Error(), "not supported") + }) + + t.Run("rejects */* wildcard pattern", func(t *testing.T) { + config := &PreviewConfig{ + Type: "hf", + IncludedModels: []string{ + "*/*", + }, + Properties: map[string]any{ + "url": server.URL, + }, + } + + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + + _, err = provider.FetchModelNamesForPreview(context.Background(), config.IncludedModels) + require.Error(t, err) + assert.Contains(t, err.Error(), "wildcard pattern") + assert.Contains(t, err.Error(), "not supported") + }) + + t.Run("rejects */prefix pattern", func(t *testing.T) { + config := &PreviewConfig{ + Type: "hf", + IncludedModels: []string{ + "*/Llama-*", + }, + Properties: map[string]any{ + "url": server.URL, + }, + } + + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + + _, err = provider.FetchModelNamesForPreview(context.Background(), config.IncludedModels) + require.Error(t, err) + assert.Contains(t, err.Error(), "wildcard pattern") + assert.Contains(t, err.Error(), "not supported") + }) +} + +func TestPreviewSourceModelsWithHFPatterns(t *testing.T) { + // Setup mock HF server + mux := http.NewServeMux() + + mux.HandleFunc("/api/whoami-v2", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"name": "test-user"}) + }) + + mux.HandleFunc("/api/models", func(w http.ResponseWriter, r *http.Request) { + author := r.URL.Query().Get("author") + if author == "test-org" { + models := []map[string]interface{}{ + {"id": "test-org/model-stable"}, + {"id": "test-org/model-experimental"}, + {"id": "test-org/model-draft"}, + } + _ = json.NewEncoder(w).Encode(models) + } else { + _ = json.NewEncoder(w).Encode([]map[string]interface{}{}) + } + }) + + server := 
httptest.NewServer(mux) + defer server.Close() + + os.Setenv("HF_API_KEY", "test-api-key") + defer os.Unsetenv("HF_API_KEY") + + t.Run("org/* pattern with excludedModels filter", func(t *testing.T) { + // Note: We test the filtering logic by calling NewHFPreviewProvider directly + // rather than PreviewSourceModels, because PreviewSourceModels has SSRF + // protection that removes custom URLs (which breaks mock server testing). + // The filtering behavior is the same - we're testing the filter logic here. + includedModels := []string{"test-org/*"} + excludedModels := []string{"*-experimental", "*-draft"} + + config := &PreviewConfig{ + Type: "hf", + IncludedModels: includedModels, + ExcludedModels: excludedModels, + Properties: map[string]any{ + "url": server.URL, + }, + } + + // Create provider and fetch model names (bypassing SSRF protection for testing) + provider, err := NewHFPreviewProvider(config) + require.NoError(t, err) + + modelNames, err := provider.FetchModelNamesForPreview(context.Background(), includedModels) + require.NoError(t, err) + require.Len(t, modelNames, 3) + + // Create filter and apply it (same logic as PreviewSourceModels) + filter, err := NewModelFilter(includedModels, excludedModels) + require.NoError(t, err) + + var included, excluded []string + for _, name := range modelNames { + if filter.Allows(name) { + included = append(included, name) + } else { + excluded = append(excluded, name) + } + } + + assert.Len(t, included, 1) + assert.Contains(t, included, "test-org/model-stable") + + assert.Len(t, excluded, 2) + assert.Contains(t, excluded, "test-org/model-experimental") + assert.Contains(t, excluded, "test-org/model-draft") + }) +} diff --git a/catalog/internal/catalog/preview.go b/catalog/internal/catalog/preview.go new file mode 100644 index 0000000000..397f11e3b1 --- /dev/null +++ b/catalog/internal/catalog/preview.go @@ -0,0 +1,161 @@ +package catalog + +import ( + "context" + "fmt" + "os" + "path/filepath" + + 
"github.com/golang/glog" + model "github.com/kubeflow/model-registry/catalog/pkg/openapi" + "k8s.io/apimachinery/pkg/util/yaml" +) + +// PreviewConfig represents the parsed preview request configuration. +type PreviewConfig struct { + Type string `json:"type" yaml:"type"` + IncludedModels []string `json:"includedModels,omitempty" yaml:"includedModels,omitempty"` + ExcludedModels []string `json:"excludedModels,omitempty" yaml:"excludedModels,omitempty"` + Properties map[string]any `json:"properties,omitempty" yaml:"properties,omitempty"` +} + +// ParsePreviewConfig parses the uploaded config bytes into a PreviewConfig. +// Extra fields (like name, id, enabled) are ignored so users can paste +// a full source config entry directly for preview. +func ParsePreviewConfig(configBytes []byte) (*PreviewConfig, error) { + var config PreviewConfig + if err := yaml.Unmarshal(configBytes, &config); err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + + if config.Type == "" { + return nil, fmt.Errorf("missing required field: type") + } + + // Validate filter patterns early + if err := ValidateSourceFilters(config.IncludedModels, config.ExcludedModels); err != nil { + return nil, err + } + + return &config, nil +} + +// PreviewSourceModels loads models from the source configuration and returns +// preview results showing which models would be included or excluded. +// If catalogDataBytes is provided, it will be used directly instead of reading from yamlCatalogPath. 
+func PreviewSourceModels(ctx context.Context, config *PreviewConfig, catalogDataBytes []byte) ([]model.ModelPreviewResult, error) { + // Create a ModelFilter from the config + filter, err := NewModelFilter(config.IncludedModels, config.ExcludedModels) + if err != nil { + return nil, fmt.Errorf("invalid filter configuration: %w", err) + } + + // Load all model names from the source (without filtering) + modelNames, err := loadModelNamesFromSource(ctx, config, catalogDataBytes) + if err != nil { + return nil, err + } + + // Create preview results for each model + results := make([]model.ModelPreviewResult, 0, len(modelNames)) + for _, name := range modelNames { + included := filter == nil || filter.Allows(name) + results = append(results, model.ModelPreviewResult{ + Name: name, + Included: included, + }) + } + + return results, nil +} + +// loadModelNamesFromSource loads model names from the specified source type. +// If catalogDataBytes is provided, it takes precedence over reading from file paths. +func loadModelNamesFromSource(ctx context.Context, config *PreviewConfig, catalogDataBytes []byte) ([]string, error) { + switch config.Type { + case "yaml": + return loadYamlModelNames(ctx, config, catalogDataBytes) + case "hf", "huggingface": + return loadHFModelNames(ctx, config) + default: + return nil, fmt.Errorf("unsupported source type: %s", config.Type) + } +} + +// loadHFModelNames fetches model names from the HuggingFace API for preview. +// For HF sources, includedModels specifies which models to fetch from HuggingFace. +// This function calls the HF API to validate models exist and get their actual names. +func loadHFModelNames(ctx context.Context, config *PreviewConfig) ([]string, error) { + if len(config.IncludedModels) == 0 { + return nil, fmt.Errorf("includedModels is required for HuggingFace source preview (specifies which models to fetch from HuggingFace)") + } + + // SECURITY: Override the URL property to prevent SSRF attacks. 
+ // An attacker could otherwise set a custom URL to leak the HF API key + // to an attacker-controlled domain. + // We ensure requests only go to the official HuggingFace API. + if config.Properties == nil { + config.Properties = make(map[string]any) + } + + if customURL, exists := config.Properties["url"]; exists { + glog.Warningf("HuggingFace preview: custom URL %q was ignored for security reasons (SSRF prevention)", customURL) + delete(config.Properties, "url") + } + + // Create HF preview provider (reuses hfModelProvider from hf_catalog.go) + provider, err := NewHFPreviewProvider(config) + if err != nil { + return nil, err + } + + // Fetch model names from HuggingFace API + return provider.FetchModelNamesForPreview(ctx, config.IncludedModels) +} + +// loadYamlModelNames loads model names from a YAML catalog. +// If catalogDataBytes is provided (stateless mode), it is used directly. +// Otherwise, the catalog is read from yamlCatalogPath (path mode). +func loadYamlModelNames(ctx context.Context, config *PreviewConfig, catalogDataBytes []byte) ([]string, error) { + var catalogBytes []byte + + if len(catalogDataBytes) > 0 { + // Stateless mode: use provided catalog data directly + catalogBytes = catalogDataBytes + } else { + // Path mode: read from yamlCatalogPath + path, ok := config.Properties[yamlCatalogPathKey].(string) + if !ok || path == "" { + return nil, fmt.Errorf("missing required property: %s (provide catalogData file or set yamlCatalogPath in config)", yamlCatalogPathKey) + } + + // Resolve relative paths - for preview, we use the current working directory + if !filepath.IsAbs(path) { + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("failed to get working directory: %w", err) + } + path = filepath.Join(cwd, path) + } + + // Read the catalog file + var err error + catalogBytes, err = os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read catalog file %s: %w", path, err) + } + } + + var catalog yamlCatalog + if 
err := yaml.UnmarshalStrict(catalogBytes, &catalog); err != nil { + return nil, fmt.Errorf("failed to parse catalog file: %w", err) + } + + // Extract model names + names := make([]string, 0, len(catalog.Models)) + for _, m := range catalog.Models { + names = append(names, m.Name) + } + + return names, nil +} diff --git a/catalog/internal/catalog/preview_test.go b/catalog/internal/catalog/preview_test.go new file mode 100644 index 0000000000..45b9e27792 --- /dev/null +++ b/catalog/internal/catalog/preview_test.go @@ -0,0 +1,684 @@ +package catalog + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParsePreviewConfig(t *testing.T) { + tests := []struct { + name string + configYAML string + expectError bool + errorMsg string + validate func(t *testing.T, config *PreviewConfig) + }{ + { + name: "valid config with all fields", + configYAML: ` +type: yaml +includedModels: + - "Granite/*" + - "Llama/*" +excludedModels: + - "*-draft" + - "*-experimental" +properties: + yamlCatalogPath: "/path/to/catalog.yaml" +`, + expectError: false, + validate: func(t *testing.T, config *PreviewConfig) { + assert.Equal(t, "yaml", config.Type) + assert.Equal(t, []string{"Granite/*", "Llama/*"}, config.IncludedModels) + assert.Equal(t, []string{"*-draft", "*-experimental"}, config.ExcludedModels) + assert.Equal(t, "/path/to/catalog.yaml", config.Properties["yamlCatalogPath"]) + }, + }, + { + name: "valid config with only type", + configYAML: ` +type: yaml +properties: + yamlCatalogPath: "/path/to/catalog.yaml" +`, + expectError: false, + validate: func(t *testing.T, config *PreviewConfig) { + assert.Equal(t, "yaml", config.Type) + assert.Empty(t, config.IncludedModels) + assert.Empty(t, config.ExcludedModels) + }, + }, + { + name: "valid huggingface config", + configYAML: ` +type: hf +includedModels: + - "microsoft/*" +properties: + apiKey: "test-key" + modelLimit: 100 +`, + 
expectError: false, + validate: func(t *testing.T, config *PreviewConfig) { + assert.Equal(t, "hf", config.Type) + assert.Equal(t, []string{"microsoft/*"}, config.IncludedModels) + assert.Equal(t, "test-key", config.Properties["apiKey"]) + }, + }, + { + name: "missing type field", + configYAML: `includedModels: ["Granite/*"]`, + expectError: true, + errorMsg: "missing required field: type", + }, + { + name: "extra fields from full source config are ignored", + configYAML: ` +name: "Community and Custom Models" +id: community_custom_models +type: yaml +enabled: true +includedModels: + - "Granite/*" +properties: + yamlCatalogPath: "/path/to/catalog.yaml" +`, + expectError: false, + validate: func(t *testing.T, config *PreviewConfig) { + // Extra fields (name, id, enabled) should be ignored + assert.Equal(t, "yaml", config.Type) + assert.Equal(t, []string{"Granite/*"}, config.IncludedModels) + assert.Equal(t, "/path/to/catalog.yaml", config.Properties["yamlCatalogPath"]) + }, + }, + { + name: "empty type field", + configYAML: ` +type: "" +includedModels: + - "Granite/*" +`, + expectError: true, + errorMsg: "missing required field: type", + }, + { + name: "invalid YAML syntax", + configYAML: ` +type: yaml +includedModels: [ + - broken +`, + expectError: true, + errorMsg: "failed to parse config", + }, + { + name: "empty pattern in includedModels", + configYAML: ` +type: yaml +includedModels: + - "Granite/*" + - "" +`, + expectError: true, + errorMsg: "pattern cannot be empty", + }, + { + name: "whitespace-only pattern", + configYAML: ` +type: yaml +includedModels: + - " " +`, + expectError: true, + errorMsg: "pattern cannot be empty", + }, + { + name: "conflicting pattern in both include and exclude", + configYAML: ` +type: yaml +includedModels: + - "Granite/*" +excludedModels: + - "Granite/*" +`, + expectError: true, + errorMsg: "defined in both includedModels", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config, err := 
ParsePreviewConfig([]byte(tt.configYAML)) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + assert.Nil(t, config) + } else { + require.NoError(t, err) + require.NotNil(t, config) + if tt.validate != nil { + tt.validate(t, config) + } + } + }) + } +} + +func TestPreviewSourceModels(t *testing.T) { + // Create a temporary catalog file for testing + tmpDir := t.TempDir() + catalogPath := filepath.Join(tmpDir, "test-catalog.yaml") + + catalogContent := ` +source: Test Source +models: + - name: Granite/3b-instruct + description: Granite model + - name: Granite/8b-code + description: Another Granite model + - name: Llama/7b-chat + description: Llama model + - name: Llama/13b-chat-draft + description: Draft Llama model + - name: Mistral/7b-instruct + description: Mistral model + - name: DeepSeek/coder-v2 + description: DeepSeek model +` + err := os.WriteFile(catalogPath, []byte(catalogContent), 0644) + require.NoError(t, err) + + tests := []struct { + name string + config *PreviewConfig + expectError bool + errorMsg string + expectedTotal int + expectedIncl int + expectedExcl int + validateNames func(t *testing.T, results []string) + }{ + { + name: "no filters - all models included", + config: &PreviewConfig{ + Type: "yaml", + Properties: map[string]any{"yamlCatalogPath": catalogPath}, + }, + expectedTotal: 6, + expectedIncl: 6, + expectedExcl: 0, + }, + { + name: "include only Granite models", + config: &PreviewConfig{ + Type: "yaml", + IncludedModels: []string{"Granite/*"}, + Properties: map[string]any{"yamlCatalogPath": catalogPath}, + }, + expectedTotal: 6, + expectedIncl: 2, + expectedExcl: 4, + validateNames: func(t *testing.T, included []string) { + assert.Contains(t, included, "Granite/3b-instruct") + assert.Contains(t, included, "Granite/8b-code") + }, + }, + { + name: "include Granite and Llama", + config: &PreviewConfig{ + Type: "yaml", + IncludedModels: []string{"Granite/*", "Llama/*"}, + Properties: 
map[string]any{"yamlCatalogPath": catalogPath}, + }, + expectedTotal: 6, + expectedIncl: 4, + expectedExcl: 2, + }, + { + name: "exclude draft models", + config: &PreviewConfig{ + Type: "yaml", + ExcludedModels: []string{"*-draft"}, + Properties: map[string]any{"yamlCatalogPath": catalogPath}, + }, + expectedTotal: 6, + expectedIncl: 5, + expectedExcl: 1, + validateNames: func(t *testing.T, included []string) { + assert.NotContains(t, included, "Llama/13b-chat-draft") + }, + }, + { + name: "include Llama but exclude drafts", + config: &PreviewConfig{ + Type: "yaml", + IncludedModels: []string{"Llama/*"}, + ExcludedModels: []string{"*-draft"}, + Properties: map[string]any{"yamlCatalogPath": catalogPath}, + }, + expectedTotal: 6, + expectedIncl: 1, // Only Llama/7b-chat + expectedExcl: 5, + validateNames: func(t *testing.T, included []string) { + assert.Equal(t, []string{"Llama/7b-chat"}, included) + }, + }, + { + name: "case insensitive matching", + config: &PreviewConfig{ + Type: "yaml", + IncludedModels: []string{"granite/*"}, // lowercase + Properties: map[string]any{"yamlCatalogPath": catalogPath}, + }, + expectedTotal: 6, + expectedIncl: 2, // Should match Granite/* + expectedExcl: 4, + }, + { + name: "wildcard in middle of pattern", + config: &PreviewConfig{ + Type: "yaml", + IncludedModels: []string{"*/7b-*"}, + Properties: map[string]any{"yamlCatalogPath": catalogPath}, + }, + expectedTotal: 6, + expectedIncl: 2, // Llama/7b-chat and Mistral/7b-instruct + expectedExcl: 4, + }, + { + name: "unsupported source type", + config: &PreviewConfig{ + Type: "unknown", + Properties: map[string]any{}, + }, + expectError: true, + errorMsg: "unsupported source type", + }, + { + name: "huggingface requires includedModels", + config: &PreviewConfig{ + Type: "hf", + Properties: map[string]any{}, + }, + expectError: true, + errorMsg: "includedModels is required for HuggingFace source preview", + }, + { + name: "missing yamlCatalogPath property", + config: &PreviewConfig{ + 
Type: "yaml", + Properties: map[string]any{}, + }, + expectError: true, + errorMsg: "missing required property", + }, + { + name: "catalog file not found", + config: &PreviewConfig{ + Type: "yaml", + Properties: map[string]any{"yamlCatalogPath": "/nonexistent/path.yaml"}, + }, + expectError: true, + errorMsg: "failed to read catalog file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + results, err := PreviewSourceModels(ctx, tt.config, nil) // nil = path-based mode + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + return + } + + require.NoError(t, err) + require.NotNil(t, results) + + // Count included and excluded + var includedCount, excludedCount int + var includedNames []string + for _, r := range results { + if r.Included { + includedCount++ + includedNames = append(includedNames, r.Name) + } else { + excludedCount++ + } + } + + assert.Equal(t, tt.expectedTotal, len(results), "total models mismatch") + assert.Equal(t, tt.expectedIncl, includedCount, "included count mismatch") + assert.Equal(t, tt.expectedExcl, excludedCount, "excluded count mismatch") + + if tt.validateNames != nil { + tt.validateNames(t, includedNames) + } + }) + } +} + +func TestLoadYamlModelNames(t *testing.T) { + tmpDir := t.TempDir() + + tests := []struct { + name string + catalogContent string + setupConfig func(path string) *PreviewConfig + expectError bool + errorMsg string + expectedNames []string + }{ + { + name: "valid catalog with multiple models", + catalogContent: ` +source: Test +models: + - name: Model/A + - name: Model/B + - name: Model/C +`, + setupConfig: func(path string) *PreviewConfig { + return &PreviewConfig{ + Type: "yaml", + Properties: map[string]any{"yamlCatalogPath": path}, + } + }, + expectedNames: []string{"Model/A", "Model/B", "Model/C"}, + }, + { + name: "empty models list", + catalogContent: ` +source: Empty +models: [] +`, + setupConfig: func(path 
string) *PreviewConfig { + return &PreviewConfig{ + Type: "yaml", + Properties: map[string]any{"yamlCatalogPath": path}, + } + }, + expectedNames: []string{}, + }, + { + name: "catalog with artifacts", + catalogContent: ` +source: With Artifacts +models: + - name: Model/WithArtifacts + description: Has artifacts + artifacts: + - uri: oci://test/artifact:v1 + customProperties: + hardware_type: + metadataType: MetadataStringValue + string_value: GPU +`, + setupConfig: func(path string) *PreviewConfig { + return &PreviewConfig{ + Type: "yaml", + Properties: map[string]any{"yamlCatalogPath": path}, + } + }, + expectedNames: []string{"Model/WithArtifacts"}, + }, + { + name: "invalid YAML content", + catalogContent: `not: valid: yaml: [`, + setupConfig: func(path string) *PreviewConfig { + return &PreviewConfig{ + Type: "yaml", + Properties: map[string]any{"yamlCatalogPath": path}, + } + }, + expectError: true, + errorMsg: "failed to parse catalog file", + }, + { + name: "relative path resolution", + catalogContent: `source: Relative +models: + - name: Relative/Model +`, + setupConfig: func(path string) *PreviewConfig { + // Use just the filename (relative path) + return &PreviewConfig{ + Type: "yaml", + Properties: map[string]any{"yamlCatalogPath": filepath.Base(path)}, + } + }, + // This will fail because relative path is resolved from cwd, not tmpDir + expectError: true, + errorMsg: "failed to read catalog file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Write catalog file + catalogPath := filepath.Join(tmpDir, tt.name+".yaml") + err := os.WriteFile(catalogPath, []byte(tt.catalogContent), 0644) + require.NoError(t, err) + + config := tt.setupConfig(catalogPath) + ctx := context.Background() + + names, err := loadYamlModelNames(ctx, config, nil) // nil = path-based mode + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + return + } + + require.NoError(t, err) + assert.Equal(t, 
tt.expectedNames, names) + }) + } +} + +func TestPreviewSourceModels_StatelessMode(t *testing.T) { + // Test stateless mode where catalog data is passed directly + catalogData := []byte(` +source: Stateless Test +models: + - name: Granite/3b-instruct + description: Granite model + - name: Llama/7b-chat + description: Llama model + - name: Mistral/7b-draft + description: Draft model +`) + + t.Run("stateless mode with catalog data", func(t *testing.T) { + config := &PreviewConfig{ + Type: "yaml", + IncludedModels: []string{"Granite/*", "Llama/*"}, + ExcludedModels: []string{"*-draft"}, + // No yamlCatalogPath needed in stateless mode + } + + results, err := PreviewSourceModels(context.Background(), config, catalogData) + require.NoError(t, err) + require.Len(t, results, 3) + + var included []string + for _, r := range results { + if r.Included { + included = append(included, r.Name) + } + } + + assert.Len(t, included, 2) + assert.Contains(t, included, "Granite/3b-instruct") + assert.Contains(t, included, "Llama/7b-chat") + }) + + t.Run("stateless mode takes precedence over path", func(t *testing.T) { + // Even with a yamlCatalogPath, catalog data should be used + config := &PreviewConfig{ + Type: "yaml", + Properties: map[string]any{"yamlCatalogPath": "/nonexistent/path.yaml"}, + } + + results, err := PreviewSourceModels(context.Background(), config, catalogData) + require.NoError(t, err) + assert.Len(t, results, 3) + }) + + t.Run("stateless mode with empty catalog data falls back to path", func(t *testing.T) { + config := &PreviewConfig{ + Type: "yaml", + Properties: map[string]any{}, // No path either + } + + _, err := PreviewSourceModels(context.Background(), config, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing required property") + }) + + t.Run("stateless mode with invalid catalog data", func(t *testing.T) { + config := &PreviewConfig{ + Type: "yaml", + } + + invalidData := []byte("not: valid: yaml: [") + _, err := 
PreviewSourceModels(context.Background(), config, invalidData) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse catalog file") + }) +} + +// Note: HuggingFace preview tests that require API calls are in hf_preview_test.go +// with proper mock HTTP servers. The tests below only test error conditions that +// don't require API calls. + +func TestPreviewSourceModels_HuggingFace_Errors(t *testing.T) { + t.Run("hf preview with empty includedModels returns error", func(t *testing.T) { + config := &PreviewConfig{ + Type: "hf", + IncludedModels: []string{}, + } + + _, err := PreviewSourceModels(context.Background(), config, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "includedModels is required") + }) + + t.Run("hf preview without API key returns error", func(t *testing.T) { + // Ensure HF_API_KEY is not set + oldKey := os.Getenv("HF_API_KEY") + os.Unsetenv("HF_API_KEY") + defer func() { + if oldKey != "" { + os.Setenv("HF_API_KEY", oldKey) + } + }() + + config := &PreviewConfig{ + Type: "hf", + IncludedModels: []string{ + "meta-llama/Llama-2-7b-chat", + }, + } + + _, err := PreviewSourceModels(context.Background(), config, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "HF_API_KEY") + }) +} + +func TestPreviewSourceModels_FilterBehavior(t *testing.T) { + // Create a temporary catalog file + tmpDir := t.TempDir() + catalogPath := filepath.Join(tmpDir, "filter-test.yaml") + + catalogContent := ` +source: Filter Test +models: + - name: ibm-granite/code-base-8b + - name: ibm-granite/code-instruct-8b + - name: ibm-granite/lab-model-experimental + - name: meta-llama/Llama-2-7b + - name: meta-llama/Llama-2-7b-draft + - name: mistralai/Mistral-7B-Instruct-v0.1 +` + err := os.WriteFile(catalogPath, []byte(catalogContent), 0644) + require.NoError(t, err) + + t.Run("exclusions take precedence over inclusions", func(t *testing.T) { + config := &PreviewConfig{ + Type: "yaml", + IncludedModels: []string{"ibm-granite/*"}, + 
ExcludedModels: []string{"*-experimental"}, + Properties: map[string]any{"yamlCatalogPath": catalogPath}, + } + + results, err := PreviewSourceModels(context.Background(), config, nil) + require.NoError(t, err) + + // Should include ibm-granite models except experimental + var included []string + for _, r := range results { + if r.Included { + included = append(included, r.Name) + } + } + + assert.Len(t, included, 2) + assert.Contains(t, included, "ibm-granite/code-base-8b") + assert.Contains(t, included, "ibm-granite/code-instruct-8b") + assert.NotContains(t, included, "ibm-granite/lab-model-experimental") + }) + + t.Run("multiple include patterns work as OR", func(t *testing.T) { + config := &PreviewConfig{ + Type: "yaml", + IncludedModels: []string{"ibm-granite/*", "meta-llama/*"}, + Properties: map[string]any{"yamlCatalogPath": catalogPath}, + } + + results, err := PreviewSourceModels(context.Background(), config, nil) + require.NoError(t, err) + + var includedCount int + for _, r := range results { + if r.Included { + includedCount++ + } + } + + // 3 ibm-granite + 2 meta-llama = 5 + assert.Equal(t, 5, includedCount) + }) + + t.Run("multiple exclude patterns work as OR", func(t *testing.T) { + config := &PreviewConfig{ + Type: "yaml", + ExcludedModels: []string{"*-experimental", "*-draft"}, + Properties: map[string]any{"yamlCatalogPath": catalogPath}, + } + + results, err := PreviewSourceModels(context.Background(), config, nil) + require.NoError(t, err) + + var excluded []string + for _, r := range results { + if !r.Included { + excluded = append(excluded, r.Name) + } + } + + assert.Len(t, excluded, 2) + assert.Contains(t, excluded, "ibm-granite/lab-model-experimental") + assert.Contains(t, excluded, "meta-llama/Llama-2-7b-draft") + }) +} diff --git a/catalog/internal/server/openapi/.openapi-generator/FILES b/catalog/internal/server/openapi/.openapi-generator/FILES index 344857afc8..a01668ebbf 100644 --- 
a/catalog/internal/server/openapi/.openapi-generator/FILES +++ b/catalog/internal/server/openapi/.openapi-generator/FILES @@ -19,6 +19,8 @@ model_catalog_model_artifact.go model_catalog_model_list.go model_catalog_source.go model_catalog_source_list.go +model_catalog_source_preview_response.go +model_catalog_source_preview_response_all_of_summary.go model_error.go model_filter_option.go model_filter_option_range.go @@ -30,6 +32,7 @@ model_metadata_proto_value.go model_metadata_string_value.go model_metadata_struct_value.go model_metadata_value.go +model_model_preview_result.go model_order_by_field.go model_sort_order.go routers.go diff --git a/catalog/internal/server/openapi/api.go b/catalog/internal/server/openapi/api.go index 9b6a2e2ec2..20e00c8341 100644 --- a/catalog/internal/server/openapi/api.go +++ b/catalog/internal/server/openapi/api.go @@ -13,6 +13,7 @@ package openapi import ( "context" "net/http" + "os" model "github.com/kubeflow/model-registry/catalog/pkg/openapi" ) @@ -25,6 +26,7 @@ type ModelCatalogServiceAPIRouter interface { FindModels(http.ResponseWriter, *http.Request) FindModelsFilterOptions(http.ResponseWriter, *http.Request) FindSources(http.ResponseWriter, *http.Request) + PreviewCatalogSource(http.ResponseWriter, *http.Request) GetModel(http.ResponseWriter, *http.Request) GetAllModelArtifacts(http.ResponseWriter, *http.Request) GetAllModelPerformanceArtifacts(http.ResponseWriter, *http.Request) @@ -39,6 +41,7 @@ type ModelCatalogServiceAPIServicer interface { FindModels(context.Context, []string, string, []string, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) FindModelsFilterOptions(context.Context) (ImplResponse, error) FindSources(context.Context, string, string, model.OrderByField, model.SortOrder, string) (ImplResponse, error) + PreviewCatalogSource(context.Context, *os.File, string, string, string, *os.File) (ImplResponse, error) GetModel(context.Context, string, string) (ImplResponse, error) 
GetAllModelArtifacts(context.Context, string, string, []model.ArtifactTypeQueryParam, []model.ArtifactTypeQueryParam, string, string, string, model.SortOrder, string) (ImplResponse, error) GetAllModelPerformanceArtifacts(context.Context, string, string, int32, bool, string, string, string, string, string, string, string, model.SortOrder, string) (ImplResponse, error) diff --git a/catalog/internal/server/openapi/api_model_catalog_service.go b/catalog/internal/server/openapi/api_model_catalog_service.go index fbe62ba074..38a2a0171d 100644 --- a/catalog/internal/server/openapi/api_model_catalog_service.go +++ b/catalog/internal/server/openapi/api_model_catalog_service.go @@ -12,6 +12,7 @@ package openapi import ( "net/http" + "os" "strings" "github.com/go-chi/chi/v5" @@ -76,6 +77,12 @@ func (c *ModelCatalogServiceAPIController) Routes() Routes { "/api/model_catalog/v1alpha1/sources", c.FindSources, }, + "PreviewCatalogSource": Route{ + "PreviewCatalogSource", + strings.ToUpper("Post"), + "/api/model_catalog/v1alpha1/sources/preview", + c.PreviewCatalogSource, + }, "GetModel": Route{ "GetModel", strings.ToUpper("Get"), @@ -124,6 +131,12 @@ func (c *ModelCatalogServiceAPIController) OrderedRoutes() []Route { "/api/model_catalog/v1alpha1/sources", c.FindSources, }, + Route{ + "PreviewCatalogSource", + strings.ToUpper("Post"), + "/api/model_catalog/v1alpha1/sources/preview", + c.PreviewCatalogSource, + }, Route{ "GetModel", strings.ToUpper("Get"), @@ -321,6 +334,75 @@ func (c *ModelCatalogServiceAPIController) FindSources(w http.ResponseWriter, r _ = EncodeJSONResponse(result.Body, &result.Code, w) } +// PreviewCatalogSource - Preview catalog source configuration +func (c *ModelCatalogServiceAPIController) PreviewCatalogSource(w http.ResponseWriter, r *http.Request) { + if err := r.ParseMultipartForm(32 << 20); err != nil { + c.errorHandler(w, r, &ParsingError{Err: err}, nil) + return + } + query, err := parseQuery(r.URL.RawQuery) + if err != nil { + c.errorHandler(w, r, 
&ParsingError{Err: err}, nil) + return + } + var configParam *os.File + { + param, err := ReadFormFileToTempFile(r, "config") + if err != nil { + c.errorHandler(w, r, &ParsingError{Param: "config", Err: err}, nil) + return + } + + configParam = param + } + + var pageSizeParam string + if query.Has("pageSize") { + param := query.Get("pageSize") + + pageSizeParam = param + } else { + } + var nextPageTokenParam string + if query.Has("nextPageToken") { + param := query.Get("nextPageToken") + + nextPageTokenParam = param + } else { + } + var filterStatusParam string + if query.Has("filterStatus") { + param := query.Get("filterStatus") + + filterStatusParam = param + } else { + param := "all" + filterStatusParam = param + } + var catalogDataParam *os.File + { + param, err := ReadFormFileToTempFile(r, "catalogData") + if err != nil { + // Optional file parameter - ignore missing file error + if err != http.ErrMissingFile { + c.errorHandler(w, r, &ParsingError{Param: "catalogData", Err: err}, nil) + return + } + } + + catalogDataParam = param + } + + result, err := c.service.PreviewCatalogSource(r.Context(), configParam, pageSizeParam, nextPageTokenParam, filterStatusParam, catalogDataParam) + // If an error occurred, encode the error with the status code + if err != nil { + c.errorHandler(w, r, err, &result) + return + } + // If no error, encode the body and the result code + _ = EncodeJSONResponse(result.Body, &result.Code, w) +} + // GetModel - Get a `CatalogModel`. 
func (c *ModelCatalogServiceAPIController) GetModel(w http.ResponseWriter, r *http.Request) { sourceIdParam := chi.URLParam(r, "source_id") diff --git a/catalog/internal/server/openapi/api_model_catalog_service_service.go b/catalog/internal/server/openapi/api_model_catalog_service_service.go index 1f1cc068d6..a36c308157 100644 --- a/catalog/internal/server/openapi/api_model_catalog_service_service.go +++ b/catalog/internal/server/openapi/api_model_catalog_service_service.go @@ -7,6 +7,7 @@ import ( "math" "net/http" "net/url" + "os" "slices" "strconv" "strings" @@ -316,6 +317,113 @@ func (m *ModelCatalogServiceAPIService) FindSources(ctx context.Context, name st return Response(http.StatusOK, res), nil } +func (m *ModelCatalogServiceAPIService) PreviewCatalogSource(ctx context.Context, configParam *os.File, pageSizeParam string, nextPageTokenParam string, filterStatusParam string, catalogDataParam *os.File) (ImplResponse, error) { + // Parse page size + pageSize := int32(10) + if pageSizeParam != "" { + parsed, err := strconv.ParseInt(pageSizeParam, 10, 32) + if err != nil { + return ErrorResponse(http.StatusBadRequest, fmt.Errorf("invalid pageSize: %w", err)), err + } + pageSize = int32(parsed) + } + + // Parse filterStatus (default: "all") + filterStatus := "all" + if filterStatusParam != "" { + filterStatus = strings.ToLower(filterStatusParam) + if filterStatus != "all" && filterStatus != "included" && filterStatus != "excluded" { + err := fmt.Errorf("invalid filterStatus: must be 'all', 'included', or 'excluded'") + return ErrorResponse(http.StatusBadRequest, err), err + } + } + + // Read and parse the uploaded config file + if configParam == nil { + err := errors.New("config file is required") + return ErrorResponse(http.StatusBadRequest, err), err + } + defer configParam.Close() + + configBytes, err := os.ReadFile(configParam.Name()) + if err != nil { + return ErrorResponse(http.StatusBadRequest, fmt.Errorf("failed to read config file: %w", err)), err + } + + 
// Read catalog data if provided (stateless mode) + var catalogDataBytes []byte + if catalogDataParam != nil { + defer catalogDataParam.Close() + catalogDataBytes, err = os.ReadFile(catalogDataParam.Name()) + if err != nil { + return ErrorResponse(http.StatusBadRequest, fmt.Errorf("failed to read catalogData file: %w", err)), err + } + } + + // Parse the config as a preview request + previewRequest, err := catalog.ParsePreviewConfig(configBytes) + if err != nil { + return ErrorResponse(http.StatusUnprocessableEntity, fmt.Errorf("invalid config: %w", err)), err + } + + // Load models using the preview service + previewResults, err := catalog.PreviewSourceModels(ctx, previewRequest, catalogDataBytes) + if err != nil { + return ErrorResponse(http.StatusUnprocessableEntity, fmt.Errorf("failed to load models: %w", err)), err + } + + // Filter by status + var filteredResults []model.ModelPreviewResult + for _, result := range previewResults { + switch filterStatus { + case "included": + if result.Included { + filteredResults = append(filteredResults, result) + } + case "excluded": + if !result.Included { + filteredResults = append(filteredResults, result) + } + default: // "all" + filteredResults = append(filteredResults, result) + } + } + + // Calculate summary from ALL results (not filtered) + var includedCount, excludedCount int32 + for _, result := range previewResults { + if result.Included { + includedCount++ + } else { + excludedCount++ + } + } + + summary := model.CatalogSourcePreviewResponseAllOfSummary{ + TotalModels: int32(len(previewResults)), + IncludedModels: includedCount, + ExcludedModels: excludedCount, + } + + // Apply pagination + paginator, err := newPaginator[model.ModelPreviewResult](pageSizeParam, model.OrderByField(""), model.SortOrder(""), nextPageTokenParam) + if err != nil { + return ErrorResponse(http.StatusBadRequest, err), err + } + + pagedResults, next := paginator.Paginate(filteredResults) + + response := 
model.CatalogSourcePreviewResponse{ + PageSize: pageSize, + Size: int32(len(pagedResults)), + NextPageToken: next.Token(), + Items: pagedResults, + Summary: summary, + } + + return Response(http.StatusOK, response), nil +} + func genCatalogCmpFunc(orderBy model.OrderByField, sortOrder model.SortOrder) (func(model.CatalogSource, model.CatalogSource) int, error) { multiplier := 1 switch model.SortOrder(strings.ToUpper(string(sortOrder))) { diff --git a/catalog/internal/server/openapi/error.go b/catalog/internal/server/openapi/error.go index 628fe717e9..23f4b70bae 100644 --- a/catalog/internal/server/openapi/error.go +++ b/catalog/internal/server/openapi/error.go @@ -34,6 +34,11 @@ func (e *ParsingError) Error() string { return e.Err.Error() } + // Provide more helpful error messages for common cases + if errors.Is(e.Err, http.ErrMissingFile) { + return fmt.Sprintf("required file field '%s' is missing from the request. Check that the field name is spelled correctly in your multipart form data.", e.Param) + } + return e.Param + ": " + e.Err.Error() } diff --git a/catalog/internal/server/openapi/pagination.go b/catalog/internal/server/openapi/pagination.go index 26ce4423f7..dfac000d71 100644 --- a/catalog/internal/server/openapi/pagination.go +++ b/catalog/internal/server/openapi/pagination.go @@ -35,6 +35,9 @@ func newPaginator[T model.Sortable](pageSize string, orderBy model.OrderByField, if err != nil { return nil, fmt.Errorf("error converting page size to int32: %w", err) } + if pageSize64 < 1 { + return nil, fmt.Errorf("pageSize must be at least 1, got %d", pageSize64) + } p.PageSize = int32(pageSize64) } diff --git a/catalog/internal/server/openapi/pagination_test.go b/catalog/internal/server/openapi/pagination_test.go index 6cc6eab40a..696cd73d8c 100644 --- a/catalog/internal/server/openapi/pagination_test.go +++ b/catalog/internal/server/openapi/pagination_test.go @@ -149,6 +149,52 @@ func TestPaginateSources(t *testing.T) { } } +func 
TestNewPaginator_InvalidPageSize(t *testing.T) { + testCases := []struct { + name string + pageSize string + expectError bool + errContains string + }{ + { + name: "pageSize=0 returns error", + pageSize: "0", + expectError: true, + errContains: "pageSize must be at least 1", + }, + { + name: "negative pageSize returns error", + pageSize: "-5", + expectError: true, + errContains: "pageSize must be at least 1", + }, + { + name: "valid pageSize=1 works", + pageSize: "1", + expectError: false, + }, + { + name: "empty pageSize uses default", + pageSize: "", + expectError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + paginator, err := newPaginator[model.CatalogSource](tc.pageSize, "ID", "", "") + if tc.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.errContains) + assert.Nil(t, paginator) + } else { + assert.NoError(t, err) + assert.NotNil(t, paginator) + } + }) + } +} + func TestPaginateSources_NoDuplicates(t *testing.T) { allSources := createCatalogSources(100) pageSize := "10" diff --git a/catalog/internal/server/openapi/type_asserts.go b/catalog/internal/server/openapi/type_asserts.go index 6b866aae0b..f5368f0088 100644 --- a/catalog/internal/server/openapi/type_asserts.go +++ b/catalog/internal/server/openapi/type_asserts.go @@ -277,6 +277,66 @@ func AssertCatalogSourceListRequired(obj model.CatalogSourceList) error { return nil } +// AssertCatalogSourcePreviewResponseAllOfSummaryConstraints checks if the values respects the defined constraints +func AssertCatalogSourcePreviewResponseAllOfSummaryConstraints(obj model.CatalogSourcePreviewResponseAllOfSummary) error { + return nil +} + +// AssertCatalogSourcePreviewResponseAllOfSummaryRequired checks if the required fields are not zero-ed +func AssertCatalogSourcePreviewResponseAllOfSummaryRequired(obj model.CatalogSourcePreviewResponseAllOfSummary) error { + elements := map[string]interface{}{ + "totalModels": obj.TotalModels, + "includedModels": 
obj.IncludedModels, + "excludedModels": obj.ExcludedModels, + } + for name, el := range elements { + if isZero := IsZeroValue(el); isZero { + return &RequiredError{Field: name} + } + } + + return nil +} + +// AssertCatalogSourcePreviewResponseConstraints checks if the values respects the defined constraints +func AssertCatalogSourcePreviewResponseConstraints(obj model.CatalogSourcePreviewResponse) error { + for _, el := range obj.Items { + if err := AssertModelPreviewResultConstraints(el); err != nil { + return err + } + } + if err := AssertCatalogSourcePreviewResponseAllOfSummaryConstraints(obj.Summary); err != nil { + return err + } + return nil +} + +// AssertCatalogSourcePreviewResponseRequired checks if the required fields are not zero-ed +func AssertCatalogSourcePreviewResponseRequired(obj model.CatalogSourcePreviewResponse) error { + elements := map[string]interface{}{ + "nextPageToken": obj.NextPageToken, + "pageSize": obj.PageSize, + "size": obj.Size, + "items": obj.Items, + "summary": obj.Summary, + } + for name, el := range elements { + if isZero := IsZeroValue(el); isZero { + return &RequiredError{Field: name} + } + } + + for _, el := range obj.Items { + if err := AssertModelPreviewResultRequired(el); err != nil { + return err + } + } + if err := AssertCatalogSourcePreviewResponseAllOfSummaryRequired(obj.Summary); err != nil { + return err + } + return nil +} + // AssertCatalogSourceRequired checks if the required fields are not zero-ed func AssertCatalogSourceRequired(obj model.CatalogSource) error { elements := map[string]interface{}{ @@ -454,6 +514,26 @@ func AssertMetadataStructValueRequired(obj model.MetadataStructValue) error { return nil } +// AssertModelPreviewResultConstraints checks if the values respects the defined constraints +func AssertModelPreviewResultConstraints(obj model.ModelPreviewResult) error { + return nil +} + +// AssertModelPreviewResultRequired checks if the required fields are not zero-ed +func 
AssertModelPreviewResultRequired(obj model.ModelPreviewResult) error { + elements := map[string]interface{}{ + "name": obj.Name, + "included": obj.Included, + } + for name, el := range elements { + if isZero := IsZeroValue(el); isZero { + return &RequiredError{Field: name} + } + } + + return nil +} + // AssertOrderByFieldConstraints checks if the values respects the defined constraints func AssertOrderByFieldConstraints(obj model.OrderByField) error { return nil diff --git a/catalog/pkg/openapi/.openapi-generator/FILES b/catalog/pkg/openapi/.openapi-generator/FILES index ced04c21e1..219a90503a 100644 --- a/catalog/pkg/openapi/.openapi-generator/FILES +++ b/catalog/pkg/openapi/.openapi-generator/FILES @@ -16,6 +16,8 @@ model_catalog_model_artifact.go model_catalog_model_list.go model_catalog_source.go model_catalog_source_list.go +model_catalog_source_preview_response.go +model_catalog_source_preview_response_all_of_summary.go model_error.go model_filter_option.go model_filter_option_range.go @@ -27,6 +29,7 @@ model_metadata_proto_value.go model_metadata_string_value.go model_metadata_struct_value.go model_metadata_value.go +model_model_preview_result.go model_order_by_field.go model_sort_order.go response.go diff --git a/catalog/pkg/openapi/api_model_catalog_service.go b/catalog/pkg/openapi/api_model_catalog_service.go index 599b083580..7c9b8d4abe 100644 --- a/catalog/pkg/openapi/api_model_catalog_service.go +++ b/catalog/pkg/openapi/api_model_catalog_service.go @@ -16,6 +16,7 @@ import ( "io" "net/http" "net/url" + "os" "reflect" "strings" ) @@ -1396,3 +1397,240 @@ func (a *ModelCatalogServiceAPIService) GetModelExecute(r ApiGetModelRequest) (* return localVarReturnValue, localVarHTTPResponse, nil } + +type ApiPreviewCatalogSourceRequest struct { + ctx context.Context + ApiService *ModelCatalogServiceAPIService + config *os.File + pageSize *string + nextPageToken *string + filterStatus *string + catalogData *os.File +} + +// YAML file containing the catalog source 
configuration. The file should contain a source definition with type and properties fields, including optional includedModels and excludedModels filters. Model filter patterns support the `*` wildcard only and are case-insensitive. Patterns match the entire model name (e.g., `ibm-granite/_*` matches all models starting with \\\"ibm-granite/\\\"). +func (r ApiPreviewCatalogSourceRequest) Config(config *os.File) ApiPreviewCatalogSourceRequest { + r.config = config + return r +} + +// Number of entities in each page. +func (r ApiPreviewCatalogSourceRequest) PageSize(pageSize string) ApiPreviewCatalogSourceRequest { + r.pageSize = &pageSize + return r +} + +// Token to use to retrieve next page of results. +func (r ApiPreviewCatalogSourceRequest) NextPageToken(nextPageToken string) ApiPreviewCatalogSourceRequest { + r.nextPageToken = &nextPageToken + return r +} + +// Filter the response to show specific model statuses: - `all` (default): Show all models regardless of inclusion status - `included`: Show only models that pass the configured filters - `excluded`: Show only models that are filtered out +func (r ApiPreviewCatalogSourceRequest) FilterStatus(filterStatus string) ApiPreviewCatalogSourceRequest { + r.filterStatus = &filterStatus + return r +} + +// Optional YAML file containing the catalog data (models). This field enables stateless preview of new sources before saving them. When provided, the catalog data is read directly from this file instead of from the `yamlCatalogPath` property in the config. **Two modes of operation:** 1. **Stateless mode (recommended for new sources):** Upload both `config` and `catalogData` files. The models are read from `catalogData`, allowing preview without saving anything to the server. 2. **Path mode (for existing sources):** Upload only `config` with a `yamlCatalogPath` property pointing to a catalog file on the server filesystem. If both `catalogData` and `yamlCatalogPath` are provided, `catalogData` takes precedence. 
+func (r ApiPreviewCatalogSourceRequest) CatalogData(catalogData *os.File) ApiPreviewCatalogSourceRequest { + r.catalogData = catalogData + return r +} + +func (r ApiPreviewCatalogSourceRequest) Execute() (*CatalogSourcePreviewResponse, *http.Response, error) { + return r.ApiService.PreviewCatalogSourceExecute(r) +} + +/* +PreviewCatalogSource Preview catalog source configuration + +Accepts a catalog source configuration and returns a list of models showing +which would be included or excluded based on the configured filters. This allows +users to test and validate their source configurations before applying them. + +**Two modes of operation:** + + 1. **Stateless mode (recommended for new sources):** Upload both `config` and + `catalogData` files via multipart form. The models are read directly from + the uploaded `catalogData`, enabling preview of new sources before saving + anything to the server. This is ideal for testing configurations. + + 2. **Path mode (for existing sources):** Upload only `config` with a `yamlCatalogPath` + property. The models are read from the specified file path on the server. + Use this for previewing changes to existing saved sources. + + @param ctx context.Context - for authentication, logging, cancellation, deadlines, tracing, etc. Passed from http.Request or context.Background(). 
+ @return ApiPreviewCatalogSourceRequest +*/ +func (a *ModelCatalogServiceAPIService) PreviewCatalogSource(ctx context.Context) ApiPreviewCatalogSourceRequest { + return ApiPreviewCatalogSourceRequest{ + ApiService: a, + ctx: ctx, + } +} + +// Execute executes the request +// +// @return CatalogSourcePreviewResponse +func (a *ModelCatalogServiceAPIService) PreviewCatalogSourceExecute(r ApiPreviewCatalogSourceRequest) (*CatalogSourcePreviewResponse, *http.Response, error) { + var ( + localVarHTTPMethod = http.MethodPost + localVarPostBody interface{} + formFiles []formFile + localVarReturnValue *CatalogSourcePreviewResponse + ) + + localBasePath, err := a.client.cfg.ServerURLWithContext(r.ctx, "ModelCatalogServiceAPIService.PreviewCatalogSource") + if err != nil { + return localVarReturnValue, nil, &GenericOpenAPIError{error: err.Error()} + } + + localVarPath := localBasePath + "/api/model_catalog/v1alpha1/sources/preview" + + localVarHeaderParams := make(map[string]string) + localVarQueryParams := url.Values{} + localVarFormParams := url.Values{} + if r.config == nil { + return localVarReturnValue, nil, reportError("config is required and must be specified") + } + + if r.pageSize != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "pageSize", r.pageSize, "form", "") + } + if r.nextPageToken != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "nextPageToken", r.nextPageToken, "form", "") + } + if r.filterStatus != nil { + parameterAddToHeaderOrQuery(localVarQueryParams, "filterStatus", r.filterStatus, "form", "") + } else { + var defaultValue string = "all" + parameterAddToHeaderOrQuery(localVarQueryParams, "filterStatus", defaultValue, "form", "") + r.filterStatus = &defaultValue + } + // to determine the Content-Type header + localVarHTTPContentTypes := []string{"multipart/form-data"} + + // set Content-Type header + localVarHTTPContentType := selectHeaderContentType(localVarHTTPContentTypes) + if localVarHTTPContentType != "" { + 
localVarHeaderParams["Content-Type"] = localVarHTTPContentType + } + + // to determine the Accept header + localVarHTTPHeaderAccepts := []string{"application/json"} + + // set Accept header + localVarHTTPHeaderAccept := selectHeaderAccept(localVarHTTPHeaderAccepts) + if localVarHTTPHeaderAccept != "" { + localVarHeaderParams["Accept"] = localVarHTTPHeaderAccept + } + var configLocalVarFormFileName string + var configLocalVarFileName string + var configLocalVarFileBytes []byte + + configLocalVarFormFileName = "config" + configLocalVarFile := r.config + + if configLocalVarFile != nil { + fbs, _ := io.ReadAll(configLocalVarFile) + + configLocalVarFileBytes = fbs + configLocalVarFileName = configLocalVarFile.Name() + configLocalVarFile.Close() + formFiles = append(formFiles, formFile{fileBytes: configLocalVarFileBytes, fileName: configLocalVarFileName, formFileName: configLocalVarFormFileName}) + } + var catalogDataLocalVarFormFileName string + var catalogDataLocalVarFileName string + var catalogDataLocalVarFileBytes []byte + + catalogDataLocalVarFormFileName = "catalogData" + catalogDataLocalVarFile := r.catalogData + + if catalogDataLocalVarFile != nil { + fbs, _ := io.ReadAll(catalogDataLocalVarFile) + + catalogDataLocalVarFileBytes = fbs + catalogDataLocalVarFileName = catalogDataLocalVarFile.Name() + catalogDataLocalVarFile.Close() + formFiles = append(formFiles, formFile{fileBytes: catalogDataLocalVarFileBytes, fileName: catalogDataLocalVarFileName, formFileName: catalogDataLocalVarFormFileName}) + } + req, err := a.client.prepareRequest(r.ctx, localVarPath, localVarHTTPMethod, localVarPostBody, localVarHeaderParams, localVarQueryParams, localVarFormParams, formFiles) + if err != nil { + return localVarReturnValue, nil, err + } + + localVarHTTPResponse, err := a.client.callAPI(req) + if err != nil || localVarHTTPResponse == nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + localVarBody, err := io.ReadAll(localVarHTTPResponse.Body) + 
localVarHTTPResponse.Body.Close() + localVarHTTPResponse.Body = io.NopCloser(bytes.NewBuffer(localVarBody)) + if err != nil { + return localVarReturnValue, localVarHTTPResponse, err + } + + if localVarHTTPResponse.StatusCode >= 300 { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: localVarHTTPResponse.Status, + } + if localVarHTTPResponse.StatusCode == 400 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 401 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 422 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + return localVarReturnValue, localVarHTTPResponse, newErr + } + if localVarHTTPResponse.StatusCode == 500 { + var v Error + err = a.client.decode(&v, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr.error = err.Error() + return localVarReturnValue, localVarHTTPResponse, newErr + } + newErr.error = formatErrorMessage(localVarHTTPResponse.Status, &v) + newErr.model = v + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + err = 
a.client.decode(&localVarReturnValue, localVarBody, localVarHTTPResponse.Header.Get("Content-Type")) + if err != nil { + newErr := &GenericOpenAPIError{ + body: localVarBody, + error: err.Error(), + } + return localVarReturnValue, localVarHTTPResponse, newErr + } + + return localVarReturnValue, localVarHTTPResponse, nil +} diff --git a/catalog/pkg/openapi/model_catalog_source.go b/catalog/pkg/openapi/model_catalog_source.go index 2e35fc98db..e51533b4c9 100644 --- a/catalog/pkg/openapi/model_catalog_source.go +++ b/catalog/pkg/openapi/model_catalog_source.go @@ -27,9 +27,9 @@ type CatalogSource struct { Enabled *bool `json:"enabled,omitempty"` // Labels for the catalog source. Labels []string `json:"labels"` - // Optional allow-list of models that are eligible for this source. Entries can be exact model names or patterns that use `*` as a wildcard. When provided, only models matching at least one pattern are considered. Pattern matching is case-insensitive, so `Granite/_*` will match `granite/model`, `Granite/model`, and `GRANITE/model`. + // Optional list of glob patterns for models to include. If specified, only models matching at least one pattern will be included. If omitted, all models are considered for inclusion. 
Pattern Syntax: - Only the `*` wildcard is supported (matches zero or more characters) - Patterns are case-insensitive (e.g., `Granite/_*` matches `granite/model` and `GRANITE/model`) - Patterns match the entire model name (anchored at start and end) - Wildcards can appear anywhere: `Granite/_*`, `*-beta`, `*deprecated*`, `*_/old*` Examples: - `ibm-granite/_*` - matches all models starting with \"ibm-granite/\" - `meta-llama/_*` - matches all models in the meta-llama namespace - `*` - matches all models Constraints: - Patterns cannot be empty or whitespace-only - A pattern cannot appear in both includedModels and excludedModels IncludedModels []string `json:"includedModels,omitempty"` - // Optional block-list of models that should be removed from the catalog even if they match `includedModels`. Patterns support the `*` wildcard. Pattern matching is case-insensitive, so `*-beta` will match `Model-Beta`, `model-beta`, and `MODEL-BETA`. + // Optional list of glob patterns for models to exclude. Models matching any pattern will be excluded even if they match an includedModels pattern. Exclusions take precedence over inclusions. 
Pattern Syntax: - Only the `*` wildcard is supported (matches zero or more characters) - Patterns are case-insensitive - Patterns match the entire model name (anchored at start and end) - Wildcards can appear anywhere in the pattern Examples: - `*-draft` - excludes all models ending with \"-draft\" - `*-experimental` - excludes experimental models - `*deprecated*` - excludes models with \"deprecated\" anywhere in the name - `*_/beta-*` - excludes models with \"/beta-\" in the path Constraints: - Patterns cannot be empty or whitespace-only - A pattern cannot appear in both includedModels and excludedModels ExcludedModels []string `json:"excludedModels,omitempty"` } diff --git a/catalog/pkg/openapi/model_catalog_source_preview_response.go b/catalog/pkg/openapi/model_catalog_source_preview_response.go new file mode 100644 index 0000000000..3a7e6645e9 --- /dev/null +++ b/catalog/pkg/openapi/model_catalog_source_preview_response.go @@ -0,0 +1,229 @@ +/* +Model Catalog REST API + +REST API for Model Registry to create and manage ML model metadata + +API version: v1alpha1 +*/ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. + +package openapi + +import ( + "encoding/json" +) + +// checks if the CatalogSourcePreviewResponse type satisfies the MappedNullable interface at compile time +var _ MappedNullable = &CatalogSourcePreviewResponse{} + +// CatalogSourcePreviewResponse Response containing models and their inclusion/exclusion status. +type CatalogSourcePreviewResponse struct { + // Token to use to retrieve next page of results. + NextPageToken string `json:"nextPageToken"` + // Maximum number of resources to return in the result. + PageSize int32 `json:"pageSize"` + // Number of items in result list. + Size int32 `json:"size"` + // Array of model preview results. 
+ Items []ModelPreviewResult `json:"items"` + Summary CatalogSourcePreviewResponseAllOfSummary `json:"summary"` +} + +type _CatalogSourcePreviewResponse CatalogSourcePreviewResponse + +// NewCatalogSourcePreviewResponse instantiates a new CatalogSourcePreviewResponse object +// This constructor will assign default values to properties that have it defined, +// and makes sure properties required by API are set, but the set of arguments +// will change when the set of required properties is changed +func NewCatalogSourcePreviewResponse(nextPageToken string, pageSize int32, size int32, items []ModelPreviewResult, summary CatalogSourcePreviewResponseAllOfSummary) *CatalogSourcePreviewResponse { + this := CatalogSourcePreviewResponse{} + this.NextPageToken = nextPageToken + this.PageSize = pageSize + this.Size = size + this.Items = items + this.Summary = summary + return &this +} + +// NewCatalogSourcePreviewResponseWithDefaults instantiates a new CatalogSourcePreviewResponse object +// This constructor will only assign default values to properties that have it defined, +// but it doesn't guarantee that properties required by API are set +func NewCatalogSourcePreviewResponseWithDefaults() *CatalogSourcePreviewResponse { + this := CatalogSourcePreviewResponse{} + return &this +} + +// GetNextPageToken returns the NextPageToken field value +func (o *CatalogSourcePreviewResponse) GetNextPageToken() string { + if o == nil { + var ret string + return ret + } + + return o.NextPageToken +} + +// GetNextPageTokenOk returns a tuple with the NextPageToken field value +// and a boolean to check if the value has been set. 
+func (o *CatalogSourcePreviewResponse) GetNextPageTokenOk() (*string, bool) { + if o == nil { + return nil, false + } + return &o.NextPageToken, true +} + +// SetNextPageToken sets field value +func (o *CatalogSourcePreviewResponse) SetNextPageToken(v string) { + o.NextPageToken = v +} + +// GetPageSize returns the PageSize field value +func (o *CatalogSourcePreviewResponse) GetPageSize() int32 { + if o == nil { + var ret int32 + return ret + } + + return o.PageSize +} + +// GetPageSizeOk returns a tuple with the PageSize field value +// and a boolean to check if the value has been set. +func (o *CatalogSourcePreviewResponse) GetPageSizeOk() (*int32, bool) { + if o == nil { + return nil, false + } + return &o.PageSize, true +} + +// SetPageSize sets field value +func (o *CatalogSourcePreviewResponse) SetPageSize(v int32) { + o.PageSize = v +} + +// GetSize returns the Size field value +func (o *CatalogSourcePreviewResponse) GetSize() int32 { + if o == nil { + var ret int32 + return ret + } + + return o.Size +} + +// GetSizeOk returns a tuple with the Size field value +// and a boolean to check if the value has been set. +func (o *CatalogSourcePreviewResponse) GetSizeOk() (*int32, bool) { + if o == nil { + return nil, false + } + return &o.Size, true +} + +// SetSize sets field value +func (o *CatalogSourcePreviewResponse) SetSize(v int32) { + o.Size = v +} + +// GetItems returns the Items field value +func (o *CatalogSourcePreviewResponse) GetItems() []ModelPreviewResult { + if o == nil { + var ret []ModelPreviewResult + return ret + } + + return o.Items +} + +// GetItemsOk returns a tuple with the Items field value +// and a boolean to check if the value has been set. 
+func (o *CatalogSourcePreviewResponse) GetItemsOk() ([]ModelPreviewResult, bool) { + if o == nil { + return nil, false + } + return o.Items, true +} + +// SetItems sets field value +func (o *CatalogSourcePreviewResponse) SetItems(v []ModelPreviewResult) { + o.Items = v +} + +// GetSummary returns the Summary field value +func (o *CatalogSourcePreviewResponse) GetSummary() CatalogSourcePreviewResponseAllOfSummary { + if o == nil { + var ret CatalogSourcePreviewResponseAllOfSummary + return ret + } + + return o.Summary +} + +// GetSummaryOk returns a tuple with the Summary field value +// and a boolean to check if the value has been set. +func (o *CatalogSourcePreviewResponse) GetSummaryOk() (*CatalogSourcePreviewResponseAllOfSummary, bool) { + if o == nil { + return nil, false + } + return &o.Summary, true +} + +// SetSummary sets field value +func (o *CatalogSourcePreviewResponse) SetSummary(v CatalogSourcePreviewResponseAllOfSummary) { + o.Summary = v +} + +func (o CatalogSourcePreviewResponse) MarshalJSON() ([]byte, error) { + toSerialize, err := o.ToMap() + if err != nil { + return []byte{}, err + } + return json.Marshal(toSerialize) +} + +func (o CatalogSourcePreviewResponse) ToMap() (map[string]interface{}, error) { + toSerialize := map[string]interface{}{} + toSerialize["nextPageToken"] = o.NextPageToken + toSerialize["pageSize"] = o.PageSize + toSerialize["size"] = o.Size + toSerialize["items"] = o.Items + toSerialize["summary"] = o.Summary + return toSerialize, nil +} + +type NullableCatalogSourcePreviewResponse struct { + value *CatalogSourcePreviewResponse + isSet bool +} + +func (v NullableCatalogSourcePreviewResponse) Get() *CatalogSourcePreviewResponse { + return v.value +} + +func (v *NullableCatalogSourcePreviewResponse) Set(val *CatalogSourcePreviewResponse) { + v.value = val + v.isSet = true +} + +func (v NullableCatalogSourcePreviewResponse) IsSet() bool { + return v.isSet +} + +func (v *NullableCatalogSourcePreviewResponse) Unset() { + v.value = 
nil + v.isSet = false +} + +func NewNullableCatalogSourcePreviewResponse(val *CatalogSourcePreviewResponse) *NullableCatalogSourcePreviewResponse { + return &NullableCatalogSourcePreviewResponse{value: val, isSet: true} +} + +func (v NullableCatalogSourcePreviewResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *NullableCatalogSourcePreviewResponse) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git a/catalog/pkg/openapi/model_catalog_source_preview_response_all_of_summary.go b/catalog/pkg/openapi/model_catalog_source_preview_response_all_of_summary.go new file mode 100644 index 0000000000..e903cb82bf --- /dev/null +++ b/catalog/pkg/openapi/model_catalog_source_preview_response_all_of_summary.go @@ -0,0 +1,174 @@ +/* +Model Catalog REST API + +REST API for Model Registry to create and manage ML model metadata + +API version: v1alpha1 +*/ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 
+ +package openapi + +import ( + "encoding/json" +) + +// checks if the CatalogSourcePreviewResponseAllOfSummary type satisfies the MappedNullable interface at compile time +var _ MappedNullable = &CatalogSourcePreviewResponseAllOfSummary{} + +// CatalogSourcePreviewResponseAllOfSummary Summary of the preview results +type CatalogSourcePreviewResponseAllOfSummary struct { + // Total number of models evaluated + TotalModels int32 `json:"totalModels"` + // Number of models that would be included + IncludedModels int32 `json:"includedModels"` + // Number of models that would be excluded + ExcludedModels int32 `json:"excludedModels"` +} + +type _CatalogSourcePreviewResponseAllOfSummary CatalogSourcePreviewResponseAllOfSummary + +// NewCatalogSourcePreviewResponseAllOfSummary instantiates a new CatalogSourcePreviewResponseAllOfSummary object +// This constructor will assign default values to properties that have it defined, +// and makes sure properties required by API are set, but the set of arguments +// will change when the set of required properties is changed +func NewCatalogSourcePreviewResponseAllOfSummary(totalModels int32, includedModels int32, excludedModels int32) *CatalogSourcePreviewResponseAllOfSummary { + this := CatalogSourcePreviewResponseAllOfSummary{} + this.TotalModels = totalModels + this.IncludedModels = includedModels + this.ExcludedModels = excludedModels + return &this +} + +// NewCatalogSourcePreviewResponseAllOfSummaryWithDefaults instantiates a new CatalogSourcePreviewResponseAllOfSummary object +// This constructor will only assign default values to properties that have it defined, +// but it doesn't guarantee that properties required by API are set +func NewCatalogSourcePreviewResponseAllOfSummaryWithDefaults() *CatalogSourcePreviewResponseAllOfSummary { + this := CatalogSourcePreviewResponseAllOfSummary{} + return &this +} + +// GetTotalModels returns the TotalModels field value +func (o *CatalogSourcePreviewResponseAllOfSummary) 
GetTotalModels() int32 { + if o == nil { + var ret int32 + return ret + } + + return o.TotalModels +} + +// GetTotalModelsOk returns a tuple with the TotalModels field value +// and a boolean to check if the value has been set. +func (o *CatalogSourcePreviewResponseAllOfSummary) GetTotalModelsOk() (*int32, bool) { + if o == nil { + return nil, false + } + return &o.TotalModels, true +} + +// SetTotalModels sets field value +func (o *CatalogSourcePreviewResponseAllOfSummary) SetTotalModels(v int32) { + o.TotalModels = v +} + +// GetIncludedModels returns the IncludedModels field value +func (o *CatalogSourcePreviewResponseAllOfSummary) GetIncludedModels() int32 { + if o == nil { + var ret int32 + return ret + } + + return o.IncludedModels +} + +// GetIncludedModelsOk returns a tuple with the IncludedModels field value +// and a boolean to check if the value has been set. +func (o *CatalogSourcePreviewResponseAllOfSummary) GetIncludedModelsOk() (*int32, bool) { + if o == nil { + return nil, false + } + return &o.IncludedModels, true +} + +// SetIncludedModels sets field value +func (o *CatalogSourcePreviewResponseAllOfSummary) SetIncludedModels(v int32) { + o.IncludedModels = v +} + +// GetExcludedModels returns the ExcludedModels field value +func (o *CatalogSourcePreviewResponseAllOfSummary) GetExcludedModels() int32 { + if o == nil { + var ret int32 + return ret + } + + return o.ExcludedModels +} + +// GetExcludedModelsOk returns a tuple with the ExcludedModels field value +// and a boolean to check if the value has been set. 
+func (o *CatalogSourcePreviewResponseAllOfSummary) GetExcludedModelsOk() (*int32, bool) { + if o == nil { + return nil, false + } + return &o.ExcludedModels, true +} + +// SetExcludedModels sets field value +func (o *CatalogSourcePreviewResponseAllOfSummary) SetExcludedModels(v int32) { + o.ExcludedModels = v +} + +func (o CatalogSourcePreviewResponseAllOfSummary) MarshalJSON() ([]byte, error) { + toSerialize, err := o.ToMap() + if err != nil { + return []byte{}, err + } + return json.Marshal(toSerialize) +} + +func (o CatalogSourcePreviewResponseAllOfSummary) ToMap() (map[string]interface{}, error) { + toSerialize := map[string]interface{}{} + toSerialize["totalModels"] = o.TotalModels + toSerialize["includedModels"] = o.IncludedModels + toSerialize["excludedModels"] = o.ExcludedModels + return toSerialize, nil +} + +type NullableCatalogSourcePreviewResponseAllOfSummary struct { + value *CatalogSourcePreviewResponseAllOfSummary + isSet bool +} + +func (v NullableCatalogSourcePreviewResponseAllOfSummary) Get() *CatalogSourcePreviewResponseAllOfSummary { + return v.value +} + +func (v *NullableCatalogSourcePreviewResponseAllOfSummary) Set(val *CatalogSourcePreviewResponseAllOfSummary) { + v.value = val + v.isSet = true +} + +func (v NullableCatalogSourcePreviewResponseAllOfSummary) IsSet() bool { + return v.isSet +} + +func (v *NullableCatalogSourcePreviewResponseAllOfSummary) Unset() { + v.value = nil + v.isSet = false +} + +func NewNullableCatalogSourcePreviewResponseAllOfSummary(val *CatalogSourcePreviewResponseAllOfSummary) *NullableCatalogSourcePreviewResponseAllOfSummary { + return &NullableCatalogSourcePreviewResponseAllOfSummary{value: val, isSet: true} +} + +func (v NullableCatalogSourcePreviewResponseAllOfSummary) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v *NullableCatalogSourcePreviewResponseAllOfSummary) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git 
a/catalog/pkg/openapi/model_model_preview_result.go b/catalog/pkg/openapi/model_model_preview_result.go new file mode 100644 index 0000000000..026fc49202 --- /dev/null +++ b/catalog/pkg/openapi/model_model_preview_result.go @@ -0,0 +1,146 @@ +/* +Model Catalog REST API + +REST API for Model Registry to create and manage ML model metadata + +API version: v1alpha1 +*/ + +// Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. + +package openapi + +import ( + "encoding/json" +) + +// checks if the ModelPreviewResult type satisfies the MappedNullable interface at compile time +var _ MappedNullable = &ModelPreviewResult{} + +// ModelPreviewResult A model with its inclusion/exclusion status based on the configured catalog source filters. +type ModelPreviewResult struct { + // Name of the model + Name string `json:"name"` + // Whether this model would be included based on the source configuration + Included bool `json:"included"` +} + +type _ModelPreviewResult ModelPreviewResult + +// NewModelPreviewResult instantiates a new ModelPreviewResult object +// This constructor will assign default values to properties that have it defined, +// and makes sure properties required by API are set, but the set of arguments +// will change when the set of required properties is changed +func NewModelPreviewResult(name string, included bool) *ModelPreviewResult { + this := ModelPreviewResult{} + this.Name = name + this.Included = included + return &this +} + +// NewModelPreviewResultWithDefaults instantiates a new ModelPreviewResult object +// This constructor will only assign default values to properties that have it defined, +// but it doesn't guarantee that properties required by API are set +func NewModelPreviewResultWithDefaults() *ModelPreviewResult { + this := ModelPreviewResult{} + return &this +} + +// GetName returns the Name field value +func (o *ModelPreviewResult) GetName() string { + if o == nil { + var ret string + return ret + } + + return 
o.Name +} + +// GetNameOk returns a tuple with the Name field value +// and a boolean to check if the value has been set. +func (o *ModelPreviewResult) GetNameOk() (*string, bool) { + if o == nil { + return nil, false + } + return &o.Name, true +} + +// SetName sets field value +func (o *ModelPreviewResult) SetName(v string) { + o.Name = v +} + +// GetIncluded returns the Included field value +func (o *ModelPreviewResult) GetIncluded() bool { + if o == nil { + var ret bool + return ret + } + + return o.Included +} + +// GetIncludedOk returns a tuple with the Included field value +// and a boolean to check if the value has been set. +func (o *ModelPreviewResult) GetIncludedOk() (*bool, bool) { + if o == nil { + return nil, false + } + return &o.Included, true +} + +// SetIncluded sets field value +func (o *ModelPreviewResult) SetIncluded(v bool) { + o.Included = v +} + +func (o ModelPreviewResult) MarshalJSON() ([]byte, error) { + toSerialize, err := o.ToMap() + if err != nil { + return []byte{}, err + } + return json.Marshal(toSerialize) +} + +func (o ModelPreviewResult) ToMap() (map[string]interface{}, error) { + toSerialize := map[string]interface{}{} + toSerialize["name"] = o.Name + toSerialize["included"] = o.Included + return toSerialize, nil +} + +type NullableModelPreviewResult struct { + value *ModelPreviewResult + isSet bool +} + +func (v NullableModelPreviewResult) Get() *ModelPreviewResult { + return v.value +} + +func (v *NullableModelPreviewResult) Set(val *ModelPreviewResult) { + v.value = val + v.isSet = true +} + +func (v NullableModelPreviewResult) IsSet() bool { + return v.isSet +} + +func (v *NullableModelPreviewResult) Unset() { + v.value = nil + v.isSet = false +} + +func NewNullableModelPreviewResult(val *ModelPreviewResult) *NullableModelPreviewResult { + return &NullableModelPreviewResult{value: val, isSet: true} +} + +func (v NullableModelPreviewResult) MarshalJSON() ([]byte, error) { + return json.Marshal(v.value) +} + +func (v 
*NullableModelPreviewResult) UnmarshalJSON(src []byte) error { + v.isSet = true + return json.Unmarshal(src, &v.value) +} diff --git a/catalog/pkg/openapi/sortable.go b/catalog/pkg/openapi/sortable.go index d14e772ef7..1769f1c070 100644 --- a/catalog/pkg/openapi/sortable.go +++ b/catalog/pkg/openapi/sortable.go @@ -35,3 +35,11 @@ func unrefString(v *string) string { } return *v } + +func (m ModelPreviewResult) SortValue(field OrderByField) string { + switch field { + case ORDERBYFIELD_ID, ORDERBYFIELD_NAME: + return m.Name + } + return "" +} diff --git a/internal/server/openapi/error.go b/internal/server/openapi/error.go index 8dae155879..33ea86a722 100644 --- a/internal/server/openapi/error.go +++ b/internal/server/openapi/error.go @@ -34,6 +34,11 @@ func (e *ParsingError) Error() string { return e.Err.Error() } + // Provide more helpful error messages for common cases + if errors.Is(e.Err, http.ErrMissingFile) { + return fmt.Sprintf("required file field '%s' is missing from the request. 
Check that the field name is spelled correctly in your multipart form data.", e.Param) + } + return e.Param + ": " + e.Err.Error() } diff --git a/templates/go-server/controller-api.mustache b/templates/go-server/controller-api.mustache index 15dd140e49..7c14511bdb 100644 --- a/templates/go-server/controller-api.mustache +++ b/templates/go-server/controller-api.mustache @@ -564,8 +564,17 @@ func (c *{{classname}}Controller) {{nickname}}(w http.ResponseWriter, r *http.Re param, err := ReadFormFileToTempFile(r, "{{baseName}}") {{/isArray}} if err != nil { + {{^required}} + // Optional file parameter - ignore missing file error + if err != http.ErrMissingFile { + c.errorHandler(w, r, &ParsingError{Param: "{{baseName}}", Err: err}, nil) + return + } + {{/required}} + {{#required}} c.errorHandler(w, r, &ParsingError{Param: "{{baseName}}", Err: err}, nil) return + {{/required}} } {{paramName}}Param = param diff --git a/templates/go-server/error.mustache b/templates/go-server/error.mustache index ace1c0df9d..523caa460e 100644 --- a/templates/go-server/error.mustache +++ b/templates/go-server/error.mustache @@ -25,6 +25,11 @@ func (e *ParsingError) Error() string { return e.Err.Error() } + // Provide more helpful error messages for common cases + if errors.Is(e.Err, http.ErrMissingFile) { + return fmt.Sprintf("required file field '%s' is missing from the request. Check that the field name is spelled correctly in your multipart form data.", e.Param) + } + return e.Param + ": " + e.Err.Error() }