Skip to content

Commit 032ce8c

Browse files
committed
[Change] Adds validation functions for count and timeout parameters
Implements new validation functions to ensure that the count parameter is a positive integer and that the timeout parameter meets specified conditions relative to the interval. This change enhances error handling and input validation in the ping functionality, improving robustness and user feedback.
1 parent b0e4df8 commit 032ce8c

File tree

2 files changed

+48
-27
lines changed

2 files changed

+48
-27
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,3 @@ strip = false
5151

5252
[target.'cfg(windows)'.dependencies]
5353
winping = { version = "0.10", features = ["async"] }
54-
once_cell = { version = "1.21" }

src/lib.rs

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,40 @@ fn validate_interval_ms(value: i64, param_name: &str) -> PyResult<u64> {
9696
Ok(value as u64)
9797
}
9898

99+
/// 验证 count 参数并转换为 usize
100+
fn validate_count(count: i32, param_name: &str) -> PyResult<usize> {
101+
if count <= 0 {
102+
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
103+
"{} ({}) must be a positive integer",
104+
param_name, count
105+
)));
106+
}
107+
Ok(count as usize)
108+
}
109+
110+
/// 验证 timeout_ms 参数并转换为 Duration
111+
///
112+
/// 如果 timeout_ms 为 None,返回 None
113+
/// 否则验证 timeout_ms 必须大于等于 interval_ms
114+
fn validate_timeout_ms(timeout_ms: Option<i64>, interval_ms: u64, param_name: &str) -> PyResult<Option<Duration>> {
115+
match timeout_ms {
116+
Some(timeout) => {
117+
let timeout_ms_u64 = validate_interval_ms(timeout, param_name)?;
118+
119+
// 确保 timeout_ms 大于 interval_ms
120+
if timeout_ms_u64 < interval_ms {
121+
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
122+
"{} ({} ms) must be greater than or equal to interval_ms ({} ms)",
123+
param_name, timeout_ms_u64, interval_ms
124+
)));
125+
}
126+
127+
Ok(Some(Duration::from_millis(timeout_ms_u64)))
128+
}
129+
None => Ok(None),
130+
}
131+
}
132+
99133
/// 从 Python 对象中提取 IP 地址字符串
100134
fn extract_target(target: &Bound<PyAny>) -> PyResult<String> {
101135
// 首先尝试直接提取为 IpAddr(包含 IPv4 和 IPv6)
@@ -380,21 +414,10 @@ impl Pinger {
380414
#[pyo3(signature = (count=4, timeout_ms=None))]
381415
fn ping_multiple(&self, count: i32, timeout_ms: Option<i64>) -> PyResult<Vec<PingResult>> {
382416
// 验证 count 参数
383-
if count <= 0 {
384-
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
385-
"count ({}) must be a positive integer",
386-
count
387-
)));
388-
}
389-
let count = count as usize;
417+
let count = validate_count(count, "count")?;
390418

391419
// 验证 timeout_ms 参数
392-
let timeout = if let Some(timeout) = timeout_ms {
393-
let timeout_ms_u64 = validate_interval_ms(timeout, "timeout_ms")?;
394-
Some(Duration::from_millis(timeout_ms_u64))
395-
} else {
396-
None
397-
};
420+
let timeout = validate_timeout_ms(timeout_ms, self.interval_ms, "timeout_ms")?;
398421

399422
let options = create_ping_options(
400423
&self.target,
@@ -450,21 +473,10 @@ impl Pinger {
450473
timeout_ms: Option<i64>,
451474
) -> PyResult<Bound<'py, PyAny>> {
452475
// 验证 count 参数
453-
if count <= 0 {
454-
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
455-
"count ({}) must be a positive integer",
456-
count
457-
)));
458-
}
459-
let count = count as usize;
476+
let count = validate_count(count, "count")?;
460477

461478
// 验证 timeout_ms 参数
462-
let timeout = if let Some(timeout) = timeout_ms {
463-
let timeout_ms_u64 = validate_interval_ms(timeout, "timeout_ms")?;
464-
Some(Duration::from_millis(timeout_ms_u64))
465-
} else {
466-
None
467-
};
479+
let timeout = validate_timeout_ms(timeout_ms, self.interval_ms, "timeout_ms")?;
468480

469481
let target = self.target.clone();
470482
let interval_ms = self.interval_ms;
@@ -570,6 +582,11 @@ impl PingStream {
570582
// 验证 interval_ms 参数
571583
let interval_ms_u64 = validate_interval_ms(interval_ms, "interval_ms")?;
572584

585+
// 验证 max_count 如果有的话
586+
if let Some(count) = max_count {
587+
validate_count(count.try_into().unwrap(), "max_count")?;
588+
}
589+
573590
// 创建 ping 选项
574591
let options = create_ping_options(&target_str, interval_ms_u64, interface, ipv4, ipv6);
575592

@@ -758,6 +775,11 @@ impl AsyncPingStream {
758775
// 验证 interval_ms 参数
759776
let interval_ms_u64 = validate_interval_ms(interval_ms, "interval_ms")?;
760777

778+
// 验证 max_count 如果有的话
779+
if let Some(count) = max_count {
780+
validate_count(count.try_into().unwrap(), "max_count")?;
781+
}
782+
761783
// 创建 ping 选项
762784
let options = create_ping_options(&target_str, interval_ms_u64, interface, ipv4, ipv6);
763785

0 commit comments

Comments
 (0)