|
1 |
| -use std::fmt; |
2 |
| -use std::num::NonZeroU8; |
| 1 | +use std::{fmt, str}; |
3 | 2 |
|
| 3 | +#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] |
4 | 4 | #[allow(unused)]
|
5 | 5 | pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, string: &str) -> fmt::Result {
|
6 |
| - let mut escaped_buf = *b"&#__;"; |
| 6 | + // Even though [`jetscii`] ships a generic implementation for unsupported platforms, |
| 7 | + // it is not well optimized for this case. This implementation should work well enough in |
| 8 | + // the meantime, until portable SIMD gets stabilized. |
| 9 | + |
| 10 | + // Instead of testing the platform, we could test the CPU features. But given that the needed |
| 11 | + // instruction set SSE 4.2 was introduced in 2008, that it has an 99.61 % availability rate |
| 12 | + // in Steam's June 2024 hardware survey, and is a prerequisite to run Windows 11, I don't |
| 13 | + // think we need to care. |
| 14 | + |
| 15 | + let mut escaped_buf = ESCAPED_BUF_INIT; |
7 | 16 | let mut last = 0;
|
8 | 17 |
|
9 | 18 | for (index, byte) in string.bytes().enumerate() {
|
10 | 19 | let escaped = match byte {
|
11 | 20 | MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize],
|
12 |
| - _ => None, |
| 21 | + _ => 0, |
13 | 22 | };
|
14 |
| - if let Some(escaped) = escaped { |
15 |
| - escaped_buf[2] = escaped[0].get(); |
16 |
| - escaped_buf[3] = escaped[1].get(); |
17 |
| - fmt.write_str(&string[last..index])?; |
18 |
| - fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?; |
| 23 | + if escaped != 0 { |
| 24 | + [escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes(); |
| 25 | + write_str_if_nonempty(&mut fmt, &string[last..index])?; |
| 26 | + // SAFETY: the content of `escaped_buf` is pure ASCII |
| 27 | + fmt.write_str(unsafe { |
| 28 | + std::str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) |
| 29 | + })?; |
19 | 30 | last = index + 1;
|
20 | 31 | }
|
21 | 32 | }
|
22 |
| - fmt.write_str(&string[last..]) |
| 33 | + write_str_if_nonempty(&mut fmt, &string[last..]) |
| 34 | +} |
| 35 | + |
| 36 | +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] |
| 37 | +#[allow(unused)] |
| 38 | +pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, mut string: &str) -> fmt::Result { |
| 39 | + let jetscii = jetscii::bytes!(b'"', b'&', b'\'', b'<', b'>'); |
| 40 | + |
| 41 | + let mut escaped_buf = ESCAPED_BUF_INIT; |
| 42 | + loop { |
| 43 | + if string.is_empty() { |
| 44 | + return Ok(()); |
| 45 | + } |
| 46 | + |
| 47 | + let found = if string.len() >= 16 { |
| 48 | + // Only strings of at least 16 bytes can be escaped using SSE instructions. |
| 49 | + match jetscii.find(string.as_bytes()) { |
| 50 | + Some(index) => { |
| 51 | + let escaped = TABLE.lookup[(string.as_bytes()[index] - MIN_CHAR) as usize]; |
| 52 | + Some((index, escaped)) |
| 53 | + } |
| 54 | + None => None, |
| 55 | + } |
| 56 | + } else { |
| 57 | + // The small-string fallback of [`jetscii`] is quite slow, so we roll our own |
| 58 | + // implementation. |
| 59 | + string.as_bytes().iter().find_map(|byte: &u8| { |
| 60 | + let escaped = get_escaped(*byte)?; |
| 61 | + let index = (byte as *const u8 as usize) - (string.as_ptr() as usize); |
| 62 | + Some((index, escaped)) |
| 63 | + }) |
| 64 | + }; |
| 65 | + let Some((index, escaped)) = found else { |
| 66 | + return fmt.write_str(string); |
| 67 | + }; |
| 68 | + |
| 69 | + [escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes(); |
| 70 | + |
| 71 | + // SAFETY: index points at an ASCII char in `string` |
| 72 | + let front; |
| 73 | + (front, string) = unsafe { |
| 74 | + ( |
| 75 | + string.get_unchecked(..index), |
| 76 | + string.get_unchecked(index + 1..), |
| 77 | + ) |
| 78 | + }; |
| 79 | + |
| 80 | + write_str_if_nonempty(&mut fmt, front)?; |
| 81 | + // SAFETY: the content of `escaped_buf` is pure ASCII |
| 82 | + fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })?; |
| 83 | + } |
23 | 84 | }
|
24 | 85 |
|
25 | 86 | #[allow(unused)]
|
26 | 87 | pub(crate) fn write_escaped_char(mut fmt: impl fmt::Write, c: char) -> fmt::Result {
|
27 |
| - fmt.write_str(match (c.is_ascii(), c as u8) { |
28 |
| - (true, b'"') => """, |
29 |
| - (true, b'&') => "&", |
30 |
| - (true, b'\'') => "'", |
31 |
| - (true, b'<') => "<", |
32 |
| - (true, b'>') => ">", |
33 |
| - _ => return fmt.write_char(c), |
34 |
| - }) |
| 88 | + if !c.is_ascii() { |
| 89 | + fmt.write_char(c) |
| 90 | + } else if let Some(escaped) = get_escaped(c as u8) { |
| 91 | + let mut escaped_buf = ESCAPED_BUF_INIT; |
| 92 | + [escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes(); |
| 93 | + // SAFETY: the content of `escaped_buf` is pure ASCII |
| 94 | + fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) }) |
| 95 | + } else { |
| 96 | + // RATIONALE: `write_char(c)` gets optimized if it is known that `c.is_ascii()` |
| 97 | + fmt.write_char(c) |
| 98 | + } |
35 | 99 | }
|
36 | 100 |
|
37 |
| -const MIN_CHAR: u8 = b'"'; |
38 |
| -const MAX_CHAR: u8 = b'>'; |
| 101 | +#[inline(always)] |
| 102 | +fn get_escaped(byte: u8) -> Option<u16> { |
| 103 | + let c = byte.wrapping_sub(MIN_CHAR); |
| 104 | + if (c < u32::BITS as u8) && (BITS & (1 << c as u32) != 0) { |
| 105 | + Some(TABLE.lookup[c as usize]) |
| 106 | + } else { |
| 107 | + None |
| 108 | + } |
| 109 | +} |
39 | 110 |
|
40 |
| -struct Table { |
41 |
| - _align: [usize; 0], |
42 |
| - lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize], |
| 111 | +#[inline(always)] |
| 112 | +fn write_str_if_nonempty(output: &mut impl fmt::Write, input: &str) -> fmt::Result { |
| 113 | + if !input.is_empty() { |
| 114 | + output.write_str(input) |
| 115 | + } else { |
| 116 | + Ok(()) |
| 117 | + } |
43 | 118 | }
|
44 | 119 |
|
45 |
| -const TABLE: Table = { |
46 |
| - const fn n(c: u8) -> Option<[NonZeroU8; 2]> { |
47 |
| - assert!(MIN_CHAR <= c && c <= MAX_CHAR); |
| 120 | +/// List of characters that need HTML escaping, not necessarily in ordinal order. |
| 121 | +/// Filling the [`TABLE`] and [`BITS`] constants will ensure that the range of lowest to hightest |
| 122 | +/// codepoint wont exceed [`u32::BITS`] (=32) items. |
| 123 | +const CHARS: &[u8] = br#""&'<>"#; |
48 | 124 |
|
49 |
| - let n0 = match NonZeroU8::new(c / 10 + b'0') { |
50 |
| - Some(n) => n, |
51 |
| - None => panic!(), |
52 |
| - }; |
53 |
| - let n1 = match NonZeroU8::new(c % 10 + b'0') { |
54 |
| - Some(n) => n, |
55 |
| - None => panic!(), |
56 |
| - }; |
57 |
| - Some([n0, n1]) |
| 125 | +/// The character with the smallest codepoint that needs HTML escaping. |
| 126 | +/// Both [`TABLE`] and [`BITS`] start at this value instead of `0`. |
| 127 | +const MIN_CHAR: u8 = { |
| 128 | + let mut v = u8::MAX; |
| 129 | + let mut i = 0; |
| 130 | + while i < CHARS.len() { |
| 131 | + if v > CHARS[i] { |
| 132 | + v = CHARS[i]; |
| 133 | + } |
| 134 | + i += 1; |
| 135 | + } |
| 136 | + v |
| 137 | +}; |
| 138 | + |
| 139 | +#[allow(unused)] |
| 140 | +const MAX_CHAR: u8 = { |
| 141 | + let mut v = u8::MIN; |
| 142 | + let mut i = 0; |
| 143 | + while i < CHARS.len() { |
| 144 | + if v < CHARS[i] { |
| 145 | + v = CHARS[i]; |
| 146 | + } |
| 147 | + i += 1; |
58 | 148 | }
|
| 149 | + v |
| 150 | +}; |
| 151 | + |
| 152 | +struct Table { |
| 153 | + _align: [usize; 0], |
| 154 | + lookup: [u16; u32::BITS as usize], |
| 155 | +} |
59 | 156 |
|
| 157 | +/// For characters that need HTML escaping, the codepoint formatted as decimal digits, |
| 158 | +/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`]. |
| 159 | +const TABLE: Table = { |
60 | 160 | let mut table = Table {
|
61 | 161 | _align: [],
|
62 |
| - lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize], |
| 162 | + lookup: [0; u32::BITS as usize], |
63 | 163 | };
|
64 |
| - |
65 |
| - table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"'); |
66 |
| - table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&'); |
67 |
| - table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\''); |
68 |
| - table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<'); |
69 |
| - table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>'); |
| 164 | + let mut i = 0; |
| 165 | + while i < CHARS.len() { |
| 166 | + let c = CHARS[i]; |
| 167 | + let h = c / 10 + b'0'; |
| 168 | + let l = c % 10 + b'0'; |
| 169 | + table.lookup[(c - MIN_CHAR) as usize] = u16::from_ne_bytes([h, l]); |
| 170 | + i += 1; |
| 171 | + } |
70 | 172 | table
|
71 | 173 | };
|
| 174 | + |
| 175 | +/// A bitset of the characters that need escaping, starting at [`MIN_CHAR`] |
| 176 | +const BITS: u32 = { |
| 177 | + let mut i = 0; |
| 178 | + let mut bits = 0; |
| 179 | + while i < CHARS.len() { |
| 180 | + bits |= 1 << (CHARS[i] - MIN_CHAR) as u32; |
| 181 | + i += 1; |
| 182 | + } |
| 183 | + bits |
| 184 | +}; |
| 185 | + |
| 186 | +// RATIONALE: llvm generates better code if the buffer is register sized |
| 187 | +const ESCAPED_BUF_INIT: [u8; 8] = *b"&#__;\0\0\0"; |
| 188 | +const ESCAPED_BUF_LEN: usize = b"&#__;".len(); |
| 189 | + |
| 190 | +#[test] |
| 191 | +fn simple() { |
| 192 | + let mut buf = String::new(); |
| 193 | + write_escaped_str(&mut buf, "<script>").unwrap(); |
| 194 | + assert_eq!(buf, "<script>"); |
| 195 | + |
| 196 | + buf.clear(); |
| 197 | + write_escaped_str(&mut buf, "s<crip>t").unwrap(); |
| 198 | + assert_eq!(buf, "s<crip>t"); |
| 199 | + |
| 200 | + buf.clear(); |
| 201 | + write_escaped_str(&mut buf, "s<cripcripcripcripcripcripcripcripcripcrip>t").unwrap(); |
| 202 | + assert_eq!(buf, "s<cripcripcripcripcripcripcripcripcripcrip>t"); |
| 203 | +} |
0 commit comments