|
| 1 | +//! # End-to-end Hawkes process example with real orderbook data |
| 2 | +//! |
| 3 | +//! Pipeline: |
| 4 | +//! |
| 5 | +//! 1. Load a Bybit SOLUSDT orderbook parquet file. |
| 6 | +//! 2. Extract timestamps → treat each snapshot as an "arrival". |
| 7 | +//! 3. Convert to milliseconds, compute interarrival times. |
| 8 | +//! 4. Hold out the last 10 arrivals as a test set. |
| 9 | +//! 5. Print descriptive statistics and gap diagnostics on the training set. |
| 10 | +//! 6. Fit a univariate Hawkes process (MLE) on the training interarrivals. |
| 11 | +//! 7. Use the fitted model to forecast the next 10 arrival times. |
| 12 | +//! 8. Compare forecast vs. actual, print the diff. |
| 13 | +//! |
| 14 | +//! ```text |
| 15 | +//! cargo run -p atelier_quant --example eg_hawkes_ob_arrivals |
| 16 | +//! ``` |
| 17 | +
|
| 18 | +use std::path::Path; |
| 19 | + |
| 20 | +use atelier_data::orderbooks::io::ob_parquet::load_parquet_to_ob; |
| 21 | +use atelier_data::temporal::{self, TimeResolution}; |
| 22 | + |
| 23 | +use atelier_quant::arrivals::extract::extract_orderbook_timestamps; |
| 24 | +use atelier_quant::arrivals::inter::{compute_interarrivals, descriptive_stats}; |
| 25 | +use atelier_quant::hawkes::estimation::{ |
| 26 | + compensator, estimate_hawkes_mle, time_rescaling_residuals, HawkesEstimationConfig, |
| 27 | +}; |
| 28 | +use atelier_quant::hawkes::HawkesProcess; |
| 29 | + |
| 30 | +// ── Helpers ───────────────────────────────────────────────────────── |
| 31 | + |
| 32 | +/// Pretty-print a horizontal separator. |
/// Print a section banner: the label centered in a 72-column rule of `═`.
fn separator(label: &str) {
    let banner = format!(" {} ", label);
    println!("\n{:═^72}", banner);
}
| 36 | + |
| 37 | +/// Print a small table row. |
/// Print one aligned table row: a left-padded label column followed by its value.
fn row(label: &str, value: impl std::fmt::Display) {
    let line = format!(" {:<30} {}", label, value);
    println!("{}", line);
}
| 41 | + |
| 42 | +// ── Main ──────────────────────────────────────────────────────────── |
| 43 | + |
| 44 | +fn main() { |
| 45 | + // ── 1. Load orderbook parquet ─────────────────────────────────── |
| 46 | + separator("1. Load Parquet"); |
| 47 | + |
| 48 | + let parquet_path = Path::new( |
| 49 | + "datasets/collected/bybit/SOLUSDT/orderbooks/ob_bybit_20260217_003728.819.parquet", |
| 50 | + ); |
| 51 | + |
| 52 | + println!(" File: {}", parquet_path.display()); |
| 53 | + |
| 54 | + let orderbooks = load_parquet_to_ob(parquet_path).unwrap_or_else(|e| { |
| 55 | + eprintln!(" ERROR: Failed to load parquet: {}", e); |
| 56 | + std::process::exit(1); |
| 57 | + }); |
| 58 | + |
| 59 | + println!(" Loaded {} orderbook snapshots", orderbooks.len()); |
| 60 | + |
| 61 | + if orderbooks.len() < 12 { |
| 62 | + eprintln!(" ERROR: Need at least 12 snapshots (10 test + 2 train), got {}", orderbooks.len()); |
| 63 | + std::process::exit(1); |
| 64 | + } |
| 65 | + |
| 66 | + // ── 2. Extract timestamps (nanoseconds) ───────────────────────── |
| 67 | + separator("2. Extract Timestamps"); |
| 68 | + |
| 69 | + let timestamps_ns = extract_orderbook_timestamps(&orderbooks); |
| 70 | + println!(" Total arrivals: {}", timestamps_ns.len()); |
| 71 | + println!( |
| 72 | + " First ts (ns): {}", |
| 73 | + timestamps_ns.first().unwrap_or(&0) |
| 74 | + ); |
| 75 | + println!( |
| 76 | + " Last ts (ns): {}", |
| 77 | + timestamps_ns.last().unwrap_or(&0) |
| 78 | + ); |
| 79 | + |
| 80 | + // Convert to milliseconds for display |
| 81 | + let first_ms = temporal::from_nanos(*timestamps_ns.first().unwrap(), TimeResolution::Milliseconds); |
| 82 | + let last_ms = temporal::from_nanos(*timestamps_ns.last().unwrap(), TimeResolution::Milliseconds); |
| 83 | + let span_s = (last_ms - first_ms) / 1000.0; |
| 84 | + println!(" Observation window: {:.3} seconds ({:.1} ms)", span_s, last_ms - first_ms); |
| 85 | + |
| 86 | + // ── 3. Validate and detect gaps ───────────────────────────────── |
| 87 | + separator("3. Validation & Gap Detection"); |
| 88 | + |
| 89 | + // Validate monotonicity |
| 90 | + let mut is_monotonic = true; |
| 91 | + for i in 1..timestamps_ns.len() { |
| 92 | + if timestamps_ns[i] <= timestamps_ns[i - 1] { |
| 93 | + eprintln!( |
| 94 | + " ✗ Monotonicity violation at index {}: {} <= {}", |
| 95 | + i, timestamps_ns[i], timestamps_ns[i - 1] |
| 96 | + ); |
| 97 | + is_monotonic = false; |
| 98 | + break; |
| 99 | + } |
| 100 | + } |
| 101 | + if is_monotonic { |
| 102 | + println!("Timestamps are strictly monotonic"); |
| 103 | + } else { |
| 104 | + std::process::exit(1); |
| 105 | + } |
| 106 | + |
| 107 | + // Detect gaps > 5 seconds (5_000_000_000 ns) — likely feed disconnects |
| 108 | + let gap_threshold_ns = 5_000_000_000_u64; |
| 109 | + let mut n_gaps = 0_usize; |
| 110 | + for i in 1..timestamps_ns.len() { |
| 111 | + let gap = timestamps_ns[i] - timestamps_ns[i - 1]; |
| 112 | + if gap > gap_threshold_ns { |
| 113 | + if n_gaps == 0 { |
| 114 | + println!(" ⚠ Gaps exceeding {:.1}s:", gap_threshold_ns as f64 / 1e9); |
| 115 | + } |
| 116 | + println!(" index {}: gap = {:.3} ms", i - 1, gap as f64 / 1e6); |
| 117 | + n_gaps += 1; |
| 118 | + } |
| 119 | + } |
| 120 | + if n_gaps == 0 { |
| 121 | + println!(" ✓ No gaps exceeding {:.1}s detected", gap_threshold_ns as f64 / 1e9); |
| 122 | + } else { |
| 123 | + println!(" Total large gaps: {}", n_gaps); |
| 124 | + } |
| 125 | + |
| 126 | + // ── 4. Train/test split ───────────────────────────────────────── |
| 127 | + separator("4. Train / Test Split"); |
| 128 | + |
| 129 | + let n_test = 10; |
| 130 | + let n_total = timestamps_ns.len(); |
| 131 | + let n_train = n_total - n_test; |
| 132 | + |
| 133 | + let train_ts = ×tamps_ns[..n_train]; |
| 134 | + let test_ts = ×tamps_ns[n_train..]; |
| 135 | + |
| 136 | + println!(" Training set: {} arrivals", train_ts.len()); |
| 137 | + println!(" Test set: {} arrivals", test_ts.len()); |
| 138 | + |
| 139 | + // ── 5. Compute interarrivals + stats (training set) ───────────── |
| 140 | + separator("5. Interarrival Statistics (Training)"); |
| 141 | + |
| 142 | + let ia_result = |
| 143 | + compute_interarrivals(train_ts, TimeResolution::Milliseconds).unwrap_or_else(|e| { |
| 144 | + eprintln!(" ERROR: {}", e); |
| 145 | + std::process::exit(1); |
| 146 | + }); |
| 147 | + |
| 148 | + let stats = descriptive_stats(&ia_result.deltas_f64).unwrap(); |
| 149 | + |
| 150 | + row("Count (gaps)", format!("{}", stats.count)); |
| 151 | + row("Mean (ms)", format!("{:.6}", stats.mean)); |
| 152 | + row("Std dev (ms)", format!("{:.6}", stats.std_dev)); |
| 153 | + row("Variance (ms²)", format!("{:.6}", stats.variance)); |
| 154 | + row("Min (ms)", format!("{:.6}", stats.min)); |
| 155 | + row("Max (ms)", format!("{:.6}", stats.max)); |
| 156 | + row("Skewness", format!("{:.4}", stats.skewness)); |
| 157 | + row("Excess kurtosis", format!("{:.4}", stats.kurtosis)); |
| 158 | + row("CV (σ/μ)", format!("{:.4}", stats.covariance)); |
| 159 | + |
| 160 | + if stats.covariance > 1.0 { |
| 161 | + println!("\n → CV > 1 indicates clustering (super-Poisson), consistent with Hawkes excitation."); |
| 162 | + } else if (stats.covariance - 1.0).abs() < 0.15 { |
| 163 | + println!("\n → CV ≈ 1 suggests near-Poisson (memoryless) arrivals."); |
| 164 | + } else { |
| 165 | + println!("\n → CV < 1 indicates regularity (sub-Poisson), less common for LOB data."); |
| 166 | + } |
| 167 | + |
| 168 | + // ── 6. Fit Hawkes MLE ─────────────────────────────────────────── |
| 169 | + separator("6. Hawkes MLE Estimation"); |
| 170 | + |
| 171 | + // Build event times in milliseconds relative to the first arrival. |
| 172 | + // This keeps numbers in a reasonable range for the optimizer. |
| 173 | + let t0_ns = train_ts[0]; |
| 174 | + let train_events_ms: Vec<f64> = train_ts |
| 175 | + .iter() |
| 176 | + .map(|&t| temporal::from_nanos(t - t0_ns, TimeResolution::Milliseconds)) |
| 177 | + .collect(); |
| 178 | + |
| 179 | + let config = HawkesEstimationConfig { |
| 180 | + max_iter: 10_000, |
| 181 | + tol: 1e-4, |
| 182 | + learning_rate: 1e-3, |
| 183 | + initial_params: None, |
| 184 | + }; |
| 185 | + |
| 186 | + let mle = estimate_hawkes_mle(&train_events_ms, &config).unwrap_or_else(|e| { |
| 187 | + eprintln!(" ERROR: MLE failed: {}", e); |
| 188 | + std::process::exit(1); |
| 189 | + }); |
| 190 | + |
| 191 | + row("μ̂ (events/ms)", format!("{:.8}", mle.mu)); |
| 192 | + row("α̂ (excitation)", format!("{:.8}", mle.alpha)); |
| 193 | + row("β̂ (decay 1/ms)", format!("{:.8}", mle.beta)); |
| 194 | + row("Branching ratio α̂/β̂", format!("{:.6}", mle.branching_ratio)); |
| 195 | + row("Log-likelihood", format!("{:.4}", mle.log_likelihood)); |
| 196 | + row("AIC", format!("{:.4}", mle.aic)); |
| 197 | + row("BIC", format!("{:.4}", mle.bic)); |
| 198 | + row("Iterations", format!("{}", mle.iterations)); |
| 199 | + row("Converged", format!("{}", mle.converged)); |
| 200 | + |
| 201 | + let theoretical_rate = mle.mu / (1.0 - mle.branching_ratio); |
| 202 | + row("Stationary rate (ev/ms)", format!("{:.8}", theoretical_rate)); |
| 203 | + row( |
| 204 | + "Stationary mean gap (ms)", |
| 205 | + format!("{:.6}", 1.0 / theoretical_rate), |
| 206 | + ); |
| 207 | + |
| 208 | + // ── 7. Goodness-of-fit: time-rescaling residuals ──────────────── |
| 209 | + separator("7. Goodness-of-Fit (Time-Rescaling)"); |
| 210 | + |
| 211 | + let residuals = |
| 212 | + time_rescaling_residuals(mle.mu, mle.alpha, mle.beta, &train_events_ms); |
| 213 | + |
| 214 | + let res_stats = descriptive_stats(&residuals); |
| 215 | + if let Some(rs) = &res_stats { |
| 216 | + row("Residuals count", format!("{}", rs.count)); |
| 217 | + row("Residuals mean", format!("{:.6}", rs.mean)); |
| 218 | + row("Residuals std dev", format!("{:.6}", rs.std_dev)); |
| 219 | + println!( |
| 220 | + "\n Under correct specification, residuals ~ Exp(1): mean ≈ 1.0, std ≈ 1.0" |
| 221 | + ); |
| 222 | + if (rs.mean - 1.0).abs() < 0.3 { |
| 223 | + println!(" → Mean {:.3} is within 30% of 1.0: reasonable fit.", rs.mean); |
| 224 | + } else { |
| 225 | + println!(" → Mean {:.3} deviates from 1.0: model may be mis-specified.", rs.mean); |
| 226 | + } |
| 227 | + } |
| 228 | + |
| 229 | + // ── 8. Forecast next 10 arrivals ──────────────────────────────── |
| 230 | + separator("8. Forecast Next 10 Arrivals"); |
| 231 | + |
| 232 | + // Strategy: use the compensator to convert from Hawkes time to |
| 233 | + // calendar time. We simulate from the fitted model starting at the |
| 234 | + // last training event. |
| 235 | + let last_train_ms = *train_events_ms.last().unwrap(); |
| 236 | + |
| 237 | + // Build a HawkesProcess with the fitted parameters and simulate |
| 238 | + let hp = HawkesProcess::new(mle.mu, mle.alpha, mle.beta).unwrap_or_else(|e| { |
| 239 | + eprintln!(" ERROR: Could not create HawkesProcess: {:?}", e); |
| 240 | + std::process::exit(1); |
| 241 | + }); |
| 242 | + |
| 243 | + // We simulate events continuing from the last training time. |
| 244 | + // The intensity at the boundary depends on the full training history, |
| 245 | + // so we pass the last training time as the start. |
| 246 | + let forecasted_events_ms = hp.generate_values(last_train_ms, n_test); |
| 247 | + |
| 248 | + // Convert forecasted absolute times (ms relative to t0) back to |
| 249 | + // nanosecond timestamps: ms × 1_000_000 = ns, then add origin. |
| 250 | + let _forecasted_ts_ns: Vec<u64> = forecasted_events_ms |
| 251 | + .iter() |
| 252 | + .map(|&t_ms| t0_ns + (t_ms * 1_000_000.0) as u64) |
| 253 | + .collect(); |
| 254 | + |
| 255 | + // ── 9. Compare forecast vs actual ─────────────────────────────── |
| 256 | + separator("9. Forecast vs Actual Comparison"); |
| 257 | + |
| 258 | + // Compute interarrival gaps for both series |
| 259 | + // For actual: gaps between last train event and each subsequent test event |
| 260 | + let actual_arrivals_ms: Vec<f64> = test_ts |
| 261 | + .iter() |
| 262 | + .map(|&t| temporal::from_nanos(t - t0_ns, TimeResolution::Milliseconds)) |
| 263 | + .collect(); |
| 264 | + |
| 265 | + let forecast_arrivals_ms: Vec<f64> = forecasted_events_ms.clone(); |
| 266 | + |
| 267 | + // Compute cumulative interarrivals from the last training point |
| 268 | + let actual_gaps: Vec<f64> = actual_arrivals_ms |
| 269 | + .iter() |
| 270 | + .map(|&t| t - last_train_ms) |
| 271 | + .collect(); |
| 272 | + |
| 273 | + let forecast_gaps: Vec<f64> = forecast_arrivals_ms |
| 274 | + .iter() |
| 275 | + .map(|&t| t - last_train_ms) |
| 276 | + .collect(); |
| 277 | + |
| 278 | + println!( |
| 279 | + " {:>4} {:>16} {:>16} {:>16}", |
| 280 | + "i", "Actual Δt (ms)", "Forecast Δt (ms)", "Diff (ms)" |
| 281 | + ); |
| 282 | + println!(" {}", "─".repeat(68)); |
| 283 | + |
| 284 | + for i in 0..n_test { |
| 285 | + let actual = actual_gaps[i]; |
| 286 | + let forecast = if i < forecast_gaps.len() { |
| 287 | + forecast_gaps[i] |
| 288 | + } else { |
| 289 | + f64::NAN |
| 290 | + }; |
| 291 | + let diff = actual - forecast; |
| 292 | + println!( |
| 293 | + " {:>4} {:>16.4} {:>16.4} {:>16.4}", |
| 294 | + i + 1, |
| 295 | + actual, |
| 296 | + forecast, |
| 297 | + diff |
| 298 | + ); |
| 299 | + } |
| 300 | + |
| 301 | + // ── 10. Summary error metrics ─────────────────────────────────── |
| 302 | + separator("10. Forecast Error Metrics"); |
| 303 | + |
| 304 | + let n_compare = n_test.min(forecast_gaps.len()); |
| 305 | + let mut sum_abs_err = 0.0_f64; |
| 306 | + let mut sum_sq_err = 0.0_f64; |
| 307 | + |
| 308 | + for i in 0..n_compare { |
| 309 | + let err = actual_gaps[i] - forecast_gaps[i]; |
| 310 | + sum_abs_err += err.abs(); |
| 311 | + sum_sq_err += err * err; |
| 312 | + } |
| 313 | + |
| 314 | + let mae = sum_abs_err / n_compare as f64; |
| 315 | + let rmse = (sum_sq_err / n_compare as f64).sqrt(); |
| 316 | + |
| 317 | + row("MAE (ms)", format!("{:.6}", mae)); |
| 318 | + row("RMSE (ms)", format!("{:.6}", rmse)); |
| 319 | + row("Mean actual gap (ms)", format!("{:.6}", actual_gaps.iter().sum::<f64>() / n_compare as f64)); |
| 320 | + row("Mean forecast gap (ms)", format!("{:.6}", forecast_gaps.iter().sum::<f64>() / n_compare as f64)); |
| 321 | + |
| 322 | + // ── 11. Compensator at boundary (diagnostic) ──────────────────── |
| 323 | + separator("11. Compensator Diagnostic"); |
| 324 | + |
| 325 | + let t_end = *train_events_ms.last().unwrap(); |
| 326 | + let comp_end = compensator(mle.mu, mle.alpha, mle.beta, &train_events_ms, t_end); |
| 327 | + let expected_comp = (train_events_ms.len() - 1) as f64; // under correct model, Λ(T) ≈ n-1 |
| 328 | + |
| 329 | + row("Λ(T) at last train event", format!("{:.4}", comp_end)); |
| 330 | + row("Expected (n-1)", format!("{}", train_events_ms.len() - 1)); |
| 331 | + row( |
| 332 | + "Ratio Λ(T)/(n-1)", |
| 333 | + format!("{:.4}", comp_end / expected_comp), |
| 334 | + ); |
| 335 | + |
| 336 | + println!("\n Under correct specification, Λ(T)/(n−1) ≈ 1.0"); |
| 337 | + |
| 338 | + separator("Done"); |
| 339 | + println!(" Example completed successfully.\n"); |
| 340 | +} |
0 commit comments