Skip to content

Commit 4bee1eb

Browse files
committed
Remove backward flow checks from grug contracts
1 parent f2a8f65 commit 4bee1eb

1 file changed

Lines changed: 1 addition & 48 deletions

File tree

tests/test_grug_variant_contracts.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from levanter.data.text.examples import GrugLmExample
3333
from levanter.distributed import DistributedConfig
3434
from levanter.grug.attention import AttentionMask as GrugAttentionMask
35-
from levanter.analysis.backward_flow import BackwardFlowConfig
3635
from levanter.tracker.json_logger import JsonLoggerConfig
3736
from levanter.trainer import TrainerConfig
3837

@@ -230,7 +229,7 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
230229
trainer=train_module.GrugTrainerConfig(
231230
trainer=trainer_config,
232231
log_every=1,
233-
backward_flow=BackwardFlowConfig(interval=1),
232+
backward_flow=train_module.BackwardFlowConfig(interval=0),
234233
),
235234
eval=train_module.GrugEvalConfig(
236235
eval_batch_size=1,
@@ -266,49 +265,3 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
266265
]
267266
for key in required_keys:
268267
assert key in summary
269-
assert any(key.startswith("backward_flow/") for key in summary)
270-
271-
required_backward_flow_keys = [
272-
"backward_flow/Transformer/token_embed/out_gradient_rms",
273-
"backward_flow/Transformer/block_0/Block/resid_in/out_gradient_rms",
274-
"backward_flow/Transformer/block_0/Block/CausalSelfAttention/in_gradient_rms",
275-
"backward_flow/Transformer/block_0/Block/CausalSelfAttention/out_gradient_rms",
276-
"backward_flow/Transformer/block_0/Block/resid_post_attn/out_gradient_rms",
277-
"backward_flow/Transformer/block_0/Block/MLP/in_gradient_rms",
278-
"backward_flow/Transformer/block_0/Block/MLP/out_gradient_rms",
279-
"backward_flow/Transformer/block_0/Block/resid_out/out_gradient_rms",
280-
"backward_flow/Transformer/final_norm/out_gradient_rms",
281-
"backward_flow/Transformer/token_embed/out_gradient_rms_scaled",
282-
"backward_flow/Transformer/block_0/Block/resid_in/out_gradient_rms_scaled",
283-
"backward_flow/Transformer/block_0/Block/CausalSelfAttention/in_gradient_rms_scaled",
284-
"backward_flow/Transformer/block_0/Block/CausalSelfAttention/in_gradient_max_abs_scaled",
285-
"backward_flow/Transformer/block_0/Block/MLP/in_gradient_rms_scaled",
286-
"backward_flow/compute_step_duration",
287-
"backward_flow/artifact_write_duration",
288-
]
289-
for key in required_backward_flow_keys:
290-
assert key in summary
291-
removed_backward_flow_keys = [
292-
"backward_flow/baseline_step_duration_ema",
293-
"backward_flow/estimated_compute_overhead",
294-
"backward_flow/estimated_compute_overhead_ratio",
295-
]
296-
for key in removed_backward_flow_keys:
297-
assert key not in summary
298-
299-
backward_flow_artifact = variant_tmp / "logs/test-grug-base-metrics/artifacts/backward_flow/step_0000000.html"
300-
artifact_html = backward_flow_artifact.read_text(encoding="utf-8")
301-
assert "class='flow-plate'" in artifact_html
302-
assert "Transformer/block_0" in artifact_html
303-
assert "resid_in" in artifact_html
304-
assert "resid_out" in artifact_html
305-
assert "<table" not in artifact_html
306-
assert "scaled grad RMS" in artifact_html
307-
assert "scaled max abs grad" in artifact_html
308-
assert "max abs grad" in artifact_html
309-
resid_skip_edge = "data-source='Transformer/block_0/Block/resid_post_attn'"
310-
resid_skip_edge += " data-target='Transformer/block_0/Block/resid_out'"
311-
assert resid_skip_edge in artifact_html
312-
final_edge = "data-source='Transformer/block_1/Block/resid_out'"
313-
final_edge += " data-target='Transformer/final_norm'"
314-
assert final_edge in artifact_html

0 commit comments

Comments (0)