Skip to content

Commit 9d1462e

Browse files
authored
fix(catalog): include short values from JSON arrays in filter_options (kubeflow#1783)
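Properties stored as JSON arrays were excluded from the filter_options output whenever the serialized array string exceeded the length limit, even if every individual value in the array was short. Array values are now expanded into their elements before the length check is applied.
Signed-off-by: Paul Boyd <paul@pboyd.io>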
1 parent f0eb5d5 commit 9d1462e

3 files changed

Lines changed: 62 additions & 63 deletions

File tree

catalog/internal/catalog/db_catalog.go

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -205,40 +205,25 @@ func (d *dbCatalogImpl) GetFilterOptions(ctx context.Context) (*apimodels.Filter
205205
continue
206206
}
207207

208-
// Deduplicate values
209-
uniqueValues := make(map[string]bool)
210-
211-
// Parse JSON arrays for fields like language and tasks
212-
for _, value := range values {
213-
var arrayValues []string
214-
if err := json.Unmarshal([]byte(value), &arrayValues); err == nil {
215-
// Successfully parsed as array, add individual values
216-
for _, v := range arrayValues {
217-
uniqueValues[v] = true
218-
}
219-
} else {
220-
// Not a JSON array
221-
uniqueValues[value] = true
222-
}
208+
if len(values) == 0 {
209+
continue
223210
}
224211

225-
if len(uniqueValues) > 0 {
226-
sortedValues := make([]string, 0, len(uniqueValues))
227-
for v := range uniqueValues {
228-
sortedValues = append(sortedValues, v)
229-
}
230-
sort.Strings(sortedValues)
212+
sortedValues := make([]string, 0, len(values))
213+
for _, v := range values {
214+
sortedValues = append(sortedValues, v)
215+
}
216+
sort.Strings(sortedValues)
231217

232-
// Convert to []interface{} (supports future non-string filter types)
233-
expandedValues := make([]interface{}, len(sortedValues))
234-
for i, v := range sortedValues {
235-
expandedValues[i] = v
236-
}
218+
// Convert to []any (supports future non-string filter types)
219+
expandedValues := make([]any, len(sortedValues))
220+
for i, v := range sortedValues {
221+
expandedValues[i] = v
222+
}
237223

238-
options[fieldName] = apimodels.FilterOption{
239-
Type: "string",
240-
Values: expandedValues,
241-
}
224+
options[fieldName] = apimodels.FilterOption{
225+
Type: "string",
226+
Values: expandedValues,
242227
}
243228
}
244229

catalog/internal/db/service/catalog_model.go

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/kubeflow/model-registry/internal/db/schema"
1313
"github.com/kubeflow/model-registry/internal/db/service"
1414
"github.com/kubeflow/model-registry/internal/db/utils"
15+
"github.com/lib/pq"
1516
"gorm.io/gorm"
1617
)
1718

@@ -237,49 +238,63 @@ func mapDataLayerToCatalogModel(modelCtx schema.Context, propertiesCtx []schema.
237238
func (r *CatalogModelRepositoryImpl) GetFilterableProperties(maxLength int) (map[string][]string, error) {
238239
config := r.GetConfig()
239240

241+
if config.DB.Name() != "postgres" {
242+
return nil, fmt.Errorf("GetFilterableProperties is only supported on PostgreSQL")
243+
}
244+
245+
db, err := config.DB.DB()
246+
if err != nil {
247+
return nil, err
248+
}
249+
240250
// Get table names using GORM utilities for database compatibility
241251
contextTable := utils.GetTableName(config.DB, &schema.Context{})
242252
propertyTable := utils.GetTableName(config.DB, &schema.ContextProperty{})
243253

244-
// Simplified query: get distinct property name/value pairs
245254
query := fmt.Sprintf(`
246-
SELECT DISTINCT cp.name, cp.string_value
247-
FROM %s cp
248-
WHERE cp.context_id IN (
249-
SELECT id FROM %s WHERE type_id = ?
250-
)
251-
AND cp.name IN (
252-
SELECT name FROM (
253-
SELECT name, MAX(CHAR_LENGTH(string_value)) as max_len
254-
FROM %s
255-
WHERE context_id IN (
256-
SELECT id FROM %s WHERE type_id = ?
255+
SELECT name, array_agg(string_value) FROM (
256+
SELECT
257+
name,
258+
string_value
259+
FROM %s WHERE
260+
context_id IN (
261+
SELECT id FROM %s WHERE type_id=$1
257262
)
258263
AND string_value IS NOT NULL
259264
AND string_value != ''
260-
GROUP BY name
261-
) AS field_lengths
262-
WHERE max_len <= ?
265+
AND string_value IS NOT JSON ARRAY
266+
267+
UNION
268+
269+
SELECT
270+
name,
271+
json_array_elements_text(string_value::json) AS string_value
272+
FROM %s WHERE
273+
context_id IN (
274+
SELECT id FROM %s WHERE type_id=$1
275+
)
276+
AND string_value IS JSON ARRAY
263277
)
264-
AND cp.string_value IS NOT NULL
265-
AND cp.string_value != ''
266-
ORDER BY cp.name, cp.string_value
278+
GROUP BY name HAVING MAX(CHAR_LENGTH(string_value)) <= $2
267279
`, propertyTable, contextTable, propertyTable, contextTable)
268280

269-
type propertyRow struct {
270-
Name string
271-
StringValue string
272-
}
273-
274-
var rows []propertyRow
275-
if err := config.DB.Raw(query, config.TypeID, config.TypeID, maxLength).Scan(&rows).Error; err != nil {
281+
rows, err := db.Query(query, config.TypeID, maxLength)
282+
if err != nil {
276283
return nil, fmt.Errorf("error querying filterable properties: %w", err)
277284
}
285+
defer rows.Close()
286+
287+
result := map[string][]string{}
288+
for rows.Next() {
289+
var name string
290+
var values pq.StringArray
291+
292+
err = rows.Scan(&name, &values)
293+
if err != nil {
294+
return nil, fmt.Errorf("error scanning filterable property row: %w", err)
295+
}
278296

279-
// Aggregate values by property name in Go
280-
result := make(map[string][]string)
281-
for _, row := range rows {
282-
result[row.Name] = append(result[row.Name], row.StringValue)
297+
result[name] = []string(values)
283298
}
284299

285300
return result, nil

catalog/internal/db/service/catalog_model_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ func TestCatalogModelRepository(t *testing.T) {
199199
updateModel := &models.CatalogModelImpl{
200200
ID: saved.GetID(), // Specify the ID for update
201201
Attributes: &models.CatalogModelAttributes{
202-
Name: apiutils.Of("updated-test-model"),
203-
ExternalID: apiutils.Of("updated-ext-456"),
204-
CreateTimeSinceEpoch: saved.GetAttributes().CreateTimeSinceEpoch, // Preserve create time
202+
Name: apiutils.Of("updated-test-model"),
203+
ExternalID: apiutils.Of("updated-ext-456"),
204+
CreateTimeSinceEpoch: saved.GetAttributes().CreateTimeSinceEpoch, // Preserve create time
205205
},
206206
Properties: &[]dbmodels.Properties{
207207
{
@@ -413,7 +413,6 @@ func TestCatalogModelRepository(t *testing.T) {
413413
assert.Contains(t, result, "license")
414414
// Should exclude longer properties
415415
assert.NotContains(t, result, "provider") // "HuggingFace" is > 10 chars
416-
assert.NotContains(t, result, "language")
417416
assert.NotContains(t, result, "tasks")
418417
})
419418
}

0 commit comments

Comments
 (0)