Skip to content

Commit 0f22576

Browse files
sutaakar and claude committed
Sanitize NaN/Infinity in metrics JSON before parsing
Python's json.dumps() allows NaN and Infinity by default, but these are not valid JSON per the spec. When training metrics contain these values (e.g., grad_norm when loss=0, eval_loss with bf16 on ROCm), Go's json.Unmarshal fails, causing the operator to never capture training progress. Add sanitizeJSON() to replace NaN/Infinity with null before parsing in both PollTrainingProgress and CaptureMetricsFromTerminationMessage. Ref: RHOAIENG-56898 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7d53779 commit 0f22576

2 files changed

Lines changed: 148 additions & 2 deletions

File tree

pkg/rhai/progression/progression.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"fmt"
2323
"io"
2424
"net/http"
25+
"regexp"
2526
"strconv"
2627
"sync"
2728
"time"
@@ -283,7 +284,7 @@ func PollTrainingProgress(ctx context.Context, pod *corev1.Pod, metricsPort stri
283284
}
284285

285286
var status TrainerStatus
286-
if err := json.Unmarshal(body, &status); err != nil {
287+
if err := json.Unmarshal(sanitizeJSON(body), &status); err != nil {
287288
return nil, fmt.Errorf("failed to parse metrics JSON: %w", err)
288289
}
289290

@@ -292,6 +293,18 @@ func PollTrainingProgress(ctx context.Context, pod *corev1.Pod, metricsPort stri
292293
return &status, nil
293294
}
294295

296+
// nonFinitePattern matches, in priority order, either a complete JSON string
// literal (including escaped quotes) or a bare NaN / Infinity / -Infinity
// token. Matching string literals first lets sanitizeJSON step over quoted
// text so that its contents are never rewritten.
var nonFinitePattern = regexp.MustCompile(`"(?:[^"\\]|\\.)*"|\bNaN\b|-?\bInfinity\b`)

// jsonNull is the replacement emitted for every non-finite numeric token.
var jsonNull = []byte("null")

// sanitizeJSON returns a copy of data in which every bare NaN, Infinity and
// -Infinity token is replaced with null.
//
// Python's json.dumps() emits these tokens by default, but they are not valid
// JSON, causing Go's json.Unmarshal to fail. Tokens are replaced wherever they
// appear as values (object members and array elements alike); text inside
// string literals is left untouched, and all surrounding whitespace is
// preserved as-is.
func sanitizeJSON(data []byte) []byte {
	return nonFinitePattern.ReplaceAllFunc(data, func(tok []byte) []byte {
		// A match beginning with a quote is a whole string literal: keep it
		// verbatim so values such as "loss: NaN" are not corrupted.
		if tok[0] == '"' {
			return tok
		}
		return jsonNull
	})
}
307+
295308
// cleanInvalidMetrics removes invalid values while keeping valid fields.
296309
// Defense against custom implementations, malformed requests, or edge cases.
297310
func cleanInvalidMetrics(m *TrainerStatus) {
@@ -495,7 +508,7 @@ func CaptureMetricsFromTerminationMessage(ctx context.Context, pod *corev1.Pod)
495508

496509
// Parse JSON from termination message
497510
var status AnnotationStatus
498-
if err := json.Unmarshal([]byte(message), &status); err != nil {
511+
if err := json.Unmarshal(sanitizeJSON([]byte(message)), &status); err != nil {
499512
return nil, fmt.Errorf("failed to parse termination message JSON: %w", err)
500513
}
501514

pkg/rhai/progression/progression_test.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,53 @@ func TestPollTrainingProgress(t *testing.T) {
271271
wantStatus: nil,
272272
wantErr: true,
273273
},
274+
{
275+
name: "NaN and Infinity values are sanitized to null",
276+
responseBody: `{
277+
"progressPercentage": 100,
278+
"estimatedRemainingSeconds": 0,
279+
"currentStep": 50,
280+
"totalSteps": 50,
281+
"currentEpoch": 5,
282+
"totalEpochs": 5,
283+
"trainMetrics": {"loss": 0.0, "grad_norm": NaN, "learning_rate": 1e-06},
284+
"evalMetrics": {"eval_loss": NaN, "eval_runtime": 0.04}
285+
}`,
286+
responseStatus: http.StatusOK,
287+
wantStatus: &TrainerStatus{
288+
ProgressPercentage: ptrInt(100),
289+
EstimatedRemainingSeconds: ptrInt(0),
290+
CurrentStep: ptrInt(50),
291+
TotalSteps: ptrInt(50),
292+
CurrentEpoch: ptrFloat64(5),
293+
TotalEpochs: ptrInt(5),
294+
TrainMetrics: map[string]interface{}{
295+
"loss": 0.0,
296+
"grad_norm": nil,
297+
"learning_rate": 1e-06,
298+
},
299+
EvalMetrics: map[string]interface{}{
300+
"eval_loss": nil,
301+
"eval_runtime": 0.04,
302+
},
303+
},
304+
wantErr: false,
305+
},
306+
{
307+
name: "negative Infinity is sanitized to null",
308+
responseBody: `{
309+
"progressPercentage": 50,
310+
"trainMetrics": {"loss": -Infinity}
311+
}`,
312+
responseStatus: http.StatusOK,
313+
wantStatus: &TrainerStatus{
314+
ProgressPercentage: ptrInt(50),
315+
TrainMetrics: map[string]interface{}{
316+
"loss": nil,
317+
},
318+
},
319+
wantErr: false,
320+
},
274321
}
275322

276323
for _, tt := range tests {
@@ -537,6 +584,59 @@ func TestGetPrimaryPod(t *testing.T) {
537584
}
538585
}
539586

587+
func TestSanitizeJSON(t *testing.T) {
588+
tests := []struct {
589+
name string
590+
input string
591+
want string
592+
}{
593+
{
594+
name: "NaN replaced with null",
595+
input: `{"grad_norm": NaN, "loss": 0.5}`,
596+
want: `{"grad_norm": null, "loss": 0.5}`,
597+
},
598+
{
599+
name: "Infinity replaced with null",
600+
input: `{"loss": Infinity}`,
601+
want: `{"loss": null}`,
602+
},
603+
{
604+
name: "negative Infinity replaced with null",
605+
input: `{"loss": -Infinity}`,
606+
want: `{"loss": null}`,
607+
},
608+
{
609+
name: "multiple NaN values",
610+
input: `{"grad_norm": NaN, "eval_loss": NaN, "loss": 0.1}`,
611+
want: `{"grad_norm": null, "eval_loss": null, "loss": 0.1}`,
612+
},
613+
{
614+
name: "no special values unchanged",
615+
input: `{"loss": 0.5, "step": 100}`,
616+
want: `{"loss": 0.5, "step": 100}`,
617+
},
618+
{
619+
name: "NaN in string value not replaced",
620+
input: `{"name": "NaNcy", "loss": NaN}`,
621+
want: `{"name": "NaNcy", "loss": null}`,
622+
},
623+
}
624+
625+
for _, tt := range tests {
626+
t.Run(tt.name, func(t *testing.T) {
627+
got := string(sanitizeJSON([]byte(tt.input)))
628+
if got != tt.want {
629+
t.Errorf("sanitizeJSON() = %q, want %q", got, tt.want)
630+
}
631+
// Verify result is valid JSON
632+
var m map[string]interface{}
633+
if err := json.Unmarshal([]byte(got), &m); err != nil {
634+
t.Errorf("sanitizeJSON() produced invalid JSON: %v", err)
635+
}
636+
})
637+
}
638+
}
639+
540640
func TestCleanInvalidMetrics(t *testing.T) {
541641
tests := []struct {
542642
name string
@@ -1267,6 +1367,39 @@ func TestCaptureMetricsFromTerminationMessage(t *testing.T) {
12671367
wantErr: true,
12681368
wantNil: true,
12691369
},
1370+
{
1371+
name: "NaN and Infinity in termination message are sanitized",
1372+
pod: &corev1.Pod{
1373+
Status: corev1.PodStatus{
1374+
ContainerStatuses: []corev1.ContainerStatus{
1375+
{
1376+
Name: "node",
1377+
State: corev1.ContainerState{
1378+
Terminated: &corev1.ContainerStateTerminated{
1379+
Message: `{"progressPercentage": 100, "estimatedRemainingSeconds": 0, "currentStep": 50, "totalSteps": 50, "trainMetrics": {"loss": 0.0, "grad_norm": NaN}, "evalMetrics": {"eval_loss": NaN, "eval_runtime": 0.04}}`,
1380+
},
1381+
},
1382+
},
1383+
},
1384+
},
1385+
},
1386+
wantErr: false,
1387+
wantNil: false,
1388+
checkFunc: func(t *testing.T, status *AnnotationStatus) {
1389+
if status.ProgressPercentage == nil || *status.ProgressPercentage != 100 {
1390+
t.Errorf("ProgressPercentage = %v, want 100", status.ProgressPercentage)
1391+
}
1392+
if status.TrainMetrics == nil {
1393+
t.Fatal("TrainMetrics is nil")
1394+
}
1395+
if status.TrainMetrics["grad_norm"] != nil {
1396+
t.Errorf("grad_norm should be nil (sanitized from NaN), got %v", status.TrainMetrics["grad_norm"])
1397+
}
1398+
if status.EvalMetrics["eval_loss"] != nil {
1399+
t.Errorf("eval_loss should be nil (sanitized from NaN), got %v", status.EvalMetrics["eval_loss"])
1400+
}
1401+
},
1402+
},
12701403
{
12711404
name: "invalid JSON in termination message",
12721405
pod: &corev1.Pod{

0 commit comments

Comments
 (0)