11use memmap2:: Mmap ;
22use pyo3:: prelude:: * ;
3- use pyo3:: types:: PyBytes ;
43use rayon:: prelude:: * ;
54use std:: fs:: File ;
5+ use std:: net:: IpAddr ;
66
77#[ pyclass]
88struct IpSearcher {
@@ -15,113 +15,111 @@ impl IpSearcher {
1515 fn new ( path : String ) -> PyResult < Self > {
1616 let file = File :: open ( path) ?;
1717 let mmap = unsafe { Mmap :: map ( & file) ? } ;
18+ if mmap. len ( ) % Self :: NODE_SIZE != 0 {
19+ return Err ( PyErr :: new :: < pyo3:: exceptions:: PyValueError , _ > (
20+ "Invalid bin file: alignment mismatch (expected 72)." ,
21+ ) ) ;
22+ }
1823 Ok ( IpSearcher { mmap } )
1924 }
2025
2126 /// 核心查询逻辑:支持 IPv4 (4字节) 和 IPv6 (16字节)
2227 fn is_china_ip ( & self , ip_bytes : & [ u8 ] ) -> bool {
23- let mut curr_ptr: usize = 0 ;
24- let node_size: usize = 68 ;
28+ let mut cursor: usize = 0 ;
2529
2630 for & byte in ip_bytes {
2731 // 节点布局: 0-31 ChildBitmap, 32-63 LeafBitmap, 64-67 BaseOffset
28- let child_bm = & self . mmap [ curr_ptr..curr_ptr + 32 ] ;
29- let leaf_bm = & self . mmap [ curr_ptr + 32 ..curr_ptr + 64 ] ;
32+ let child_bitmap = & self . mmap [ cursor..cursor + 32 ] ;
33+ let leaf_bitmap = & self . mmap [ cursor + 32 ..cursor + 64 ] ;
3034
3135 // 1. 检查当前步长是否匹配 (Leaf)
32- if self . check_bit ( leaf_bm , byte) {
36+ if self . check_bit ( leaf_bitmap , byte) {
3337 return true ;
3438 }
3539
3640 // 2. 检查是否有子节点
37- if !self . check_bit ( child_bm , byte) {
41+ if !self . check_bit ( child_bitmap , byte) {
3842 return false ;
3943 }
4044
4145 // 3. 计算跳转偏移 (Popcount)
4246 // 读取 4 字节的 BaseOffset (小端序)
4347 let base_offset =
44- u32:: from_le_bytes ( self . mmap [ curr_ptr + 64 ..curr_ptr + 68 ] . try_into ( ) . unwrap ( ) )
48+ u32:: from_le_bytes ( self . mmap [ cursor + 64 ..cursor + 68 ] . try_into ( ) . unwrap ( ) )
4549 as usize ;
4650
4751 // 获取当前字节之前的 '1' 的数量,确定子节点索引
48- let index = self . get_popcount ( child_bm , byte) ;
49- curr_ptr = base_offset + ( index * node_size ) ;
52+ let index = self . get_popcount ( child_bitmap , byte) ;
53+ cursor = base_offset + ( index * Self :: NODE_SIZE ) ;
5054 }
5155 false
5256 }
5357
54- // 这里的 Bound<'_, PyBytes> 允许我们直接访问 Python 的内存
55- fn batch_check ( & self , ip_list : Vec < Bound < ' _ , PyBytes > > ) -> Vec < bool > {
56- ip_list
57- . into_iter ( )
58- . map ( |py_bytes| {
59- // as_bytes() 返回 &[u8],不需要拷贝数据
60- self . is_china_ip ( py_bytes. as_bytes ( ) )
61- } )
62- . collect ( )
63- }
64-
65- /// 极致性能版:接收一个扁平化的字节流(每 4 或 16 字节代表一个 IP)
58+ /// 极致性能版:接收扁平化字节流(每 4 或 16 字节代表一个 IP)
6659 fn batch_check_packed ( & self , packed_ips : & [ u8 ] , is_v6 : bool ) -> Vec < bool > {
67- let stride = if is_v6 { 16 } else { 4 } ;
60+ let ip_stride = if is_v6 { 16 } else { 4 } ;
6861
6962 // 使用 chunks_exact 确保每次切出固定长度的 IP 字节块
7063 // 这是极致性能的关键:内存完全连续,没有 Python 对象开销
7164 packed_ips
72- . chunks_exact ( stride )
65+ . chunks_exact ( ip_stride )
7366 . map ( |ip_chunk| self . is_china_ip ( ip_chunk) )
7467 . collect ( )
7568 }
7669
77- fn batch_check_packed_parallel (
78- & self ,
79- py : Python < ' _ > ,
80- packed_ips : & [ u8 ] ,
81- is_v6 : bool ,
82- ) -> Vec < bool > {
83- let stride = if is_v6 { 16 } else { 4 } ;
84-
85- // par_chunks_exact 是 Rayon 提供的并行切片方法
70+ fn batch_check_strings ( & self , py : Python < ' _ > , ips : Vec < String > ) -> Vec < bool > {
8671 py. allow_threads ( || {
87- packed_ips
88- . par_chunks_exact ( stride)
89- . map ( |ip_chunk| self . is_china_ip ( ip_chunk) )
72+ ips. into_par_iter ( )
73+ . map ( |ip_str| match ip_str. parse :: < IpAddr > ( ) {
74+ Ok ( IpAddr :: V4 ( v4) ) => self . is_china_ip ( & v4. octets ( ) ) ,
75+ Ok ( IpAddr :: V6 ( v6) ) => self . is_china_ip ( & v6. octets ( ) ) ,
76+ Err ( _) => false ,
77+ } )
9078 . collect ( )
9179 } )
9280 }
9381}
9482
9583impl IpSearcher {
84+ const NODE_SIZE : usize = 72 ;
85+
9686 #[ inline]
9787 fn check_bit ( & self , bitmap : & [ u8 ] , byte : u8 ) -> bool {
98- let idx = ( byte >> 3 ) as usize ;
99- let bit = 7 - ( byte % 8 ) ; // 对应 Python 的 (1 << (7 - (k % 8)))
100- ( bitmap[ idx ] >> bit ) & 1 == 1
88+ let byte_index = ( byte >> 3 ) as usize ;
89+ let bit_index = 7 - ( byte % 8 ) ; // 对应 Python 的 (1 << (7 - (k % 8)))
90+ ( bitmap[ byte_index ] >> bit_index ) & 1 == 1
10191 }
10292
10393 #[ inline]
10494 fn get_popcount ( & self , bitmap : & [ u8 ] , byte : u8 ) -> usize {
105- let byte_idx = ( byte >> 3 ) as usize ;
106- let bit_in_byte = 7 - ( byte % 8 ) ;
107- let mut count = 0 ;
108-
109- // 1. 累加之前所有字节中 1 的个数
110- for i in 0 ..byte_idx {
95+ let byte_index = ( byte >> 3 ) as usize ;
96+ let bit_index = 7 - ( byte % 8 ) ;
97+ let mut count: usize = 0 ;
98+
99+ // 1. 以 u64 为单位统计,减少循环次数
100+ let full_byte_count = byte_index;
101+ let chunk_count = full_byte_count / 8 ;
102+ for i in 0 ..chunk_count {
103+ let start = i * 8 ;
104+ let value =
105+ u64:: from_le_bytes ( bitmap[ start..start + 8 ] . try_into ( ) . unwrap ( ) ) ;
106+ count += value. count_ones ( ) as usize ;
107+ }
108+ for i in ( chunk_count * 8 ) ..full_byte_count {
111109 count += bitmap[ i] . count_ones ( ) as usize ;
112110 }
113111
114112 // 2. 累加当前字节中,目标位“左侧”所有 1 的个数
115113 // 我们需要一个掩码来保留比 bit_in_byte 更高的位
116114 // 例如:如果 bit_in_byte 是 5 (二进制 00100000),
117115 // 我们需要掩码 11000000 来计算它之前的 1
118- let mask = if bit_in_byte == 7 {
116+ let mask = if bit_index == 7 {
119117 0
120118 } else {
121- 0xFF << ( bit_in_byte + 1 )
119+ 0xFF << ( bit_index + 1 )
122120 } ;
123121
124- count += ( bitmap[ byte_idx ] & mask) . count_ones ( ) as usize ;
122+ count += ( bitmap[ byte_index ] & mask) . count_ones ( ) as usize ;
125123
126124 // 返回值即为该子节点在子节点数组中的索引(从 0 开始)
127125 count
0 commit comments