Skip to content

Commit b545943

Browse files
committed
v0.1.1
1 parent 001d279 commit b545943

File tree

5 files changed

+55
-57
lines changed

5 files changed

+55
-57
lines changed

.github/workflows/release.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ jobs:
5050
python-version:
5151
- "3.13"
5252

53+
env:
54+
PYTHONIOENCODING: "utf-8"
55+
PYTHONUTF8: "1"
56+
5357
steps:
5458
- name: Checkout
5559
uses: actions/checkout@v4

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "poptrie"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
edition = "2021"
55

66
[lib]

build_bin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
class BinBuilder:
6-
NODE_SIZE = 68 # 32字节Child_BM + 32字节Leaf_BM + 4字节Base_Offset
6+
NODE_SIZE = 72 # 32字节Child_BM + 32字节Leaf_BM + 4字节Base_Offset + 4字节Padding
77

88
def __init__(self):
99
# 根节点:{ 'children': {byte: node_dict}, 'is_leaf': bool }
@@ -111,12 +111,12 @@ def save(self, output_path):
111111

112112
if nodes_with_children:
113113
# 记录跳转到下一层这些子节点的起始偏移
114-
final_data.extend(struct.pack("<I", next_layer_start_offset))
114+
final_data.extend(struct.pack("<II", next_layer_start_offset, 0))
115115
next_layer.extend(nodes_with_children)
116116
next_layer_start_offset += len(nodes_with_children) * self.NODE_SIZE
117117
else:
118118
# 没有子节点可跳转
119-
final_data.extend(struct.pack("<I", 0))
119+
final_data.extend(struct.pack("<II", 0, 0))
120120

121121
current_layer = next_layer
122122

src/lib.rs

Lines changed: 46 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use memmap2::Mmap;
22
use pyo3::prelude::*;
3-
use pyo3::types::PyBytes;
43
use rayon::prelude::*;
54
use std::fs::File;
5+
use std::net::IpAddr;
66

77
#[pyclass]
88
struct IpSearcher {
@@ -15,113 +15,111 @@ impl IpSearcher {
1515
fn new(path: String) -> PyResult<Self> {
1616
let file = File::open(path)?;
1717
let mmap = unsafe { Mmap::map(&file)? };
18+
if mmap.len() % Self::NODE_SIZE != 0 {
19+
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
20+
"Invalid bin file: alignment mismatch (expected 72).",
21+
));
22+
}
1823
Ok(IpSearcher { mmap })
1924
}
2025

