-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlib.rs
More file actions
147 lines (120 loc) · 3.87 KB
/
lib.rs
File metadata and controls
147 lines (120 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Copyright 2025 Massimiliano Pippi
// SPDX-License-Identifier: MIT
use axum::{
extract::State,
http::{Request, Uri},
middleware::{self, Next},
response::Response,
routing::post,
Router,
};
use clap::Parser;
use clap_verbosity_flag::Verbosity;
use colored::Colorize;
use std::net::{IpAddr, SocketAddr};
use std::time::Duration;
use tower_http::timeout::TimeoutLayer;
pub mod chat_completions;
pub mod responses;
pub mod server_state;
use crate::server_state::ServerState;
// Command-line options accepted by the `roy` server.
//
// Plain `//` comments are used here on purpose: clap's derive macro feeds `///`
// doc comments into the generated help text, and the help output must stay as
// defined by the explicit `help = "..."` / `about = "..."` attributes below.
#[derive(Parser, Clone)]
#[command(
    name = "roy",
    version = env!("CARGO_PKG_VERSION"),
    about = "A HTTP server compatible with the OpenAI platform format that simulates errors and rate limit data"
)]
pub struct Args {
    // Verbosity flags contributed by the `clap_verbosity_flag` crate.
    #[command(flatten)]
    pub verbosity: Verbosity,

    // Port part of the socket address the server binds to.
    #[arg(long, default_value = "8000", help = "Port to listen on")]
    pub port: u16,

    // IP part of the socket address the server binds to.
    #[arg(long, default_value = "0.0.0.0", help = "Address to listen on")]
    pub address: IpAddr,

    // Kept as a raw string because it may be a single number or a `min:max`
    // range; it is not parsed in this file.
    #[arg(
        long,
        default_value = "250",
        help = "Length of response (fixed number or range like '10:100')"
    )]
    pub response_length: Option<String>,

    // NOTE(review): no check here that this is a valid HTTP status code —
    // confirm the consumer validates it.
    #[arg(long, help = "HTTP error code to return")]
    pub error_code: Option<u16>,

    // NOTE(review): the 0-100 range from the help text is not enforced by the
    // parser — confirm the consumer handles larger values.
    #[arg(long, help = "Error rate percentage (0-100)")]
    pub error_rate: Option<u32>,

    // Requests-per-minute limit.
    #[arg(
        long,
        default_value = "500",
        help = "Maximum number of requests per minute"
    )]
    pub rpm: u32,

    // Tokens-per-minute limit.
    #[arg(
        long,
        default_value = "30000",
        help = "Maximum number of tokens per minute"
    )]
    pub tpm: u32,

    // Kept as a raw string because it may be a single number or a `min:max`
    // range; it is not parsed in this file.
    #[arg(
        long,
        help = "Slowdown in milliseconds (fixed number or range like '10:100')"
    )]
    pub slowdown: Option<String>,

    // When present, `run` wraps the whole router in a timeout layer of this
    // many milliseconds.
    #[arg(long, help = "Timeout in milliseconds")]
    pub timeout: Option<u64>,
}
/// Handler registered as the router fallback in [`run`].
///
/// Logs the path of the request at `warn` level and replies with
/// `404 Not Found` and the body `"Not Found"`.
pub async fn not_found(uri: Uri) -> (axum::http::StatusCode, String) {
    use axum::http::StatusCode;

    let path = uri.path();
    log::warn!("Path not found: {}", path);

    (StatusCode::NOT_FOUND, String::from("Not Found"))
}
/// Completes once the process is asked to stop.
///
/// Waits for Ctrl-C on every platform and additionally for `SIGTERM` on Unix.
/// When either future finishes, an empty line is printed to stdout and the
/// start of the graceful shutdown is logged at `info` level.
///
/// # Panics
///
/// On Unix, panics if the `SIGTERM` handler cannot be installed.
async fn shutdown_signal() {
    // Install the SIGTERM stream up front so an installation failure surfaces
    // before we start waiting.
    #[cfg(unix)]
    let mut sigterm = {
        use tokio::signal::unix::{signal, SignalKind};
        signal(SignalKind::terminate()).expect("failed to install signal handler")
    };
    #[cfg(unix)]
    let terminate = sigterm.recv();

    // Elsewhere there is no SIGTERM: a never-resolving future keeps the
    // `select!` below identical on every platform, leaving Ctrl-C as the only
    // trigger.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    let interrupt = tokio::signal::ctrl_c();

    // Whichever future completes first ends the wait; its output is ignored.
    tokio::select! {
        _ = interrupt => {},
        _ = terminate => {},
    }

    println!();
    log::info!("Signal received, starting graceful shutdown");
}
/// Builds the router, binds the listener and serves requests until
/// [`shutdown_signal`] completes.
///
/// The router exposes `POST /v1/chat/completions` and `POST /v1/responses`,
/// both behind a middleware that sleeps for `ServerState::get_slodown_ms()`
/// milliseconds before running the handler. Every other path goes to
/// [`not_found`]. If `args.timeout` is set, the whole router is wrapped in a
/// [`TimeoutLayer`] with that many milliseconds.
///
/// # Errors
///
/// Returns an error if binding `args.address:args.port` fails or if the
/// server terminates with an I/O error.
pub async fn run(args: Args) -> anyhow::Result<()> {
    // Middleware: pause for the delay reported by the shared state, then hand
    // the request to the next service in the stack.
    async fn delay_request(
        State(server): State<ServerState>,
        req: Request<axum::body::Body>,
        next: Next,
    ) -> Response {
        let delay_ms = server.get_slodown_ms();
        log::debug!("Slowing down request by {}ms", delay_ms);
        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
        next.run(req).await
    }

    // Read the `Copy` settings needed locally first, so `args` can then be
    // moved into the shared state.
    let request_timeout = args.timeout;
    let addr = SocketAddr::new(args.address, args.port);
    let state = ServerState::new(args);

    // `route_layer` is applied after the two routes and before the fallback;
    // per axum's documentation it only runs for requests that match a route.
    let router = Router::new()
        .route(
            "/v1/chat/completions",
            post(chat_completions::chat_completions),
        )
        .route("/v1/responses", post(responses::responses))
        .route_layer(middleware::from_fn_with_state(state.clone(), delay_request))
        .fallback(not_found)
        .with_state(state);

    // The optional timeout wraps the complete router, fallback included.
    let app = match request_timeout {
        Some(millis) => router.layer(TimeoutLayer::new(Duration::from_millis(millis))),
        None => router,
    };

    let listener = tokio::net::TcpListener::bind(addr).await?;

    let url = format!("http://{}", addr).blue();
    println!("Roy server running on {}", url);

    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await?;

    Ok(())
}