|
| 1 | +//! Prometheus metrics collector and HTTP handler. |
| 2 | +
|
| 3 | +use std::{fmt::Write, sync::Arc}; |
| 4 | + |
| 5 | +use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Router}; |
| 6 | +use prometheus_client::{ |
| 7 | + encoding::text::encode, |
| 8 | + metrics::{counter::Counter, gauge::Gauge}, |
| 9 | + registry::Registry, |
| 10 | +}; |
| 11 | + |
| 12 | +use crate::{ |
| 13 | + histogram::LogHistogram, |
| 14 | + stat::{CommonStat, HttpStat, RxStat, SocketStat, TxStat}, |
| 15 | +}; |
| 16 | + |
| 17 | +/// Prometheus histogram buckets for latency (in seconds). |
| 18 | +/// Range: 5us to 10s with logarithmic distribution. |
| 19 | +const LATENCY_BUCKETS: [f64; 20] = [ |
| 20 | + 0.000_005, // 5us |
| 21 | + 0.000_010, // 10us |
| 22 | + 0.000_025, // 25us |
| 23 | + 0.000_050, // 50us |
| 24 | + 0.000_100, // 100us |
| 25 | + 0.000_250, // 250us |
| 26 | + 0.000_500, // 500us |
| 27 | + 0.001, // 1ms |
| 28 | + 0.002_5, // 2.5ms |
| 29 | + 0.005, // 5ms |
| 30 | + 0.010, // 10ms |
| 31 | + 0.025, // 25ms |
| 32 | + 0.050, // 50ms |
| 33 | + 0.100, // 100ms |
| 34 | + 0.250, // 250ms |
| 35 | + 0.500, // 500ms |
| 36 | + 1.0, // 1s |
| 37 | + 2.5, // 2.5s |
| 38 | + 5.0, // 5s |
| 39 | + 10.0, // 10s |
| 40 | +]; |
| 41 | + |
| 42 | +/// Collector that gathers metrics from stat sources and exports them to |
| 43 | +/// Prometheus. |
| 44 | +pub struct MetricsCollector { |
| 45 | + registry: Registry, |
| 46 | + generator_rps: Gauge, |
| 47 | + requests_total: Counter, |
| 48 | + responses_total: Counter, |
| 49 | + timeouts_total: Counter, |
| 50 | + bytes_tx_total: Counter, |
| 51 | + bytes_rx_total: Counter, |
| 52 | + http_2xx_total: Counter, |
| 53 | + http_3xx_total: Counter, |
| 54 | + http_4xx_total: Counter, |
| 55 | + http_5xx_total: Counter, |
| 56 | + sockets_created_total: Counter, |
| 57 | + socket_errors_total: Counter, |
| 58 | + retransmits_total: Counter, |
| 59 | +} |
| 60 | + |
| 61 | +impl MetricsCollector { |
| 62 | + /// Creates a new metrics collector and registers all metrics. |
| 63 | + pub fn new() -> Self { |
| 64 | + let mut registry = Registry::default(); |
| 65 | + |
| 66 | + let generator_rps = Gauge::default(); |
| 67 | + registry.register( |
| 68 | + "dwd_generator_rps", |
| 69 | + "Target RPS from the generator", |
| 70 | + generator_rps.clone(), |
| 71 | + ); |
| 72 | + |
| 73 | + let requests_total = Counter::default(); |
| 74 | + registry.register( |
| 75 | + "dwd_requests_total", |
| 76 | + "Total number of requests sent", |
| 77 | + requests_total.clone(), |
| 78 | + ); |
| 79 | + |
| 80 | + let responses_total = Counter::default(); |
| 81 | + registry.register( |
| 82 | + "dwd_responses_total", |
| 83 | + "Total number of responses received", |
| 84 | + responses_total.clone(), |
| 85 | + ); |
| 86 | + |
| 87 | + let timeouts_total = Counter::default(); |
| 88 | + registry.register( |
| 89 | + "dwd_timeouts_total", |
| 90 | + "Total number of request timeouts", |
| 91 | + timeouts_total.clone(), |
| 92 | + ); |
| 93 | + |
| 94 | + let bytes_tx_total = Counter::default(); |
| 95 | + registry.register("dwd_bytes_tx_total", "Total bytes transmitted", bytes_tx_total.clone()); |
| 96 | + |
| 97 | + let bytes_rx_total = Counter::default(); |
| 98 | + registry.register("dwd_bytes_rx_total", "Total bytes received", bytes_rx_total.clone()); |
| 99 | + |
| 100 | + let http_2xx_total = Counter::default(); |
| 101 | + registry.register("dwd_http_2xx_total", "Total HTTP 2xx responses", http_2xx_total.clone()); |
| 102 | + |
| 103 | + let http_3xx_total = Counter::default(); |
| 104 | + registry.register("dwd_http_3xx_total", "Total HTTP 3xx responses", http_3xx_total.clone()); |
| 105 | + |
| 106 | + let http_4xx_total = Counter::default(); |
| 107 | + registry.register("dwd_http_4xx_total", "Total HTTP 4xx responses", http_4xx_total.clone()); |
| 108 | + |
| 109 | + let http_5xx_total = Counter::default(); |
| 110 | + registry.register("dwd_http_5xx_total", "Total HTTP 5xx responses", http_5xx_total.clone()); |
| 111 | + |
| 112 | + let sockets_created_total = Counter::default(); |
| 113 | + registry.register( |
| 114 | + "dwd_sockets_created_total", |
| 115 | + "Total number of sockets created", |
| 116 | + sockets_created_total.clone(), |
| 117 | + ); |
| 118 | + |
| 119 | + let socket_errors_total = Counter::default(); |
| 120 | + registry.register( |
| 121 | + "dwd_socket_errors_total", |
| 122 | + "Total number of socket errors", |
| 123 | + socket_errors_total.clone(), |
| 124 | + ); |
| 125 | + |
| 126 | + let retransmits_total = Counter::default(); |
| 127 | + registry.register( |
| 128 | + "dwd_retransmits_total", |
| 129 | + "Total number of TCP retransmits", |
| 130 | + retransmits_total.clone(), |
| 131 | + ); |
| 132 | + |
| 133 | + Self { |
| 134 | + registry, |
| 135 | + generator_rps, |
| 136 | + requests_total, |
| 137 | + responses_total, |
| 138 | + timeouts_total, |
| 139 | + bytes_tx_total, |
| 140 | + bytes_rx_total, |
| 141 | + http_2xx_total, |
| 142 | + http_3xx_total, |
| 143 | + http_4xx_total, |
| 144 | + http_5xx_total, |
| 145 | + sockets_created_total, |
| 146 | + socket_errors_total, |
| 147 | + retransmits_total, |
| 148 | + } |
| 149 | + } |
| 150 | + |
| 151 | + /// Updates common stats (generator RPS). |
| 152 | + pub fn update_common<S: CommonStat>(&self, stat: &S) { |
| 153 | + self.generator_rps.set(stat.generator() as i64); |
| 154 | + } |
| 155 | + |
| 156 | + /// Updates TX stats. |
| 157 | + pub fn update_tx<S: TxStat>(&self, stat: &S) { |
| 158 | + let requests = stat.num_requests(); |
| 159 | + let bytes_tx = stat.bytes_tx(); |
| 160 | + |
| 161 | + // Calculate delta from current counter value. |
| 162 | + let current_requests = self.requests_total.get(); |
| 163 | + if requests > current_requests { |
| 164 | + self.requests_total.inc_by(requests - current_requests); |
| 165 | + } |
| 166 | + |
| 167 | + let current_bytes_tx = self.bytes_tx_total.get(); |
| 168 | + if bytes_tx > current_bytes_tx { |
| 169 | + self.bytes_tx_total.inc_by(bytes_tx - current_bytes_tx); |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + /// Updates RX stats. |
| 174 | + pub fn update_rx<S: RxStat>(&self, stat: &S) { |
| 175 | + let responses = stat.num_responses(); |
| 176 | + let timeouts = stat.num_timeouts(); |
| 177 | + let bytes_rx = stat.bytes_rx(); |
| 178 | + |
| 179 | + let current_responses = self.responses_total.get(); |
| 180 | + if responses > current_responses { |
| 181 | + self.responses_total.inc_by(responses - current_responses); |
| 182 | + } |
| 183 | + |
| 184 | + let current_timeouts = self.timeouts_total.get(); |
| 185 | + if timeouts > current_timeouts { |
| 186 | + self.timeouts_total.inc_by(timeouts - current_timeouts); |
| 187 | + } |
| 188 | + |
| 189 | + let current_bytes_rx = self.bytes_rx_total.get(); |
| 190 | + if bytes_rx > current_bytes_rx { |
| 191 | + self.bytes_rx_total.inc_by(bytes_rx - current_bytes_rx); |
| 192 | + } |
| 193 | + } |
| 194 | + |
| 195 | + /// Updates HTTP stats. |
| 196 | + pub fn update_http<S: HttpStat>(&self, stat: &S) { |
| 197 | + let num_2xx = stat.num_2xx(); |
| 198 | + let num_3xx = stat.num_3xx(); |
| 199 | + let num_4xx = stat.num_4xx(); |
| 200 | + let num_5xx = stat.num_5xx(); |
| 201 | + |
| 202 | + let current_2xx = self.http_2xx_total.get(); |
| 203 | + if num_2xx > current_2xx { |
| 204 | + self.http_2xx_total.inc_by(num_2xx - current_2xx); |
| 205 | + } |
| 206 | + |
| 207 | + let current_3xx = self.http_3xx_total.get(); |
| 208 | + if num_3xx > current_3xx { |
| 209 | + self.http_3xx_total.inc_by(num_3xx - current_3xx); |
| 210 | + } |
| 211 | + |
| 212 | + let current_4xx = self.http_4xx_total.get(); |
| 213 | + if num_4xx > current_4xx { |
| 214 | + self.http_4xx_total.inc_by(num_4xx - current_4xx); |
| 215 | + } |
| 216 | + |
| 217 | + let current_5xx = self.http_5xx_total.get(); |
| 218 | + if num_5xx > current_5xx { |
| 219 | + self.http_5xx_total.inc_by(num_5xx - current_5xx); |
| 220 | + } |
| 221 | + } |
| 222 | + |
| 223 | + /// Updates socket stats. |
| 224 | + pub fn update_socket<S: SocketStat>(&self, stat: &S) { |
| 225 | + let created = stat.num_sock_created(); |
| 226 | + let errors = stat.num_sock_errors(); |
| 227 | + let retransmits = stat.num_retransmits(); |
| 228 | + |
| 229 | + let current_created = self.sockets_created_total.get(); |
| 230 | + if created > current_created { |
| 231 | + self.sockets_created_total.inc_by(created - current_created); |
| 232 | + } |
| 233 | + |
| 234 | + let current_errors = self.socket_errors_total.get(); |
| 235 | + if errors > current_errors { |
| 236 | + self.socket_errors_total.inc_by(errors - current_errors); |
| 237 | + } |
| 238 | + |
| 239 | + let current_retransmits = self.retransmits_total.get(); |
| 240 | + if retransmits > current_retransmits { |
| 241 | + self.retransmits_total.inc_by(retransmits - current_retransmits); |
| 242 | + } |
| 243 | + } |
| 244 | + |
| 245 | + /// Encodes all metrics to Prometheus text format. |
| 246 | + pub fn encode(&self) -> String { |
| 247 | + let mut buffer = String::new(); |
| 248 | + encode(&mut buffer, &self.registry).expect("encoding should not fail"); |
| 249 | + buffer |
| 250 | + } |
| 251 | +} |
| 252 | + |
| 253 | +impl Default for MetricsCollector { |
| 254 | + fn default() -> Self { |
| 255 | + Self::new() |
| 256 | + } |
| 257 | +} |
| 258 | + |
| 259 | +/// Encodes a LogHistogram to Prometheus histogram format. |
| 260 | +/// |
| 261 | +/// The histogram is encoded manually because prometheus-client's Histogram |
| 262 | +/// uses observe() which accumulates values, but we need to export absolute |
| 263 | +/// cumulative bucket counts from the log-histogram. |
| 264 | +fn encode_histogram(hist: &LogHistogram) -> String { |
| 265 | + let snapshot = hist.snapshot(); |
| 266 | + let factor = LogHistogram::factor(); |
| 267 | + |
| 268 | + // Calculate cumulative counts for prometheus buckets. |
| 269 | + // For each prometheus bucket with upper bound B (in seconds), |
| 270 | + // we sum all log-bucket counts where the upper bound <= B. |
| 271 | + let mut bucket_counts = vec![0u64; LATENCY_BUCKETS.len()]; |
| 272 | + let mut total_count = 0u64; |
| 273 | + let mut total_sum = 0.0f64; |
| 274 | + |
| 275 | + for (idx, &count) in snapshot.iter().enumerate() { |
| 276 | + if count == 0 { |
| 277 | + continue; |
| 278 | + } |
| 279 | + |
| 280 | + total_count += count; |
| 281 | + |
| 282 | + // Upper bound of this log-bucket in microseconds. |
| 283 | + let upper_us = factor.powi(idx as i32); |
| 284 | + // Lower bound for sum calculation. |
| 285 | + let lower_us = if idx == 0 { 0.0 } else { factor.powi(idx as i32 - 1) }; |
| 286 | + // Midpoint in seconds for sum calculation. |
| 287 | + let midpoint_sec = (lower_us + upper_us) / 2.0 / 1_000_000.0; |
| 288 | + total_sum += midpoint_sec * count as f64; |
| 289 | + |
| 290 | + // Upper bound in seconds. |
| 291 | + let upper_sec = upper_us / 1_000_000.0; |
| 292 | + |
| 293 | + // Add count to all prometheus buckets whose upper bound >= upper_sec. |
| 294 | + for (bucket_idx, &bucket_bound) in LATENCY_BUCKETS.iter().enumerate() { |
| 295 | + if bucket_bound >= upper_sec { |
| 296 | + bucket_counts[bucket_idx] += count; |
| 297 | + } |
| 298 | + } |
| 299 | + } |
| 300 | + |
| 301 | + let mut output = String::new(); |
| 302 | + writeln!( |
| 303 | + output, |
| 304 | + "# HELP dwd_latency_seconds Response latency histogram in seconds" |
| 305 | + ) |
| 306 | + .unwrap(); |
| 307 | + writeln!(output, "# TYPE dwd_latency_seconds histogram").unwrap(); |
| 308 | + |
| 309 | + // Buckets must be cumulative. |
| 310 | + let mut cumulative = 0u64; |
| 311 | + for (idx, &bound) in LATENCY_BUCKETS.iter().enumerate() { |
| 312 | + cumulative += bucket_counts[idx]; |
| 313 | + writeln!( |
| 314 | + output, |
| 315 | + "dwd_latency_seconds_bucket{{le=\"{:.6}\"}} {}", |
| 316 | + bound, cumulative |
| 317 | + ) |
| 318 | + .unwrap(); |
| 319 | + } |
| 320 | + writeln!(output, "dwd_latency_seconds_bucket{{le=\"+Inf\"}} {}", total_count).unwrap(); |
| 321 | + writeln!(output, "dwd_latency_seconds_sum {:.6}", total_sum).unwrap(); |
| 322 | + writeln!(output, "dwd_latency_seconds_count {}", total_count).unwrap(); |
| 323 | + |
| 324 | + output |
| 325 | +} |
| 326 | + |
| 327 | +/// Trait for stat sources that can be collected. |
| 328 | +pub trait StatSource: Send + Sync { |
| 329 | + /// Updates the metrics collector with current stats. |
| 330 | + fn collect(&self, collector: &MetricsCollector); |
| 331 | + |
| 332 | + /// Returns the latency histogram if available. |
| 333 | + fn histogram(&self) -> Option<LogHistogram> { |
| 334 | + None |
| 335 | + } |
| 336 | +} |
| 337 | + |
| 338 | +/// Shared state for the metrics handler. |
| 339 | +pub struct MetricsState { |
| 340 | + collector: MetricsCollector, |
| 341 | + stat_source: Arc<dyn StatSource>, |
| 342 | +} |
| 343 | + |
| 344 | +impl MetricsState { |
| 345 | + /// Creates a new metrics state. |
| 346 | + pub fn new(stat_source: Arc<dyn StatSource>) -> Self { |
| 347 | + Self { |
| 348 | + collector: MetricsCollector::new(), |
| 349 | + stat_source, |
| 350 | + } |
| 351 | + } |
| 352 | +} |
| 353 | + |
| 354 | +/// Creates a router for metrics endpoints. |
| 355 | +pub fn router(state: Arc<MetricsState>) -> Router { |
| 356 | + Router::new() |
| 357 | + .route("/api/v1/metrics", get(metrics_handler)) |
| 358 | + .with_state(state) |
| 359 | +} |
| 360 | + |
| 361 | +async fn metrics_handler(State(state): State<Arc<MetricsState>>) -> impl IntoResponse { |
| 362 | + // Update metrics from stat source. |
| 363 | + state.stat_source.collect(&state.collector); |
| 364 | + |
| 365 | + // Encode to prometheus format. |
| 366 | + let mut body = state.collector.encode(); |
| 367 | + |
| 368 | + // Add histogram if available (encoded separately due to its special nature). |
| 369 | + if let Some(hist) = state.stat_source.histogram() { |
| 370 | + body.push_str(&encode_histogram(&hist)); |
| 371 | + } |
| 372 | + |
| 373 | + ( |
| 374 | + StatusCode::OK, |
| 375 | + [( |
| 376 | + axum::http::header::CONTENT_TYPE, |
| 377 | + "text/plain; version=0.0.4; charset=utf-8", |
| 378 | + )], |
| 379 | + body, |
| 380 | + ) |
| 381 | +} |
0 commit comments