@@ -96,6 +96,40 @@ fn validate_interval_ms(value: i64, param_name: &str) -> PyResult<u64> {
9696 Ok ( value as u64 )
9797}
9898
99+ /// 验证 count 参数并转换为 usize
100+ fn validate_count ( count : i32 , param_name : & str ) -> PyResult < usize > {
101+ if count <= 0 {
102+ return Err ( PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( format ! (
103+ "{} ({}) must be a positive integer" ,
104+ param_name, count
105+ ) ) ) ;
106+ }
107+ Ok ( count as usize )
108+ }
109+
110+ /// 验证 timeout_ms 参数并转换为 Duration
111+ ///
112+ /// 如果 timeout_ms 为 None,返回 None
113+ /// 否则验证 timeout_ms 必须大于等于 interval_ms
114+ fn validate_timeout_ms ( timeout_ms : Option < i64 > , interval_ms : u64 , param_name : & str ) -> PyResult < Option < Duration > > {
115+ match timeout_ms {
116+ Some ( timeout) => {
117+ let timeout_ms_u64 = validate_interval_ms ( timeout, param_name) ?;
118+
119+ // 确保 timeout_ms 大于 interval_ms
120+ if timeout_ms_u64 < interval_ms {
121+ return Err ( PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( format ! (
122+ "{} ({} ms) must be greater than or equal to interval_ms ({} ms)" ,
123+ param_name, timeout_ms_u64, interval_ms
124+ ) ) ) ;
125+ }
126+
127+ Ok ( Some ( Duration :: from_millis ( timeout_ms_u64) ) )
128+ }
129+ None => Ok ( None ) ,
130+ }
131+ }
132+
99133/// 从 Python 对象中提取 IP 地址字符串
100134fn extract_target ( target : & Bound < PyAny > ) -> PyResult < String > {
101135 // 首先尝试直接提取为 IpAddr(包含 IPv4 和 IPv6)
@@ -380,21 +414,10 @@ impl Pinger {
380414 #[ pyo3( signature = ( count=4 , timeout_ms=None ) ) ]
381415 fn ping_multiple ( & self , count : i32 , timeout_ms : Option < i64 > ) -> PyResult < Vec < PingResult > > {
382416 // 验证 count 参数
383- if count <= 0 {
384- return Err ( PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( format ! (
385- "count ({}) must be a positive integer" ,
386- count
387- ) ) ) ;
388- }
389- let count = count as usize ;
417+ let count = validate_count ( count, "count" ) ?;
390418
391419 // 验证 timeout_ms 参数
392- let timeout = if let Some ( timeout) = timeout_ms {
393- let timeout_ms_u64 = validate_interval_ms ( timeout, "timeout_ms" ) ?;
394- Some ( Duration :: from_millis ( timeout_ms_u64) )
395- } else {
396- None
397- } ;
420+ let timeout = validate_timeout_ms ( timeout_ms, self . interval_ms , "timeout_ms" ) ?;
398421
399422 let options = create_ping_options (
400423 & self . target ,
@@ -450,21 +473,10 @@ impl Pinger {
450473 timeout_ms : Option < i64 > ,
451474 ) -> PyResult < Bound < ' py , PyAny > > {
452475 // 验证 count 参数
453- if count <= 0 {
454- return Err ( PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > ( format ! (
455- "count ({}) must be a positive integer" ,
456- count
457- ) ) ) ;
458- }
459- let count = count as usize ;
476+ let count = validate_count ( count, "count" ) ?;
460477
461478 // 验证 timeout_ms 参数
462- let timeout = if let Some ( timeout) = timeout_ms {
463- let timeout_ms_u64 = validate_interval_ms ( timeout, "timeout_ms" ) ?;
464- Some ( Duration :: from_millis ( timeout_ms_u64) )
465- } else {
466- None
467- } ;
479+ let timeout = validate_timeout_ms ( timeout_ms, self . interval_ms , "timeout_ms" ) ?;
468480
469481 let target = self . target . clone ( ) ;
470482 let interval_ms = self . interval_ms ;
@@ -570,6 +582,11 @@ impl PingStream {
570582 // 验证 interval_ms 参数
571583 let interval_ms_u64 = validate_interval_ms ( interval_ms, "interval_ms" ) ?;
572584
585+ // 验证 max_count 如果有的话
586+ if let Some ( count) = max_count {
587+ validate_count ( count. try_into ( ) . unwrap ( ) , "max_count" ) ?;
588+ }
589+
573590 // 创建 ping 选项
574591 let options = create_ping_options ( & target_str, interval_ms_u64, interface, ipv4, ipv6) ;
575592
@@ -758,6 +775,11 @@ impl AsyncPingStream {
758775 // 验证 interval_ms 参数
759776 let interval_ms_u64 = validate_interval_ms ( interval_ms, "interval_ms" ) ?;
760777
778+ // 验证 max_count 如果有的话
779+ if let Some ( count) = max_count {
780+ validate_count ( count. try_into ( ) . unwrap ( ) , "max_count" ) ?;
781+ }
782+
761783 // 创建 ping 选项
762784 let options = create_ping_options ( & target_str, interval_ms_u64, interface, ipv4, ipv6) ;
763785
0 commit comments