Skip to content

Commit fe8750f

Browse files
committed
Speed-up html-escaping using jetscii
```text
$ cargo bench --bench escape
Before the PR:           [3.6464 µs 3.6512 µs 3.6564 µs]
Impl. without `jetscii`: [3.4837 µs 3.4899 µs 3.4968 µs]
Impl. with `jetscii`:    [2.0264 µs 2.0335 µs 2.0418 µs]
```

Until portable SIMD gets stabilized, I don't think we can do much for non-x86 platforms. And even after it is stabilized, I guess any optimizations should be implemented upstream in memchr and/or jetscii.
1 parent bf03e44 commit fe8750f

File tree

4 files changed

+183
-42
lines changed

4 files changed

+183
-42
lines changed

rinja/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ percent-encoding = { version = "2.1.0", optional = true }
4242
serde = { version = "1.0", optional = true }
4343
serde_json = { version = "1.0", optional = true }
4444

45+
[target.'cfg(any(target_arch = "x86", target_arch = "x86_64"))'.dependencies]
46+
jetscii = "0.5.3"
47+
4548
[dev-dependencies]
4649
criterion = "0.5"
4750

rinja/src/html.rs

Lines changed: 174 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,203 @@
1-
use std::fmt;
2-
use std::num::NonZeroU8;
1+
use std::{fmt, str};
32

3+
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
44
#[allow(unused)]
55
pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, string: &str) -> fmt::Result {
6-
let mut escaped_buf = *b"&#__;";
6+
// Even though [`jetscii`] ships a generic implementation for unsupported platforms,
7+
// it is not well optimized for this case. This implementation should work well enough in
8+
// the meantime, until portable SIMD gets stabilized.
9+
10+
// Instead of testing the platform, we could test the CPU features. But given that the needed
11+
// instruction set SSE 4.2 was introduced in 2008, that it has a 99.61 % availability rate
12+
// in Steam's June 2024 hardware survey, and is a prerequisite to run Windows 11, I don't
13+
// think we need to care.
14+
15+
let mut escaped_buf = ESCAPED_BUF_INIT;
716
let mut last = 0;
817

918
for (index, byte) in string.bytes().enumerate() {
1019
let escaped = match byte {
1120
MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize],
12-
_ => None,
21+
_ => 0,
1322
};
14-
if let Some(escaped) = escaped {
15-
escaped_buf[2] = escaped[0].get();
16-
escaped_buf[3] = escaped[1].get();
17-
fmt.write_str(&string[last..index])?;
18-
fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?;
23+
if escaped != 0 {
24+
[escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes();
25+
write_str_if_nonempty(&mut fmt, &string[last..index])?;
26+
// SAFETY: the content of `escaped_buf` is pure ASCII
27+
fmt.write_str(unsafe {
28+
std::str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN])
29+
})?;
1930
last = index + 1;
2031
}
2132
}
22-
fmt.write_str(&string[last..])
33+
write_str_if_nonempty(&mut fmt, &string[last..])
34+
}
35+
36+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
37+
#[allow(unused)]
38+
pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, mut string: &str) -> fmt::Result {
39+
let jetscii = jetscii::bytes!(b'"', b'&', b'\'', b'<', b'>');
40+
41+
let mut escaped_buf = ESCAPED_BUF_INIT;
42+
loop {
43+
if string.is_empty() {
44+
return Ok(());
45+
}
46+
47+
let found = if string.len() >= 16 {
48+
// Only strings of at least 16 bytes can be escaped using SSE instructions.
49+
match jetscii.find(string.as_bytes()) {
50+
Some(index) => {
51+
let escaped = TABLE.lookup[(string.as_bytes()[index] - MIN_CHAR) as usize];
52+
Some((index, escaped))
53+
}
54+
None => None,
55+
}
56+
} else {
57+
// The small-string fallback of [`jetscii`] is quite slow, so we roll our own
58+
// implementation.
59+
string.as_bytes().iter().find_map(|byte: &u8| {
60+
let escaped = get_escaped(*byte)?;
61+
let index = (byte as *const u8 as usize) - (string.as_ptr() as usize);
62+
Some((index, escaped))
63+
})
64+
};
65+
let Some((index, escaped)) = found else {
66+
return fmt.write_str(string);
67+
};
68+
69+
[escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes();
70+
71+
// SAFETY: index points at an ASCII char in `string`
72+
let front;
73+
(front, string) = unsafe {
74+
(
75+
string.get_unchecked(..index),
76+
string.get_unchecked(index + 1..),
77+
)
78+
};
79+
80+
write_str_if_nonempty(&mut fmt, front)?;
81+
// SAFETY: the content of `escaped_buf` is pure ASCII
82+
fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })?;
83+
}
2384
}
2485

2586
#[allow(unused)]
2687
pub(crate) fn write_escaped_char(mut fmt: impl fmt::Write, c: char) -> fmt::Result {
27-
fmt.write_str(match (c.is_ascii(), c as u8) {
28-
(true, b'"') => "&#34;",
29-
(true, b'&') => "&#38;",
30-
(true, b'\'') => "&#39;",
31-
(true, b'<') => "&#60;",
32-
(true, b'>') => "&#62;",
33-
_ => return fmt.write_char(c),
34-
})
88+
if !c.is_ascii() {
89+
fmt.write_char(c)
90+
} else if let Some(escaped) = get_escaped(c as u8) {
91+
let mut escaped_buf = ESCAPED_BUF_INIT;
92+
[escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes();
93+
// SAFETY: the content of `escaped_buf` is pure ASCII
94+
fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })
95+
} else {
96+
// RATIONALE: `write_char(c)` gets optimized if it is known that `c.is_ascii()`
97+
fmt.write_char(c)
98+
}
3599
}
36100

