Skip to content

Commit 9e4101b

Browse files
fix: adjust progression e2e tests
Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent 40ae060 commit 9e4101b

1 file changed

Lines changed: 42 additions & 42 deletions

File tree

pkg/rhai/progression/progression_test.go

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,9 @@ func TestPollTrainingProgress(t *testing.T) {
241241
wantStatus: &TrainerStatus{
242242
ProgressPercentage: ptrInt(45),
243243
EstimatedRemainingSeconds: ptrInt(9855),
244-
CurrentStep: 4530,
244+
CurrentStep: ptrInt(4530),
245245
TotalSteps: ptrInt(10000),
246-
CurrentEpoch: 2,
246+
CurrentEpoch: ptrFloat64(2),
247247
TotalEpochs: ptrInt(5),
248248
TrainMetrics: map[string]interface{}{
249249
"loss": 0.234,
@@ -334,9 +334,9 @@ func TestUpdateTrainerStatusAnnotation(t *testing.T) {
334334
status: &AnnotationStatus{
335335
ProgressPercentage: ptrInt(10),
336336
EstimatedRemainingSeconds: ptrInt(9000),
337-
CurrentStep: 100,
337+
CurrentStep: ptrInt(100),
338338
TotalSteps: ptrInt(1000),
339-
CurrentEpoch: 1,
339+
CurrentEpoch: ptrFloat64(1),
340340
TotalEpochs: ptrInt(10),
341341
LastUpdatedTime: "2025-11-18T10:00:00Z",
342342
},
@@ -352,9 +352,9 @@ func TestUpdateTrainerStatusAnnotation(t *testing.T) {
352352
},
353353
status: &AnnotationStatus{
354354
ProgressPercentage: ptrInt(50),
355-
CurrentStep: 500,
355+
CurrentStep: ptrInt(500),
356356
TotalSteps: ptrInt(1000),
357-
CurrentEpoch: 5,
357+
CurrentEpoch: ptrFloat64(5),
358358
TrainMetrics: map[string]interface{}{
359359
"loss": 0.5,
360360
"learning_rate": 0.001,
@@ -542,9 +542,9 @@ func TestCleanInvalidMetrics(t *testing.T) {
542542
status: &TrainerStatus{
543543
ProgressPercentage: ptrInt(50),
544544
EstimatedRemainingSeconds: ptrInt(1200),
545-
CurrentStep: 500,
545+
CurrentStep: ptrInt(500),
546546
TotalSteps: ptrInt(1000),
547-
CurrentEpoch: 2,
547+
CurrentEpoch: ptrFloat64(2),
548548
TotalEpochs: ptrInt(5),
549549
},
550550
verify: func(t *testing.T, s *TrainerStatus) {
@@ -554,8 +554,8 @@ func TestCleanInvalidMetrics(t *testing.T) {
554554
if s.EstimatedRemainingSeconds == nil || *s.EstimatedRemainingSeconds != 1200 {
555555
t.Errorf("EstimatedRemainingSeconds should remain 1200, got %v", s.EstimatedRemainingSeconds)
556556
}
557-
if s.CurrentStep != 500 {
558-
t.Errorf("CurrentStep should remain 500, got %d", s.CurrentStep)
557+
if s.CurrentStep == nil || *s.CurrentStep != 500 {
558+
t.Errorf("CurrentStep should remain 500, got %v", s.CurrentStep)
559559
}
560560
if s.TotalSteps == nil || *s.TotalSteps != 1000 {
561561
t.Errorf("TotalSteps should remain 1000, got %v", s.TotalSteps)
@@ -566,31 +566,31 @@ func TestCleanInvalidMetrics(t *testing.T) {
566566
name: "negative progress percentage removed",
567567
status: &TrainerStatus{
568568
ProgressPercentage: ptrInt(-1),
569-
CurrentStep: 500,
569+
CurrentStep: ptrInt(500),
570570
},
571571
verify: func(t *testing.T, s *TrainerStatus) {
572572
if s.ProgressPercentage != nil {
573573
t.Errorf("ProgressPercentage should be nil (removed), got %v", *s.ProgressPercentage)
574574
}
575-
if s.CurrentStep != 500 {
576-
t.Errorf("CurrentStep should remain unchanged, got %d", s.CurrentStep)
575+
if s.CurrentStep == nil || *s.CurrentStep != 500 {
576+
t.Errorf("CurrentStep should remain unchanged, got %v", s.CurrentStep)
577577
}
578578
},
579579
},
580580
{
581581
name: "progress percentage over 100 removed",
582582
status: &TrainerStatus{
583583
ProgressPercentage: ptrInt(150),
584-
CurrentStep: 500,
584+
CurrentStep: ptrInt(500),
585585
TotalSteps: ptrInt(1000),
586586
},
587587
verify: func(t *testing.T, s *TrainerStatus) {
588588
if s.ProgressPercentage != nil {
589589
t.Errorf("ProgressPercentage should be nil (removed), got %v", *s.ProgressPercentage)
590590
}
591591
// Other fields should remain
592-
if s.CurrentStep != 500 {
593-
t.Errorf("CurrentStep should remain 500, got %d", s.CurrentStep)
592+
if s.CurrentStep == nil || *s.CurrentStep != 500 {
593+
t.Errorf("CurrentStep should remain 500, got %v", s.CurrentStep)
594594
}
595595
if s.TotalSteps == nil || *s.TotalSteps != 1000 {
596596
t.Errorf("TotalSteps should remain 1000, got %v", s.TotalSteps)
@@ -601,11 +601,11 @@ func TestCleanInvalidMetrics(t *testing.T) {
601601
name: "negative current step clamped to 0",
602602
status: &TrainerStatus{
603603
ProgressPercentage: ptrInt(50),
604-
CurrentStep: -1,
604+
CurrentStep: ptrInt(-1),
605605
},
606606
verify: func(t *testing.T, s *TrainerStatus) {
607-
if s.CurrentStep != 0 {
608-
t.Errorf("CurrentStep should be clamped to 0, got %d", s.CurrentStep)
607+
if s.CurrentStep == nil || *s.CurrentStep != 0 {
608+
t.Errorf("CurrentStep should be clamped to 0, got %v", s.CurrentStep)
609609
}
610610
// ProgressPercentage should remain valid
611611
if s.ProgressPercentage == nil || *s.ProgressPercentage != 50 {
@@ -617,7 +617,7 @@ func TestCleanInvalidMetrics(t *testing.T) {
617617
name: "negative total steps removed",
618618
status: &TrainerStatus{
619619
ProgressPercentage: ptrInt(50),
620-
CurrentStep: 500,
620+
CurrentStep: ptrInt(500),
621621
TotalSteps: ptrInt(-100),
622622
},
623623
verify: func(t *testing.T, s *TrainerStatus) {
@@ -633,7 +633,7 @@ func TestCleanInvalidMetrics(t *testing.T) {
633633
{
634634
name: "zero total steps preserved (valid for indefinite training)",
635635
status: &TrainerStatus{
636-
CurrentStep: 500,
636+
CurrentStep: ptrInt(500),
637637
TotalSteps: ptrInt(0),
638638
},
639639
verify: func(t *testing.T, s *TrainerStatus) {
@@ -646,27 +646,27 @@ func TestCleanInvalidMetrics(t *testing.T) {
646646
name: "negative current epoch clamped to 0",
647647
status: &TrainerStatus{
648648
ProgressPercentage: ptrInt(50),
649-
CurrentStep: 500,
650-
CurrentEpoch: -5,
649+
CurrentStep: ptrInt(500),
650+
CurrentEpoch: ptrFloat64(-5),
651651
},
652652
verify: func(t *testing.T, s *TrainerStatus) {
653-
if s.CurrentEpoch != 0 {
654-
t.Errorf("CurrentEpoch should be clamped to 0, got %d", s.CurrentEpoch)
653+
if s.CurrentEpoch == nil || *s.CurrentEpoch != 0 {
654+
t.Errorf("CurrentEpoch should be clamped to 0, got %v", s.CurrentEpoch)
655655
}
656656
},
657657
},
658658
{
659659
name: "negative total epochs removed",
660660
status: &TrainerStatus{
661-
CurrentEpoch: 2,
661+
CurrentEpoch: ptrFloat64(2),
662662
TotalEpochs: ptrInt(-3),
663663
},
664664
verify: func(t *testing.T, s *TrainerStatus) {
665665
if s.TotalEpochs != nil {
666666
t.Errorf("TotalEpochs should be nil (removed), got %v", *s.TotalEpochs)
667667
}
668-
if s.CurrentEpoch != 2 {
669-
t.Errorf("CurrentEpoch should remain 2, got %d", s.CurrentEpoch)
668+
if s.CurrentEpoch == nil || *s.CurrentEpoch != 2 {
669+
t.Errorf("CurrentEpoch should remain 2, got %v", s.CurrentEpoch)
670670
}
671671
},
672672
},
@@ -689,8 +689,8 @@ func TestCleanInvalidMetrics(t *testing.T) {
689689
name: "nil progress percentage preserved",
690690
status: &TrainerStatus{
691691
ProgressPercentage: nil,
692-
CurrentStep: 500,
693-
CurrentEpoch: 1,
692+
CurrentStep: ptrInt(500),
693+
CurrentEpoch: ptrFloat64(1),
694694
},
695695
verify: func(t *testing.T, s *TrainerStatus) {
696696
if s.ProgressPercentage != nil {
@@ -703,9 +703,9 @@ func TestCleanInvalidMetrics(t *testing.T) {
703703
status: &TrainerStatus{
704704
ProgressPercentage: ptrInt(150),
705705
EstimatedRemainingSeconds: ptrInt(-50),
706-
CurrentStep: -10,
706+
CurrentStep: ptrInt(-10),
707707
TotalSteps: ptrInt(-100),
708-
CurrentEpoch: -2,
708+
CurrentEpoch: ptrFloat64(-2),
709709
TotalEpochs: ptrInt(5),
710710
},
711711
verify: func(t *testing.T, s *TrainerStatus) {
@@ -715,14 +715,14 @@ func TestCleanInvalidMetrics(t *testing.T) {
715715
if s.EstimatedRemainingSeconds != nil {
716716
t.Errorf("EstimatedRemainingSeconds should be nil, got %v", *s.EstimatedRemainingSeconds)
717717
}
718-
if s.CurrentStep != 0 {
719-
t.Errorf("CurrentStep should be 0, got %d", s.CurrentStep)
718+
if s.CurrentStep == nil || *s.CurrentStep != 0 {
719+
t.Errorf("CurrentStep should be 0, got %v", s.CurrentStep)
720720
}
721721
if s.TotalSteps != nil {
722722
t.Errorf("TotalSteps should be nil, got %v", *s.TotalSteps)
723723
}
724-
if s.CurrentEpoch != 0 {
725-
t.Errorf("CurrentEpoch should be 0, got %d", s.CurrentEpoch)
724+
if s.CurrentEpoch == nil || *s.CurrentEpoch != 0 {
725+
t.Errorf("CurrentEpoch should be 0, got %v", s.CurrentEpoch)
726726
}
727727
// TotalEpochs was valid, should remain
728728
if s.TotalEpochs == nil || *s.TotalEpochs != 5 {
@@ -751,9 +751,9 @@ func TestToAnnotationStatus(t *testing.T) {
751751
input: &TrainerStatus{
752752
ProgressPercentage: ptrInt(45),
753753
EstimatedRemainingSeconds: ptrInt(3665),
754-
CurrentStep: 4500,
754+
CurrentStep: ptrInt(4500),
755755
TotalSteps: ptrInt(10000),
756-
CurrentEpoch: 2,
756+
CurrentEpoch: ptrFloat64(2),
757757
TotalEpochs: ptrInt(5),
758758
TrainMetrics: map[string]interface{}{
759759
"loss": 0.234,
@@ -776,8 +776,8 @@ func TestToAnnotationStatus(t *testing.T) {
776776
if result.LastUpdatedTime == "" {
777777
t.Error("LastUpdatedTime should be set")
778778
}
779-
if result.CurrentStep != 4500 {
780-
t.Errorf("CurrentStep = %d, want 4500", result.CurrentStep)
779+
if result.CurrentStep == nil || *result.CurrentStep != 4500 {
780+
t.Errorf("CurrentStep = %v, want 4500", result.CurrentStep)
781781
}
782782
if len(result.TrainMetrics) != 2 {
783783
t.Errorf("TrainMetrics length = %d, want 2", len(result.TrainMetrics))
@@ -789,7 +789,7 @@ func TestToAnnotationStatus(t *testing.T) {
789789
input: &TrainerStatus{
790790
ProgressPercentage: ptrInt(50),
791791
EstimatedRemainingSeconds: nil,
792-
CurrentStep: 500,
792+
CurrentStep: ptrInt(500),
793793
},
794794
verify: func(t *testing.T, result *AnnotationStatus) {
795795
if result.EstimatedRemainingSeconds != nil {
@@ -805,7 +805,7 @@ func TestToAnnotationStatus(t *testing.T) {
805805
input: &TrainerStatus{
806806
ProgressPercentage: ptrInt(100),
807807
EstimatedRemainingSeconds: ptrInt(0),
808-
CurrentStep: 1000,
808+
CurrentStep: ptrInt(1000),
809809
},
810810
verify: func(t *testing.T, result *AnnotationStatus) {
811811
if result.EstimatedRemainingTimeSummary != "" {

0 commit comments

Comments
 (0)