Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ catalog/pkg/openapi/model_catalog_model_artifact.go linguist-generated=true
catalog/pkg/openapi/model_catalog_model_list.go linguist-generated=true
catalog/pkg/openapi/model_catalog_source.go linguist-generated=true
catalog/pkg/openapi/model_catalog_source_list.go linguist-generated=true
catalog/pkg/openapi/model_catalog_source_preview_response.go linguist-generated=true
catalog/pkg/openapi/model_catalog_source_preview_response_all_of_summary.go linguist-generated=true
catalog/pkg/openapi/model_error.go linguist-generated=true
catalog/pkg/openapi/model_filter_option.go linguist-generated=true
catalog/pkg/openapi/model_filter_option_range.go linguist-generated=true
Expand All @@ -39,6 +41,7 @@ catalog/pkg/openapi/model_metadata_proto_value.go linguist-generated=true
catalog/pkg/openapi/model_metadata_string_value.go linguist-generated=true
catalog/pkg/openapi/model_metadata_struct_value.go linguist-generated=true
catalog/pkg/openapi/model_metadata_value.go linguist-generated=true
catalog/pkg/openapi/model_model_preview_result.go linguist-generated=true
catalog/pkg/openapi/model_order_by_field.go linguist-generated=true
catalog/pkg/openapi/model_sort_order.go linguist-generated=true
catalog/pkg/openapi/response.go linguist-generated=true
Expand Down
340 changes: 331 additions & 9 deletions api/openapi/catalog.yaml

Large diffs are not rendered by default.

342 changes: 333 additions & 9 deletions api/openapi/src/catalog.yaml

Large diffs are not rendered by default.

318 changes: 318 additions & 0 deletions catalog/internal/catalog/hf_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ const (
defaultAPIKeyEnvVar = "HF_API_KEY"
urlKey = "url"
apiKeyEnvVarKey = "apiKeyEnvVar"
maxModelsKey = "maxModels"

// defaultMaxModels is the default limit for models fetched PER PATTERN.
// This limit is applied independently to each pattern in includedModels
// (e.g., "ibm-granite/*", "meta-llama/*") to prevent overloading the
// HuggingFace API with too many requests and to respect rate limiting.
//
// Example: with maxModels=100 and 3 patterns, up to 300 models total may be fetched.
// Set to 0 to disable the limit (not recommended for large organizations).
defaultMaxModels = 500
)

// gatedString is a custom type that can unmarshal both boolean and string values from JSON
Expand Down Expand Up @@ -71,6 +81,10 @@ type hfModelProvider struct {
baseURL string
includedModels []string
filter *ModelFilter
// maxModels limits how many models to fetch PER PATTERN (e.g., per "org/*").
// This is applied independently to each pattern to respect HuggingFace API rate limits.
// A value of 0 means no limit.
maxModels int
}

// hfModelInfo represents the structure of HuggingFace API model information
Expand Down Expand Up @@ -660,3 +674,307 @@ func init() {
panic(err)
}
}

// NewHFPreviewProvider creates an hfModelProvider for preview use.
// It initializes the provider from a PreviewConfig without starting the full model loading.
func NewHFPreviewProvider(config *PreviewConfig) (*hfModelProvider, error) {
p := &hfModelProvider{
client: &http.Client{Timeout: 30 * time.Second},
baseURL: defaultHuggingFaceURL,
maxModels: defaultMaxModels,
}

// Parse API key from environment variable
apiKeyEnvVar := defaultAPIKeyEnvVar
if envVar, ok := config.Properties[apiKeyEnvVarKey].(string); ok && envVar != "" {
apiKeyEnvVar = envVar
}
apiKey := os.Getenv(apiKeyEnvVar)
if apiKey == "" {
return nil, fmt.Errorf("missing %s environment variable for HuggingFace preview", apiKeyEnvVar)
}
p.apiKey = apiKey

// Parse base URL (optional, defaults to huggingface.co)
if url, ok := config.Properties[urlKey].(string); ok && url != "" {
p.baseURL = strings.TrimSuffix(url, "/")
}

// Parse maxModels limit (optional, defaults to 500)
// This limit is applied PER PATTERN (e.g., each "org/*" pattern gets its own limit)
// to prevent overloading the HuggingFace API and respect rate limiting.
// Set to 0 to disable the limit.
if maxModels, ok := config.Properties[maxModelsKey]; ok {
switch v := maxModels.(type) {
case int:
p.maxModels = v
case int64:
p.maxModels = int(v)
case float64:
p.maxModels = int(v)
}
}

return p, nil
}

