Skip to content

Commit 1913341

Browse files
danpithsPratham-Mishra04
authored andcommitted
refactor: replaced ModelMatcher interface with ModelCatalog struct to avoid having to do to nil checks using reflect
1 parent 871d512 commit 1913341

File tree

4 files changed

+37
-62
lines changed

4 files changed

+37
-62
lines changed

framework/modelcatalog/main.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,17 @@ func (mc *ModelCatalog) Cleanup() error {
634634

635635
return nil
636636
}
637+
638+
// NewTestCatalog creates a minimal ModelCatalog for testing purposes.
639+
// It does not start background sync workers or connect to external services.
640+
func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog {
641+
if baseModelIndex == nil {
642+
baseModelIndex = make(map[string]string)
643+
}
644+
return &ModelCatalog{
645+
modelPool: make(map[schemas.ModelProvider][]string),
646+
baseModelIndex: baseModelIndex,
647+
pricingData: make(map[string]configstoreTables.TableModelPricing),
648+
done: make(chan struct{}),
649+
}
650+
}

plugins/governance/model_provider_governance_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,7 +1947,7 @@ func TestStore_CheckModelBudget_CrossProviderModelMatch(t *testing.T) {
19471947
budget := buildBudgetWithUsage("budget1", 100.0, 100.0, "1h") // At limit
19481948
modelConfig := buildModelConfig("mc1", "gpt-4o", nil, budget, nil)
19491949

1950-
mc := newMockModelMatcher(t)
1950+
mc := newTestModelCatalog(t)
19511951
store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{
19521952
ModelConfigs: []configstoreTables.TableModelConfig{*modelConfig},
19531953
Budgets: []configstoreTables.TableBudget{*budget},
@@ -1967,7 +1967,7 @@ func TestStore_CheckModelBudget_CrossProviderModelMatch_WithinLimit(t *testing.T
19671967
budget := buildBudget("budget1", 100.0, "1h")
19681968
modelConfig := buildModelConfig("mc1", "gpt-4o", nil, budget, nil)
19691969

1970-
mc := newMockModelMatcher(t)
1970+
mc := newTestModelCatalog(t)
19711971
store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{
19721972
ModelConfigs: []configstoreTables.TableModelConfig{*modelConfig},
19731973
Budgets: []configstoreTables.TableBudget{*budget},
@@ -1985,7 +1985,7 @@ func TestStore_CheckModelRateLimit_CrossProviderModelMatch(t *testing.T) {
19851985
rateLimit := buildRateLimitWithUsage("rl1", 10000, 10000, 1000, 0) // Token limit at max
19861986
modelConfig := buildModelConfig("mc1", "gpt-4o", nil, nil, rateLimit)
19871987

1988-
mc := newMockModelMatcher(t)
1988+
mc := newTestModelCatalog(t)
19891989
store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{
19901990
ModelConfigs: []configstoreTables.TableModelConfig{*modelConfig},
19911991
RateLimits: []configstoreTables.TableRateLimit{*rateLimit},
@@ -2005,7 +2005,7 @@ func TestStore_UpdateModelBudgetUsage_CrossProviderModelMatch(t *testing.T) {
20052005
budget := buildBudget("budget1", 100.0, "1h")
20062006
modelConfig := buildModelConfig("mc1", "gpt-4o", nil, budget, nil)
20072007

2008-
mc := newMockModelMatcher(t)
2008+
mc := newTestModelCatalog(t)
20092009
store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{
20102010
ModelConfigs: []configstoreTables.TableModelConfig{*modelConfig},
20112011
Budgets: []configstoreTables.TableBudget{*budget},
@@ -2033,7 +2033,7 @@ func TestStore_UpdateModelRateLimitUsage_CrossProviderModelMatch(t *testing.T) {
20332033
rateLimit := buildRateLimitWithUsage("rl1", 100, 0, 1000, 0) // Low token limit
20342034
modelConfig := buildModelConfig("mc1", "gpt-4o", nil, nil, rateLimit)
20352035

2036-
mc := newMockModelMatcher(t)
2036+
mc := newTestModelCatalog(t)
20372037
store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{
20382038
ModelConfigs: []configstoreTables.TableModelConfig{*modelConfig},
20392039
RateLimits: []configstoreTables.TableRateLimit{*rateLimit},
@@ -2059,7 +2059,7 @@ func TestStore_CheckModelBudget_ModelWithProvider_ExactMatchOnly(t *testing.T) {
20592059
providerStr := "openai"
20602060
modelConfig := buildModelConfig("mc1", "gpt-4o", &providerStr, budget, nil)
20612061

2062-
mc := newMockModelMatcher(t)
2062+
mc := newTestModelCatalog(t)
20632063
store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{
20642064
ModelConfigs: []configstoreTables.TableModelConfig{*modelConfig},
20652065
Budgets: []configstoreTables.TableBudget{*budget},

plugins/governance/store.go

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,10 @@ import (
1313
"github.com/maximhq/bifrost/core/schemas"
1414
"github.com/maximhq/bifrost/framework/configstore"
1515
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
16+
"github.com/maximhq/bifrost/framework/modelcatalog"
1617
"gorm.io/gorm"
1718
)
1819

19-
// ModelMatcher provides cross-provider model name matching.
20-
// This is satisfied by *modelcatalog.ModelCatalog.
21-
type ModelMatcher interface {
22-
GetBaseModelName(model string) string
23-
IsSameModel(model1, model2 string) bool
24-
}
25-
2620
// LocalGovernanceStore provides in-memory cache for governance data with fast, non-blocking access
2721
type LocalGovernanceStore struct {
2822
// Core data maps using sync.Map for lock-free reads
@@ -50,8 +44,8 @@ type LocalGovernanceStore struct {
5044
// Config store for refresh operations
5145
configStore configstore.ConfigStore
5246

53-
// Model matcher for cross-provider model matching (optional)
54-
modelMatcher ModelMatcher
47+
// Model catalog for cross-provider model matching (optional)
48+
modelCatalog *modelcatalog.ModelCatalog
5549

5650
// Logger
5751
logger schemas.Logger
@@ -141,9 +135,9 @@ type GovernanceStore interface {
141135
}
142136

143137
// NewLocalGovernanceStore creates a new in-memory governance store
144-
// The modelMatcher parameter is optional (can be nil) and enables cross-provider model matching
138+
// The modelCatalog parameter is optional (can be nil) and enables cross-provider model matching
145139
// for governance lookups (e.g., "openai/gpt-4o" matching config for "gpt-4o").
146-
func NewLocalGovernanceStore(ctx context.Context, logger schemas.Logger, configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig, modelMatcher ModelMatcher) (*LocalGovernanceStore, error) {
140+
func NewLocalGovernanceStore(ctx context.Context, logger schemas.Logger, configStore configstore.ConfigStore, governanceConfig *configstore.GovernanceConfig, modelCatalog *modelcatalog.ModelCatalog) (*LocalGovernanceStore, error) {
147141
// Create singleton CEL environment once for all routing rule compilations
148142
env, err := createCELEnvironment()
149143
if err != nil {
@@ -154,7 +148,7 @@ func NewLocalGovernanceStore(ctx context.Context, logger schemas.Logger, configS
154148
configStore: configStore,
155149
logger: logger,
156150
routingCELEnv: env,
157-
modelMatcher: modelMatcher,
151+
modelCatalog: modelCatalog,
158152
LastDBUsagesBudgets: make(map[string]float64),
159153
LastDBUsagesRequestsRateLimits: make(map[string]int64),
160154
LastDBUsagesTokensRateLimits: make(map[string]int64),
@@ -599,8 +593,8 @@ func (gs *LocalGovernanceStore) CheckProviderRateLimit(ctx context.Context, requ
599593
// Returns the matching config and the display name for error messages.
600594
func (gs *LocalGovernanceStore) findModelOnlyConfig(model string) (*configstoreTables.TableModelConfig, string) {
601595
// If modelMatcher is available, try normalized base model name first (cross-provider matching)
602-
if gs.modelMatcher != nil {
603-
baseName := gs.modelMatcher.GetBaseModelName(model)
596+
if gs.modelCatalog != nil {
597+
baseName := gs.modelCatalog.GetBaseModelName(model)
604598
if baseName != model {
605599
if value, exists := gs.modelConfigs.Load(baseName); exists && value != nil {
606600
if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil {
@@ -1787,8 +1781,8 @@ func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, c
17871781
} else {
17881782
// Global config (applies to all providers) - store under normalized model name
17891783
key := mc.ModelName
1790-
if gs.modelMatcher != nil {
1791-
key = gs.modelMatcher.GetBaseModelName(mc.ModelName)
1784+
if gs.modelCatalog != nil {
1785+
key = gs.modelCatalog.GetBaseModelName(mc.ModelName)
17921786
}
17931787
gs.modelConfigs.Store(key, mc)
17941788
}
@@ -2406,8 +2400,8 @@ func (gs *LocalGovernanceStore) UpdateModelConfigInMemory(mc *configstoreTables.
24062400
gs.modelConfigs.Store(key, &clone)
24072401
} else {
24082402
key := clone.ModelName
2409-
if gs.modelMatcher != nil {
2410-
key = gs.modelMatcher.GetBaseModelName(clone.ModelName)
2403+
if gs.modelCatalog != nil {
2404+
key = gs.modelCatalog.GetBaseModelName(clone.ModelName)
24112405
}
24122406
gs.modelConfigs.Store(key, &clone)
24132407
}

plugins/governance/test_utils.go

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -325,46 +325,13 @@ func fetchDatasheetBaseIndex() {
325325
datasheetBaseIndex = index
326326
}
327327

328-
// mockModelMatcher implements ModelMatcher for testing.
329-
// It fetches the default datasheet to build a base model index,
330-
// matching the production ModelCatalog.GetBaseModelName behavior:
331-
// 1. Direct lookup in base model index
332-
// 2. Strip provider prefix and retry lookup
333-
// 3. Fallback to algorithmic date/version stripping
334-
type mockModelMatcher struct {
335-
baseModelIndex map[string]string
336-
}
337-
338-
func (m *mockModelMatcher) GetBaseModelName(model string) string {
339-
// Step 1: Direct lookup in base model index
340-
if base, ok := m.baseModelIndex[model]; ok {
341-
return base
342-
}
343-
344-
// Step 2: Strip provider prefix and try again
345-
_, baseName := schemas.ParseModelString(model, "")
346-
if baseName != model {
347-
if base, ok := m.baseModelIndex[baseName]; ok {
348-
return base
349-
}
350-
}
351-
352-
// Step 3: Fallback to algorithmic date/version stripping
353-
return schemas.BaseModelName(baseName)
354-
}
355-
356-
func (m *mockModelMatcher) IsSameModel(model1, model2 string) bool {
357-
if model1 == model2 {
358-
return true
359-
}
360-
return m.GetBaseModelName(model1) == m.GetBaseModelName(model2)
361-
}
362-
363-
func newMockModelMatcher(t *testing.T) ModelMatcher {
328+
// newTestModelCatalog creates a test ModelCatalog using the fetched datasheet base model index.
329+
// This provides proper nil-pointer semantics (unlike an interface wrapper).
330+
func newTestModelCatalog(t *testing.T) *modelcatalog.ModelCatalog {
364331
t.Helper()
365332
datasheetOnce.Do(fetchDatasheetBaseIndex)
366333
if datasheetErr != nil {
367-
t.Skipf("skipping: failed to fetch datasheet for mock model matcher: %v", datasheetErr)
334+
t.Skipf("skipping: failed to fetch datasheet for test model catalog: %v", datasheetErr)
368335
}
369-
return &mockModelMatcher{baseModelIndex: datasheetBaseIndex}
336+
return modelcatalog.NewTestCatalog(datasheetBaseIndex)
370337
}

0 commit comments

Comments
 (0)