Skip to content

Commit 4a29dc8

Browse files
authored
feat(catalog): remove orphaned models (kubeflow#1962)
* feat(catalog): remove orphaned models Remove models from the database that are no longer referenced. This can happen because a model is removed or excluded from a source, or the source can be deleted or disabled. Signed-off-by: Paul Boyd <paul@pboyd.io> * fix(catalog): race condition causing empty models list Signed-off-by: Paul Boyd <paul@pboyd.io> --------- Signed-off-by: Paul Boyd <paul@pboyd.io>
1 parent 7bbb7e1 commit 4a29dc8

11 files changed

Lines changed: 601 additions & 5 deletions

File tree

catalog/internal/catalog/catalog_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"reflect"
77
"sort"
8+
"sync"
89
"testing"
910
"time"
1011

@@ -376,6 +377,9 @@ func TestLoadCatalogSourcesWithMockRepositories(t *testing.T) {
376377
// Wait a bit for the goroutine to process
377378
time.Sleep(100 * time.Millisecond)
378379

380+
mockModelRepo.mu.RLock()
381+
defer mockModelRepo.mu.RUnlock()
382+
379383
// Verify that the model was saved
380384
if len(mockModelRepo.SavedModels) != 1 {
381385
t.Errorf("Expected 1 model to be saved, got %d", len(mockModelRepo.SavedModels))
@@ -390,11 +394,17 @@ func TestLoadCatalogSourcesWithMockRepositories(t *testing.T) {
390394
}
391395
}
392396

397+
mockModelArtifactRepo.mu.RLock()
398+
defer mockModelArtifactRepo.mu.RUnlock()
399+
393400
// Verify that artifacts were saved
394401
if len(mockModelArtifactRepo.SavedArtifacts) != 1 {
395402
t.Errorf("Expected 1 model artifact to be saved, got %d", len(mockModelArtifactRepo.SavedArtifacts))
396403
}
397404

405+
mockMetricsArtifactRepo.mu.RLock()
406+
defer mockMetricsArtifactRepo.mu.RUnlock()
407+
398408
if len(mockMetricsArtifactRepo.SavedMetrics) != 1 {
399409
t.Errorf("Expected 1 metrics artifact to be saved, got %d", len(mockMetricsArtifactRepo.SavedMetrics))
400410
}
@@ -470,6 +480,9 @@ func TestLoadCatalogSourcesWithRepositoryErrors(t *testing.T) {
470480
// Wait for processing
471481
time.Sleep(100 * time.Millisecond)
472482

483+
mockModelRepo.mu.RLock()
484+
defer mockModelRepo.mu.RUnlock()
485+
473486
// Verify that no models were saved due to the error
474487
if len(mockModelRepo.SavedModels) != 0 {
475488
t.Errorf("Expected 0 models to be saved due to error, got %d", len(mockModelRepo.SavedModels))
@@ -638,11 +651,14 @@ func (m *MockCatalogModelRepositoryWithErrors) Save(model dbmodels.CatalogModel)
638651

639652
// MockCatalogModelRepository mocks the CatalogModelRepository interface.
640653
type MockCatalogModelRepository struct {
654+
mu sync.RWMutex
641655
SavedModels []dbmodels.CatalogModel
642656
NextID int32
643657
}
644658

645659
func (m *MockCatalogModelRepository) GetByID(id int32) (dbmodels.CatalogModel, error) {
660+
m.mu.RLock()
661+
defer m.mu.RUnlock()
646662
for _, model := range m.SavedModels {
647663
if model.GetID() != nil && *model.GetID() == id {
648664
return model, nil
@@ -652,6 +668,8 @@ func (m *MockCatalogModelRepository) GetByID(id int32) (dbmodels.CatalogModel, e
652668
}
653669

654670
func (m *MockCatalogModelRepository) List(listOptions dbmodels.CatalogModelListOptions) (*mrmodels.ListWrapper[dbmodels.CatalogModel], error) {
671+
m.mu.RLock()
672+
defer m.mu.RUnlock()
655673
return &mrmodels.ListWrapper[dbmodels.CatalogModel]{
656674
Items: m.SavedModels,
657675
PageSize: int32(len(m.SavedModels)),
@@ -660,6 +678,8 @@ func (m *MockCatalogModelRepository) List(listOptions dbmodels.CatalogModelListO
660678
}
661679

662680
func (m *MockCatalogModelRepository) GetByName(name string) (dbmodels.CatalogModel, error) {
681+
m.mu.RLock()
682+
defer m.mu.RUnlock()
663683
for _, model := range m.SavedModels {
664684
if model.GetAttributes() != nil && model.GetAttributes().Name != nil && *model.GetAttributes().Name == name {
665685
return model, nil
@@ -669,6 +689,9 @@ func (m *MockCatalogModelRepository) GetByName(name string) (dbmodels.CatalogMod
669689
}
670690

671691
func (m *MockCatalogModelRepository) Save(model dbmodels.CatalogModel) (dbmodels.CatalogModel, error) {
692+
m.mu.Lock()
693+
defer m.mu.Unlock()
694+
672695
m.NextID++
673696
id := m.NextID
674697

@@ -685,13 +708,31 @@ func (m *MockCatalogModelRepository) Save(model dbmodels.CatalogModel) (dbmodels
685708
return savedModel, nil
686709
}
687710

711+
func (m *MockCatalogModelRepository) DeleteBySource(sourceID string) error {
712+
// Mock implementation - no-op for testing
713+
return nil
714+
}
715+
716+
func (m *MockCatalogModelRepository) DeleteByID(id int32) error {
717+
// Mock implementation - no-op for testing
718+
return nil
719+
}
720+
721+
func (m *MockCatalogModelRepository) GetDistinctSourceIDs() ([]string, error) {
722+
// Mock implementation - return empty list by default
723+
return []string{}, nil
724+
}
725+
688726
// MockCatalogModelArtifactRepository mocks the CatalogModelArtifactRepository interface.
689727
type MockCatalogModelArtifactRepository struct {
728+
mu sync.RWMutex
690729
SavedArtifacts []dbmodels.CatalogModelArtifact
691730
NextID int32
692731
}
693732

694733
func (m *MockCatalogModelArtifactRepository) GetByID(id int32) (dbmodels.CatalogModelArtifact, error) {
734+
m.mu.RLock()
735+
defer m.mu.RUnlock()
695736
for _, artifact := range m.SavedArtifacts {
696737
if artifact.GetID() != nil && *artifact.GetID() == id {
697738
return artifact, nil
@@ -701,6 +742,8 @@ func (m *MockCatalogModelArtifactRepository) GetByID(id int32) (dbmodels.Catalog
701742
}
702743

703744
func (m *MockCatalogModelArtifactRepository) List(listOptions dbmodels.CatalogModelArtifactListOptions) (*mrmodels.ListWrapper[dbmodels.CatalogModelArtifact], error) {
745+
m.mu.RLock()
746+
defer m.mu.RUnlock()
704747
return &mrmodels.ListWrapper[dbmodels.CatalogModelArtifact]{
705748
Items: m.SavedArtifacts,
706749
PageSize: int32(len(m.SavedArtifacts)),
@@ -709,6 +752,9 @@ func (m *MockCatalogModelArtifactRepository) List(listOptions dbmodels.CatalogMo
709752
}
710753

711754
func (m *MockCatalogModelArtifactRepository) Save(modelArtifact dbmodels.CatalogModelArtifact, parentResourceID *int32) (dbmodels.CatalogModelArtifact, error) {
755+
m.mu.Lock()
756+
defer m.mu.Unlock()
757+
712758
m.NextID++
713759
id := m.NextID
714760

@@ -727,11 +773,15 @@ func (m *MockCatalogModelArtifactRepository) Save(modelArtifact dbmodels.Catalog
727773

728774
// MockCatalogMetricsArtifactRepository mocks the CatalogMetricsArtifactRepository interface.
729775
type MockCatalogMetricsArtifactRepository struct {
776+
mu sync.RWMutex
730777
SavedMetrics []dbmodels.CatalogMetricsArtifact
731778
NextID int32
732779
}
733780

734781
func (m *MockCatalogMetricsArtifactRepository) GetByID(id int32) (dbmodels.CatalogMetricsArtifact, error) {
782+
m.mu.RLock()
783+
defer m.mu.RUnlock()
784+
735785
for _, metrics := range m.SavedMetrics {
736786
if metrics.GetID() != nil && *metrics.GetID() == id {
737787
return metrics, nil
@@ -741,6 +791,9 @@ func (m *MockCatalogMetricsArtifactRepository) GetByID(id int32) (dbmodels.Catal
741791
}
742792

743793
func (m *MockCatalogMetricsArtifactRepository) List(listOptions dbmodels.CatalogMetricsArtifactListOptions) (*mrmodels.ListWrapper[dbmodels.CatalogMetricsArtifact], error) {
794+
m.mu.RLock()
795+
defer m.mu.RUnlock()
796+
744797
return &mrmodels.ListWrapper[dbmodels.CatalogMetricsArtifact]{
745798
Items: m.SavedMetrics,
746799
PageSize: int32(len(m.SavedMetrics)),
@@ -749,6 +802,9 @@ func (m *MockCatalogMetricsArtifactRepository) List(listOptions dbmodels.Catalog
749802
}
750803

751804
func (m *MockCatalogMetricsArtifactRepository) Save(metricsArtifact dbmodels.CatalogMetricsArtifact, parentResourceID *int32) (dbmodels.CatalogMetricsArtifact, error) {
805+
m.mu.Lock()
806+
defer m.mu.Unlock()
807+
752808
m.NextID++
753809
id := m.NextID
754810

@@ -766,6 +822,9 @@ func (m *MockCatalogMetricsArtifactRepository) Save(metricsArtifact dbmodels.Cat
766822
}
767823

768824
func (m *MockCatalogMetricsArtifactRepository) BatchSave(metricsArtifacts []dbmodels.CatalogMetricsArtifact, parentResourceID *int32) ([]dbmodels.CatalogMetricsArtifact, error) {
825+
m.mu.Lock()
826+
defer m.mu.Unlock()
827+
769828
savedArtifacts := make([]dbmodels.CatalogMetricsArtifact, len(metricsArtifacts))
770829

771830
for i, metricsArtifact := range metricsArtifacts {
@@ -790,11 +849,14 @@ func (m *MockCatalogMetricsArtifactRepository) BatchSave(metricsArtifacts []dbmo
790849

791850
// MockCatalogArtifactRepository mocks the CatalogArtifactRepository interface.
792851
type MockCatalogArtifactRepository struct {
852+
mu sync.RWMutex
793853
SavedArtifacts []dbmodels.CatalogArtifact
794854
NextID int32
795855
}
796856

797857
func (m *MockCatalogArtifactRepository) GetByID(id int32) (dbmodels.CatalogArtifact, error) {
858+
m.mu.RLock()
859+
defer m.mu.RUnlock()
798860
for _, artifact := range m.SavedArtifacts {
799861
// Check both model and metrics artifacts for the ID
800862
if artifact.CatalogModelArtifact != nil && artifact.CatalogModelArtifact.GetID() != nil && *artifact.CatalogModelArtifact.GetID() == id {
@@ -808,6 +870,8 @@ func (m *MockCatalogArtifactRepository) GetByID(id int32) (dbmodels.CatalogArtif
808870
}
809871

810872
func (m *MockCatalogArtifactRepository) List(listOptions dbmodels.CatalogArtifactListOptions) (*mrmodels.ListWrapper[dbmodels.CatalogArtifact], error) {
873+
m.mu.RLock()
874+
defer m.mu.RUnlock()
811875
return &mrmodels.ListWrapper[dbmodels.CatalogArtifact]{
812876
Items: m.SavedArtifacts,
813877
PageSize: int32(len(m.SavedArtifacts)),

catalog/internal/catalog/hf_catalog.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,12 @@ func (p *hfModelProvider) emit(ctx context.Context, models []ModelProviderRecord
577577
return
578578
}
579579
}
580+
581+
// Send an empty record to indicate that we're done with the batch.
582+
select {
583+
case out <- ModelProviderRecord{}:
584+
case <-done:
585+
}
580586
}
581587

582588
// validateCredentials checks if the HuggingFace API key credentials are valid

catalog/internal/catalog/loader.go

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"path/filepath"
99
"sync"
1010

11+
mapset "github.com/deckarep/golang-set/v2"
1112
"github.com/golang/glog"
1213
dbmodels "github.com/kubeflow/model-registry/catalog/internal/db/models"
1314
"github.com/kubeflow/model-registry/catalog/internal/db/service"
@@ -26,6 +27,9 @@ type ModelProviderRecord struct {
2627
// expected to spawn a goroutine and return immediately. The returned channel must
2728
// close when the goroutine ends. The goroutine should end when the context is
2829
// canceled, but may end sooner.
30+
//
31+
// The function may emit a record with a nil Model to indicate that the
32+
// complete set of models has been sent.
2933
type ModelProviderFunc func(ctx context.Context, source *Source, reldir string) (<-chan ModelProviderRecord, error)
3034

3135
var registeredModelProviders = map[string]ModelProviderFunc{}
@@ -120,8 +124,14 @@ func (l *Loader) Start(ctx context.Context) error {
120124
}
121125
}
122126

127+
// Delete models from unknown or disabled sources
128+
err := l.removeModelsFromMissingSources()
129+
if err != nil {
130+
return fmt.Errorf("faied to remove models from missing sources: %w", err)
131+
}
132+
123133
// Phase 2: Load models from merged sources (once, after all merging is complete)
124-
err := l.loadAllModels(ctx)
134+
err = l.loadAllModels(ctx)
125135
if err != nil {
126136
return err
127137
}
@@ -269,6 +279,9 @@ func (l *Loader) updateDatabase(ctx context.Context) error {
269279

270280
go func() {
271281
for record := range records {
282+
if record.Model == nil {
283+
continue
284+
}
272285
attr := record.Model.GetAttributes()
273286
if attr == nil || attr.Name == nil {
274287
continue
@@ -377,7 +390,30 @@ func (l *Loader) readProviderRecords(ctx context.Context) <-chan ModelProviderRe
377390
wg.Add(1)
378391
go func() {
379392
defer wg.Done()
393+
394+
modelNames := []string{}
395+
380396
for r := range records {
397+
if r.Model == nil {
398+
glog.V(2).Infof("%s: trigger cleanup", source.Id)
399+
400+
// Copy the list of model names, then clear it.
401+
modelNameSet := mapset.NewSet(modelNames...)
402+
modelNames = modelNames[:0]
403+
404+
go func() {
405+
err := l.removeOrphanedModelsFromSource(source.Id, modelNameSet)
406+
if err != nil {
407+
glog.Errorf("error removing orphaned models: %v", err)
408+
}
409+
}()
410+
continue
411+
}
412+
413+
if attr := r.Model.GetAttributes(); attr != nil && attr.Name != nil {
414+
modelNames = append(modelNames, *attr.Name)
415+
}
416+
381417
// Set source_id on every returned model.
382418
l.setModelSourceID(r.Model, source.Id)
383419

@@ -424,3 +460,57 @@ func (l *Loader) setModelSourceID(model dbmodels.CatalogModel, sourceID string)
424460

425461
*props = append(*props, mrmodels.NewStringProperty("source_id", sourceID, false))
426462
}
463+
464+
func (l *Loader) removeModelsFromMissingSources() error {
465+
enabledSourceIDs := mapset.NewSet[string]()
466+
for id, source := range l.Sources.AllSources() {
467+
if source.Enabled == nil || *source.Enabled {
468+
enabledSourceIDs.Add(id)
469+
}
470+
}
471+
472+
existingSourceIDs, err := l.services.CatalogModelRepository.GetDistinctSourceIDs()
473+
if err != nil {
474+
return fmt.Errorf("unable to retrieve existing source IDs: %w", err)
475+
}
476+
477+
for oldSource := range mapset.NewSet(existingSourceIDs...).Difference(enabledSourceIDs).Iter() {
478+
glog.Infof("Removing models from source %s", oldSource)
479+
480+
err = l.services.CatalogModelRepository.DeleteBySource(oldSource)
481+
if err != nil {
482+
return fmt.Errorf("unable to remove models from source %q: %w", oldSource, err)
483+
}
484+
}
485+
486+
return nil
487+
}
488+
489+
func (l *Loader) removeOrphanedModelsFromSource(sourceID string, valid mapset.Set[string]) error {
490+
list, err := l.services.CatalogModelRepository.List(dbmodels.CatalogModelListOptions{
491+
SourceIDs: &[]string{sourceID},
492+
})
493+
if err != nil {
494+
return fmt.Errorf("unable to list models from source %q: %w", sourceID, err)
495+
}
496+
497+
for _, model := range list.Items {
498+
attr := model.GetAttributes()
499+
if attr == nil || attr.Name == nil || model.GetID() == nil {
500+
continue
501+
}
502+
503+
if valid.Contains(*attr.Name) {
504+
continue
505+
}
506+
507+
glog.Infof("Removing %s model %s", sourceID, *attr.Name)
508+
509+
err = l.services.CatalogModelRepository.DeleteByID(*model.GetID())
510+
if err != nil {
511+
return fmt.Errorf("unable to remove model %d (%s from source %s): %w", *model.GetID(), *attr.Name, sourceID, err)
512+
}
513+
}
514+
515+
return nil
516+
}

0 commit comments

Comments
 (0)