37-
const MIN_CHAR: u8 = b'"';
38-
const MAX_CHAR: u8 = b'>';
101+
#[inline(always)]
102+
fn get_escaped(byte: u8) -> Option<u16> {
103+
let c = byte.wrapping_sub(MIN_CHAR);
104+
if (c < u32::BITS as u8) && (BITS & (1 << c as u32) != 0) {
105+
Some(TABLE.lookup[c as usize])
106+
} else {
107+
None
108+
}
109+
}
39110

40-
struct Table {
41-
_align: [usize; 0],
42-
lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize],
111+
#[inline(always)]
112+
fn write_str_if_nonempty(output: &mut impl fmt::Write, input: &str) -> fmt::Result {
113+
if !input.is_empty() {
114+
output.write_str(input)
115+
} else {
116+
Ok(())
117+
}
43118
}
44119

45-
const TABLE: Table = {
46-
const fn n(c: u8) -> Option<[NonZeroU8; 2]> {
47-
assert!(MIN_CHAR <= c && c <= MAX_CHAR);
120+
/// List of characters that need HTML escaping, not necessarily in ordinal order.
121+
/// Filling the [`TABLE`] and [`BITS`] constants will ensure that the range of lowest to highest
122+
/// codepoint won't exceed [`u32::BITS`] (=32) items.
123+
const CHARS: &[u8] = br#""&'<>"#;
48124

49-
let n0 = match NonZeroU8::new(c / 10 + b'0') {
50-
Some(n) => n,
51-
None => panic!(),
52-
};
53-
let n1 = match NonZeroU8::new(c % 10 + b'0') {
54-
Some(n) => n,
55-
None => panic!(),
56-
};
57-
Some([n0, n1])
125+
/// The character with the smallest codepoint that needs HTML escaping.
126+
/// Both [`TABLE`] and [`BITS`] start at this value instead of `0`.
127+
const MIN_CHAR: u8 = {
128+
let mut v = u8::MAX;
129+
let mut i = 0;
130+
while i < CHARS.len() {
131+
if v > CHARS[i] {
132+
v = CHARS[i];
133+
}
134+
i += 1;
135+
}
136+
v
137+
};
138+
139+
#[allow(unused)]
140+
const MAX_CHAR: u8 = {
141+
let mut v = u8::MIN;
142+
let mut i = 0;
143+
while i < CHARS.len() {
144+
if v < CHARS[i] {
145+
v = CHARS[i];
146+
}
147+
i += 1;
58148
}
149+
v
150+
};
151+
152+
struct Table {
153+
_align: [usize; 0],
154+
lookup: [u16; u32::BITS as usize],
155+
}
59156

157+
/// For characters that need HTML escaping, the codepoint formatted as decimal digits,
158+
/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`].
159+
const TABLE: Table = {
60160
let mut table = Table {
61161
_align: [],
62-
lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize],
162+
lookup: [0; u32::BITS as usize],
63163
};
64-
65-
table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"');
66-
table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&');
67-
table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\'');
68-
table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<');
69-
table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>');
164+
let mut i = 0;
165+
while i < CHARS.len() {
166+
let c = CHARS[i];
167+
let h = c / 10 + b'0';
168+
let l = c % 10 + b'0';
169+
table.lookup[(c - MIN_CHAR) as usize] = u16::from_ne_bytes([h, l]);
170+
i += 1;
171+
}
70172
table
71173
};
174+
175+
/// A bitset of the characters that need escaping, starting at [`MIN_CHAR`]
176+
const BITS: u32 = {
177+
let mut i = 0;
178+
let mut bits = 0;
179+
while i < CHARS.len() {
180+
bits |= 1 << (CHARS[i] - MIN_CHAR) as u32;
181+
i += 1;
182+
}
183+
bits
184+
};
185+
186+
// RATIONALE: llvm generates better code if the buffer is register sized
187+
const ESCAPED_BUF_INIT: [u8; 8] = *b"&#__;\0\0\0";
188+
const ESCAPED_BUF_LEN: usize = b"&#__;".len();
189+
190+
#[test]
191+
fn simple() {
192+
let mut buf = String::new();
193+
write_escaped_str(&mut buf, "<script>").unwrap();
194+
assert_eq!(buf, "&#60;script&#62;");
195+
196+
buf.clear();
197+
write_escaped_str(&mut buf, "s<crip>t").unwrap();
198+
assert_eq!(buf, "s&#60;crip&#62;t");
199+
200+
buf.clear();
201+
write_escaped_str(&mut buf, "s<cripcripcripcripcripcripcripcripcripcrip>t").unwrap();
202+
assert_eq!(buf, "s&#60;cripcripcripcripcripcripcripcripcripcrip&#62;t");
203+
}

rinja_derive/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ quote = "1"
3636
serde = { version = "1.0", optional = true, features = ["derive"] }
3737
syn = "2.0.3"
3838

39+
[target.'cfg(any(target_arch = "x86", target_arch = "x86_64"))'.dependencies]
40+
jetscii = "0.5.3"
41+
3942
[dev-dependencies]
4043
console = "0.15.8"
4144
similar = "2.6.0"

rinja_derive_standalone/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ quote = "1"
3535
serde = { version = "1.0", optional = true, features = ["derive"] }
3636
syn = "2"
3737

38+
[target.'cfg(any(target_arch = "x86", target_arch = "x86_64"))'.dependencies]
39+
jetscii = "0.5.3"
40+
3841
[dev-dependencies]
3942
criterion = "0.5"
4043

0 commit comments

Comments
 (0)