diff --git a/catalog/README.md b/catalog/README.md index bf196f2da1..0b5b765573 100644 --- a/catalog/README.md +++ b/catalog/README.md @@ -13,6 +13,7 @@ The catalog service operates as a **metadata aggregation layer** that: ### Supported Catalog Sources - **YAML Catalog** - Static YAML files containing model metadata +- **HuggingFace Hub** - Discover models from HuggingFace's model repository ## REST API @@ -75,6 +76,76 @@ catalogs: path: "./models" ``` +### HuggingFace Source Configuration + +The HuggingFace catalog source allows you to discover and import models from the HuggingFace Hub. To configure a HuggingFace source: + +#### 1. Set Your API Key + +The HuggingFace provider requires an API key for authentication. By default, the service reads the API key from the `HF_API_KEY` environment variable: + +```bash +export HF_API_KEY="your-huggingface-api-key-here" +``` + +**Getting a HuggingFace API Key:** +1. Sign up or log in to [HuggingFace](https://huggingface.co) +2. Go to your [Settings > Access Tokens](https://huggingface.co/settings/tokens) +3. Create a new token with "Read" permissions +4. Copy the token and set it as an environment variable + +**For Kubernetes deployments:** +- Store the API key in a Kubernetes Secret +- Reference it in your deployment configuration +- The catalog service will read it from the configured environment variable (defaults to `HF_API_KEY`) + +**Custom Environment Variable Name:** +You can configure a custom environment variable name per source by setting the `apiKeyEnvVar` property in your source configuration (see below). This is useful when you need different API keys for different sources. + +**Important Notes:** +- **Private Models**: For private models, the API key must belong to an account that has been granted access to the model. Without proper access, the catalog service will not be able to retrieve model information. +- **Gated Models**: For gated models (models with usage restrictions), you must accept the model's terms of service on HuggingFace before the catalog service can access all available model information. Visit the model's page on HuggingFace and accept the terms to ensure full metadata is available. + +#### 2. Configure the Source + +Add a HuggingFace source to your `catalog-sources.yaml`: + +```yaml +catalogs: + - name: "HuggingFace Hub" + id: "huggingface" + type: "hf" + enabled: true + # Required: List of model identifiers to include + # Format: "organization/model-name" or "username/model-name" + includedModels: + - "meta-llama/Llama-3.1-8B-Instruct" + - "ibm-granite/granite-4.0-h-small" + - "microsoft/phi-2" + + # Optional: Exclude specific models or patterns + # Supports exact matches or patterns ending with "*" + excludedModels: + - "some-org/unwanted-model" + - "another-org/test-*" # Excludes all models starting with "test-" + + # Optional: Configure a custom environment variable name for the API key + # Defaults to "HF_API_KEY" if not specified + properties: + apiKeyEnvVar: "MY_CUSTOM_API_KEY_VAR" +``` + +#### Model Filtering + +Both `includedModels` and `excludedModels` are top-level properties (not nested under `properties`): + +- **`includedModels`** (required): List of model identifiers to fetch from HuggingFace. Format: `"organization/model-name"` or `"username/model-name"` +- **`excludedModels`** (optional): List of models or patterns to exclude from the results + +The `excludedModels` property supports: +- **Exact matches**: `"meta-llama/Llama-3.1-8B-Instruct"` - excludes this specific model +- **Pattern matching**: `"test-*"` - excludes all models starting with "test-" + ## Development ### Prerequisites diff --git a/catalog/internal/catalog/assets/catalog_logo.svg b/catalog/internal/catalog/assets/catalog_logo.svg new file mode 100644 index 0000000000..d1fb11bd47 --- /dev/null +++ b/catalog/internal/catalog/assets/catalog_logo.svg @@ -0,0 +1,2 @@ + + diff --git a/catalog/internal/catalog/hf_catalog.go b/catalog/internal/catalog/hf_catalog.go index 0ae837b749..00b459b338 100644 --- a/catalog/internal/catalog/hf_catalog.go +++ b/catalog/internal/catalog/hf_catalog.go @@ -2,78 +2,587 @@ package catalog import ( "context" + _ "embed" + "encoding/base64" + "encoding/json" "fmt" + "io" "net/http" + "os" + "strconv" "strings" "time" "github.com/golang/glog" - "github.com/kubeflow/model-registry/catalog/pkg/openapi" - model "github.com/kubeflow/model-registry/catalog/pkg/openapi" + dbmodels "github.com/kubeflow/model-registry/catalog/internal/db/models" + apimodels "github.com/kubeflow/model-registry/catalog/pkg/openapi" + "github.com/kubeflow/model-registry/internal/db/models" ) -type hfCatalogImpl struct { - client *http.Client - apiKey string - baseURL string +const ( + defaultHuggingFaceURL = "https://huggingface.co" + defaultAPIKeyEnvVar = "HF_API_KEY" + urlKey = "url" + apiKeyEnvVarKey = "apiKeyEnvVar" +) + +// gatedString is a custom type that can unmarshal both boolean and string values from JSON +// It converts booleans to strings (false -> "false", true -> "true") +type gatedString string + +// UnmarshalJSON implements json.Unmarshaler to handle both boolean and string values +func (g *gatedString) UnmarshalJSON(data []byte) error { + // Handle null/empty + if len(data) == 0 || string(data) == "null" { + *g = gatedString("") + return nil + } + + // Try to unmarshal as boolean first (handles true/false) + var b bool + if err := json.Unmarshal(data, &b); err == nil { + *g = gatedString(strconv.FormatBool(b)) + return nil + } + + // If not a boolean, try as string (handles quoted strings) + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("gated field must be boolean or string, got: %s", string(data)) + } + *g = gatedString(s) + return nil } -var _ APIProvider = &hfCatalogImpl{} +// String returns the string value +func (g gatedString) String() string { + return string(g) +} -const ( - defaultHuggingFaceURL = "https://huggingface.co" +// hfModel implements apimodels.CatalogModel and populates it from HuggingFace API data +type hfModel struct { + apimodels.CatalogModel +} + +type hfModelProvider struct { + client *http.Client + sourceId string + apiKey string + baseURL string + includedModels []string + filter *ModelFilter +} + +// hfModelInfo represents the structure of HuggingFace API model information +type hfModelInfo struct { + ID string `json:"id"` + Author string `json:"author,omitempty"` + Sha string `json:"sha,omitempty"` + CreatedAt string `json:"createdAt,omitempty"` + UpdatedAt string `json:"updatedAt,omitempty"` + Private bool `json:"private,omitempty"` + Gated gatedString `json:"gated,omitempty"` + Downloads int `json:"downloads,omitempty"` + Tags []string `json:"tags,omitempty"` + PipelineTag string `json:"pipeline_tag,omitempty"` + LibraryName string `json:"library_name,omitempty"` + ModelID string `json:"modelId,omitempty"` + Task string `json:"task,omitempty"` + Siblings []hfFile `json:"siblings,omitempty"` + Config *hfConfig `json:"config,omitempty"` + CardData *hfCard `json:"cardData,omitempty"` +} + +type hfFile struct { + RFileName string `json:"rfilename"` +} + +type hfConfig struct { + Architectures []string `json:"architectures,omitempty"` + ModelType string `json:"model_type,omitempty"` +} + +type hfCard struct { + Data map[string]interface{} `json:"data,omitempty"` +} + +//go:embed assets/catalog_logo.svg +var catalogLogoSVG []byte + +var ( + catalogModelLogo = "data:image/svg+xml;base64," + base64.StdEncoding.EncodeToString(catalogLogoSVG) ) -func (h *hfCatalogImpl) GetModel(ctx context.Context, modelName string, sourceID string) (*openapi.CatalogModel, error) { - // TODO: Implement HuggingFace model retrieval - return nil, fmt.Errorf("HuggingFace model retrieval not yet implemented") +// populateFromHFInfo populates the hfModel's CatalogModel fields from HuggingFace API data +func (hfm *hfModel) populateFromHFInfo(ctx context.Context, provider *hfModelProvider, hfInfo *hfModelInfo, sourceId string, originalModelName string) { + // Set model name + modelName := hfInfo.ID + if modelName == "" { + modelName = hfInfo.ModelID + } + if modelName == "" { + modelName = originalModelName + } + hfm.Name = modelName + + // Set ExternalId + if hfInfo.ID != "" { + hfm.ExternalId = &hfInfo.ID + } + + // Set SourceId + if sourceId != "" { + hfm.SourceId = &sourceId + } + + // Convert timestamps + if hfInfo.CreatedAt != "" { + if createTime, err := parseHFTime(hfInfo.CreatedAt); err == nil { + createTimeStr := strconv.FormatInt(createTime, 10) + hfm.CreateTimeSinceEpoch = &createTimeStr + } + } + if hfInfo.UpdatedAt != "" { + if updateTime, err := parseHFTime(hfInfo.UpdatedAt); err == nil { + updateTimeStr := strconv.FormatInt(updateTime, 10) + hfm.LastUpdateTimeSinceEpoch = &updateTimeStr + } + } + + // Extract license from tags + // Skip license tags in custom properties to avoid duplication + var filteredTags []string + if len(hfInfo.Tags) > 0 { + filteredTags = make([]string, 0, len(hfInfo.Tags)) + for _, tag := range hfInfo.Tags { + if strings.HasPrefix(tag, "license:") { + // Extract license (only first one) + if hfm.License == nil { + license := strings.TrimPrefix(tag, "license:") + if license != "" { + hfm.License = &license + } + } + } else { + filteredTags = append(filteredTags, tag) + } + } + } + + // Extract README from sibling files first (preferred source) + // Check for common README filenames + readmeFilenames := []string{"README.md", "readme.md", "Readme.md", "README", "readme"} + + for _, sibling := range hfInfo.Siblings { + for _, readmeFilename := range readmeFilenames { + if sibling.RFileName == readmeFilename { + if readmeContent, err := provider.fetchFileContent(ctx, modelName, readmeFilename); err == nil { + hfm.Readme = &readmeContent + break + } else { + glog.V(2).Infof("Failed to fetch README from sibling file %s for model %s: %v", readmeFilename, modelName, err) + } + } + } + if hfm.Readme != nil { + break + } + } + + // Extract description from cardData if available + if hfInfo.CardData != nil && hfInfo.CardData.Data != nil { + // Extract description from cardData if available + if desc, ok := hfInfo.CardData.Data["description"].(string); ok && desc != "" { + hfm.Description = &desc + } + + // Extract language from cardData if available + if langData, ok := hfInfo.CardData.Data["language"].([]interface{}); ok && len(langData) > 0 { + languages := make([]string, 0, len(langData)) + for _, lang := range langData { + if langStr, ok := lang.(string); ok && langStr != "" { + languages = append(languages, langStr) + } + } + if len(languages) > 0 { + hfm.Language = languages + } + } + + // Extract license link from cardData if available + // Check common field names for license link/URL + if hfm.LicenseLink == nil { + licenseLinkFields := []string{"license_link", "licenseLink", "license_url", "licenseUrl", "license"} + for _, field := range licenseLinkFields { + if link, ok := hfInfo.CardData.Data[field].(string); ok && link != "" { + if strings.HasPrefix(link, "http://") || strings.HasPrefix(link, "https://") { + hfm.LicenseLink = &link + break + } + } + } + } + + } + + // Set provider from author + if hfInfo.Author != "" { + hfm.Provider = &hfInfo.Author + } + + // Set library name + if hfInfo.LibraryName != "" { + hfm.LibraryName = &hfInfo.LibraryName + } + + // Set logo + hfm.Logo = &catalogModelLogo + + // Convert tasks + var tasks []string + if hfInfo.Task != "" { + tasks = append(tasks, hfInfo.Task) + } + if hfInfo.PipelineTag != "" && hfInfo.PipelineTag != hfInfo.Task { + tasks = append(tasks, hfInfo.PipelineTag) + } + if len(tasks) > 0 { + hfm.Tasks = tasks + } + + // Convert tags and other metadata to custom properties + customProps := make(map[string]apimodels.MetadataValue) + + customProps["hf_private"] = apimodels.MetadataValue{ + MetadataStringValue: &apimodels.MetadataStringValue{ + StringValue: strconv.FormatBool(hfInfo.Private), + }, + } + + customProps["hf_gated"] = apimodels.MetadataValue{ + MetadataStringValue: &apimodels.MetadataStringValue{ + StringValue: hfInfo.Gated.String(), + }, + } + + if len(filteredTags) > 0 { + if tagsJSON, err := json.Marshal(filteredTags); err == nil { + customProps["hf_tags"] = apimodels.MetadataValue{ + MetadataStringValue: &apimodels.MetadataStringValue{ + StringValue: string(tagsJSON), + }, + } + } + } + + if hfInfo.Config != nil { + if len(hfInfo.Config.Architectures) > 0 { + if archJSON, err := json.Marshal(hfInfo.Config.Architectures); err == nil { + customProps["hf_architectures"] = apimodels.MetadataValue{ + MetadataStringValue: &apimodels.MetadataStringValue{ + StringValue: string(archJSON), + }, + } + } + } + if hfInfo.Config.ModelType != "" { + customProps["hf_model_type"] = apimodels.MetadataValue{ + MetadataStringValue: &apimodels.MetadataStringValue{ + StringValue: hfInfo.Config.ModelType, + }, + } + } + } + + if len(customProps) > 0 { + hfm.SetCustomProperties(customProps) + } } -func (h *hfCatalogImpl) ListModels(ctx context.Context, params ListModelsParams) (model.CatalogModelList, error) { - // TODO: Implement HuggingFace model listing - // For now, return empty list to satisfy interface - return model.CatalogModelList{ - Items: []model.CatalogModel{}, - PageSize: 0, - Size: 0, - }, nil +func (p *hfModelProvider) Models(ctx context.Context) (<-chan ModelProviderRecord, error) { + // Read the catalog and report errors + catalog, err := p.getModelsFromHF(ctx) + if err != nil { + return nil, err + } + + ch := make(chan ModelProviderRecord) + go func() { + defer close(ch) + + // Send the initial list right away. + p.emit(ctx, catalog, ch) + }() + + return ch, nil } -func (h *hfCatalogImpl) GetArtifacts(ctx context.Context, modelName string, sourceID string, params ListArtifactsParams) (openapi.CatalogArtifactList, error) { - // TODO: Implement HuggingFace model artifacts retrieval - // For now, return empty list to satisfy interface - return openapi.CatalogArtifactList{ - Items: []openapi.CatalogArtifact{}, - PageSize: 0, - Size: 0, - }, nil +func (p *hfModelProvider) getModelsFromHF(ctx context.Context) ([]ModelProviderRecord, error) { + var records []ModelProviderRecord + + for _, modelName := range p.includedModels { + // Skip if excluded - check before fetching to avoid unnecessary API calls + if !p.filter.Allows(modelName) { + glog.V(2).Infof("Skipping excluded model: %s", modelName) + continue + } + + modelInfo, err := p.fetchModelInfo(ctx, modelName) + if err != nil { + glog.Errorf("Failed to fetch model info for %s: %v", modelName, err) + continue + } + + record := p.convertHFModelToRecord(ctx, modelInfo, modelName) + + // Additional safety check: verify the final model name is not excluded + // (in case the model name changed during conversion, e.g., from hfInfo.ID) + if record.Model.GetAttributes() != nil && record.Model.GetAttributes().Name != nil { + finalModelName := *record.Model.GetAttributes().Name + if !p.filter.Allows(finalModelName) { + glog.V(2).Infof("Skipping excluded model (after conversion): %s", finalModelName) + continue + } + } + + records = append(records, record) + } + + return records, nil } -func (h *hfCatalogImpl) GetFilterOptions(ctx context.Context) (*openapi.FilterOptionsList, error) { - // TODO: Implement HuggingFace filter options retrieval - // For now, return empty options to satisfy interface - emptyFilters := make(map[string]openapi.FilterOption) - return &openapi.FilterOptionsList{ - Filters: &emptyFilters, - }, nil +func (p *hfModelProvider) fetchModelInfo(ctx context.Context, modelName string) (*hfModelInfo, error) { + // The HF API requires the full model identifier: org/model-name (aka repo/model-name) + + // Normalize the model name (remove any leading/trailing slashes) + modelName = strings.Trim(modelName, "/") + + // Construct the API URL with the full model identifier + apiURL := fmt.Sprintf("%s/api/models/%s", p.baseURL, modelName) + + glog.V(2).Infof("Fetching HuggingFace model info from: %s", apiURL) + + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set User-Agent header (HuggingFace API expects this) + req.Header.Set("User-Agent", "model-registry-catalog") + + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch model info for %s: %w", modelName, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HuggingFace API returned status %d for model %s: %s", resp.StatusCode, modelName, string(bodyBytes)) + } + + var modelInfo hfModelInfo + if err := json.NewDecoder(resp.Body).Decode(&modelInfo); err != nil { + return nil, fmt.Errorf("failed to decode model info for %s: %w", modelName, err) + } + + // Ensure ID is set from modelName if not present in API response + if modelInfo.ID == "" { + modelInfo.ID = modelName + } + + return &modelInfo, nil +} + +// fetchFileContent fetches the content of a file from HuggingFace repository +func (p *hfModelProvider) fetchFileContent(ctx context.Context, modelName string, filename string) (string, error) { + // Normalize the model name (remove any leading/trailing slashes) + modelName = strings.Trim(modelName, "/") + + // Construct the API URL for raw file content + // HuggingFace API endpoint: {baseURL}/{model_id}/raw/main/{filename} + apiURL := fmt.Sprintf("%s/%s/raw/main/%s", p.baseURL, modelName, filename) + + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + // Set User-Agent header + req.Header.Set("User-Agent", "model-registry-catalog") + + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to fetch file %s for model %s: %w", filename, modelName, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("HuggingFace API returned status %d for file %s in model %s: %s", resp.StatusCode, filename, modelName, string(bodyBytes)) + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read file content for %s in model %s: %w", filename, modelName, err) + } + + return string(bodyBytes), nil +} + +func (p *hfModelProvider) convertHFModelToRecord(ctx context.Context, hfInfo *hfModelInfo, originalModelName string) ModelProviderRecord { + // Create hfModel and populate it from HF API data + hfm := &hfModel{} + hfm.populateFromHFInfo(ctx, p, hfInfo, p.sourceId, originalModelName) + + // Convert to database model + model := dbmodels.CatalogModelImpl{} + + // Convert model attributes + modelName := hfm.Name + attrs := &dbmodels.CatalogModelAttributes{ + Name: &modelName, + ExternalID: hfm.ExternalId, + } + + // Convert timestamps if available + if hfm.CreateTimeSinceEpoch != nil { + if createTime, err := strconv.ParseInt(*hfm.CreateTimeSinceEpoch, 10, 64); err == nil { + attrs.CreateTimeSinceEpoch = &createTime + } + } + if hfm.LastUpdateTimeSinceEpoch != nil { + if updateTime, err := strconv.ParseInt(*hfm.LastUpdateTimeSinceEpoch, 10, 64); err == nil { + attrs.LastUpdateTimeSinceEpoch = &updateTime + } + } + + model.Attributes = attrs + + // Convert model properties + properties, customProperties := convertHFModelProperties(&hfm.CatalogModel) + if len(properties) > 0 { + model.Properties = &properties + } + if len(customProperties) > 0 { + model.CustomProperties = &customProperties + } + + return ModelProviderRecord{ + Model: &model, + Artifacts: []dbmodels.CatalogArtifact{}, // HF models don't have artifacts from the API + } +} + +// convertHFModelProperties converts CatalogModel properties to database format +func convertHFModelProperties(catalogModel *apimodels.CatalogModel) ([]models.Properties, []models.Properties) { + var properties []models.Properties + var customProperties []models.Properties + + // Regular properties + if catalogModel.Description != nil { + properties = append(properties, models.NewStringProperty("description", *catalogModel.Description, false)) + } + if catalogModel.Readme != nil { + properties = append(properties, models.NewStringProperty("readme", *catalogModel.Readme, false)) + } + if catalogModel.Provider != nil { + properties = append(properties, models.NewStringProperty("provider", *catalogModel.Provider, false)) + } + if catalogModel.License != nil { + properties = append(properties, models.NewStringProperty("license", *catalogModel.License, false)) + } + if catalogModel.LicenseLink != nil { + properties = append(properties, models.NewStringProperty("license_link", *catalogModel.LicenseLink, false)) + } + if catalogModel.LibraryName != nil { + properties = append(properties, models.NewStringProperty("library_name", *catalogModel.LibraryName, false)) + } + if catalogModel.Logo != nil { + properties = append(properties, models.NewStringProperty("logo", *catalogModel.Logo, false)) + } + if catalogModel.SourceId != nil { + properties = append(properties, models.NewStringProperty("source_id", *catalogModel.SourceId, false)) + } + + // Convert array properties + if len(catalogModel.Tasks) > 0 { + if tasksJSON, err := json.Marshal(catalogModel.Tasks); err == nil { + properties = append(properties, models.NewStringProperty("tasks", string(tasksJSON), false)) + } + } + if len(catalogModel.Language) > 0 { + if languageJSON, err := json.Marshal(catalogModel.Language); err == nil { + properties = append(properties, models.NewStringProperty("language", string(languageJSON), false)) + } + } + + // Convert custom properties from the CatalogModel + if catalogModel.CustomProperties != nil { + for key, value := range catalogModel.GetCustomProperties() { + prop := convertMetadataValueToProperty(key, value) + customProperties = append(customProperties, prop) + } + } + + return properties, customProperties } -// validateCredentials checks if the HuggingFace API credentials are valid -func (h *hfCatalogImpl) validateCredentials(ctx context.Context) error { +// parseHFTime parses HuggingFace timestamp format (ISO 8601) +func parseHFTime(timeStr string) (int64, error) { + t, err := time.Parse(time.RFC3339, timeStr) + if err != nil { + return 0, err + } + return t.UnixMilli(), nil +} + +func (p *hfModelProvider) emit(ctx context.Context, models []ModelProviderRecord, out chan<- ModelProviderRecord) { + done := ctx.Done() + for _, model := range models { + // Check if model should be excluded by name + if model.Model.GetAttributes() != nil && model.Model.GetAttributes().Name != nil { + modelName := *model.Model.GetAttributes().Name + if !p.filter.Allows(modelName) { + glog.V(2).Infof("Skipping excluded model in emit: %s", modelName) + continue + } + } + + select { + case out <- model: + case <-done: + return + } + } +} + +// validateCredentials checks if the HuggingFace API key credentials are valid +func (p *hfModelProvider) validateCredentials(ctx context.Context) error { glog.Infof("Validating HuggingFace API credentials") // Make a simple API call to validate credentials - apiURL := h.baseURL + "/api/whoami-v2" + apiURL := p.baseURL + "/api/whoami-v2" req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) if err != nil { return fmt.Errorf("failed to create validation request: %w", err) } - if h.apiKey != "" { - req.Header.Set("Authorization", "Bearer "+h.apiKey) + req.Header.Set("User-Agent", "model-registry-catalog") + + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) } - resp, err := h.client.Do(req) + resp, err := p.client.Do(req) if err != nil { return fmt.Errorf("failed to validate HuggingFace credentials: %w", err) } @@ -83,46 +592,71 @@ func (h *hfCatalogImpl) validateCredentials(ctx context.Context) error { return fmt.Errorf("invalid HuggingFace API credentials") } if resp.StatusCode != http.StatusOK { - return fmt.Errorf("HuggingFace API validation failed with status: %d", resp.StatusCode) + bodyBytes, _ := io.ReadAll(resp.Body) + return fmt.Errorf("HuggingFace API validation failed with status: %d: %s", resp.StatusCode, string(bodyBytes)) } glog.Infof("HuggingFace credentials validated successfully") return nil } -// newHfCatalog creates a new HuggingFace catalog source -func newHfCatalog(source *Source, reldir string) (APIProvider, error) { - apiKey, ok := source.Properties["apiKey"].(string) - if !ok || apiKey == "" { - return nil, fmt.Errorf("missing or invalid 'apiKey' property for HuggingFace catalog") - } +func newHFModelProvider(ctx context.Context, source *Source, reldir string) (<-chan ModelProviderRecord, error) { + p := &hfModelProvider{} + p.client = &http.Client{Timeout: 30 * time.Second} - baseURL := defaultHuggingFaceURL - if url, ok := source.Properties["url"].(string); ok && url != "" { - baseURL = strings.TrimSuffix(url, "/") + // Parse Source ID + sourceId := source.GetId() + if sourceId == "" { + return nil, fmt.Errorf("missing source ID for HuggingFace catalog") } + p.sourceId = sourceId - // Optional model limit for future implementation - modelLimit := 100 - if limit, ok := source.Properties["modelLimit"].(int); ok && limit > 0 { - modelLimit = limit + // Parse API key from environment variable + // Allow the environment variable name to be configured via properties, defaulting to HF_API_KEY + apiKeyEnvVar := defaultAPIKeyEnvVar + if envVar, ok := source.Properties[apiKeyEnvVarKey].(string); ok && envVar != "" { + apiKeyEnvVar = envVar } + apiKey := os.Getenv(apiKeyEnvVar) + if apiKey == "" { + return nil, fmt.Errorf("missing %s environment variable for HuggingFace catalog", apiKeyEnvVar) + } + p.apiKey = apiKey - glog.Infof("Configuring HuggingFace catalog with URL: %s, modelLimit: %d", baseURL, modelLimit) - - h := &hfCatalogImpl{ - client: &http.Client{Timeout: 30 * time.Second}, - apiKey: apiKey, - baseURL: baseURL, + // Parse base URL (optional, defaults to huggingface.co) + // This allows tests to use mock servers by providing a custom URL + p.baseURL = defaultHuggingFaceURL + if url, ok := source.Properties[urlKey].(string); ok && url != "" { + p.baseURL = strings.TrimSuffix(url, "/") // Remove trailing slash if present } - // Validate credentials during initialization (as required by Jira ticket) - ctx := context.Background() - if err := h.validateCredentials(ctx); err != nil { + // Validate credentials before proceeding + if err := p.validateCredentials(ctx); err != nil { glog.Errorf("HuggingFace catalog credential validation failed: %v", err) return nil, fmt.Errorf("failed to validate HuggingFace catalog credentials: %w", err) } - glog.Infof("HuggingFace catalog source configured successfully") - return h, nil + // Use top-level IncludedModels from Source as the list of models to fetch + // These can be specific model names (required for HF API) or patterns + if len(source.IncludedModels) == 0 { + return nil, fmt.Errorf("includedModels cannot be empty for HuggingFace catalog") + } + + p.includedModels = source.IncludedModels + + // Create ModelFilter from source configuration (handles IncludedModels/ExcludedModels from Source) + // Note: IncludedModels are used both for fetching and filtering + filter, err := NewModelFilterFromSource(source, nil, nil) + if err != nil { + return nil, err + } + p.filter = filter + + return p.Models(ctx) +} + +func init() { + if err := RegisterModelProvider("hf", newHFModelProvider); err != nil { + panic(err) + } } diff --git a/catalog/internal/catalog/hf_catalog_test.go b/catalog/internal/catalog/hf_catalog_test.go index e21b384ed7..7cf032f22b 100644 --- a/catalog/internal/catalog/hf_catalog_test.go +++ b/catalog/internal/catalog/hf_catalog_test.go @@ -3,172 +3,556 @@ package catalog import ( "context" "net/http" - "net/http/httptest" - "strings" "testing" - "github.com/kubeflow/model-registry/catalog/pkg/openapi" + apimodels "github.com/kubeflow/model-registry/catalog/pkg/openapi" + "github.com/kubeflow/model-registry/internal/db/models" ) -func TestNewHfCatalog_MissingAPIKey(t *testing.T) { - source := &Source{ - CatalogSource: openapi.CatalogSource{ - Id: "test_hf", - Name: "Test HF", +func TestPopulateFromHFInfo(t *testing.T) { + tests := []struct { + name string + hfInfo *hfModelInfo + sourceId string + originalModelName string + expectedName string + expectedExternalID *string + expectedSourceID *string + expectedProvider *string + expectedLicense *string + expectedLibrary *string + hasReadme bool + hasDescription bool + hasTasks bool + hasCustomProps bool + }{ + { + name: "complete model info", + hfInfo: &hfModelInfo{ + ID: "test-org/test-model", + Author: "test-author", + Sha: "abc123", + CreatedAt: "2023-01-01T00:00:00Z", + UpdatedAt: "2023-01-02T00:00:00Z", + Downloads: 1000, + Tags: []string{"license:mit", "transformers", "pytorch"}, + PipelineTag: "text-generation", + Task: "text-generation", + LibraryName: "transformers", + Config: &hfConfig{ + Architectures: []string{"GPT2LMHeadModel"}, + ModelType: "gpt2", + }, + CardData: &hfCard{ + Data: map[string]interface{}{ + "description": "A test model description", + }, + }, + }, + sourceId: "test-source-id", + originalModelName: "test-org/test-model", + expectedName: "test-org/test-model", + expectedProvider: stringPtr("test-author"), + expectedLicense: stringPtr("mit"), + expectedLibrary: stringPtr("transformers"), + hasTasks: true, + hasCustomProps: true, + hasReadme: false, // No README fetching in unit tests }, - Type: "hf", - Properties: map[string]any{ - "url": "https://huggingface.co", + { + name: "model with ModelID fallback", + hfInfo: &hfModelInfo{ + ModelID: "fallback-model-id", + Author: "another-author", + }, + sourceId: "source-2", + originalModelName: "original-name", + expectedName: "fallback-model-id", + expectedProvider: stringPtr("another-author"), + }, + { + name: "model with original name fallback", + hfInfo: &hfModelInfo{ + Author: "author-3", + }, + sourceId: "source-3", + originalModelName: "fallback-original-name", + expectedName: "fallback-original-name", + expectedProvider: stringPtr("author-3"), + }, + { + name: "model with license in tags", + hfInfo: &hfModelInfo{ + ID: "test/licensed-model", + Tags: []string{"license:apache-2.0", "other-tag"}, + }, + sourceId: "source-4", + originalModelName: "test/licensed-model", + expectedName: "test/licensed-model", + expectedLicense: stringPtr("apache-2.0"), + hasCustomProps: true, + }, + { + name: "model with tasks", + hfInfo: &hfModelInfo{ + ID: "test/task-model", + Task: "text-classification", + PipelineTag: "sentiment-analysis", + }, + sourceId: "source-5", + originalModelName: "test/task-model", + expectedName: "test/task-model", + hasTasks: true, + }, + { + name: "model with description in cardData", + hfInfo: &hfModelInfo{ + ID: "test/desc-model", + CardData: &hfCard{ + Data: map[string]interface{}{ + "description": "This is a test description", + }, + }, + }, + sourceId: "source-6", + originalModelName: "test/desc-model", + expectedName: "test/desc-model", + hasDescription: true, + }, + { + name: "minimal model info", + hfInfo: &hfModelInfo{ + ID: "minimal/model", + }, + sourceId: "source-7", + originalModelName: "minimal/model", + expectedName: "minimal/model", }, } - _, err := newHfCatalog(source, "") - if err == nil { - t.Fatal("Expected error for missing API key, got nil") - } - if err.Error() != "missing or invalid 'apiKey' property for HuggingFace catalog" { - t.Fatalf("Expected specific error message, got: %s", err.Error()) - } -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a mock provider with HTTP client to avoid nil pointer + // Note: README fetching will fail, but that's expected in unit tests + provider := &hfModelProvider{ + sourceId: tt.sourceId, + client: &http.Client{}, + } -func TestNewHfCatalog_WithValidCredentials(t *testing.T) { - // Create mock server that returns valid response for credential validation - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check for authorization header - auth := r.Header.Get("Authorization") - if auth != "Bearer test-api-key" { - w.WriteHeader(http.StatusUnauthorized) - return - } - - switch r.URL.Path { - case "/api/whoami-v2": - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"name": "test-user", "type": "user"}`)) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - source := &Source{ - CatalogSource: openapi.CatalogSource{ - Id: "test_hf", - Name: "Test HF", - }, - Type: "hf", - Properties: map[string]any{ - "apiKey": "test-api-key", - "url": server.URL, - "modelLimit": 10, - }, - } + // Create hfModel and populate it + hfm := &hfModel{} + ctx := context.Background() + hfm.populateFromHFInfo(ctx, provider, tt.hfInfo, tt.sourceId, tt.originalModelName) - catalog, err := newHfCatalog(source, "") - if err != nil { - t.Fatalf("Failed to create HF catalog: %v", err) - } + // Verify name + if hfm.Name != tt.expectedName { + t.Errorf("Name = %v, want %v", hfm.Name, tt.expectedName) + } - hfCatalog := catalog.(*hfCatalogImpl) + // Verify ExternalID + if tt.expectedExternalID != nil { + if hfm.ExternalId == nil || *hfm.ExternalId != *tt.expectedExternalID { + t.Errorf("ExternalId = %v, want %v", hfm.ExternalId, tt.expectedExternalID) + } + } else if tt.hfInfo.ID != "" { + // If hfInfo has ID, ExternalId should be set + if hfm.ExternalId == nil || *hfm.ExternalId != tt.hfInfo.ID { + t.Errorf("ExternalId = %v, want %v", hfm.ExternalId, tt.hfInfo.ID) + } + } - // Test that methods return appropriate responses for stub implementation - ctx := context.Background() + // Verify SourceID + if tt.expectedSourceID != nil { + if hfm.SourceId == nil || *hfm.SourceId != *tt.expectedSourceID { + t.Errorf("SourceId = %v, want %v", hfm.SourceId, tt.expectedSourceID) + } + } else if tt.sourceId != "" { + if hfm.SourceId == nil || *hfm.SourceId != tt.sourceId { + t.Errorf("SourceId = %v, want %v", hfm.SourceId, tt.sourceId) + } + } - // Test GetModel - should return not implemented error - model, err := hfCatalog.GetModel(ctx, "test-model", "") - if err == nil { - t.Fatal("Expected not implemented error, got nil") - } - if model != nil { - t.Fatal("Expected nil model, got non-nil") - } + // Verify Provider + if tt.expectedProvider != nil { + if hfm.Provider == nil || *hfm.Provider != *tt.expectedProvider { + t.Errorf("Provider = %v, want %v", hfm.Provider, tt.expectedProvider) + } + } - // Test ListModels - should return empty list - listParams := ListModelsParams{ - Query: "", - OrderBy: openapi.ORDERBYFIELD_NAME, - SortOrder: openapi.SORTORDER_ASC, - } - modelList, err := hfCatalog.ListModels(ctx, listParams) - if err != nil { - t.Fatalf("Failed to list models: %v", err) + // Verify License + if tt.expectedLicense != nil { + if hfm.License == nil || *hfm.License != *tt.expectedLicense { + t.Errorf("License = %v, want %v", hfm.License, tt.expectedLicense) + } + } + + // Verify LibraryName + if tt.expectedLibrary != nil { + if hfm.LibraryName == nil || *hfm.LibraryName != *tt.expectedLibrary { + t.Errorf("LibraryName = %v, want %v", hfm.LibraryName, tt.expectedLibrary) + } + } + + // Verify Tasks + if tt.hasTasks { + if len(hfm.Tasks) == 0 { + t.Error("Expected tasks to be set, but got empty slice") + } + } + + // Verify Description + if tt.hasDescription { + if hfm.Description == nil { + t.Error("Expected description to be set, but got nil") + } + } + + // Verify CustomProperties + if tt.hasCustomProps { + if hfm.GetCustomProperties() == nil || len(hfm.GetCustomProperties()) == 0 { + t.Error("Expected custom properties to be set, but got nil or empty") + } + } + + // Verify timestamps if present + if tt.hfInfo.CreatedAt != "" { + if hfm.CreateTimeSinceEpoch == nil { + t.Error("Expected CreateTimeSinceEpoch to be set") + } + } + if tt.hfInfo.UpdatedAt != "" { + if hfm.LastUpdateTimeSinceEpoch == nil { + t.Error("Expected LastUpdateTimeSinceEpoch to be set") + } + } + }) } - if len(modelList.Items) != 0 { - t.Fatalf("Expected 0 models, got %d", len(modelList.Items)) +} + +func TestConvertHFModelToRecord(t *testing.T) { + tests := []struct { + name string + hfInfo *hfModelInfo + originalModelName string + sourceId string + verifyFunc func(t *testing.T, record ModelProviderRecord) + }{ + { + name: "complete model conversion", + hfInfo: &hfModelInfo{ + ID: "test-org/complete-model", + Author: "test-author", + CreatedAt: "2023-01-01T00:00:00Z", + UpdatedAt: "2023-01-02T00:00:00Z", + Tags: []string{"license:mit"}, + LibraryName: "transformers", + Task: "text-generation", + CardData: &hfCard{ + Data: map[string]interface{}{ + "description": "A complete test model", + }, + }, + }, + originalModelName: "test-org/complete-model", + sourceId: "test-source", + verifyFunc: func(t *testing.T, record ModelProviderRecord) { + if record.Model == nil { + t.Fatal("Model should not be nil") + } + attrs := record.Model.GetAttributes() + if attrs == nil { + t.Fatal("Attributes should not be nil") + } + if attrs.Name == nil || *attrs.Name != "test-org/complete-model" { + t.Errorf("Name = %v, want 'test-org/complete-model'", attrs.Name) + } + if attrs.ExternalID == nil || *attrs.ExternalID != "test-org/complete-model" { + t.Errorf("ExternalID = %v, want 'test-org/complete-model'", attrs.ExternalID) + } + if attrs.CreateTimeSinceEpoch == nil { + t.Error("CreateTimeSinceEpoch should be set") + } + if attrs.LastUpdateTimeSinceEpoch == nil { + t.Error("LastUpdateTimeSinceEpoch should be set") + } + if record.Model.GetProperties() == nil || len(*record.Model.GetProperties()) == 0 { + t.Error("Properties should be set") + } + if len(record.Artifacts) != 0 { + t.Errorf("Artifacts should be empty, got %d", len(record.Artifacts)) + } + }, + }, + { + name: "minimal model conversion", + hfInfo: &hfModelInfo{ + ID: "minimal/model", + }, + originalModelName: "minimal/model", + sourceId: "source-1", + verifyFunc: func(t *testing.T, record ModelProviderRecord) { + if record.Model == nil { + t.Fatal("Model should not be nil") + } + attrs := record.Model.GetAttributes() + if attrs == nil || attrs.Name == nil { + t.Fatal("Attributes and Name should not be nil") + } + if *attrs.Name != "minimal/model" { + t.Errorf("Name = %v, want 'minimal/model'", attrs.Name) + } + }, + }, } - // Test GetArtifacts - should return empty list - artifacts, err := hfCatalog.GetArtifacts(ctx, "test-model", "", ListArtifactsParams{}) - if err != nil { - t.Fatalf("Failed to get artifacts: %v", err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := &hfModelProvider{ + sourceId: tt.sourceId, + } + ctx := context.Background() + record := provider.convertHFModelToRecord(ctx, tt.hfInfo, tt.originalModelName) + tt.verifyFunc(t, record) + }) } - if artifacts.Items == nil { - t.Fatal("Expected artifacts list, got nil") +} + +func TestConvertHFModelProperties(t *testing.T) { + tests := []struct { + name string + catalogModel *apimodels.CatalogModel + wantProps bool + wantCustom bool + verifyFunc func(t *testing.T, props []models.Properties, customProps []models.Properties) + }{ + { + name: "model with all properties", + catalogModel: &apimodels.CatalogModel{ + Name: "test-model", + Description: stringPtr("Test description"), + Readme: stringPtr("# Test README"), + Provider: stringPtr("test-provider"), + License: stringPtr("mit"), + LibraryName: stringPtr("transformers"), + SourceId: stringPtr("source-1"), + Tasks: []string{"text-generation"}, + }, + wantProps: true, + wantCustom: false, + verifyFunc: func(t *testing.T, props []models.Properties, customProps []models.Properties) { + if len(props) == 0 { + t.Error("Expected properties to be set") + } + }, + }, + { + name: "model with custom properties", + catalogModel: func() *apimodels.CatalogModel { + model := &apimodels.CatalogModel{ + Name: "test-model", + } + customProps := map[string]apimodels.MetadataValue{ + "hf_tags": { + MetadataStringValue: &apimodels.MetadataStringValue{ + StringValue: `["tag1","tag2"]`, + }, + }, + } + model.SetCustomProperties(customProps) + return model + }(), + wantProps: false, + wantCustom: true, + verifyFunc: func(t *testing.T, props []models.Properties, customProps []models.Properties) { + if len(customProps) == 0 { + t.Error("Expected custom properties to be set") + } + }, + }, + { + name: "model with minimal properties", + catalogModel: &apimodels.CatalogModel{ + Name: "minimal-model", + }, + wantProps: false, + wantCustom: false, + verifyFunc: func(t *testing.T, props []models.Properties, customProps []models.Properties) { + if len(props) != 0 { + t.Errorf("Expected no properties, got %d", len(props)) + } + if len(customProps) != 0 { + t.Errorf("Expected no custom properties, got %d", len(customProps)) + } + }, + }, + { + name: "model with tasks", + catalogModel: &apimodels.CatalogModel{ + Name: "task-model", + Tasks: []string{"classification", "generation"}, + }, + wantProps: true, + verifyFunc: func(t *testing.T, props []models.Properties, customProps []models.Properties) { + // Should have tasks property + foundTasks := false + for _, prop := range props { + if prop.Name == "tasks" { + foundTasks = true + break + } + } + if !foundTasks { + t.Error("Expected tasks property to be present") + } + }, + }, } - if len(artifacts.Items) != 0 { - t.Fatalf("Expected 0 artifacts, got %d", len(artifacts.Items)) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + props, customProps := convertHFModelProperties(tt.catalogModel) + if (len(props) > 0) != tt.wantProps { + t.Errorf("Properties presence = %v, want %v", len(props) > 0, tt.wantProps) + } + if (len(customProps) > 0) != tt.wantCustom { + t.Errorf("Custom properties presence = %v, want %v", len(customProps) > 0, tt.wantCustom) + } + if tt.verifyFunc != nil { + tt.verifyFunc(t, props, customProps) + } + }) } } -func TestNewHfCatalog_InvalidCredentials(t *testing.T) { - // Create mock server that returns 401 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - })) - defer server.Close() - - source := &Source{ - CatalogSource: openapi.CatalogSource{ - Id: "test_hf", - Name: "Test HF", +func TestHFModelProviderWithModelFilter(t *testing.T) { + tests := []struct { + name string + includedModels []string + excludedModels []string + modelName string + wantAllowed bool + description string + }{ + { + name: "model matches included pattern", + includedModels: []string{"ibm-granite/*"}, + excludedModels: nil, + modelName: "ibm-granite/granite-4.0-h-small", + wantAllowed: true, + description: "Model matching included pattern should be allowed", }, - Type: "hf", - Properties: map[string]any{ - "apiKey": "invalid-key", - "url": server.URL, + { + name: "model does not match included pattern", + includedModels: []string{"ibm-granite/*"}, + excludedModels: nil, + modelName: "meta-llama/Llama-3.2-1B", + wantAllowed: false, + description: "Model not matching included pattern should be excluded", + }, + { + name: "model matches excluded pattern", + includedModels: []string{"ibm-granite/*"}, + excludedModels: []string{"*-beta"}, + modelName: "ibm-granite/granite-4.0-h-beta", + wantAllowed: false, + description: "Model matching excluded pattern should be excluded even if it matches included", + }, + { + name: "model matches included but not excluded", + includedModels: []string{"ibm-granite/*"}, + excludedModels: []string{"*-beta"}, + modelName: "ibm-granite/granite-4.0-h-small", + wantAllowed: true, + description: "Model matching included but not excluded should be allowed", + }, + { + name: "case insensitive matching", + includedModels: []string{"IBM-Granite/*"}, + excludedModels: nil, + modelName: "ibm-granite/granite-4.0-h-small", + wantAllowed: true, + description: "Filtering should be case-insensitive", + }, + { + name: "no included patterns allows all", + includedModels: nil, + excludedModels: []string{"*-beta"}, + modelName: "test/model", + wantAllowed: true, + description: "No included patterns means all models are allowed (unless excluded)", }, } - _, err := newHfCatalog(source, "") - if err == nil { - t.Fatal("Expected error for invalid credentials, got nil") - } - if !strings.Contains(err.Error(), "invalid HuggingFace API credentials") { - t.Fatalf("Expected credential validation error, got: %s", err.Error()) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a filter with the test patterns + filter, err := NewModelFilter(tt.includedModels, tt.excludedModels) + if err != nil { + t.Fatalf("NewModelFilter() error = %v, want nil", err) + } + + // Create a provider with the filter + provider := &hfModelProvider{ + filter: filter, + } + + // Test that the filter works correctly + got := provider.filter.Allows(tt.modelName) + if got != tt.wantAllowed { + t.Errorf("ModelFilter.Allows(%q) = %v, want %v. %s", tt.modelName, got, tt.wantAllowed, tt.description) + } + }) } } -func TestNewHfCatalog_DefaultConfiguration(t *testing.T) { - // Create mock server for default HuggingFace URL - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"name": "test-user"}`)) - })) - defer server.Close() - - source := &Source{ - CatalogSource: openapi.CatalogSource{ - Id: "test_hf", - Name: "Test HF", - }, - Type: "hf", - Properties: map[string]any{ - "apiKey": "test-key", - "url": server.URL, // Override default for testing +func TestPopulateFromHFInfoWithCustomProperties(t *testing.T) { + hfInfo := &hfModelInfo{ + ID: "test/custom-props-model", + Sha: "sha123", + Downloads: 5000, + Tags: []string{"tag1", "tag2", "license:apache-2.0"}, + Config: &hfConfig{ + Architectures: []string{"BertModel", "BertForSequenceClassification"}, + ModelType: "bert", }, } - catalog, err := newHfCatalog(source, "") - if err != nil { - t.Fatalf("Failed to create HF catalog with defaults: %v", err) + provider := &hfModelProvider{ + sourceId: "test-source", + } + + hfm := &hfModel{} + ctx := context.Background() + hfm.populateFromHFInfo(ctx, provider, hfInfo, "test-source", "test/custom-props-model") + + customProps := hfm.GetCustomProperties() + if customProps == nil { + t.Fatal("Custom properties should not be nil") + } + + // Verify hf_tags + if tagsVal, ok := customProps["hf_tags"]; !ok { + t.Error("Expected hf_tags in custom properties") + } else if tagsVal.MetadataStringValue == nil { + t.Error("hf_tags should be a string value") } - hfCatalog := catalog.(*hfCatalogImpl) - if hfCatalog.apiKey != "test-key" { - t.Fatalf("Expected apiKey 'test-key', got '%s'", hfCatalog.apiKey) + // Verify hf_architectures + if archVal, ok := customProps["hf_architectures"]; !ok { + t.Error("Expected hf_architectures in custom properties") + } else if archVal.MetadataStringValue == nil { + t.Error("hf_architectures should be a string value") } - if hfCatalog.baseURL != server.URL { - t.Fatalf("Expected baseURL '%s', got '%s'", server.URL, hfCatalog.baseURL) + + // Verify hf_model_type + if modelTypeVal, ok := customProps["hf_model_type"]; !ok { + t.Error("Expected hf_model_type in custom properties") + } else if modelTypeVal.MetadataStringValue == nil || modelTypeVal.MetadataStringValue.StringValue != "bert" { + t.Errorf("hf_model_type = %v, want 'bert'", modelTypeVal.MetadataStringValue) } } + +// Helper function to create string pointers +func stringPtr(s string) *string { + return &s +} diff --git a/catalog/internal/catalog/testdata/test-hf-catalog-sources.yaml b/catalog/internal/catalog/testdata/test-hf-catalog-sources.yaml index 164b3482f1..03ca2b6882 100644 --- a/catalog/internal/catalog/testdata/test-hf-catalog-sources.yaml +++ b/catalog/internal/catalog/testdata/test-hf-catalog-sources.yaml @@ -5,13 +5,21 @@ catalogs: enabled: true properties: apiKey: "hf_test_api_key_here" - url: "https://huggingface.co" modelLimit: 50 + includedModels: + - "ibm-granite/granite-4.0-h-small" + - "ibm-granite/granite-4.0-h-tiny" + excludedModels: + - "ibm-granite/granite-4.0-h-small" - name: "HuggingFace Invalid Credentials" id: hf_invalid type: hf enabled: false # disabled so it doesn't cause startup failures in tests properties: apiKey: "invalid_key" - url: "https://huggingface.co" modelLimit: 10 + includedModels: + - "ibm-granite/granite-4.0-h-small" + - "ibm-granite/granite-4.0-h-tiny" + excludedModels: + - "ibm-granite/granite-4.0-h-small" diff --git a/manifests/kustomize/options/catalog/base/deployment.yaml b/manifests/kustomize/options/catalog/base/deployment.yaml index 3051f24fe9..ce3fa3cc71 100644 --- a/manifests/kustomize/options/catalog/base/deployment.yaml +++ b/manifests/kustomize/options/catalog/base/deployment.yaml @@ -44,6 +44,12 @@ spec: secretKeyRef: name: model-catalog-postgres key: POSTGRES_PASSWORD + - name: HF_API_KEY + valueFrom: + secretKeyRef: + name: model-catalog-hf-api-key + key: HF_API_KEY + optional: true command: - /model-registry - catalog diff --git a/manifests/kustomize/options/catalog/base/hf-api-key.env b/manifests/kustomize/options/catalog/base/hf-api-key.env new file mode 100644 index 0000000000..19f7f26480 --- /dev/null +++ b/manifests/kustomize/options/catalog/base/hf-api-key.env @@ -0,0 +1,2 @@ +HF_API_KEY=your-huggingface-api-key-here + diff --git a/manifests/kustomize/options/catalog/base/kustomization.yaml b/manifests/kustomize/options/catalog/base/kustomization.yaml index 36a465931b..d07eda9d07 100644 --- a/manifests/kustomize/options/catalog/base/kustomization.yaml +++ b/manifests/kustomize/options/catalog/base/kustomization.yaml @@ -25,3 +25,8 @@ secretGenerator: - postgres.env options: disableNameSuffixHash: true +- name: model-catalog-hf-api-key + envs: + - hf-api-key.env + options: + disableNameSuffixHash: true diff --git a/manifests/kustomize/options/catalog/base/sources.yaml b/manifests/kustomize/options/catalog/base/sources.yaml index 49a00160f3..7560017fff 100644 --- a/manifests/kustomize/options/catalog/base/sources.yaml +++ b/manifests/kustomize/options/catalog/base/sources.yaml @@ -5,3 +5,4 @@ catalogs: enabled: true properties: yamlCatalogPath: sample-catalog.yaml +