2126
/// 核心查询逻辑:支持 IPv4 (4字节) 和 IPv6 (16字节)
2227
fn is_china_ip(&self, ip_bytes: &[u8]) -> bool {
23-
let mut curr_ptr: usize = 0;
24-
let node_size: usize = 68;
28+
let mut cursor: usize = 0;
2529

2630
for &byte in ip_bytes {
2731
// 节点布局: 0-31 ChildBitmap, 32-63 LeafBitmap, 64-67 BaseOffset
28-
let child_bm = &self.mmap[curr_ptr..curr_ptr + 32];
29-
let leaf_bm = &self.mmap[curr_ptr + 32..curr_ptr + 64];
32+
let child_bitmap = &self.mmap[cursor..cursor + 32];
33+
let leaf_bitmap = &self.mmap[cursor + 32..cursor + 64];
3034

3135
// 1. 检查当前步长是否匹配 (Leaf)
32-
if self.check_bit(leaf_bm, byte) {
36+
if self.check_bit(leaf_bitmap, byte) {
3337
return true;
3438
}
3539

3640
// 2. 检查是否有子节点
37-
if !self.check_bit(child_bm, byte) {
41+
if !self.check_bit(child_bitmap, byte) {
3842
return false;
3943
}
4044

4145
// 3. 计算跳转偏移 (Popcount)
4246
// 读取 4 字节的 BaseOffset (小端序)
4347
let base_offset =
44-
u32::from_le_bytes(self.mmap[curr_ptr + 64..curr_ptr + 68].try_into().unwrap())
48+
u32::from_le_bytes(self.mmap[cursor + 64..cursor + 68].try_into().unwrap())
4549
as usize;
4650

4751
// 获取当前字节之前的 '1' 的数量,确定子节点索引
48-
let index = self.get_popcount(child_bm, byte);
49-
curr_ptr = base_offset + (index * node_size);
52+
let index = self.get_popcount(child_bitmap, byte);
53+
cursor = base_offset + (index * Self::NODE_SIZE);
5054
}
5155
false
5256
}
5357

54-
// 这里的 Bound<'_, PyBytes> 允许我们直接访问 Python 的内存
55-
fn batch_check(&self, ip_list: Vec<Bound<'_, PyBytes>>) -> Vec<bool> {
56-
ip_list
57-
.into_iter()
58-
.map(|py_bytes| {
59-
// as_bytes() 返回 &[u8],不需要拷贝数据
60-
self.is_china_ip(py_bytes.as_bytes())
61-
})
62-
.collect()
63-
}
64-
65-
/// 极致性能版:接收一个扁平化的字节流(每 4 或 16 字节代表一个 IP)
58+
/// 极致性能版:接收扁平化字节流(每 4 或 16 字节代表一个 IP)
6659
fn batch_check_packed(&self, packed_ips: &[u8], is_v6: bool) -> Vec<bool> {
67-
let stride = if is_v6 { 16 } else { 4 };
60+
let ip_stride = if is_v6 { 16 } else { 4 };
6861

6962
// 使用 chunks_exact 确保每次切出固定长度的 IP 字节块
7063
// 这是极致性能的关键:内存完全连续,没有 Python 对象开销
7164
packed_ips
72-
.chunks_exact(stride)
65+
.chunks_exact(ip_stride)
7366
.map(|ip_chunk| self.is_china_ip(ip_chunk))
7467
.collect()
7568
}
7669

77-
fn batch_check_packed_parallel(
78-
&self,
79-
py: Python<'_>,
80-
packed_ips: &[u8],
81-
is_v6: bool,
82-
) -> Vec<bool> {
83-
let stride = if is_v6 { 16 } else { 4 };
84-
85-
// par_chunks_exact 是 Rayon 提供的并行切片方法
70+
fn batch_check_strings(&self, py: Python<'_>, ips: Vec<String>) -> Vec<bool> {
8671
py.allow_threads(|| {
87-
packed_ips
88-
.par_chunks_exact(stride)
89-
.map(|ip_chunk| self.is_china_ip(ip_chunk))
72+
ips.into_par_iter()
73+
.map(|ip_str| match ip_str.parse::<IpAddr>() {
74+
Ok(IpAddr::V4(v4)) => self.is_china_ip(&v4.octets()),
75+
Ok(IpAddr::V6(v6)) => self.is_china_ip(&v6.octets()),
76+
Err(_) => false,
77+
})
9078
.collect()
9179
})
9280
}
9381
}
9482

9583
impl IpSearcher {
84+
const NODE_SIZE: usize = 72;
85+
9686
#[inline]
9787
fn check_bit(&self, bitmap: &[u8], byte: u8) -> bool {
98-
let idx = (byte >> 3) as usize;
99-
let bit = 7 - (byte % 8); // 对应 Python 的 (1 << (7 - (k % 8)))
100-
(bitmap[idx] >> bit) & 1 == 1
88+
let byte_index = (byte >> 3) as usize;
89+
let bit_index = 7 - (byte % 8); // 对应 Python 的 (1 << (7 - (k % 8)))
90+
(bitmap[byte_index] >> bit_index) & 1 == 1
10191
}
10292

10393
#[inline]
10494
fn get_popcount(&self, bitmap: &[u8], byte: u8) -> usize {
105-
let byte_idx = (byte >> 3) as usize;
106-
let bit_in_byte = 7 - (byte % 8);
107-
let mut count = 0;
108-
109-
// 1. 累加之前所有字节中 1 的个数
110-
for i in 0..byte_idx {
95+
let byte_index = (byte >> 3) as usize;
96+
let bit_index = 7 - (byte % 8);
97+
let mut count: usize = 0;
98+
99+
// 1. 以 u64 为单位统计,减少循环次数
100+
let full_byte_count = byte_index;
101+
let chunk_count = full_byte_count / 8;
102+
for i in 0..chunk_count {
103+
let start = i * 8;
104+
let value =
105+
u64::from_le_bytes(bitmap[start..start + 8].try_into().unwrap());
106+
count += value.count_ones() as usize;
107+
}
108+
for i in (chunk_count * 8)..full_byte_count {
111109
count += bitmap[i].count_ones() as usize;
112110
}
113111

114112
// 2. 累加当前字节中,目标位“左侧”所有 1 的个数
115113
// 我们需要一个掩码来保留比 bit_in_byte 更高的位
116114
// 例如:如果 bit_in_byte 是 5 (二进制 00100000),
117115
// 我们需要掩码 11000000 来计算它之前的 1
118-
let mask = if bit_in_byte == 7 {
116+
let mask = if bit_index == 7 {
119117
0
120118
} else {
121-
0xFF << (bit_in_byte + 1)
119+
0xFF << (bit_index + 1)
122120
};
123121

124-
count += (bitmap[byte_idx] & mask).count_ones() as usize;
122+
count += (bitmap[byte_index] & mask).count_ones() as usize;
125123

126124
// 返回值即为该子节点在子节点数组中的索引(从 0 开始)
127125
count

tests/test_poptrie.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,7 @@ def test_ipv6_boundary(self):
5555

5656
def test_batch_check(self):
5757
ips = ["1.0.1.1", "8.8.8.8", "240e::1", "2001:db8::"]
58-
ip_bytes_list = [
59-
socket.inet_pton(socket.AF_INET6 if ':' in ip else socket.AF_INET, ip)
60-
for ip in ips
61-
]
62-
results = self.searcher.batch_check(ip_bytes_list)
58+
results = self.searcher.batch_check_strings(ips)
6359
self.assertEqual(results, [True, False, True, False])
6460

6561
def test_batch_check_packed(self):

0 commit comments

Comments
 (0)