Skip to content

Commit f01fd62

Browse files
committed
move profile extension to a separate module
Signed-off-by: Yuchen Liang <yuchenl3@andrew.cmu.edu>
1 parent 87ec4c7 commit f01fd62

3 files changed

Lines changed: 130 additions & 82 deletions

File tree

optd/core/src/rules/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod logical_join_inner_commute;
66
mod logical_select_join_transpose;
77
mod logical_select_simplify;
88
mod pass;
9+
mod profile;
910
mod simplification;
1011

1112
pub use decorrelation::*;
@@ -15,5 +16,6 @@ pub use logical_join_inner_assoc::LogicalJoinInnerAssocRule;
1516
pub use logical_join_inner_commute::LogicalJoinInnerCommuteRule;
1617
pub use logical_select_join_transpose::LogicalSelectJoinTransposeRule;
1718
pub use logical_select_simplify::LogicalSelectSimplifyRule;
18-
pub use pass::{PassExtension, PassManager, PassManagerBuilder, PassProfilingExtension, PlanPass};
19+
pub use pass::{PassExtension, PassManager, PassManagerBuilder, PlanPass};
20+
pub use profile::PassProfilingExtension;
1921
pub use simplification::SimplificationPass;

optd/core/src/rules/pass.rs

Lines changed: 1 addition & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
use std::{
2-
collections::HashMap,
3-
sync::{Arc, Mutex},
4-
time::{Duration, Instant},
5-
};
1+
use std::sync::Arc;
62

73
use crate::{
84
error::Result,
@@ -130,82 +126,6 @@ impl PassManagerBuilder {
130126
}
131127
}
132128

