Skip to content

Commit a054eea

Browse files
authored
feat(catalog): sort by model accuracy (kubeflow#1814)
* feat(catalog): sort by model accuracy Add a new sort option called ACCURACY which sorts models by the value of the `overall_average` property of associated artifacts. Signed-off-by: Paul Boyd <paul@pboyd.io> * fix(catalog): fix bugs with nextPageToken and the accuracy sort Signed-off-by: Paul Boyd <paul@pboyd.io> * chore(catalog): add a note about the ACCURACY sort option Signed-off-by: Paul Boyd <paul@pboyd.io> --------- Signed-off-by: Paul Boyd <paul@pboyd.io>
1 parent 44bbf62 commit a054eea

6 files changed

Lines changed: 610 additions & 45 deletions

File tree

api/openapi/catalog.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,12 +585,17 @@ components:
585585
string_value: my_value
586586
metadataType: MetadataStringValue
587587
OrderByField:
588-
description: Supported fields for ordering result entities.
588+
description: |-
589+
Supported fields for ordering result entities.
590+
591+
The `ACCURACY` sort only applies to catalog models, and will sort by
592+
the `overall_average` property in any linked metrics artifact.
589593
enum:
590594
- CREATE_TIME
591595
- LAST_UPDATE_TIME
592596
- ID
593597
- NAME
598+
- ACCURACY
594599
type: string
595600
SortOrder:
596601
description: Supported sort direction for ordering result entities.

api/openapi/src/catalog.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,12 +348,17 @@ components:
348348
additionalProperties:
349349
$ref: "#/components/schemas/FilterOption"
350350
OrderByField:
351-
description: Supported fields for ordering result entities.
351+
description: |-
352+
Supported fields for ordering result entities.
353+
354+
The `ACCURACY` sort only applies to catalog models, and will sort by
355+
the `overall_average` property in any linked metrics artifact.
352356
enum:
353357
- CREATE_TIME
354358
- LAST_UPDATE_TIME
355359
- ID
356360
- NAME
361+
- ACCURACY
357362
type: string
358363

359364
responses:

catalog/internal/db/service/catalog_model.go

Lines changed: 231 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package service
22

33
import (
4+
"encoding/base64"
45
"errors"
56
"fmt"
7+
"strconv"
68
"strings"
79

10+
"github.com/golang/glog"
811
"github.com/kubeflow/model-registry/catalog/internal/db/filter"
912
"github.com/kubeflow/model-registry/catalog/internal/db/models"
1013
dbmodels "github.com/kubeflow/model-registry/internal/db/models"
@@ -15,28 +18,34 @@ import (
1518
"gorm.io/gorm"
1619
)
1720

21+
// accuracyProperty is the property of a metrics artifact to use when sorting by accuracy.
22+
const accuracyProperty = "overall_average"
23+
1824
var ErrCatalogModelNotFound = errors.New("catalog model by id not found")
1925

2026
type CatalogModelRepositoryImpl struct {
2127
*service.GenericRepository[models.CatalogModel, schema.Context, schema.ContextProperty, *models.CatalogModelListOptions]
28+
metricsArtifactTypeID int32
2229
}
2330

2431
func NewCatalogModelRepository(db *gorm.DB, typeID int32) models.CatalogModelRepository {
2532
r := &CatalogModelRepositoryImpl{}
2633

2734
r.GenericRepository = service.NewGenericRepository(service.GenericRepositoryConfig[models.CatalogModel, schema.Context, schema.ContextProperty, *models.CatalogModelListOptions]{
28-
DB: db,
29-
TypeID: typeID,
30-
EntityToSchema: mapCatalogModelToContext,
31-
SchemaToEntity: mapDataLayerToCatalogModel,
32-
EntityToProperties: mapCatalogModelToContextProperties,
33-
NotFoundError: ErrCatalogModelNotFound,
34-
EntityName: "catalog model",
35-
PropertyFieldName: "context_id",
36-
ApplyListFilters: applyCatalogModelListFilters,
37-
IsNewEntity: func(entity models.CatalogModel) bool { return entity.GetID() == nil },
38-
HasCustomProperties: func(entity models.CatalogModel) bool { return entity.GetCustomProperties() != nil },
39-
EntityMappingFuncs: filter.NewCatalogEntityMappings(),
35+
DB: db,
36+
TypeID: typeID,
37+
EntityToSchema: mapCatalogModelToContext,
38+
SchemaToEntity: mapDataLayerToCatalogModel,
39+
EntityToProperties: mapCatalogModelToContextProperties,
40+
NotFoundError: ErrCatalogModelNotFound,
41+
EntityName: "catalog model",
42+
PropertyFieldName: "context_id",
43+
ApplyListFilters: applyCatalogModelListFilters,
44+
CreatePaginationToken: r.createPaginationToken,
45+
ApplyCustomOrdering: r.applyAccuracyOrdering,
46+
IsNewEntity: func(entity models.CatalogModel) bool { return entity.GetID() == nil },
47+
HasCustomProperties: func(entity models.CatalogModel) bool { return entity.GetCustomProperties() != nil },
48+
EntityMappingFuncs: filter.NewCatalogEntityMappings(),
4049
})
4150

4251
return r
@@ -298,3 +307,213 @@ func (r *CatalogModelRepositoryImpl) GetFilterableProperties(maxLength int) (map
298307

299308
return result, nil
300309
}
310+
311+
// getMetricsArtifactTypeID looks up the type ID for CatalogMetricsArtifact dynamically
312+
func (r *CatalogModelRepositoryImpl) getMetricsArtifactTypeID() (int32, error) {
313+
if r.metricsArtifactTypeID != 0 {
314+
return r.metricsArtifactTypeID, nil
315+
}
316+
317+
// Look up the type ID dynamically from the database
318+
var typeRecord struct {
319+
ID int32 `gorm:"column:id"`
320+
}
321+
322+
err := r.GetConfig().DB.
323+
Table("\"Type\"").
324+
Select("id").
325+
Where("name = ?", CatalogMetricsArtifactTypeName).
326+
First(&typeRecord).Error
327+
328+
if err != nil {
329+
return 0, fmt.Errorf("failed to lookup CatalogMetricsArtifact type ID: %w", err)
330+
}
331+
332+
// Cache the type ID for future use
333+
r.metricsArtifactTypeID = typeRecord.ID
334+
return typeRecord.ID, nil
335+
}
336+
337+
// applyAccuracyOrdering applies custom ordering logic for ACCURACY orderBy field
338+
func (r *CatalogModelRepositoryImpl) applyAccuracyOrdering(query *gorm.DB, listOptions *models.CatalogModelListOptions) *gorm.DB {
339+
orderBy := listOptions.GetOrderBy()
340+
341+
// Only apply custom ordering for ACCURACY orderBy
342+
if orderBy != "ACCURACY" {
343+
// Fall back to standard pagination for non-ACCURACY ordering
344+
return r.ApplyStandardPagination(query, listOptions, []models.CatalogModel{})
345+
}
346+
347+
// Get the metrics artifact type ID
348+
metricsTypeID, err := r.getMetricsArtifactTypeID()
349+
if err != nil {
350+
// Fall back to standard pagination if we can't get the type ID
351+
return r.ApplyStandardPagination(query, listOptions, []models.CatalogModel{})
352+
}
353+
354+
db := r.GetConfig().DB
355+
contextTable := utils.GetTableName(db, &schema.Context{})
356+
attributionTable := utils.GetTableName(db, &schema.Attribution{})
357+
artifactTable := utils.GetTableName(db, &schema.Artifact{})
358+
propertyTable := utils.GetTableName(db, &schema.ArtifactProperty{})
359+
360+
sortOrder := listOptions.GetSortOrder()
361+
pageSize := listOptions.GetPageSize()
362+
363+
// Build the accuracy subquery
364+
// This gets the accuracy score for each model from its AccuracyMetric artifacts
365+
accuracySubquery := db.
366+
Select(fmt.Sprintf("%s.id, max(%s.double_value) AS accuracy", contextTable, propertyTable)).
367+
Table(contextTable).
368+
Joins(fmt.Sprintf("LEFT JOIN %s ON %s.id=%s.context_id", attributionTable, contextTable, attributionTable)).
369+
Joins(fmt.Sprintf("LEFT JOIN %s ON %s.artifact_id=%s.id AND %s.type_id=?", artifactTable, attributionTable, artifactTable, artifactTable), metricsTypeID).
370+
Joins(fmt.Sprintf("LEFT JOIN %s ON %s.id=%s.artifact_id AND %s.name=?", propertyTable, artifactTable, propertyTable, propertyTable), accuracyProperty).
371+
Where(contextTable+".type_id=?", r.GetConfig().TypeID).
372+
Group(contextTable + ".id")
373+
374+
// Join the main query with the accuracy subquery
375+
query = query.
376+
Joins("LEFT JOIN (?) accuracy ON \"Context\".id=accuracy.id", accuracySubquery)
377+
378+
// Apply sorting order
379+
if sortOrder != "ASC" {
380+
sortOrder = "DESC"
381+
}
382+
query = query.Order(fmt.Sprintf("accuracy %s NULLS LAST, %s.id", sortOrder, contextTable))
383+
384+
// Handle cursor-based pagination with nextPageToken
385+
nextPageToken := listOptions.GetNextPageToken()
386+
if nextPageToken != "" {
387+
// Parse the cursor from the token
388+
if cursor, err := r.parseNextPageToken(nextPageToken); err == nil {
389+
// Apply WHERE clause for cursor-based pagination with ACCURACY
390+
query = r.applyCursorPagination(query, cursor, sortOrder)
391+
}
392+
// If token parsing fails, fall back to no cursor (first page)
393+
}
394+
395+
// Apply pagination limit
396+
if pageSize > 0 {
397+
query = query.Limit(int(pageSize) + 1) // +1 to detect if there are more pages
398+
}
399+
400+
return query
401+
}
402+
403+
// cursor represents a pagination cursor with ID and accuracy value
404+
type accuracyCursor struct {
405+
ID int32
406+
Accuracy *float64
407+
}
408+
409+
// parseNextPageToken decodes a nextPageToken and extracts the cursor information
410+
func (r *CatalogModelRepositoryImpl) parseNextPageToken(token string) (*accuracyCursor, error) {
411+
// Sanity check the length before decoding
412+
if len(token) > 64 {
413+
return nil, fmt.Errorf("invalid nextPageToken")
414+
}
415+
416+
decoded, err := base64.StdEncoding.DecodeString(token)
417+
if err != nil {
418+
return nil, fmt.Errorf("failed to decode token: %w", err)
419+
}
420+
421+
parts := strings.Split(string(decoded), ":")
422+
if len(parts) != 2 {
423+
return nil, fmt.Errorf("invalid cursor format, expected 'ID:Value'")
424+
}
425+
426+
id, err := strconv.ParseInt(parts[0], 10, 32)
427+
if err != nil {
428+
return nil, fmt.Errorf("invalid ID in cursor: %w", err)
429+
}
430+
431+
cursor := accuracyCursor{ID: int32(id)}
432+
433+
// Parse accuracy value from cursor
434+
accuracy, err := strconv.ParseFloat(parts[1], 64)
435+
if err == nil {
436+
cursor.Accuracy = &accuracy
437+
}
438+
439+
return &cursor, nil
440+
}
441+
442+
// applyCursorPagination applies WHERE clause for cursor-based pagination with ACCURACY sorting
443+
func (r *CatalogModelRepositoryImpl) applyCursorPagination(query *gorm.DB, cursor *accuracyCursor, sortOrder string) *gorm.DB {
444+
contextTable := utils.GetTableName(query, &schema.Context{})
445+
446+
// Handle NULL accuracy values in cursor
447+
if cursor.Accuracy == nil {
448+
// For models without accuracy, just use ID-based pagination
449+
return query.Where(fmt.Sprintf("accuracy IS NULL AND %s.id > ?", contextTable), cursor.ID)
450+
}
451+
452+
accuracyValue := *cursor.Accuracy
453+
454+
// Apply cursor pagination logic for ACCURACY sorting
455+
if sortOrder == "ASC" {
456+
// For ASC: get records where (accuracy > cursor_accuracy) OR (accuracy = cursor_accuracy AND id > cursor_id)
457+
// Also include NULL values at the end
458+
return query.Where("(accuracy > ? OR (accuracy = ? AND "+contextTable+".id > ?) OR accuracy IS NULL)",
459+
accuracyValue, accuracyValue, cursor.ID)
460+
} else {
461+
// For DESC: get records where (accuracy < cursor_accuracy) OR (accuracy = cursor_accuracy AND id > cursor_id)
462+
return query.Where("(accuracy < ? OR (accuracy = ? AND "+contextTable+".id > ?) OR accuracy IS NULL)",
463+
accuracyValue, accuracyValue, cursor.ID)
464+
}
465+
}
466+
467+
func (r *CatalogModelRepositoryImpl) createPaginationToken(lastItem schema.Context, listOptions *models.CatalogModelListOptions) string {
468+
if listOptions.GetOrderBy() == "ACCURACY" {
469+
// The accuracy metric is not available from the context table,
470+
// so we'll need another query to get it.
471+
472+
db := r.GetConfig().DB
473+
contextTable := utils.GetTableName(db, &schema.Context{})
474+
attributionTable := utils.GetTableName(db, &schema.Attribution{})
475+
artifactTable := utils.GetTableName(db, &schema.Artifact{})
476+
propertyTable := utils.GetTableName(db, &schema.ArtifactProperty{})
477+
metricsTypeID, err := r.getMetricsArtifactTypeID()
478+
if err != nil {
479+
glog.Warningf("Failed to get metrics artifact type ID: %v", err)
480+
return r.CreateDefaultPaginationToken(lastItem, listOptions)
481+
}
482+
483+
query := db.
484+
Select("MAX(double_value) AS accuracy").
485+
Table(contextTable).
486+
Joins(fmt.Sprintf("LEFT JOIN %s ON %s.id=%s.context_id", attributionTable, contextTable, attributionTable)).
487+
Joins(fmt.Sprintf("LEFT JOIN %s ON %s.artifact_id=%s.id", artifactTable, attributionTable, artifactTable)).
488+
Joins(fmt.Sprintf("LEFT JOIN %s ON %s.id=%s.artifact_id", propertyTable, artifactTable, propertyTable)).
489+
Where(artifactTable+".type_id=?", metricsTypeID).
490+
Where(propertyTable+".name=?", accuracyProperty).
491+
Where(contextTable+".id=?", lastItem.ID)
492+
493+
var result struct {
494+
Accuracy *float64 `gorm:"accuracy"`
495+
}
496+
err = query.Scan(&result).Error
497+
if err != nil {
498+
glog.Warningf("Failed to get accuracy score: %v", err)
499+
return r.CreateDefaultPaginationToken(lastItem, listOptions)
500+
}
501+
502+
return createAccuracyPaginationToken(lastItem.ID, result.Accuracy)
503+
}
504+
505+
return r.CreateDefaultPaginationToken(lastItem, listOptions)
506+
}
507+
508+
// createAccuracyPaginationToken creates a pagination token for ACCURACY sorting
509+
func createAccuracyPaginationToken(entityID int32, accuracyValue *float64) string {
510+
var valueStr string
511+
if accuracyValue != nil {
512+
valueStr = fmt.Sprintf("%.15f", *accuracyValue)
513+
} else {
514+
valueStr = "" // Represents NULL
515+
}
516+
517+
cursor := fmt.Sprintf("%d:%s", entityID, valueStr)
518+
return base64.StdEncoding.EncodeToString([]byte(cursor))
519+
}

0 commit comments

Comments
 (0)