11package service
22
33import (
4+ "encoding/base64"
45 "errors"
56 "fmt"
7+ "strconv"
68 "strings"
79
10+ "github.com/golang/glog"
811 "github.com/kubeflow/model-registry/catalog/internal/db/filter"
912 "github.com/kubeflow/model-registry/catalog/internal/db/models"
1013 dbmodels "github.com/kubeflow/model-registry/internal/db/models"
@@ -15,28 +18,34 @@ import (
1518 "gorm.io/gorm"
1619)
1720
// accuracyProperty is the name of the metrics-artifact property whose double
// value is used as the sort key when ordering by accuracy.
const accuracyProperty = "overall_average"
23+
// ErrCatalogModelNotFound is the sentinel error the catalog model repository
// is configured with as its NotFoundError.
var ErrCatalogModelNotFound = errors.New("catalog model by id not found")
1925
// CatalogModelRepositoryImpl is the catalog model repository. It embeds the
// generic repository instantiated for models.CatalogModel, schema.Context,
// schema.ContextProperty and *models.CatalogModelListOptions.
type CatalogModelRepositoryImpl struct {
	*service.GenericRepository[models.CatalogModel, schema.Context, schema.ContextProperty, *models.CatalogModelListOptions]
	// metricsArtifactTypeID caches the type ID resolved by
	// getMetricsArtifactTypeID; zero means it has not been resolved yet.
	metricsArtifactTypeID int32
}
2330
2431func NewCatalogModelRepository (db * gorm.DB , typeID int32 ) models.CatalogModelRepository {
2532 r := & CatalogModelRepositoryImpl {}
2633
2734 r .GenericRepository = service .NewGenericRepository (service.GenericRepositoryConfig [models.CatalogModel , schema.Context , schema.ContextProperty , * models.CatalogModelListOptions ]{
28- DB : db ,
29- TypeID : typeID ,
30- EntityToSchema : mapCatalogModelToContext ,
31- SchemaToEntity : mapDataLayerToCatalogModel ,
32- EntityToProperties : mapCatalogModelToContextProperties ,
33- NotFoundError : ErrCatalogModelNotFound ,
34- EntityName : "catalog model" ,
35- PropertyFieldName : "context_id" ,
36- ApplyListFilters : applyCatalogModelListFilters ,
37- IsNewEntity : func (entity models.CatalogModel ) bool { return entity .GetID () == nil },
38- HasCustomProperties : func (entity models.CatalogModel ) bool { return entity .GetCustomProperties () != nil },
39- EntityMappingFuncs : filter .NewCatalogEntityMappings (),
35+ DB : db ,
36+ TypeID : typeID ,
37+ EntityToSchema : mapCatalogModelToContext ,
38+ SchemaToEntity : mapDataLayerToCatalogModel ,
39+ EntityToProperties : mapCatalogModelToContextProperties ,
40+ NotFoundError : ErrCatalogModelNotFound ,
41+ EntityName : "catalog model" ,
42+ PropertyFieldName : "context_id" ,
43+ ApplyListFilters : applyCatalogModelListFilters ,
44+ CreatePaginationToken : r .createPaginationToken ,
45+ ApplyCustomOrdering : r .applyAccuracyOrdering ,
46+ IsNewEntity : func (entity models.CatalogModel ) bool { return entity .GetID () == nil },
47+ HasCustomProperties : func (entity models.CatalogModel ) bool { return entity .GetCustomProperties () != nil },
48+ EntityMappingFuncs : filter .NewCatalogEntityMappings (),
4049 })
4150
4251 return r
@@ -298,3 +307,213 @@ func (r *CatalogModelRepositoryImpl) GetFilterableProperties(maxLength int) (map
298307
299308 return result , nil
300309}
310+
311+ // getMetricsArtifactTypeID looks up the type ID for CatalogMetricsArtifact dynamically
312+ func (r * CatalogModelRepositoryImpl ) getMetricsArtifactTypeID () (int32 , error ) {
313+ if r .metricsArtifactTypeID != 0 {
314+ return r .metricsArtifactTypeID , nil
315+ }
316+
317+ // Look up the type ID dynamically from the database
318+ var typeRecord struct {
319+ ID int32 `gorm:"column:id"`
320+ }
321+
322+ err := r .GetConfig ().DB .
323+ Table ("\" Type\" " ).
324+ Select ("id" ).
325+ Where ("name = ?" , CatalogMetricsArtifactTypeName ).
326+ First (& typeRecord ).Error
327+
328+ if err != nil {
329+ return 0 , fmt .Errorf ("failed to lookup CatalogMetricsArtifact type ID: %w" , err )
330+ }
331+
332+ // Cache the type ID for future use
333+ r .metricsArtifactTypeID = typeRecord .ID
334+ return typeRecord .ID , nil
335+ }
336+
337+ // applyAccuracyOrdering applies custom ordering logic for ACCURACY orderBy field
338+ func (r * CatalogModelRepositoryImpl ) applyAccuracyOrdering (query * gorm.DB , listOptions * models.CatalogModelListOptions ) * gorm.DB {
339+ orderBy := listOptions .GetOrderBy ()
340+
341+ // Only apply custom ordering for ACCURACY orderBy
342+ if orderBy != "ACCURACY" {
343+ // Fall back to standard pagination for non-ACCURACY ordering
344+ return r .ApplyStandardPagination (query , listOptions , []models.CatalogModel {})
345+ }
346+
347+ // Get the metrics artifact type ID
348+ metricsTypeID , err := r .getMetricsArtifactTypeID ()
349+ if err != nil {
350+ // Fall back to standard pagination if we can't get the type ID
351+ return r .ApplyStandardPagination (query , listOptions , []models.CatalogModel {})
352+ }
353+
354+ db := r .GetConfig ().DB
355+ contextTable := utils .GetTableName (db , & schema.Context {})
356+ attributionTable := utils .GetTableName (db , & schema.Attribution {})
357+ artifactTable := utils .GetTableName (db , & schema.Artifact {})
358+ propertyTable := utils .GetTableName (db , & schema.ArtifactProperty {})
359+
360+ sortOrder := listOptions .GetSortOrder ()
361+ pageSize := listOptions .GetPageSize ()
362+
363+ // Build the accuracy subquery
364+ // This gets the accuracy score for each model from its AccuracyMetric artifacts
365+ accuracySubquery := db .
366+ Select (fmt .Sprintf ("%s.id, max(%s.double_value) AS accuracy" , contextTable , propertyTable )).
367+ Table (contextTable ).
368+ Joins (fmt .Sprintf ("LEFT JOIN %s ON %s.id=%s.context_id" , attributionTable , contextTable , attributionTable )).
369+ Joins (fmt .Sprintf ("LEFT JOIN %s ON %s.artifact_id=%s.id AND %s.type_id=?" , artifactTable , attributionTable , artifactTable , artifactTable ), metricsTypeID ).
370+ Joins (fmt .Sprintf ("LEFT JOIN %s ON %s.id=%s.artifact_id AND %s.name=?" , propertyTable , artifactTable , propertyTable , propertyTable ), accuracyProperty ).
371+ Where (contextTable + ".type_id=?" , r .GetConfig ().TypeID ).
372+ Group (contextTable + ".id" )
373+
374+ // Join the main query with the accuracy subquery
375+ query = query .
376+ Joins ("LEFT JOIN (?) accuracy ON \" Context\" .id=accuracy.id" , accuracySubquery )
377+
378+ // Apply sorting order
379+ if sortOrder != "ASC" {
380+ sortOrder = "DESC"
381+ }
382+ query = query .Order (fmt .Sprintf ("accuracy %s NULLS LAST, %s.id" , sortOrder , contextTable ))
383+
384+ // Handle cursor-based pagination with nextPageToken
385+ nextPageToken := listOptions .GetNextPageToken ()
386+ if nextPageToken != "" {
387+ // Parse the cursor from the token
388+ if cursor , err := r .parseNextPageToken (nextPageToken ); err == nil {
389+ // Apply WHERE clause for cursor-based pagination with ACCURACY
390+ query = r .applyCursorPagination (query , cursor , sortOrder )
391+ }
392+ // If token parsing fails, fall back to no cursor (first page)
393+ }
394+
395+ // Apply pagination limit
396+ if pageSize > 0 {
397+ query = query .Limit (int (pageSize ) + 1 ) // +1 to detect if there are more pages
398+ }
399+
400+ return query
401+ }
402+
// accuracyCursor is the decoded form of a pagination token used with ACCURACY
// ordering: the ID carried in the token together with its accuracy value.
type accuracyCursor struct {
	// ID is the integer parsed from the first field of the token.
	ID int32
	// Accuracy is the value parsed from the second field of the token. It is
	// nil when the token carried no parsable value, which cursor pagination
	// treats as a NULL accuracy.
	Accuracy *float64
}
408+
409+ // parseNextPageToken decodes a nextPageToken and extracts the cursor information
410+ func (r * CatalogModelRepositoryImpl ) parseNextPageToken (token string ) (* accuracyCursor , error ) {
411+ // Sanity check the length before decoding
412+ if len (token ) > 64 {
413+ return nil , fmt .Errorf ("invalid nextPageToken" )
414+ }
415+
416+ decoded , err := base64 .StdEncoding .DecodeString (token )
417+ if err != nil {
418+ return nil , fmt .Errorf ("failed to decode token: %w" , err )
419+ }
420+
421+ parts := strings .Split (string (decoded ), ":" )
422+ if len (parts ) != 2 {
423+ return nil , fmt .Errorf ("invalid cursor format, expected 'ID:Value'" )
424+ }
425+
426+ id , err := strconv .ParseInt (parts [0 ], 10 , 32 )
427+ if err != nil {
428+ return nil , fmt .Errorf ("invalid ID in cursor: %w" , err )
429+ }
430+
431+ cursor := accuracyCursor {ID : int32 (id )}
432+
433+ // Parse accuracy value from cursor
434+ accuracy , err := strconv .ParseFloat (parts [1 ], 64 )
435+ if err == nil {
436+ cursor .Accuracy = & accuracy
437+ }
438+
439+ return & cursor , nil
440+ }
441+
442+ // applyCursorPagination applies WHERE clause for cursor-based pagination with ACCURACY sorting
443+ func (r * CatalogModelRepositoryImpl ) applyCursorPagination (query * gorm.DB , cursor * accuracyCursor , sortOrder string ) * gorm.DB {
444+ contextTable := utils .GetTableName (query , & schema.Context {})
445+
446+ // Handle NULL accuracy values in cursor
447+ if cursor .Accuracy == nil {
448+ // For models without accuracy, just use ID-based pagination
449+ return query .Where (fmt .Sprintf ("accuracy IS NULL AND %s.id > ?" , contextTable ), cursor .ID )
450+ }
451+
452+ accuracyValue := * cursor .Accuracy
453+
454+ // Apply cursor pagination logic for ACCURACY sorting
455+ if sortOrder == "ASC" {
456+ // For ASC: get records where (accuracy > cursor_accuracy) OR (accuracy = cursor_accuracy AND id > cursor_id)
457+ // Also include NULL values at the end
458+ return query .Where ("(accuracy > ? OR (accuracy = ? AND " + contextTable + ".id > ?) OR accuracy IS NULL)" ,
459+ accuracyValue , accuracyValue , cursor .ID )
460+ } else {
461+ // For DESC: get records where (accuracy < cursor_accuracy) OR (accuracy = cursor_accuracy AND id > cursor_id)
462+ return query .Where ("(accuracy < ? OR (accuracy = ? AND " + contextTable + ".id > ?) OR accuracy IS NULL)" ,
463+ accuracyValue , accuracyValue , cursor .ID )
464+ }
465+ }
466+
467+ func (r * CatalogModelRepositoryImpl ) createPaginationToken (lastItem schema.Context , listOptions * models.CatalogModelListOptions ) string {
468+ if listOptions .GetOrderBy () == "ACCURACY" {
469+ // The accuracy metric is not available from the context table,
470+ // so we'll need another query to get it.
471+
472+ db := r .GetConfig ().DB
473+ contextTable := utils .GetTableName (db , & schema.Context {})
474+ attributionTable := utils .GetTableName (db , & schema.Attribution {})
475+ artifactTable := utils .GetTableName (db , & schema.Artifact {})
476+ propertyTable := utils .GetTableName (db , & schema.ArtifactProperty {})
477+ metricsTypeID , err := r .getMetricsArtifactTypeID ()
478+ if err != nil {
479+ glog .Warningf ("Failed to get metrics artifact type ID: %v" , err )
480+ return r .CreateDefaultPaginationToken (lastItem , listOptions )
481+ }
482+
483+ query := db .
484+ Select ("MAX(double_value) AS accuracy" ).
485+ Table (contextTable ).
486+ Joins (fmt .Sprintf ("LEFT JOIN %s ON %s.id=%s.context_id" , attributionTable , contextTable , attributionTable )).
487+ Joins (fmt .Sprintf ("LEFT JOIN %s ON %s.artifact_id=%s.id" , artifactTable , attributionTable , artifactTable )).
488+ Joins (fmt .Sprintf ("LEFT JOIN %s ON %s.id=%s.artifact_id" , propertyTable , artifactTable , propertyTable )).
489+ Where (artifactTable + ".type_id=?" , metricsTypeID ).
490+ Where (propertyTable + ".name=?" , accuracyProperty ).
491+ Where (contextTable + ".id=?" , lastItem .ID )
492+
493+ var result struct {
494+ Accuracy * float64 `gorm:"accuracy"`
495+ }
496+ err = query .Scan (& result ).Error
497+ if err != nil {
498+ glog .Warningf ("Failed to get accuracy score: %v" , err )
499+ return r .CreateDefaultPaginationToken (lastItem , listOptions )
500+ }
501+
502+ return createAccuracyPaginationToken (lastItem .ID , result .Accuracy )
503+ }
504+
505+ return r .CreateDefaultPaginationToken (lastItem , listOptions )
506+ }
507+
// createAccuracyPaginationToken encodes an ACCURACY-ordering cursor as the
// standard base64 encoding of "ID:Value".
//
// Value is empty when accuracyValue is nil (a NULL accuracy). Otherwise it is
// the shortest decimal representation that parses back to exactly the same
// float64. Exactness matters because the cursor is later compared with
// "accuracy = ?" to break ties; a rounded value would miss the tie and cause
// rows to be skipped or repeated. The 'g' format also keeps the token short
// for very large or very small magnitudes.
func createAccuracyPaginationToken(entityID int32, accuracyValue *float64) string {
	cursor := strconv.FormatInt(int64(entityID), 10) + ":"
	if accuracyValue != nil {
		cursor += strconv.FormatFloat(*accuracyValue, 'g', -1, 64)
	}
	return base64.StdEncoding.EncodeToString([]byte(cursor))
}
0 commit comments