133-
/// Pass extension that emits wall-clock timing metrics for each pass run.
134-
#[derive(Default)]
135-
pub struct PassProfilingExtension {
136-
/// Mutable timing state shared across pass invocations.
137-
state: Mutex<ProfilingPassState>,
138-
}
139-
140-
/// Internal state accumulated by [`PassProfilingExtension`].
141-
#[derive(Default)]
142-
struct ProfilingPassState {
143-
/// Passes that have started but not yet finished.
144-
active_passes: Vec<ActivePass>,
145-
/// Aggregated per-pass timing totals keyed by pass name.
146-
totals: HashMap<&'static str, PassProfileMetric>,
147-
}
148-
149-
/// Timing information for a pass that is currently executing.
150-
struct ActivePass {
151-
/// Stable name of the in-flight pass.
152-
name: &'static str,
153-
/// Timestamp captured immediately before the pass started.
154-
started_at: Instant,
155-
}
156-
157-
/// Aggregated metrics for one pass name.
158-
#[derive(Default)]
159-
struct PassProfileMetric {
160-
/// Number of times the pass has been executed.
161-
invocations: usize,
162-
/// Total wall-clock time spent running the pass.
163-
total_duration: Duration,
164-
}
165-
166-
impl PassExtension for PassProfilingExtension {
167-
fn before_pass(
168-
&self,
169-
pass_name: &'static str,
170-
_root: &Arc<Operator>,
171-
_ctx: &IRContext,
172-
) -> Result<()> {
173-
self.state.lock().unwrap().active_passes.push(ActivePass {
174-
name: pass_name,
175-
started_at: Instant::now(),
176-
});
177-
Ok(())
178-
}
179-
180-
fn after_pass(
181-
&self,
182-
pass_name: &'static str,
183-
before: &Arc<Operator>,
184-
after: &Arc<Operator>,
185-
_ctx: &IRContext,
186-
) -> Result<()> {
187-
let mut state = self.state.lock().unwrap();
188-
let active = state
189-
.active_passes
190-
.pop()
191-
.filter(|active| active.name == pass_name);
192-
let elapsed = active
193-
.map(|active| active.started_at.elapsed())
194-
.unwrap_or_default();
195-
let metric = state.totals.entry(pass_name).or_default();
196-
metric.invocations += 1;
197-
metric.total_duration += elapsed;
198-
eprintln!(
199-
"[optd pass profile] pass={pass_name} changed={} elapsed_ms={:.3} total_ms={:.3} invocations={}",
200-
before != after,
201-
elapsed.as_secs_f64() * 1_000.0,
202-
metric.total_duration.as_secs_f64() * 1_000.0,
203-
metric.invocations,
204-
);
205-
Ok(())
206-
}
207-
}
208-
209129
#[cfg(test)]
210130
mod tests {
211131
use super::{PassExtension, PassManager, PlanPass};

optd/core/src/rules/profile.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
use std::{
2+
collections::HashMap,
3+
sync::{Arc, Mutex},
4+
time::{Duration, Instant},
5+
};
6+
7+
use tracing::info;
8+
9+
use crate::{
10+
error::Result,
11+
ir::{IRContext, Operator},
12+
};
13+
14+
use super::PassExtension;
15+
16+
/// Pass extension that emits wall-clock timing metrics for each pass run.
17+
///
18+
/// Metrics are logged through `tracing`, so callers can enable or suppress
19+
/// output with their subscriber configuration. Without an explicit target, the
20+
/// events use this module path as their target.
21+
#[derive(Default)]
22+
pub struct PassProfilingExtension {
23+
/// Mutable timing state shared across pass invocations.
24+
state: Mutex<ProfilingPassState>,
25+
}
26+
27+
/// Internal state accumulated by [`PassProfilingExtension`].
28+
#[derive(Default)]
29+
struct ProfilingPassState {
30+
/// Passes that have started but not yet finished.
31+
active_passes: Vec<ActivePass>,
32+
/// Aggregated per-pass timing totals keyed by pass name.
33+
totals: HashMap<&'static str, PassProfileMetric>,
34+
}
35+
36+
/// Timing information for a pass that is currently executing.
37+
struct ActivePass {
38+
/// Stable name of the in-flight pass.
39+
name: &'static str,
40+
/// Timestamp captured immediately before the pass started.
41+
started_at: Instant,
42+
}
43+
44+
/// Aggregated metrics for one pass name.
45+
#[derive(Default)]
46+
struct PassProfileMetric {
47+
/// Number of times the pass has been executed.
48+
invocations: usize,
49+
/// Total wall-clock time spent running the pass.
50+
total_duration: Duration,
51+
}
52+
53+
/// Formats a duration in milliseconds using 3 significant digits.
54+
fn format_duration_ms(duration: Duration) -> String {
55+
let value = duration.as_secs_f64() * 1_000.0;
56+
if value == 0.0 {
57+
return "0".to_string();
58+
}
59+
60+
let exponent = value.abs().log10().floor() as i32;
61+
let scale = 10f64.powi(2 - exponent);
62+
let rounded = (value * scale).round() / scale;
63+
let decimals = (2 - exponent).max(0) as usize;
64+
format!("{rounded:.decimals$}")
65+
}
66+
67+
impl PassExtension for PassProfilingExtension {
68+
fn before_pass(
69+
&self,
70+
pass_name: &'static str,
71+
_root: &Arc<Operator>,
72+
_ctx: &IRContext,
73+
) -> Result<()> {
74+
self.state.lock().unwrap().active_passes.push(ActivePass {
75+
name: pass_name,
76+
started_at: Instant::now(),
77+
});
78+
Ok(())
79+
}
80+
81+
fn after_pass(
82+
&self,
83+
pass_name: &'static str,
84+
before: &Arc<Operator>,
85+
after: &Arc<Operator>,
86+
_ctx: &IRContext,
87+
) -> Result<()> {
88+
let mut state = self.state.lock().unwrap();
89+
let active = state
90+
.active_passes
91+
.pop()
92+
.expect("after_pass called without a matching before_pass");
93+
assert_eq!(
94+
active.name, pass_name,
95+
"after_pass order mismatch: expected {pass_name}, found {}",
96+
active.name
97+
);
98+
let elapsed = active.started_at.elapsed();
99+
let metric = state.totals.entry(pass_name).or_default();
100+
metric.invocations += 1;
101+
metric.total_duration += elapsed;
102+
info!(
103+
pass = pass_name,
104+
changed = before != after,
105+
elapsed_ms = %format_duration_ms(elapsed),
106+
total_ms = %format_duration_ms(metric.total_duration),
107+
invocations = metric.invocations,
108+
"optimizer pass profile",
109+
);
110+
Ok(())
111+
}
112+
}
113+
114+
#[cfg(test)]
115+
mod tests {
116+
use std::time::Duration;
117+
118+
#[test]
119+
fn format_duration_ms_uses_three_significant_digits() {
120+
assert_eq!(super::format_duration_ms(Duration::from_nanos(0)), "0");
121+
assert_eq!(super::format_duration_ms(Duration::from_micros(456)), "0.456");
122+
assert_eq!(super::format_duration_ms(Duration::from_micros(12_340)), "12.3");
123+
assert_eq!(super::format_duration_ms(Duration::from_micros(123_400)), "123");
124+
assert_eq!(super::format_duration_ms(Duration::from_micros(1_234_000)), "1230");
125+
}
126+
}

0 commit comments

Comments
 (0)