Skip to content

Commit ba63c3a

Browse files
authored
feat: prometheus metrics export (#21)
1 parent d83433e commit ba63c3a

File tree

11 files changed

+845
-88
lines changed

11 files changed

+845
-88
lines changed

Cargo.lock

Lines changed: 311 additions & 82 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dwd/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ pnet = "0.35"
4646
libc = "0.2"
4747
jemallocator = "0.5"
4848
rand = "0.9"
49+
axum = "0.8"
50+
prometheus-client = "0.24"
4951

5052
[target.'cfg(target_os = "linux")'.dependencies]
5153
netlink-packet-core = { version = "0.7" }

dwd/src/api/metrics.rs

Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
//! Prometheus metrics collector and HTTP handler.
2+
3+
use std::{fmt::Write, sync::Arc};
4+
5+
use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Router};
6+
use prometheus_client::{
7+
encoding::text::encode,
8+
metrics::{counter::Counter, gauge::Gauge},
9+
registry::Registry,
10+
};
11+
12+
use crate::{
13+
histogram::LogHistogram,
14+
stat::{CommonStat, HttpStat, RxStat, SocketStat, TxStat},
15+
};
16+
17+
/// Upper bounds, in seconds, of the exported Prometheus latency buckets.
///
/// Spans 5us to 10s in 1 / 2.5 / 5 steps per decade (roughly logarithmic).
/// Values are strictly increasing.
const LATENCY_BUCKETS: [f64; 20] = [
    // Microseconds: 5, 10, 25, 50, 100, 250, 500.
    5e-6, 10e-6, 25e-6, 50e-6, 100e-6, 250e-6, 500e-6,
    // Milliseconds: 1, 2.5, 5, 10, 25, 50, 100, 250, 500.
    1e-3, 2.5e-3, 5e-3, 10e-3, 25e-3, 50e-3, 100e-3, 250e-3, 500e-3,
    // Seconds: 1, 2.5, 5, 10.
    1.0, 2.5, 5.0, 10.0,
];
41+
42+
/// Collector that gathers metrics from stat sources and exports them to
43+
/// Prometheus.
44+
pub struct MetricsCollector {
45+
registry: Registry,
46+
generator_rps: Gauge,
47+
requests_total: Counter,
48+
responses_total: Counter,
49+
timeouts_total: Counter,
50+
bytes_tx_total: Counter,
51+
bytes_rx_total: Counter,
52+
http_2xx_total: Counter,
53+
http_3xx_total: Counter,
54+
http_4xx_total: Counter,
55+
http_5xx_total: Counter,
56+
sockets_created_total: Counter,
57+
socket_errors_total: Counter,
58+
retransmits_total: Counter,
59+
}
60+
61+
impl MetricsCollector {
62+
/// Creates a new metrics collector and registers all metrics.
63+
pub fn new() -> Self {
64+
let mut registry = Registry::default();
65+
66+
let generator_rps = Gauge::default();
67+
registry.register(
68+
"dwd_generator_rps",
69+
"Target RPS from the generator",
70+
generator_rps.clone(),
71+
);
72+
73+
let requests_total = Counter::default();
74+
registry.register(
75+
"dwd_requests_total",
76+
"Total number of requests sent",
77+
requests_total.clone(),
78+
);
79+
80+
let responses_total = Counter::default();
81+
registry.register(
82+
"dwd_responses_total",
83+
"Total number of responses received",
84+
responses_total.clone(),
85+
);
86+
87+
let timeouts_total = Counter::default();
88+
registry.register(
89+
"dwd_timeouts_total",
90+
"Total number of request timeouts",
91+
timeouts_total.clone(),
92+
);
93+
94+
let bytes_tx_total = Counter::default();
95+
registry.register("dwd_bytes_tx_total", "Total bytes transmitted", bytes_tx_total.clone());
96+
97+
let bytes_rx_total = Counter::default();
98+
registry.register("dwd_bytes_rx_total", "Total bytes received", bytes_rx_total.clone());
99+
100+
let http_2xx_total = Counter::default();
101+
registry.register("dwd_http_2xx_total", "Total HTTP 2xx responses", http_2xx_total.clone());
102+
103+
let http_3xx_total = Counter::default();
104+
registry.register("dwd_http_3xx_total", "Total HTTP 3xx responses", http_3xx_total.clone());
105+
106+
let http_4xx_total = Counter::default();
107+
registry.register("dwd_http_4xx_total", "Total HTTP 4xx responses", http_4xx_total.clone());
108+
109+
let http_5xx_total = Counter::default();
110+
registry.register("dwd_http_5xx_total", "Total HTTP 5xx responses", http_5xx_total.clone());
111+
112+
let sockets_created_total = Counter::default();
113+
registry.register(
114+
"dwd_sockets_created_total",
115+
"Total number of sockets created",
116+
sockets_created_total.clone(),
117+
);
118+
119+
let socket_errors_total = Counter::default();
120+
registry.register(
121+
"dwd_socket_errors_total",
122+
"Total number of socket errors",
123+
socket_errors_total.clone(),
124+
);
125+
126+
let retransmits_total = Counter::default();
127+
registry.register(
128+
"dwd_retransmits_total",
129+
"Total number of TCP retransmits",
130+
retransmits_total.clone(),
131+
);
132+
133+
Self {
134+
registry,
135+
generator_rps,
136+
requests_total,
137+
responses_total,
138+
timeouts_total,
139+
bytes_tx_total,
140+
bytes_rx_total,
141+
http_2xx_total,
142+
http_3xx_total,
143+
http_4xx_total,
144+
http_5xx_total,
145+
sockets_created_total,
146+
socket_errors_total,
147+
retransmits_total,
148+
}
149+
}
150+
151+
/// Updates common stats (generator RPS).
152+
pub fn update_common<S: CommonStat>(&self, stat: &S) {
153+
self.generator_rps.set(stat.generator() as i64);
154+
}
155+
156+
/// Updates TX stats.
157+
pub fn update_tx<S: TxStat>(&self, stat: &S) {
158+
let requests = stat.num_requests();
159+
let bytes_tx = stat.bytes_tx();
160+
161+
// Calculate delta from current counter value.
162+
let current_requests = self.requests_total.get();
163+
if requests > current_requests {
164+
self.requests_total.inc_by(requests - current_requests);
165+
}
166+
167+
let current_bytes_tx = self.bytes_tx_total.get();
168+
if bytes_tx > current_bytes_tx {
169+
self.bytes_tx_total.inc_by(bytes_tx - current_bytes_tx);
170+
}
171+
}
172+
173+
/// Updates RX stats.
174+
pub fn update_rx<S: RxStat>(&self, stat: &S) {
175+
let responses = stat.num_responses();
176+
let timeouts = stat.num_timeouts();
177+
let bytes_rx = stat.bytes_rx();
178+
179+
let current_responses = self.responses_total.get();
180+
if responses > current_responses {
181+
self.responses_total.inc_by(responses - current_responses);
182+
}
183+
184+
let current_timeouts = self.timeouts_total.get();
185+
if timeouts > current_timeouts {
186+
self.timeouts_total.inc_by(timeouts - current_timeouts);
187+
}
188+
189+
let current_bytes_rx = self.bytes_rx_total.get();
190+
if bytes_rx > current_bytes_rx {
191+
self.bytes_rx_total.inc_by(bytes_rx - current_bytes_rx);
192+
}
193+
}
194+
195+
/// Updates HTTP stats.
196+
pub fn update_http<S: HttpStat>(&self, stat: &S) {
197+
let num_2xx = stat.num_2xx();
198+
let num_3xx = stat.num_3xx();
199+
let num_4xx = stat.num_4xx();
200+
let num_5xx = stat.num_5xx();
201+
202+
let current_2xx = self.http_2xx_total.get();
203+
if num_2xx > current_2xx {
204+
self.http_2xx_total.inc_by(num_2xx - current_2xx);
205+
}
206+
207+
let current_3xx = self.http_3xx_total.get();
208+
if num_3xx > current_3xx {
209+
self.http_3xx_total.inc_by(num_3xx - current_3xx);
210+
}
211+
212+
let current_4xx = self.http_4xx_total.get();
213+
if num_4xx > current_4xx {
214+
self.http_4xx_total.inc_by(num_4xx - current_4xx);
215+
}
216+
217+
let current_5xx = self.http_5xx_total.get();
218+
if num_5xx > current_5xx {
219+
self.http_5xx_total.inc_by(num_5xx - current_5xx);
220+
}
221+
}
222+
223+
/// Updates socket stats.
224+
pub fn update_socket<S: SocketStat>(&self, stat: &S) {
225+
let created = stat.num_sock_created();
226+
let errors = stat.num_sock_errors();
227+
let retransmits = stat.num_retransmits();
228+
229+
let current_created = self.sockets_created_total.get();
230+
if created > current_created {
231+
self.sockets_created_total.inc_by(created - current_created);
232+
}
233+
234+
let current_errors = self.socket_errors_total.get();
235+
if errors > current_errors {
236+
self.socket_errors_total.inc_by(errors - current_errors);
237+
}
238+
239+
let current_retransmits = self.retransmits_total.get();
240+
if retransmits > current_retransmits {
241+
self.retransmits_total.inc_by(retransmits - current_retransmits);
242+
}
243+
}
244+
245+
/// Encodes all metrics to Prometheus text format.
246+
pub fn encode(&self) -> String {
247+
let mut buffer = String::new();
248+
encode(&mut buffer, &self.registry).expect("encoding should not fail");
249+
buffer
250+
}
251+
}
252+
253+
impl Default for MetricsCollector {
254+
fn default() -> Self {
255+
Self::new()
256+
}
257+
}
258+
259+
/// Encodes a LogHistogram to Prometheus histogram format.
260+
///
261+
/// The histogram is encoded manually because prometheus-client's Histogram
262+
/// uses observe() which accumulates values, but we need to export absolute
263+
/// cumulative bucket counts from the log-histogram.
264+
fn encode_histogram(hist: &LogHistogram) -> String {
265+
let snapshot = hist.snapshot();
266+
let factor = LogHistogram::factor();
267+
268+
// Calculate cumulative counts for prometheus buckets.
269+
// For each prometheus bucket with upper bound B (in seconds),
270+
// we sum all log-bucket counts where the upper bound <= B.
271+
let mut bucket_counts = vec![0u64; LATENCY_BUCKETS.len()];
272+
let mut total_count = 0u64;
273+
let mut total_sum = 0.0f64;
274+
275+
for (idx, &count) in snapshot.iter().enumerate() {
276+
if count == 0 {
277+
continue;
278+
}
279+
280+
total_count += count;
281+
282+
// Upper bound of this log-bucket in microseconds.
283+
let upper_us = factor.powi(idx as i32);
284+
// Lower bound for sum calculation.
285+
let lower_us = if idx == 0 { 0.0 } else { factor.powi(idx as i32 - 1) };
286+
// Midpoint in seconds for sum calculation.
287+
let midpoint_sec = (lower_us + upper_us) / 2.0 / 1_000_000.0;
288+
total_sum += midpoint_sec * count as f64;
289+
290+
// Upper bound in seconds.
291+
let upper_sec = upper_us / 1_000_000.0;
292+
293+
// Add count to all prometheus buckets whose upper bound >= upper_sec.
294+
for (bucket_idx, &bucket_bound) in LATENCY_BUCKETS.iter().enumerate() {
295+
if bucket_bound >= upper_sec {
296+
bucket_counts[bucket_idx] += count;
297+
}
298+
}
299+
}
300+
301+
let mut output = String::new();
302+
writeln!(
303+
output,
304+
"# HELP dwd_latency_seconds Response latency histogram in seconds"
305+
)
306+
.unwrap();
307+
writeln!(output, "# TYPE dwd_latency_seconds histogram").unwrap();
308+
309+
// Buckets must be cumulative.
310+
let mut cumulative = 0u64;
311+
for (idx, &bound) in LATENCY_BUCKETS.iter().enumerate() {
312+
cumulative += bucket_counts[idx];
313+
writeln!(
314+
output,
315+
"dwd_latency_seconds_bucket{{le=\"{:.6}\"}} {}",
316+
bound, cumulative
317+
)
318+
.unwrap();
319+
}
320+
writeln!(output, "dwd_latency_seconds_bucket{{le=\"+Inf\"}} {}", total_count).unwrap();
321+
writeln!(output, "dwd_latency_seconds_sum {:.6}", total_sum).unwrap();
322+
writeln!(output, "dwd_latency_seconds_count {}", total_count).unwrap();
323+
324+
output
325+
}
326+
327+
/// Trait for stat sources that can be collected.
328+
pub trait StatSource: Send + Sync {
329+
/// Updates the metrics collector with current stats.
330+
fn collect(&self, collector: &MetricsCollector);
331+
332+
/// Returns the latency histogram if available.
333+
fn histogram(&self) -> Option<LogHistogram> {
334+
None
335+
}
336+
}
337+
338+
/// Shared state for the metrics handler.
339+
pub struct MetricsState {
340+
collector: MetricsCollector,
341+
stat_source: Arc<dyn StatSource>,
342+
}
343+
344+
impl MetricsState {
345+
/// Creates a new metrics state.
346+
pub fn new(stat_source: Arc<dyn StatSource>) -> Self {
347+
Self {
348+
collector: MetricsCollector::new(),
349+
stat_source,
350+
}
351+
}
352+
}
353+
354+
/// Creates a router for metrics endpoints.
355+
pub fn router(state: Arc<MetricsState>) -> Router {
356+
Router::new()
357+
.route("/api/v1/metrics", get(metrics_handler))
358+
.with_state(state)
359+
}
360+
361+
async fn metrics_handler(State(state): State<Arc<MetricsState>>) -> impl IntoResponse {
362+
// Update metrics from stat source.
363+
state.stat_source.collect(&state.collector);
364+
365+
// Encode to prometheus format.
366+
let mut body = state.collector.encode();
367+
368+
// Add histogram if available (encoded separately due to its special nature).
369+
if let Some(hist) = state.stat_source.histogram() {
370+
body.push_str(&encode_histogram(&hist));
371+
}
372+
373+
(
374+
StatusCode::OK,
375+
[(
376+
axum::http::header::CONTENT_TYPE,
377+
"text/plain; version=0.0.4; charset=utf-8",
378+
)],
379+
body,
380+
)
381+
}

0 commit comments

Comments
 (0)