Skip to content

Commit f33d223

Browse files
committed
feat: add hf maxModels safeguard
Signed-off-by: Alessio Pragliola <seth.pro@gmail.com>
1 parent 2afab7f commit f33d223

2 files changed

Lines changed: 108 additions & 3 deletions

File tree

catalog/internal/catalog/hf_catalog.go

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ const (
2424
defaultAPIKeyEnvVar = "HF_API_KEY"
2525
urlKey = "url"
2626
apiKeyEnvVarKey = "apiKeyEnvVar"
27+
maxModelsKey = "maxModels"
28+
29+
// defaultMaxModels is the default limit for models fetched PER PATTERN.
30+
// This limit is applied independently to each pattern in includedModels
31+
// (e.g., "ibm-granite/*", "meta-llama/*") to prevent overloading the
32+
// HuggingFace API with too many requests and to respect rate limiting.
33+
//
34+
// Example: with maxModels=100 and 3 patterns, up to 300 models total may be fetched.
35+
// Set to 0 to disable the limit (not recommended for large organizations).
36+
defaultMaxModels = 500
2737
)
2838

2939
// gatedString is a custom type that can unmarshal both boolean and string values from JSON
@@ -71,6 +81,10 @@ type hfModelProvider struct {
7181
baseURL string
7282
includedModels []string
7383
filter *ModelFilter
84+
// maxModels limits how many models to fetch PER PATTERN (e.g., per "org/*").
85+
// This is applied independently to each pattern to respect HuggingFace API rate limits.
86+
// A value of 0 means no limit.
87+
maxModels int
7488
}
7589

7690
// hfModelInfo represents the structure of HuggingFace API model information
@@ -665,8 +679,9 @@ func init() {
665679
// It initializes the provider from a PreviewConfig without starting the full model loading.
666680
func NewHFPreviewProvider(config *PreviewConfig) (*hfModelProvider, error) {
667681
p := &hfModelProvider{
668-
client: &http.Client{Timeout: 30 * time.Second},
669-
baseURL: defaultHuggingFaceURL,
682+
client: &http.Client{Timeout: 30 * time.Second},
683+
baseURL: defaultHuggingFaceURL,
684+
maxModels: defaultMaxModels,
670685
}
671686

672687
// Parse API key from environment variable
@@ -685,6 +700,21 @@ func NewHFPreviewProvider(config *PreviewConfig) (*hfModelProvider, error) {
685700
p.baseURL = strings.TrimSuffix(url, "/")
686701
}
687702

703+
// Parse maxModels limit (optional, defaults to 500)
704+
// This limit is applied PER PATTERN (e.g., each "org/*" pattern gets its own limit)
705+
// to prevent overloading the HuggingFace API and respect rate limiting.
706+
// Set to 0 to disable the limit.
707+
if maxModels, ok := config.Properties[maxModelsKey]; ok {
708+
switch v := maxModels.(type) {
709+
case int:
710+
p.maxModels = v
711+
case int64:
712+
p.maxModels = int(v)
713+
case float64:
714+
p.maxModels = int(v)
715+
}
716+
}
717+
688718
return p, nil
689719
}
690720

@@ -767,6 +797,13 @@ func (p *hfModelProvider) listModelsByAuthor(ctx context.Context, author string,
767797
default:
768798
}
769799

800+
// Check if we've reached the maxModels limit for this pattern
801+
// (maxModels is applied per-pattern to respect HF API rate limits)
802+
if p.maxModels > 0 && len(allModels) >= p.maxModels {
803+
glog.Warningf("Reached maxModels limit (%d) for pattern author=%s, stopping pagination", p.maxModels, author)
804+
break
805+
}
806+
770807
// Build API URL
771808
apiURL := fmt.Sprintf("%s/api/models?author=%s&limit=%d", p.baseURL, author, limit)
772809
if searchPrefix != "" {
@@ -808,6 +845,11 @@ func (p *hfModelProvider) listModelsByAuthor(ctx context.Context, author string,
808845

809846
// Extract model IDs
810847
for _, m := range models {
848+
// Check limit before adding each model
849+
if p.maxModels > 0 && len(allModels) >= p.maxModels {
850+
break
851+
}
852+
811853
modelID := m.ID
812854
if modelID == "" {
813855
modelID = m.ModelID
@@ -842,7 +884,7 @@ func (p *hfModelProvider) listModelsByAuthor(ctx context.Context, author string,
842884
cursor = nextCursor
843885
}
844886

845-
glog.Infof("Listed %d models from author %s", len(allModels), author)
887+
glog.Infof("Listed %d models from author %s (maxModels: %d)", len(allModels), author, p.maxModels)
846888
return allModels, nil
847889
}
848890

catalog/internal/catalog/hf_catalog_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,69 @@ func TestListModelsByAuthor(t *testing.T) {
747747
// "other-model" should be filtered out
748748
assert.NotContains(t, models, "search-org/other-model")
749749
})
750+
751+
t.Run("respects maxModels limit", func(t *testing.T) {
752+
callCount = 0
753+
config := &PreviewConfig{
754+
Type: "hf",
755+
Properties: map[string]any{
756+
"url": server.URL,
757+
"maxModels": 50, // Limit to 50 models
758+
},
759+
}
760+
761+
provider, err := NewHFPreviewProvider(config)
762+
require.NoError(t, err)
763+
assert.Equal(t, 50, provider.maxModels)
764+
765+
models, err := provider.listModelsByAuthor(context.Background(), "test-org", "")
766+
require.NoError(t, err)
767+
768+
// Should stop at 50 models (first page has 100, but we limit to 50)
769+
assert.Len(t, models, 50)
770+
771+
// Should have only made 1 API call (stopped before second page)
772+
assert.Equal(t, 1, callCount)
773+
})
774+
775+
t.Run("uses default maxModels when not specified", func(t *testing.T) {
776+
config := &PreviewConfig{
777+
Type: "hf",
778+
Properties: map[string]any{
779+
"url": server.URL,
780+
},
781+
}
782+
783+
provider, err := NewHFPreviewProvider(config)
784+
require.NoError(t, err)
785+
786+
// Should use default (500)
787+
assert.Equal(t, 500, provider.maxModels)
788+
})
789+
790+
t.Run("maxModels 0 means no limit", func(t *testing.T) {
791+
callCount = 0
792+
config := &PreviewConfig{
793+
Type: "hf",
794+
Properties: map[string]any{
795+
"url": server.URL,
796+
"maxModels": 0, // No limit
797+
},
798+
}
799+
800+
provider, err := NewHFPreviewProvider(config)
801+
require.NoError(t, err)
802+
assert.Equal(t, 0, provider.maxModels)
803+
804+
models, err := provider.listModelsByAuthor(context.Background(), "test-org", "")
805+
require.NoError(t, err)
806+
807+
// Should get all 102 models (100 from page 1 + 2 from page 2)
808+
assert.Len(t, models, 102)
809+
810+
// Should have made 2 API calls
811+
assert.Equal(t, 2, callCount)
812+
})
750813
}
751814

752815
func TestFetchModelNamesForPreviewWithPatterns(t *testing.T) {

0 commit comments

Comments
 (0)