// hfListResponse represents a single model in the HuggingFace list API response.
type hfListModel struct {
ID string `json:"id"`
ModelID string `json:"modelId,omitempty"`
Author string `json:"author,omitempty"`
Private bool `json:"private,omitempty"`
Downloads int `json:"downloads,omitempty"`
}

// PatternType indicates how to handle an includedModels pattern.
type PatternType int

const (
PatternExact PatternType = iota // e.g., "org/model-name" - direct fetch
PatternOrgAll // e.g., "org/*" - list all from org
PatternOrgPrefix // e.g., "org/prefix*" - list from org with search
PatternInvalid // e.g., "*", "*/*" - not supported
)

// parseModelPattern analyzes a model identifier to determine how to fetch it.
// Returns: patternType, org, searchPrefix
func parseModelPattern(pattern string) (PatternType, string, string) {
pattern = strings.TrimSpace(pattern)

// Reject unsupported wildcard patterns that would try to list all HuggingFace models
// HuggingFace has millions of models, so we require a specific organization
if pattern == "*" || pattern == "*/*" {
return PatternInvalid, "", ""
}

// Reject patterns like "*/something" where org is a wildcard
if strings.HasPrefix(pattern, "*/") {
return PatternInvalid, "", ""
}

// Check if it's an org/* pattern
if strings.HasSuffix(pattern, "/*") {
org := strings.TrimSuffix(pattern, "/*")
// Ensure org is not empty or just whitespace
if org == "" || strings.TrimSpace(org) == "" {
return PatternInvalid, "", ""
}
return PatternOrgAll, org, ""
}

// Check if it has a wildcard after org/prefix
if strings.Contains(pattern, "/") && strings.HasSuffix(pattern, "*") {
parts := strings.SplitN(pattern, "/", 2)
if len(parts) == 2 {
org := parts[0]
// Ensure org is not empty or a wildcard
if org == "" || org == "*" {
return PatternInvalid, "", ""
}
prefix := strings.TrimSuffix(parts[1], "*")
if prefix != "" {
return PatternOrgPrefix, org, prefix
}
}
}

// Exact model name
return PatternExact, "", ""
}

// listModelsByAuthor fetches all models from an organization using the HuggingFace list API with pagination.
// If searchPrefix is provided, it filters models that start with that prefix.
func (p *hfModelProvider) listModelsByAuthor(ctx context.Context, author string, searchPrefix string) ([]string, error) {
var allModels []string
limit := 100 // Max allowed by HF API
cursor := ""

for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

// Check if we've reached the maxModels limit for this pattern
// (maxModels is applied per-pattern to respect HF API rate limits)
if p.maxModels > 0 && len(allModels) >= p.maxModels {
glog.Warningf("Reached maxModels limit (%d) for pattern author=%s, stopping pagination", p.maxModels, author)
break
}

// Build API URL
apiURL := fmt.Sprintf("%s/api/models?author=%s&limit=%d", p.baseURL, author, limit)
if searchPrefix != "" {
apiURL += "&search=" + searchPrefix
}
if cursor != "" {
apiURL += "&cursor=" + cursor
}

glog.V(2).Infof("Fetching HuggingFace models list: %s", apiURL)

req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create list request: %w", err)
}

req.Header.Set("User-Agent", "model-registry-catalog")
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}

resp, err := p.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to list models for author %s: %w", author, err)
}

if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return nil, fmt.Errorf("HuggingFace API returned status %d for author %s: %s", resp.StatusCode, author, string(bodyBytes))
}

