|
| 1 | +use std::{ |
| 2 | + collections::HashMap, |
| 3 | + sync::{Arc, Mutex}, |
| 4 | + time::{Duration, Instant}, |
| 5 | +}; |
| 6 | + |
| 7 | +use tracing::info; |
| 8 | + |
| 9 | +use crate::{ |
| 10 | + error::Result, |
| 11 | + ir::{IRContext, Operator}, |
| 12 | +}; |
| 13 | + |
| 14 | +use super::PassExtension; |
| 15 | + |
| 16 | +/// Pass extension that emits wall-clock timing metrics for each pass run. |
| 17 | +/// |
| 18 | +/// Metrics are logged through `tracing`, so callers can enable or suppress |
| 19 | +/// output with their subscriber configuration. Without an explicit target, the |
| 20 | +/// events use this module path as their target. |
| 21 | +#[derive(Default)] |
| 22 | +pub struct PassProfilingExtension { |
| 23 | + /// Mutable timing state shared across pass invocations. |
| 24 | + state: Mutex<ProfilingPassState>, |
| 25 | +} |
| 26 | + |
| 27 | +/// Internal state accumulated by [`PassProfilingExtension`]. |
| 28 | +#[derive(Default)] |
| 29 | +struct ProfilingPassState { |
| 30 | + /// Passes that have started but not yet finished. |
| 31 | + active_passes: Vec<ActivePass>, |
| 32 | + /// Aggregated per-pass timing totals keyed by pass name. |
| 33 | + totals: HashMap<&'static str, PassProfileMetric>, |
| 34 | +} |
| 35 | + |
| 36 | +/// Timing information for a pass that is currently executing. |
| 37 | +struct ActivePass { |
| 38 | + /// Stable name of the in-flight pass. |
| 39 | + name: &'static str, |
| 40 | + /// Timestamp captured immediately before the pass started. |
| 41 | + started_at: Instant, |
| 42 | +} |
| 43 | + |
| 44 | +/// Aggregated metrics for one pass name. |
| 45 | +#[derive(Default)] |
| 46 | +struct PassProfileMetric { |
| 47 | + /// Number of times the pass has been executed. |
| 48 | + invocations: usize, |
| 49 | + /// Total wall-clock time spent running the pass. |
| 50 | + total_duration: Duration, |
| 51 | +} |
| 52 | + |
| 53 | +/// Formats a duration in milliseconds using 3 significant digits. |
| 54 | +fn format_duration_ms(duration: Duration) -> String { |
| 55 | + let value = duration.as_secs_f64() * 1_000.0; |
| 56 | + if value == 0.0 { |
| 57 | + return "0".to_string(); |
| 58 | + } |
| 59 | + |
| 60 | + let exponent = value.abs().log10().floor() as i32; |
| 61 | + let scale = 10f64.powi(2 - exponent); |
| 62 | + let rounded = (value * scale).round() / scale; |
| 63 | + let decimals = (2 - exponent).max(0) as usize; |
| 64 | + format!("{rounded:.decimals$}") |
| 65 | +} |
| 66 | + |
| 67 | +impl PassExtension for PassProfilingExtension { |
| 68 | + fn before_pass( |
| 69 | + &self, |
| 70 | + pass_name: &'static str, |
| 71 | + _root: &Arc<Operator>, |
| 72 | + _ctx: &IRContext, |
| 73 | + ) -> Result<()> { |
| 74 | + self.state.lock().unwrap().active_passes.push(ActivePass { |
| 75 | + name: pass_name, |
| 76 | + started_at: Instant::now(), |
| 77 | + }); |
| 78 | + Ok(()) |
| 79 | + } |
| 80 | + |
| 81 | + fn after_pass( |
| 82 | + &self, |
| 83 | + pass_name: &'static str, |
| 84 | + before: &Arc<Operator>, |
| 85 | + after: &Arc<Operator>, |
| 86 | + _ctx: &IRContext, |
| 87 | + ) -> Result<()> { |
| 88 | + let mut state = self.state.lock().unwrap(); |
| 89 | + let active = state |
| 90 | + .active_passes |
| 91 | + .pop() |
| 92 | + .expect("after_pass called without a matching before_pass"); |
| 93 | + assert_eq!( |
| 94 | + active.name, pass_name, |
| 95 | + "after_pass order mismatch: expected {pass_name}, found {}", |
| 96 | + active.name |
| 97 | + ); |
| 98 | + let elapsed = active.started_at.elapsed(); |
| 99 | + let metric = state.totals.entry(pass_name).or_default(); |
| 100 | + metric.invocations += 1; |
| 101 | + metric.total_duration += elapsed; |
| 102 | + info!( |
| 103 | + pass = pass_name, |
| 104 | + changed = before != after, |
| 105 | + elapsed_ms = %format_duration_ms(elapsed), |
| 106 | + total_ms = %format_duration_ms(metric.total_duration), |
| 107 | + invocations = metric.invocations, |
| 108 | + "optimizer pass profile", |
| 109 | + ); |
| 110 | + Ok(()) |
| 111 | + } |
| 112 | +} |
| 113 | + |
| 114 | +#[cfg(test)] |
| 115 | +mod tests { |
| 116 | + use std::time::Duration; |
| 117 | + |
| 118 | + #[test] |
| 119 | + fn format_duration_ms_uses_three_significant_digits() { |
| 120 | + assert_eq!(super::format_duration_ms(Duration::from_nanos(0)), "0"); |
| 121 | + assert_eq!(super::format_duration_ms(Duration::from_micros(456)), "0.456"); |
| 122 | + assert_eq!(super::format_duration_ms(Duration::from_micros(12_340)), "12.3"); |
| 123 | + assert_eq!(super::format_duration_ms(Duration::from_micros(123_400)), "123"); |
| 124 | + assert_eq!(super::format_duration_ms(Duration::from_micros(1_234_000)), "1230"); |
| 125 | + } |
| 126 | +} |
0 commit comments