|
54 | 54 |
|
55 | 55 | logger = logging.getLogger(__name__) |
56 | 56 | _BACKWARD_FLOW_METRICS_KEY = "_backward_flow" |
57 | | -_BACKWARD_FLOW_BASELINE_DURATION_EMA_ALPHA = 0.1 |
58 | 57 | _BACKWARD_FLOW_DEFAULT_INTERVAL = 50 |
59 | 58 |
|
60 | 59 |
|
@@ -540,7 +539,6 @@ def _init_state(model_rng): |
540 | 539 |
|
541 | 540 | last_loss: float | jax.Array = 0.0 |
542 | 541 | last_step_duration = 0.0 |
543 | | - non_backward_flow_step_duration_ema: float | None = None |
544 | 542 | backward_flow_graph: BackwardFlowGraph | None = None |
545 | 543 |
|
546 | 544 | # Main optimization loop. |
@@ -575,22 +573,6 @@ def _init_state(model_rng): |
575 | 573 | backward_flow_timing_metrics = None |
576 | 574 | if compute_backward_flow: |
577 | 575 | backward_flow_timing_metrics = {"backward_flow/compute_step_duration": duration} |
578 | | - if non_backward_flow_step_duration_ema is not None: |
579 | | - baseline_duration = non_backward_flow_step_duration_ema |
580 | | - backward_flow_timing_metrics.update( |
581 | | - { |
582 | | - "backward_flow/baseline_step_duration_ema": baseline_duration, |
583 | | - "backward_flow/estimated_compute_overhead": duration - baseline_duration, |
584 | | - "backward_flow/estimated_compute_overhead_ratio": duration / baseline_duration, |
585 | | - } |
586 | | - ) |
587 | | - elif non_backward_flow_step_duration_ema is None: |
588 | | - non_backward_flow_step_duration_ema = duration |
589 | | - else: |
590 | | - alpha = _BACKWARD_FLOW_BASELINE_DURATION_EMA_ALPHA |
591 | | - non_backward_flow_step_duration_ema = (1.0 - alpha) * non_backward_flow_step_duration_ema + ( |
592 | | - alpha * duration |
593 | | - ) |
594 | 576 |
|
595 | 577 | hook_start = time.perf_counter() |
596 | 578 | with jax.profiler.TraceAnnotation("callbacks"): |
|
0 commit comments