|
32 | 32 | from levanter.data.text.examples import GrugLmExample |
33 | 33 | from levanter.distributed import DistributedConfig |
34 | 34 | from levanter.grug.attention import AttentionMask as GrugAttentionMask |
35 | | -from levanter.analysis.backward_flow import BackwardFlowConfig |
36 | 35 | from levanter.tracker.json_logger import JsonLoggerConfig |
37 | 36 | from levanter.trainer import TrainerConfig |
38 | 37 |
|
@@ -230,7 +229,7 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path): |
230 | 229 | trainer=train_module.GrugTrainerConfig( |
231 | 230 | trainer=trainer_config, |
232 | 231 | log_every=1, |
233 | | - backward_flow=BackwardFlowConfig(interval=1), |
| 232 | + backward_flow=train_module.BackwardFlowConfig(interval=0), |
234 | 233 | ), |
235 | 234 | eval=train_module.GrugEvalConfig( |
236 | 235 | eval_batch_size=1, |
@@ -266,49 +265,3 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path): |
266 | 265 | ] |
267 | 266 | for key in required_keys: |
268 | 267 | assert key in summary |
269 | | - assert any(key.startswith("backward_flow/") for key in summary) |
270 | | - |
271 | | - required_backward_flow_keys = [ |
272 | | - "backward_flow/Transformer/token_embed/out_gradient_rms", |
273 | | - "backward_flow/Transformer/block_0/Block/resid_in/out_gradient_rms", |
274 | | - "backward_flow/Transformer/block_0/Block/CausalSelfAttention/in_gradient_rms", |
275 | | - "backward_flow/Transformer/block_0/Block/CausalSelfAttention/out_gradient_rms", |
276 | | - "backward_flow/Transformer/block_0/Block/resid_post_attn/out_gradient_rms", |
277 | | - "backward_flow/Transformer/block_0/Block/MLP/in_gradient_rms", |
278 | | - "backward_flow/Transformer/block_0/Block/MLP/out_gradient_rms", |
279 | | - "backward_flow/Transformer/block_0/Block/resid_out/out_gradient_rms", |
280 | | - "backward_flow/Transformer/final_norm/out_gradient_rms", |
281 | | - "backward_flow/Transformer/token_embed/out_gradient_rms_scaled", |
282 | | - "backward_flow/Transformer/block_0/Block/resid_in/out_gradient_rms_scaled", |
283 | | - "backward_flow/Transformer/block_0/Block/CausalSelfAttention/in_gradient_rms_scaled", |
284 | | - "backward_flow/Transformer/block_0/Block/CausalSelfAttention/in_gradient_max_abs_scaled", |
285 | | - "backward_flow/Transformer/block_0/Block/MLP/in_gradient_rms_scaled", |
286 | | - "backward_flow/compute_step_duration", |
287 | | - "backward_flow/artifact_write_duration", |
288 | | - ] |
289 | | - for key in required_backward_flow_keys: |
290 | | - assert key in summary |
291 | | - removed_backward_flow_keys = [ |
292 | | - "backward_flow/baseline_step_duration_ema", |
293 | | - "backward_flow/estimated_compute_overhead", |
294 | | - "backward_flow/estimated_compute_overhead_ratio", |
295 | | - ] |
296 | | - for key in removed_backward_flow_keys: |
297 | | - assert key not in summary |
298 | | - |
299 | | - backward_flow_artifact = variant_tmp / "logs/test-grug-base-metrics/artifacts/backward_flow/step_0000000.html" |
300 | | - artifact_html = backward_flow_artifact.read_text(encoding="utf-8") |
301 | | - assert "class='flow-plate'" in artifact_html |
302 | | - assert "Transformer/block_0" in artifact_html |
303 | | - assert "resid_in" in artifact_html |
304 | | - assert "resid_out" in artifact_html |
305 | | - assert "<table" not in artifact_html |
306 | | - assert "scaled grad RMS" in artifact_html |
307 | | - assert "scaled max abs grad" in artifact_html |
308 | | - assert "max abs grad" in artifact_html |
309 | | - resid_skip_edge = "data-source='Transformer/block_0/Block/resid_post_attn'" |
310 | | - resid_skip_edge += " data-target='Transformer/block_0/Block/resid_out'" |
311 | | - assert resid_skip_edge in artifact_html |
312 | | - final_edge = "data-source='Transformer/block_1/Block/resid_out'" |
313 | | - final_edge += " data-target='Transformer/final_norm'" |
314 | | - assert final_edge in artifact_html |
0 commit comments