22import os
33import json
44import glob
5- import sys
5+ from numpy . testing import assert_allclose
66import test_utils
77from statistics import mean
88
# Relative tolerance for the loss comparison: passed as `rtol` to
# numpy.testing.assert_allclose when the run's loss values are checked
# against the baseline loss values in test_loss.
9+ LOSS_RTOL = 0.10
# NOTE(review): the uses of STEP_TIME_MULT and E2E_TIME_MULT are not visible
# in this portion of the file; presumably they are multipliers applied to
# baseline step-time / end-to-end-time figures -- confirm at their use sites.
910STEP_TIME_MULT = 0.95
1011E2E_TIME_MULT = 0.95
# Absolute path of the directory that contains this test file.
1112test_dir = os .path .dirname (os .path .abspath (__file__ ))
@@ -22,9 +23,21 @@ def test_loss(baseline_filename):
2223 event_file = os .path .join (results_dir , test_config , "logdir/tensorboard/logdir/events*" )
2324 event_file = glob .glob (event_file )[0 ]
2425 with open (baseline_filepath , "r" ) as baseline_file :
25- end_step = json .load (baseline_file )["end_step" ]
26+ baseline_data = json .load (baseline_file )
27+ loss_expected_values = baseline_data ["loss_values" ]
28+ start_step = baseline_data ["start_step" ]
29+ end_step = baseline_data ["end_step" ]
30+ interval = baseline_data ["step_interval" ]
31+ loss_expected = {step : loss_expected_values [i ] for i , step in enumerate (
32+ range (start_step , end_step + 1 , interval ))}
2633 loss_actual = test_utils .read_maxtext_tb_tag (event_file , loss_summary_name )
27- assert 0 <= loss_actual [end_step ] < 1.8e-3 , f"Loss at final step: { loss_actual [end_step ]} , Expected 0 <= loss < 1.8e-3"
34+ assert loss_expected .keys () == loss_actual .keys (), \
35+ f"Steps at which loss was emitted for run do not match baseline. \
36+ Actual steps: { loss_actual .keys ()} , Baseline steps: { loss_expected .keys ()} "
37+ assert_allclose (list (loss_actual .values ()), list (loss_expected .values ()),
38+ rtol = LOSS_RTOL ,
39+ err_msg = f"Run loss values: { loss_actual .values ()} , \
40+ Baseline loss values: { loss_expected .values ()} " )
2841
2942
3043@pytest .mark .parametrize ("baseline_filename" , os .listdir (baselines_dir ))
0 commit comments