@@ -271,6 +271,53 @@ func TestPollTrainingProgress(t *testing.T) {
271271 wantStatus : nil ,
272272 wantErr : true ,
273273 },
274+ {
275+ name : "NaN and Infinity values are sanitized to null" ,
276+ responseBody : `{
277+ "progressPercentage": 100,
278+ "estimatedRemainingSeconds": 0,
279+ "currentStep": 50,
280+ "totalSteps": 50,
281+ "currentEpoch": 5,
282+ "totalEpochs": 5,
283+ "trainMetrics": {"loss": 0.0, "grad_norm": NaN, "learning_rate": 1e-06},
284+ "evalMetrics": {"eval_loss": NaN, "eval_runtime": 0.04}
285+ }` ,
286+ responseStatus : http .StatusOK ,
287+ wantStatus : & TrainerStatus {
288+ ProgressPercentage : ptrInt (100 ),
289+ EstimatedRemainingSeconds : ptrInt (0 ),
290+ CurrentStep : ptrInt (50 ),
291+ TotalSteps : ptrInt (50 ),
292+ CurrentEpoch : ptrFloat64 (5 ),
293+ TotalEpochs : ptrInt (5 ),
294+ TrainMetrics : map [string ]interface {}{
295+ "loss" : 0.0 ,
296+ "grad_norm" : nil ,
297+ "learning_rate" : 1e-06 ,
298+ },
299+ EvalMetrics : map [string ]interface {}{
300+ "eval_loss" : nil ,
301+ "eval_runtime" : 0.04 ,
302+ },
303+ },
304+ wantErr : false ,
305+ },
306+ {
307+ name : "negative Infinity is sanitized to null" ,
308+ responseBody : `{
309+ "progressPercentage": 50,
310+ "trainMetrics": {"loss": -Infinity}
311+ }` ,
312+ responseStatus : http .StatusOK ,
313+ wantStatus : & TrainerStatus {
314+ ProgressPercentage : ptrInt (50 ),
315+ TrainMetrics : map [string ]interface {}{
316+ "loss" : nil ,
317+ },
318+ },
319+ wantErr : false ,
320+ },
274321 }
275322
276323 for _ , tt := range tests {
@@ -537,6 +584,59 @@ func TestGetPrimaryPod(t *testing.T) {
537584 }
538585}
539586
587+ func TestSanitizeJSON (t * testing.T ) {
588+ tests := []struct {
589+ name string
590+ input string
591+ want string
592+ }{
593+ {
594+ name : "NaN replaced with null" ,
595+ input : `{"grad_norm": NaN, "loss": 0.5}` ,
596+ want : `{"grad_norm": null, "loss": 0.5}` ,
597+ },
598+ {
599+ name : "Infinity replaced with null" ,
600+ input : `{"loss": Infinity}` ,
601+ want : `{"loss": null}` ,
602+ },
603+ {
604+ name : "negative Infinity replaced with null" ,
605+ input : `{"loss": -Infinity}` ,
606+ want : `{"loss": null}` ,
607+ },
608+ {
609+ name : "multiple NaN values" ,
610+ input : `{"grad_norm": NaN, "eval_loss": NaN, "loss": 0.1}` ,
611+ want : `{"grad_norm": null, "eval_loss": null, "loss": 0.1}` ,
612+ },
613+ {
614+ name : "no special values unchanged" ,
615+ input : `{"loss": 0.5, "step": 100}` ,
616+ want : `{"loss": 0.5, "step": 100}` ,
617+ },
618+ {
619+ name : "NaN in string value not replaced" ,
620+ input : `{"name": "NaNcy", "loss": NaN}` ,
621+ want : `{"name": "NaNcy", "loss": null}` ,
622+ },
623+ }
624+
625+ for _ , tt := range tests {
626+ t .Run (tt .name , func (t * testing.T ) {
627+ got := string (sanitizeJSON ([]byte (tt .input )))
628+ if got != tt .want {
629+ t .Errorf ("sanitizeJSON() = %q, want %q" , got , tt .want )
630+ }
631+ // Verify result is valid JSON
632+ var m map [string ]interface {}
633+ if err := json .Unmarshal ([]byte (got ), & m ); err != nil {
634+ t .Errorf ("sanitizeJSON() produced invalid JSON: %v" , err )
635+ }
636+ })
637+ }
638+ }
639+
540640func TestCleanInvalidMetrics (t * testing.T ) {
541641 tests := []struct {
542642 name string
@@ -1267,6 +1367,39 @@ func TestCaptureMetricsFromTerminationMessage(t *testing.T) {
12671367 wantErr : true ,
12681368 wantNil : true ,
12691369 },
1370+ {
1371+ name : "NaN and Infinity in termination message are sanitized" ,
1372+ pod : & corev1.Pod {
1373+ Status : corev1.PodStatus {
1374+ ContainerStatuses : []corev1.ContainerStatus {
1375+ {
1376+ Name : "node" ,
1377+ State : corev1.ContainerState {
1378+ Terminated : & corev1.ContainerStateTerminated {
1379+ Message : `{"progressPercentage": 100, "estimatedRemainingSeconds": 0, "currentStep": 50, "totalSteps": 50, "trainMetrics": {"loss": 0.0, "grad_norm": NaN}, "evalMetrics": {"eval_loss": NaN, "eval_runtime": 0.04}}` ,
1380+ },
1381+ },
1382+ },
1383+ },
1384+ },
1385+ },
1386+ wantErr : false ,
1387+ wantNil : false ,
1388+ checkFunc : func (t * testing.T , status * AnnotationStatus ) {
1389+ if status .ProgressPercentage == nil || * status .ProgressPercentage != 100 {
1390+ t .Errorf ("ProgressPercentage = %v, want 100" , status .ProgressPercentage )
1391+ }
1392+ if status .TrainMetrics == nil {
1393+ t .Fatal ("TrainMetrics is nil" )
1394+ }
1395+ if status .TrainMetrics ["grad_norm" ] != nil {
1396+ t .Errorf ("grad_norm should be nil (sanitized from NaN), got %v" , status .TrainMetrics ["grad_norm" ])
1397+ }
1398+ if status .EvalMetrics ["eval_loss" ] != nil {
1399+ t .Errorf ("eval_loss should be nil (sanitized from NaN), got %v" , status .EvalMetrics ["eval_loss" ])
1400+ }
1401+ },
1402+ },
12701403 {
12711404 name : "invalid JSON in termination message" ,
12721405 pod : & corev1.Pod {
0 commit comments