Skip to content

Commit 57265a4

Browse files
committed
inter-arrival model
1 parent 3d4ea6f commit 57265a4

File tree

16 files changed

+1866
-29
lines changed

16 files changed

+1866
-29
lines changed

atelier-quant/Cargo.toml

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,32 @@ path = "src/lib.rs"
2525

2626
[[example]]
2727
name = "eg_inter_ats"
28-
path = "examples/eg_intera_ats.rs"
28+
path = "examples/eg_inter_ats.rs"
29+
30+
[[example]]
31+
name = "eg_hawkes_ob_arrivals"
32+
path = "examples/eg_hawkes_ob_arrivals.rs"
33+
34+
[[test]]
35+
name = "test_extracts"
36+
path = "tests/arrivals/test_extracts.rs"
37+
38+
[[test]]
39+
name = "test_inter"
40+
path = "tests/arrivals/test_inter.rs"
2941

3042
[[test]]
31-
name = "test_inter-ats"
32-
path = "tests/base/test_inter_ats.rs"
43+
name = "test_inter_stats"
44+
path = "tests/arrivals/test_inter_stats.rs"
45+
46+
# --- Model: Hawkes --- #
47+
48+
[[test]]
49+
name = "test_model_estimation_hawkes"
50+
path = "tests/hawkes/test_estimation.rs"
3351

3452
[dependencies]
3553
approx = { version = "0.5.1" }
36-
atelier_data = { version = "0.0.13", features = ["parquet"] }
3754
chrono = { version = "0.4", features = ["serde"] }
3855
futures = { version = "0.3" }
3956
lazy_static = { version = "1.4" }
@@ -47,6 +64,11 @@ thiserror = { version = "1.0.64" }
4764
tokio = { version = "1", features = ["full"] }
4865
toml = { version = "0.8" }
4966