var models []hfListModel
if err := json.NewDecoder(resp.Body).Decode(&models); err != nil {
resp.Body.Close()
return nil, fmt.Errorf("failed to decode models list for author %s: %w", author, err)
}
resp.Body.Close()

// Extract model IDs
for _, m := range models {
// Check limit before adding each model
if p.maxModels > 0 && len(allModels) >= p.maxModels {
break
}

modelID := m.ID
if modelID == "" {
modelID = m.ModelID
}
if modelID == "" {
continue
}

// If we have a search prefix, double-check it matches
// (HF search is fuzzy, so we need to verify)
if searchPrefix != "" {
// Extract the model name part (after org/)
parts := strings.SplitN(modelID, "/", 2)
if len(parts) == 2 {
modelName := parts[1]
if !strings.HasPrefix(strings.ToLower(modelName), strings.ToLower(searchPrefix)) {
continue
}
}
}

allModels = append(allModels, modelID)
}

// Check for next page via Link header
linkHeader := resp.Header.Get("Link")
nextCursor := parseNextCursor(linkHeader)
if nextCursor == "" || len(models) < limit {
// No more pages
break
}
cursor = nextCursor
}

glog.Infof("Listed %d models from author %s (maxModels: %d)", len(allModels), author, p.maxModels)
return allModels, nil
}

// parseNextCursor extracts the cursor for the next page from the Link header.
// Link header format: <url>; rel="next"
func parseNextCursor(linkHeader string) string {
if linkHeader == "" {
return ""
}

// Parse Link header for rel="next"
for _, link := range strings.Split(linkHeader, ",") {
link = strings.TrimSpace(link)
if strings.Contains(link, `rel="next"`) {
// Extract URL between < and >
start := strings.Index(link, "<")
end := strings.Index(link, ">")
if start >= 0 && end > start {
nextURL := link[start+1 : end]
// Extract cursor parameter from URL
if idx := strings.Index(nextURL, "cursor="); idx >= 0 {
cursor := nextURL[idx+7:]
// Handle if there are more parameters after cursor
if ampIdx := strings.Index(cursor, "&"); ampIdx >= 0 {
cursor = cursor[:ampIdx]
}
return cursor
}
}
}
}
return ""
}

// FetchModelNamesForPreview fetches model info from HuggingFace API for the given model identifiers
// and returns the actual model names. This is used for preview functionality.
// Supports patterns like "org/*" and "org/prefix*" which use the paginated list API.
func (p *hfModelProvider) FetchModelNamesForPreview(ctx context.Context, modelIdentifiers []string) ([]string, error) {
if len(modelIdentifiers) == 0 {
return nil, fmt.Errorf("includedModels is required for HuggingFace source preview")
}

// Validate credentials first
if err := p.validateCredentials(ctx); err != nil {
return nil, fmt.Errorf("failed to validate HuggingFace credentials: %w", err)
}

names := make([]string, 0)

for _, pattern := range modelIdentifiers {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

patternType, org, searchPrefix := parseModelPattern(pattern)

switch patternType {
case PatternInvalid:
// Reject unsupported wildcard patterns
return nil, fmt.Errorf("wildcard pattern %q is not supported - HuggingFace requires a specific organization (e.g., 'ibm-granite/*' or 'meta-llama/Llama-2-*')", pattern)

case PatternOrgAll, PatternOrgPrefix:
// Use paginated list API
glog.Infof("Using HuggingFace list API for pattern: %s (org=%s, prefix=%s)", pattern, org, searchPrefix)
models, err := p.listModelsByAuthor(ctx, org, searchPrefix)
if err != nil {
glog.Warningf("Failed to list models for pattern %s: %v", pattern, err)
// Don't fail completely, just skip this pattern
continue
}
names = append(names, models...)

case PatternExact:
// Direct fetch for exact model name
modelInfo, err := p.fetchModelInfo(ctx, pattern)
if err != nil {
glog.Warningf("Failed to fetch model info for preview: %s: %v", pattern, err)
names = append(names, pattern)
continue
}

actualName := modelInfo.ID
if actualName == "" {
actualName = pattern
}
names = append(names, actualName)
}
}

return names, nil
}
Loading
Loading