Skip to content

Commit 06fc02d

Browse files
authored
feat: add overall_average custom property when loading perf data (#1930)
* feat: add overall_average custom property when loading perf data

  Signed-off-by: Alessio Pragliola <[email protected]>

* chore: improve test coverage

  Signed-off-by: Alessio Pragliola <[email protected]>

---------

Signed-off-by: Alessio Pragliola <[email protected]>
1 parent 4481151 commit 06fc02d

File tree

2 files changed

+106
-5
lines changed

2 files changed

+106
-5
lines changed

catalog/internal/catalog/performance_metrics.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import (
2121
// metadataJSON represents the minimal structure needed from metadata.json files
2222
// The ID field is used to look up existing models; the optional overall_accuracy value is carried along for the accuracy metrics artifact
2323
type metadataJSON struct {
24-
ID string `json:"id"` // Maps to model name for lookup
24+
ID string `json:"id"` // Maps to model name for lookup
25+
OverallAccuracy *float64 `json:"overall_accuracy"` // Overall accuracy score for the model
2526
}
2627

2728
// parseMetadataJSON parses JSON data into a metadataJSON struct, extracting the ID field and the optional overall_accuracy field
@@ -298,12 +299,12 @@ func processModelDirectory(dirPath string, modelRepo dbmodels.CatalogModelReposi
298299
glog.V(2).Infof("Found existing model %s with ID %d, processing metrics", metadata.ID, modelID)
299300

300301
// Use batch processing for all artifacts
301-
return processModelArtifactsBatch(dirPath, modelID, metadata.ID, metricsArtifactRepo, metricsArtifactTypeID)
302+
return processModelArtifactsBatch(dirPath, modelID, metadata.ID, metadata.OverallAccuracy, metricsArtifactRepo, metricsArtifactTypeID)
302303
}
303304

304305
// processModelArtifactsBatch processes all metric artifacts for a model in batch
305306
// This reduces DB overhead by parsing, checking, and inserting in optimized phases
306-
func processModelArtifactsBatch(dirPath string, modelID int32, modelName string, metricsArtifactRepo dbmodels.CatalogMetricsArtifactRepository, metricsArtifactTypeID int32) (int, error) {
307+
func processModelArtifactsBatch(dirPath string, modelID int32, modelName string, overallAccuracy *float64, metricsArtifactRepo dbmodels.CatalogMetricsArtifactRepository, metricsArtifactTypeID int32) (int, error) {
307308
// Parse all metrics files
308309
var evaluationRecords []evaluationRecord
309310
var performanceRecords []performanceRecord
@@ -359,7 +360,7 @@ func processModelArtifactsBatch(dirPath string, modelID int32, modelName string,
359360
if len(evaluationRecords) > 0 {
360361
externalID := fmt.Sprintf("accuracy-metrics-model-%d", modelID)
361362
if !existingArtifactsMap[externalID] {
362-
artifact := createAccuracyMetricsArtifact(evaluationRecords, modelID, metricsArtifactTypeID, nil, nil)
363+
artifact := createAccuracyMetricsArtifact(evaluationRecords, modelID, metricsArtifactTypeID, overallAccuracy, nil, nil)
363364
artifactsToInsert = append(artifactsToInsert, artifact)
364365
} else {
365366
glog.V(2).Infof("Accuracy metrics artifact already exists, skipping")
@@ -463,7 +464,7 @@ func parsePerformanceFile(filePath string) ([]performanceRecord, error) {
463464
}
464465

465466
// createAccuracyMetricsArtifact creates a single metrics artifact from all evaluation records; a non-nil overallAccuracy is added as the overall_average custom property
466-
func createAccuracyMetricsArtifact(evalRecords []evaluationRecord, modelID int32, typeID int32, existingID *int32, existingCreateTime *int64) *dbmodels.CatalogMetricsArtifactImpl {
467+
func createAccuracyMetricsArtifact(evalRecords []evaluationRecord, modelID int32, typeID int32, overallAccuracy *float64, existingID *int32, existingCreateTime *int64) *dbmodels.CatalogMetricsArtifactImpl {
467468
artifactName := fmt.Sprintf("accuracy-metrics-model-%d", modelID)
468469
externalID := fmt.Sprintf("accuracy-metrics-model-%d", modelID)
469470

@@ -506,6 +507,14 @@ func createAccuracyMetricsArtifact(evalRecords []evaluationRecord, modelID int32
506507
}
507508
}
508509

510+
// Add overall_average custom property from metadata.json overall_accuracy field
511+
if overallAccuracy != nil {
512+
customProperties = append(customProperties, models.Properties{
513+
Name: "overall_average",
514+
DoubleValue: overallAccuracy,
515+
})
516+
}
517+
509518
// Create the metrics artifact with metricsType set to accuracy-metrics
510519
metricsArtifact := &dbmodels.CatalogMetricsArtifactImpl{
511520
ID: existingID, // Use existing ID if updating

catalog/internal/catalog/performance_metrics_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,98 @@ func TestParseMetadataJSON_OnlyIDMatters(t *testing.T) {
235235
}
236236
}
237237

238+
func TestOverallAccuracyToOverallAverage(t *testing.T) {
239+
t.Run("parse overall_accuracy from metadata", func(t *testing.T) {
240+
tests := []struct {
241+
name string
242+
jsonData string
243+
wantNil bool
244+
wantValue float64
245+
}{
246+
{
247+
name: "overall_accuracy present",
248+
jsonData: `{"id": "model-1", "overall_accuracy": 85.5}`,
249+
wantNil: false,
250+
wantValue: 85.5,
251+
},
252+
{
253+
name: "overall_accuracy is zero",
254+
jsonData: `{"id": "model-2", "overall_accuracy": 0}`,
255+
wantNil: false,
256+
wantValue: 0.0,
257+
},
258+
{
259+
name: "overall_accuracy is null",
260+
jsonData: `{"id": "model-3", "overall_accuracy": null}`,
261+
wantNil: true,
262+
},
263+
{
264+
name: "overall_accuracy missing",
265+
jsonData: `{"id": "model-4"}`,
266+
wantNil: true,
267+
},
268+
}
269+
270+
for _, tt := range tests {
271+
t.Run(tt.name, func(t *testing.T) {
272+
metadata, err := parseMetadataJSON([]byte(tt.jsonData))
273+
if err != nil {
274+
t.Fatalf("parseMetadataJSON() error = %v", err)
275+
}
276+
277+
if tt.wantNil {
278+
if metadata.OverallAccuracy != nil {
279+
t.Errorf("OverallAccuracy = %v, want nil", *metadata.OverallAccuracy)
280+
}
281+
} else {
282+
if metadata.OverallAccuracy == nil {
283+
t.Errorf("OverallAccuracy = nil, want %v", tt.wantValue)
284+
} else if *metadata.OverallAccuracy != tt.wantValue {
285+
t.Errorf("OverallAccuracy = %v, want %v", *metadata.OverallAccuracy, tt.wantValue)
286+
}
287+
}
288+
})
289+
}
290+
})
291+
292+
t.Run("artifact has overall_average when overall_accuracy provided", func(t *testing.T) {
293+
overallAccuracy := 87.5
294+
evalRecords := []evaluationRecord{
295+
{Benchmark: "mmlu", CustomProperties: map[string]interface{}{"score": 90.0}},
296+
}
297+
298+
artifact := createAccuracyMetricsArtifact(evalRecords, 1, 100, &overallAccuracy, nil, nil)
299+
300+
found := false
301+
for _, prop := range *artifact.CustomProperties {
302+
if prop.Name == "overall_average" && prop.DoubleValue != nil {
303+
if *prop.DoubleValue != overallAccuracy {
304+
t.Errorf("overall_average = %v, want %v", *prop.DoubleValue, overallAccuracy)
305+
}
306+
found = true
307+
break
308+
}
309+
}
310+
if !found {
311+
t.Error("overall_average custom property not found in artifact")
312+
}
313+
})
314+
315+
t.Run("artifact has no overall_average when overall_accuracy is nil", func(t *testing.T) {
316+
evalRecords := []evaluationRecord{
317+
{Benchmark: "mmlu", CustomProperties: map[string]interface{}{"score": 90.0}},
318+
}
319+
320+
artifact := createAccuracyMetricsArtifact(evalRecords, 1, 100, nil, nil, nil)
321+
322+
for _, prop := range *artifact.CustomProperties {
323+
if prop.Name == "overall_average" {
324+
t.Error("overall_average should not exist when overall_accuracy is nil")
325+
}
326+
}
327+
})
328+
}
329+
238330
func TestEvaluationRecordUnmarshalJSON(t *testing.T) {
239331
tests := []struct {
240332
name string

0 commit comments

Comments
 (0)