Skip to content

Commit f1ec0a5

Browse files
committed
Remove backward flow duration EMA
1 parent 6ceac44 commit f1ec0a5

2 files changed

Lines changed: 7 additions & 18 deletions

File tree

experiments/grug/base/train.py

Lines changed: 0 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,6 @@
5454

5555
logger = logging.getLogger(__name__)
5656
_BACKWARD_FLOW_METRICS_KEY = "_backward_flow"
57-
_BACKWARD_FLOW_BASELINE_DURATION_EMA_ALPHA = 0.1
5857
_BACKWARD_FLOW_DEFAULT_INTERVAL = 50
5958

6059

@@ -540,7 +539,6 @@ def _init_state(model_rng):
540539

541540
last_loss: float | jax.Array = 0.0
542541
last_step_duration = 0.0
543-
non_backward_flow_step_duration_ema: float | None = None
544542
backward_flow_graph: BackwardFlowGraph | None = None
545543

546544
# Main optimization loop.
@@ -575,22 +573,6 @@ def _init_state(model_rng):
575573
backward_flow_timing_metrics = None
576574
if compute_backward_flow:
577575
backward_flow_timing_metrics = {"backward_flow/compute_step_duration": duration}
578-
if non_backward_flow_step_duration_ema is not None:
579-
baseline_duration = non_backward_flow_step_duration_ema
580-
backward_flow_timing_metrics.update(
581-
{
582-
"backward_flow/baseline_step_duration_ema": baseline_duration,
583-
"backward_flow/estimated_compute_overhead": duration - baseline_duration,
584-
"backward_flow/estimated_compute_overhead_ratio": duration / baseline_duration,
585-
}
586-
)
587-
elif non_backward_flow_step_duration_ema is None:
588-
non_backward_flow_step_duration_ema = duration
589-
else:
590-
alpha = _BACKWARD_FLOW_BASELINE_DURATION_EMA_ALPHA
591-
non_backward_flow_step_duration_ema = (1.0 - alpha) * non_backward_flow_step_duration_ema + (
592-
alpha * duration
593-
)
594576

595577
hook_start = time.perf_counter()
596578
with jax.profiler.TraceAnnotation("callbacks"):

tests/test_grug_variant_contracts.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -288,6 +288,13 @@ def test_grug_base_run_emits_expected_metrics_with_json_tracker(tmp_path: Path):
288288
]
289289
for key in required_backward_flow_keys:
290290
assert key in summary
291+
removed_backward_flow_keys = [
292+
"backward_flow/baseline_step_duration_ema",
293+
"backward_flow/estimated_compute_overhead",
294+
"backward_flow/estimated_compute_overhead_ratio",
295+
]
296+
for key in removed_backward_flow_keys:
297+
assert key not in summary
291298

292299
backward_flow_artifact = variant_tmp / "logs/test-grug-base-metrics/artifacts/backward_flow/step_0000000.html"
293300
artifact_html = backward_flow_artifact.read_text(encoding="utf-8")

0 commit comments

Comments
 (0)