67+
[dependencies.atelier_data]
68+
git = "https://github.com/IteraLabs/atelier-rs/"
69+
branch = "feature/9-inter-arrival-ts-modeling"
70+
features = ["parquet"]
71+
5072
[lints.rust]
5173
trivial_casts = "warn"
5274
trivial_numeric_casts = "warn"
Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
//! # End-to-end Hawkes process example with real orderbook data
2+
//!
3+
//! Pipeline:
4+
//!
5+
//! 1. Load a Bybit SOLUSDT orderbook parquet file.
6+
//! 2. Extract timestamps → treat each snapshot as an "arrival".
7+
//! 3. Convert to milliseconds, compute interarrival times.
8+
//! 4. Hold out the last 10 arrivals as a test set.
9+
//! 5. Print descriptive statistics and gap diagnostics on the training set.
10+
//! 6. Fit a univariate Hawkes process (MLE) on the training interarrivals.
11+
//! 7. Use the fitted model to forecast the next 10 arrival times.
12+
//! 8. Compare forecast vs. actual, print the diff.
13+
//!
14+
//! ```text
15+
//! cargo run -p atelier_quant --example eg_hawkes_ob_arrivals
16+
//! ```
17+
18+
use std::path::Path;
19+
20+
use atelier_data::orderbooks::io::ob_parquet::load_parquet_to_ob;
21+
use atelier_data::temporal::{self, TimeResolution};
22+
23+
use atelier_quant::arrivals::extract::extract_orderbook_timestamps;
24+
use atelier_quant::arrivals::inter::{compute_interarrivals, descriptive_stats};
25+
use atelier_quant::hawkes::estimation::{
26+
compensator, estimate_hawkes_mle, time_rescaling_residuals, HawkesEstimationConfig,
27+
};
28+
use atelier_quant::hawkes::HawkesProcess;
29+
30+
// ── Helpers ─────────────────────────────────────────────────────────
31+
32+
/// Pretty-print a horizontal separator.
33+
fn separator(label: &str) {
34+
println!("\n{:═^72}", format!(" {} ", label));
35+
}
36+
37+
/// Print a small table row.
38+
fn row(label: &str, value: impl std::fmt::Display) {
39+
println!(" {:<30} {}", label, value);
40+
}
41+
42+
// ── Main ────────────────────────────────────────────────────────────
43+
44+
fn main() {
45+
// ── 1. Load orderbook parquet ───────────────────────────────────
46+
separator("1. Load Parquet");
47+
48+
let parquet_path = Path::new(
49+
"datasets/collected/bybit/SOLUSDT/orderbooks/ob_bybit_20260217_003728.819.parquet",
50+
);
51+
52+
println!(" File: {}", parquet_path.display());
53+
54+
let orderbooks = load_parquet_to_ob(parquet_path).unwrap_or_else(|e| {
55+
eprintln!(" ERROR: Failed to load parquet: {}", e);
56+
std::process::exit(1);
57+
});
58+
59+
println!(" Loaded {} orderbook snapshots", orderbooks.len());
60+
61+
if orderbooks.len() < 12 {
62+
eprintln!(" ERROR: Need at least 12 snapshots (10 test + 2 train), got {}", orderbooks.len());
63+
std::process::exit(1);
64+
}
65+
66+
// ── 2. Extract timestamps (nanoseconds) ─────────────────────────
67+
separator("2. Extract Timestamps");
68+
69+
let timestamps_ns = extract_orderbook_timestamps(&orderbooks);
70+
println!(" Total arrivals: {}", timestamps_ns.len());
71+
println!(
72+
" First ts (ns): {}",
73+
timestamps_ns.first().unwrap_or(&0)
74+
);
75+
println!(
76+
" Last ts (ns): {}",
77+
timestamps_ns.last().unwrap_or(&0)
78+
);
79+
80+
// Convert to milliseconds for display
81+
let first_ms = temporal::from_nanos(*timestamps_ns.first().unwrap(), TimeResolution::Milliseconds);
82+
let last_ms = temporal::from_nanos(*timestamps_ns.last().unwrap(), TimeResolution::Milliseconds);
83+
let span_s = (last_ms - first_ms) / 1000.0;
84+
println!(" Observation window: {:.3} seconds ({:.1} ms)", span_s, last_ms - first_ms);
85+
86+
// ── 3. Validate and detect gaps ─────────────────────────────────
87+
separator("3. Validation & Gap Detection");
88+
89+
// Validate monotonicity
90+
let mut is_monotonic = true;
91+
for i in 1..timestamps_ns.len() {
92+
if timestamps_ns[i] <= timestamps_ns[i - 1] {
93+
eprintln!(
94+
" ✗ Monotonicity violation at index {}: {} <= {}",
95+
i, timestamps_ns[i], timestamps_ns[i - 1]
96+
);
97+
is_monotonic = false;
98+
break;
99+
}
100+
}
101+
if is_monotonic {
102+
println!("Timestamps are strictly monotonic");
103+
} else {
104+
std::process::exit(1);
105+
}
106+
107+
// Detect gaps > 5 seconds (5_000_000_000 ns) — likely feed disconnects
108+
let gap_threshold_ns = 5_000_000_000_u64;
109+
let mut n_gaps = 0_usize;
110+
for i in 1..timestamps_ns.len() {
111+
let gap = timestamps_ns[i] - timestamps_ns[i - 1];
112+
if gap > gap_threshold_ns {
113+
if n_gaps == 0 {
114+
println!(" ⚠ Gaps exceeding {:.1}s:", gap_threshold_ns as f64 / 1e9);
115+
}
116+
println!(" index {}: gap = {:.3} ms", i - 1, gap as f64 / 1e6);
117+
n_gaps += 1;
118+
}
119+
}
120+
if n_gaps == 0 {
121+
println!(" ✓ No gaps exceeding {:.1}s detected", gap_threshold_ns as f64 / 1e9);
122+
} else {
123+
println!(" Total large gaps: {}", n_gaps);
124+
}
125+
126+
// ── 4. Train/test split ─────────────────────────────────────────
127+
separator("4. Train / Test Split");
128+
129+
let n_test = 10;
130+
let n_total = timestamps_ns.len();
131+
let n_train = n_total - n_test;
132+
133+
let train_ts = &timestamps_ns[..n_train];
134+
let test_ts = &timestamps_ns[n_train..];
135+
136+
println!(" Training set: {} arrivals", train_ts.len());
137+
println!(" Test set: {} arrivals", test_ts.len());
138+
139+
// ── 5. Compute interarrivals + stats (training set) ─────────────
140+
separator("5. Interarrival Statistics (Training)");
141+
142+
let ia_result =
143+
compute_interarrivals(train_ts, TimeResolution::Milliseconds).unwrap_or_else(|e| {
144+
eprintln!(" ERROR: {}", e);
145+
std::process::exit(1);
146+
});
147+
148+
let stats = descriptive_stats(&ia_result.deltas_f64).unwrap();
149+
150+
row("Count (gaps)", format!("{}", stats.count));
151+
row("Mean (ms)", format!("{:.6}", stats.mean));
152+
row("Std dev (ms)", format!("{:.6}", stats.std_dev));
153+
row("Variance (ms²)", format!("{:.6}", stats.variance));
154+
row("Min (ms)", format!("{:.6}", stats.min));
155+
row("Max (ms)", format!("{:.6}", stats.max));
156+
row("Skewness", format!("{:.4}", stats.skewness));
157+
row("Excess kurtosis", format!("{:.4}", stats.kurtosis));
158+
row("CV (σ/μ)", format!("{:.4}", stats.covariance));
159+
160+
if stats.covariance > 1.0 {
161+
println!("\n → CV > 1 indicates clustering (super-Poisson), consistent with Hawkes excitation.");
162+
} else if (stats.covariance - 1.0).abs() < 0.15 {
163+
println!("\n → CV ≈ 1 suggests near-Poisson (memoryless) arrivals.");
164+
} else {
165+
println!("\n → CV < 1 indicates regularity (sub-Poisson), less common for LOB data.");
166+
}
167+
168+
// ── 6. Fit Hawkes MLE ───────────────────────────────────────────
169+
separator("6. Hawkes MLE Estimation");
170+
171+
// Build event times in milliseconds relative to the first arrival.
172+
// This keeps numbers in a reasonable range for the optimizer.
173+
let t0_ns = train_ts[0];
174+
let train_events_ms: Vec<f64> = train_ts
175+
.iter()
176+
.map(|&t| temporal::from_nanos(t - t0_ns, TimeResolution::Milliseconds))
177+
.collect();
178+
179+
let config = HawkesEstimationConfig {
180+
max_iter: 10_000,
181+
tol: 1e-4,
182+
learning_rate: 1e-3,
183+
initial_params: None,
184+
};
185+
186+
let mle = estimate_hawkes_mle(&train_events_ms, &config).unwrap_or_else(|e| {
187+
eprintln!(" ERROR: MLE failed: {}", e);
188+
std::process::exit(1);
189+
});
190+
191+
row("μ̂ (events/ms)", format!("{:.8}", mle.mu));
192+
row("α̂ (excitation)", format!("{:.8}", mle.alpha));
193+
row("β̂ (decay 1/ms)", format!("{:.8}", mle.beta));
194+
row("Branching ratio α̂/β̂", format!("{:.6}", mle.branching_ratio));
195+
row("Log-likelihood", format!("{:.4}", mle.log_likelihood));
196+
row("AIC", format!("{:.4}", mle.aic));
197+
row("BIC", format!("{:.4}", mle.bic));
198+
row("Iterations", format!("{}", mle.iterations));
199+
row("Converged", format!("{}", mle.converged));
200+
201+
let theoretical_rate = mle.mu / (1.0 - mle.branching_ratio);
202+
row("Stationary rate (ev/ms)", format!("{:.8}", theoretical_rate));
203+
row(
204+
"Stationary mean gap (ms)",
205+
format!("{:.6}", 1.0 / theoretical_rate),
206+
);
207+
208+
// ── 7. Goodness-of-fit: time-rescaling residuals ────────────────
209+
separator("7. Goodness-of-Fit (Time-Rescaling)");
210+
211+
let residuals =
212+
time_rescaling_residuals(mle.mu, mle.alpha, mle.beta, &train_events_ms);
213+
214+
let res_stats = descriptive_stats(&residuals);
215+
if let Some(rs) = &res_stats {
216+
row("Residuals count", format!("{}", rs.count));
217+
row("Residuals mean", format!("{:.6}", rs.mean));
218+
row("Residuals std dev", format!("{:.6}", rs.std_dev));
219+
println!(
220+
"\n Under correct specification, residuals ~ Exp(1): mean ≈ 1.0, std ≈ 1.0"
221+
);
222+
if (rs.mean - 1.0).abs() < 0.3 {
223+
println!(" → Mean {:.3} is within 30% of 1.0: reasonable fit.", rs.mean);
224+
} else {
225+
println!(" → Mean {:.3} deviates from 1.0: model may be mis-specified.", rs.mean);
226+
}
227+
}
228+
229+
// ── 8. Forecast next 10 arrivals ────────────────────────────────
230+
separator("8. Forecast Next 10 Arrivals");
231+
232+
// Strategy: use the compensator to convert from Hawkes time to
233+
// calendar time. We simulate from the fitted model starting at the
234+
// last training event.
235+
let last_train_ms = *train_events_ms.last().unwrap();
236+
237+
// Build a HawkesProcess with the fitted parameters and simulate
238+
let hp = HawkesProcess::new(mle.mu, mle.alpha, mle.beta).unwrap_or_else(|e| {
239+
eprintln!(" ERROR: Could not create HawkesProcess: {:?}", e);
240+
std::process::exit(1);
241+
});
242+
243+
// We simulate events continuing from the last training time.
244+
// The intensity at the boundary depends on the full training history,
245+
// so we pass the last training time as the start.
246+
let forecasted_events_ms = hp.generate_values(last_train_ms, n_test);
247+
248+
// Convert forecasted absolute times (ms relative to t0) back to
249+
// nanosecond timestamps: ms × 1_000_000 = ns, then add origin.
250+
let _forecasted_ts_ns: Vec<u64> = forecasted_events_ms
251+
.iter()
252+
.map(|&t_ms| t0_ns + (t_ms * 1_000_000.0) as u64)
253+
.collect();
254+
255+
// ── 9. Compare forecast vs actual ───────────────────────────────
256+
separator("9. Forecast vs Actual Comparison");
257+
258+
// Compute interarrival gaps for both series
259+
// For actual: gaps between last train event and each subsequent test event
260+
let actual_arrivals_ms: Vec<f64> = test_ts
261+
.iter()
262+
.map(|&t| temporal::from_nanos(t - t0_ns, TimeResolution::Milliseconds))
263+
.collect();
264+
265+
let forecast_arrivals_ms: Vec<f64> = forecasted_events_ms.clone();
266+
267+
// Compute cumulative interarrivals from the last training point
268+
let actual_gaps: Vec<f64> = actual_arrivals_ms
269+
.iter()
270+
.map(|&t| t - last_train_ms)
271+
.collect();
272+
273+
let forecast_gaps: Vec<f64> = forecast_arrivals_ms
274+
.iter()
275+
.map(|&t| t - last_train_ms)
276+
.collect();
277+
278+
println!(
279+
" {:>4} {:>16} {:>16} {:>16}",
280+
"i", "Actual Δt (ms)", "Forecast Δt (ms)", "Diff (ms)"
281+
);
282+
println!(" {}", "─".repeat(68));
283+
284+
for i in 0..n_test {
285+
let actual = actual_gaps[i];
286+
let forecast = if i < forecast_gaps.len() {
287+
forecast_gaps[i]
288+
} else {
289+
f64::NAN
290+
};
291+
let diff = actual - forecast;
292+
println!(
293+
" {:>4} {:>16.4} {:>16.4} {:>16.4}",
294+
i + 1,
295+
actual,
296+
forecast,
297+
diff
298+
);
299+
}
300+
301+
// ── 10. Summary error metrics ───────────────────────────────────
302+
separator("10. Forecast Error Metrics");
303+
304+
let n_compare = n_test.min(forecast_gaps.len());
305+
let mut sum_abs_err = 0.0_f64;
306+
let mut sum_sq_err = 0.0_f64;
307+
308+
for i in 0..n_compare {
309+
let err = actual_gaps[i] - forecast_gaps[i];
310+
sum_abs_err += err.abs();
311+
sum_sq_err += err * err;
312+
}
313+
314+
let mae = sum_abs_err / n_compare as f64;
315+
let rmse = (sum_sq_err / n_compare as f64).sqrt();
316+
317+
row("MAE (ms)", format!("{:.6}", mae));
318+
row("RMSE (ms)", format!("{:.6}", rmse));
319+
row("Mean actual gap (ms)", format!("{:.6}", actual_gaps.iter().sum::<f64>() / n_compare as f64));
320+
row("Mean forecast gap (ms)", format!("{:.6}", forecast_gaps.iter().sum::<f64>() / n_compare as f64));
321+
322+
// ── 11. Compensator at boundary (diagnostic) ────────────────────
323+
separator("11. Compensator Diagnostic");
324+
325+
let t_end = *train_events_ms.last().unwrap();
326+
let comp_end = compensator(mle.mu, mle.alpha, mle.beta, &train_events_ms, t_end);
327+
let expected_comp = (train_events_ms.len() - 1) as f64; // under correct model, Λ(T) ≈ n-1
328+
329+
row("Λ(T) at last train event", format!("{:.4}", comp_end));
330+
row("Expected (n-1)", format!("{}", train_events_ms.len() - 1));
331+
row(
332+
"Ratio Λ(T)/(n-1)",
333+
format!("{:.4}", comp_end / expected_comp),
334+
);
335+
336+
println!("\n Under correct specification, Λ(T)/(n−1) ≈ 1.0");
337+
338+
separator("Done");
339+
println!(" Example completed successfully.\n");
340+
}

0 commit comments

Comments
 (0)