@@ -47,6 +47,17 @@ def robust_center(values):
4747 return statistics .median (filtered )
4848
4949
50+ def collect_once (rust_env ):
51+ torch_records = run_json_command ([sys .executable , "demo_visual/demo_benchmark.py" ])
52+ rust_records = run_json_command (
53+ ["cargo" , "run" , "-p" , "demo_visual" , "--bin" , "bench_alignment_ci" , "--release" ],
54+ env = rust_env ,
55+ )
56+ torch_finish = latest_finish (torch_records , "PyTorch" )
57+ rust_finish = latest_finish (rust_records , "RusTorch" )
58+ return torch_finish , rust_finish
59+
60+
5061def main ():
5162 repeat = max (int (os .environ .get ("RUST_TORCH_REPEAT" , "3" )), 1 )
5263 torch_losses = []
@@ -66,19 +77,32 @@ def main():
6677 rust_env .setdefault ("RUSTORCH_GRAD_PATH" , "tensor" )
6778
6879 for _ in range (repeat ):
69- torch_records = run_json_command ([sys .executable , "demo_visual/demo_benchmark.py" ])
70- rust_records = run_json_command (
71- ["cargo" , "run" , "-p" , "demo_visual" , "--bin" , "bench_alignment_ci" , "--release" ],
72- env = rust_env ,
73- )
74- torch_finish = latest_finish (torch_records , "PyTorch" )
75- rust_finish = latest_finish (rust_records , "RusTorch" )
76- torch_losses .append (max (float (torch_finish .get ("final_loss" , 0.0 )), 1e-12 ))
77- rust_losses .append (max (float (rust_finish .get ("final_loss" , 0.0 )), 1e-12 ))
78- torch_accs .append (float (torch_finish .get ("final_accuracy" , 0.0 )))
79- rust_accs .append (float (rust_finish .get ("final_accuracy" , 0.0 )))
80- torch_speeds .append (max (float (torch_finish .get ("avg_speed" , 0.0 )), 1e-12 ))
81- rust_speeds .append (max (float (rust_finish .get ("avg_speed" , 0.0 )), 1e-12 ))
80+ collected = False
81+ for attempt in range (2 ):
82+ active_env = rust_env .copy ()
83+ if attempt == 1 :
84+ active_env ["RUSTORCH_LINEAR_FUSED" ] = "0"
85+ active_env ["RUSTORCH_CPU_MATMUL_STRATEGY" ] = "parallel"
86+ active_env ["RUSTORCH_FUSED_PIPELINE_STRATEGY" ] = "off"
87+ active_env ["RUSTORCH_GRAD_PATH" ] = "tensor"
88+ try :
89+ torch_finish , rust_finish = collect_once (active_env )
90+ torch_losses .append (max (float (torch_finish .get ("final_loss" , 0.0 )), 1e-12 ))
91+ rust_losses .append (max (float (rust_finish .get ("final_loss" , 0.0 )), 1e-12 ))
92+ torch_accs .append (float (torch_finish .get ("final_accuracy" , 0.0 )))
93+ rust_accs .append (float (rust_finish .get ("final_accuracy" , 0.0 )))
94+ torch_speeds .append (max (float (torch_finish .get ("avg_speed" , 0.0 )), 1e-12 ))
95+ rust_speeds .append (max (float (rust_finish .get ("avg_speed" , 0.0 )), 1e-12 ))
96+ collected = True
97+ break
98+ except Exception as e :
99+ print (f"WARN: alignment sample failed on attempt { attempt + 1 } : { e } " , file = sys .stderr )
100+ if not collected :
101+ print ("WARN: skip one alignment sample due to runtime instability" , file = sys .stderr )
102+
103+ if not torch_losses or not rust_losses :
104+ print ("FAIL: no valid alignment samples collected" , file = sys .stderr )
105+ sys .exit (1 )
82106
83107 torch_loss = robust_center (torch_losses )
84108 rust_loss = robust_center (rust_losses )
0 commit comments