Skip to content

Commit d437078

Browse files
committed
starknet_transaction_prover: add CORS validation unit tests
Covers build_cors_layer (empty/wildcard/allowlist branches), cors_mode labelling, and normalize_cors_allow_origins (scheme/host/userinfo/path/ query rejection, default-port stripping, deduplication of equivalent origins, and wildcard collapsing the rest of the list).
1 parent 63d5c0b commit d437078

2 files changed

Lines changed: 56 additions & 0 deletions

File tree

crates/starknet_transaction_prover/src/server.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ pub mod tls;
3939

4040
pub use health::{HealthLayer, HEALTH_PATH};
4141

42+
#[cfg(test)]
43+
mod cors_test;
4244
#[cfg(test)]
4345
mod rpc_spec_test;
4446

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
use assert_matches::assert_matches;
2+
use rstest::rstest;
3+
4+
use crate::errors::ConfigError;
5+
use crate::server::cors::{build_cors_layer, cors_mode, normalize_cors_allow_origins};
6+
7+
#[rstest]
8+
#[case::empty(&[], false)]
9+
#[case::wildcard(&["*"], true)]
10+
#[case::allowlist(&["http://example.com"], true)]
11+
fn test_build_cors_layer(#[case] origins: &[&str], #[case] expect_layer: bool) {
12+
let origins: Vec<String> = origins.iter().map(|s| s.to_string()).collect();
13+
let layer = build_cors_layer(&origins).unwrap();
14+
assert_eq!(layer.is_some(), expect_layer);
15+
}
16+
17+
#[rstest]
18+
#[case::disabled(&[], "disabled")]
19+
#[case::wildcard(&["*"], "wildcard")]
20+
#[case::allowlist(&["http://example.com"], "allowlist")]
21+
#[case::multiple_origins(&["http://a.com", "http://b.com"], "allowlist")]
22+
fn test_cors_mode_labels(#[case] origins: &[&str], #[case] expected_label: &str) {
23+
let origins: Vec<String> = origins.iter().map(|s| s.to_string()).collect();
24+
assert_eq!(cors_mode(&origins), expected_label);
25+
}
26+
27+
#[rstest]
28+
#[case::ftp_scheme(&["ftp://example.com"])]
29+
#[case::missing_host(&["http://"])]
30+
#[case::userinfo(&["http://user:pass@example.com"])]
31+
#[case::path(&["http://example.com/path"])]
32+
#[case::query(&["http://example.com?q=1"])]
33+
fn test_normalize_rejects_invalid_origin(#[case] origins: &[&str]) {
34+
let origins: Vec<String> = origins.iter().map(|s| s.to_string()).collect();
35+
assert_matches!(normalize_cors_allow_origins(origins), Err(ConfigError::InvalidArgument(_)));
36+
}
37+
38+
#[rstest]
39+
#[case::strip_http_default_port(&["http://example.com:80"], &["http://example.com"])]
40+
#[case::strip_https_default_port(&["https://example.com:443"], &["https://example.com"])]
41+
#[case::preserve_non_default_port(&["http://example.com:8080"], &["http://example.com:8080"])]
42+
#[case::dedup_equivalent_origins(
43+
&["http://example.com", "http://example.com:80"],
44+
&["http://example.com"],
45+
)]
46+
#[case::wildcard_collapses_others(
47+
&["http://example.com", "*", "https://foo.bar"],
48+
&["*"],
49+
)]
50+
fn test_normalize_valid_origin(#[case] input: &[&str], #[case] expected: &[&str]) {
51+
let input: Vec<String> = input.iter().map(|s| s.to_string()).collect();
52+
let expected: Vec<String> = expected.iter().map(|s| s.to_string()).collect();
53+
assert_eq!(normalize_cors_allow_origins(input).unwrap(), expected);
54+
}

0 commit comments

Comments
 (0)