diff --git a/components/collections/src/codepointtrie/cptrie.rs b/components/collections/src/codepointtrie/cptrie.rs index dbcbb8e72e4..60d03520b2b 100644 --- a/components/collections/src/codepointtrie/cptrie.rs +++ b/components/collections/src/codepointtrie/cptrie.rs @@ -382,6 +382,12 @@ impl<'trie, T: TrieValue> CodePointTrie<'trie, T> { return Err(Error::DataTooShortForFastAccess); } + // The builder is supposed to support direct indexing to the data array + // by ASCII. + if data.len() < 128 { + return Err(Error::DataTooShortForAsciiAccess); + } + // Invariant upheld for `data`: If we got this far, the length of `data` // satisfies `data`'s length invariant on the assumption that the contents // of `fast_index` subslice of `index` and `header.trie_type` will not @@ -576,6 +582,12 @@ impl<'trie, T: TrieValue> CodePointTrie<'trie, T> { ); let bit_prefix = (code_point as usize) >> FAST_TYPE_SHIFT; + let bit_suffix = (code_point & FAST_TYPE_DATA_MASK) as usize; + self.get_bit_prefix_suffix_assuming_fast_index(bit_prefix, bit_suffix) + } + + #[inline(always)] + unsafe fn get_bit_prefix_suffix_assuming_fast_index(&self, bit_prefix: usize, bit_suffix: usize) -> T { debug_assert!(bit_prefix < self.index.len()); // SAFETY: Relying on the length invariant of `self.index` having // been checked and on the unchangedness invariant of `self.index` @@ -583,7 +595,6 @@ impl<'trie, T: TrieValue> CodePointTrie<'trie, T> { let base_offset_to_data: usize = usize::from(u16::from_unaligned(*unsafe { self.index.as_ule_slice().get_unchecked(bit_prefix) })); - let bit_suffix = (code_point & FAST_TYPE_DATA_MASK) as usize; // SAFETY: Cannot overflow with supported (32-bit and 64-bit) `usize` // sizes, since `base_offset_to_data` was extended from `u16` and // `bit_suffix` is at most `FAST_TYPE_DATA_MASK`, which is well @@ -694,6 +705,96 @@ impl<'trie, T: TrieValue> CodePointTrie<'trie, T> { } } + /// Returns the value that is associated with `latin1` in this [`CodePointTrie`]. 
+ #[inline(always)] + pub fn get8(&self, latin1: u8) -> T { + let code_point = u32::from(latin1); + debug_assert!(code_point <= SMALL_TYPE_FAST_INDEXING_MAX); + // SAFETY: `u8` is always below `SMALL_TYPE_FAST_INDEXING_MAX` and, + // therefore, below `FAST_TYPE_FAST_INDEXING_MAX`. + unsafe { self.get32_assuming_fast_index(code_point) } + } + + /// Returns the value that is associated with `ascii` in this [`CodePointTrie`]. + /// + /// # Safety + /// + /// `ascii` must be less than 128. + #[inline(always)] + pub unsafe fn get7(&self, ascii: u8) -> T { + debug_assert!(ascii < 128); + debug_assert!((ascii as usize) < self.data.len()); + // SAFETY: Length of `self.data` checked in the constructor. + T::from_unaligned(*unsafe { self.data.as_ule_slice().get_unchecked(ascii as usize) }) + } + + /// Returns the value that is associated with a two-byte UTF-8 sequence in this [`CodePointTrie`]. + /// + /// `high_five` is the low five bits of the lead byte of a two-byte UTF-8 sequence. + /// `low_six` is the low six bits of the trail byte of a two-byte UTF-8 sequence. + /// + /// # Safety + /// + /// `high_five` must not have bit positions other than the lowest 5 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariants are + /// violated or `high_five` represents non-shortest form. + #[inline(always)] + pub unsafe fn get_utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + debug_assert!(low_six <= 0b111_111); // Safety invariant. + debug_assert!(high_five <= 0b11_111); // Safety invariant. + debug_assert!(high_five > 0b1); // Non-shortest form; not safety invariant. + // SAFETY: The highest character representable as a two-byte + // UTF-8 sequence is U+07FF, eleven binary ones, which is below + // both `SMALL_TYPE_FAST_INDEXING_MAX` and `FAST_TYPE_FAST_INDEXING_MAX`. 
+ self.get_bit_prefix_suffix_assuming_fast_index(high_five as usize, low_six as usize) + } + + /// Returns the value that is associated with a three-byte UTF-8 or WTF-8 sequence in this [`CodePointTrie`]. + /// + /// `high_ten` is the low four bits of the lead byte of a three-byte UTF-8 or WTF-8 sequence shifted left by 6 followed by the low six bits of the first trail byte. + /// `low_six` is the low six bits of the last trail byte of a three-byte UTF-8 or WTF-8 sequence. + /// + /// Sequences representing surrogates (WTF-8) are allowed. + /// + /// # Safety + /// + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Intended Invariant + /// + /// `high_ten` must not have bit positions other than the lowest 10 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariant is + /// violated or `high_ten` is out of range for a three-byte WTF-8 (or UTF-8) + /// sequence. + #[inline(always)] + pub unsafe fn get_utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + debug_assert!(low_six <= 0b111_111); // Safety invariant. + debug_assert!(high_ten <= 0b1111_111_111); // Not actually a _safety_ invariant for this impl. + debug_assert!(high_ten > 0b11_111); // Non-shortest form; not safety invariant. + + let fast_max = match self.header.trie_type { + TrieType::Fast => FAST_TYPE_FAST_INDEXING_MAX, + TrieType::Small => SMALL_TYPE_FAST_INDEXING_MAX, + }; + // Keep only the prefix bits: + let max_bit_prefix = fast_max >> FAST_TYPE_SHIFT; + if high_ten <= max_bit_prefix { + // SAFETY: The caller is responsible for upholding the safety + // invariant for `low_six` and we just checked the safety + // invariant of `high_ten`. + self.get_bit_prefix_suffix_assuming_fast_index(high_ten as usize, low_six as usize) + } else { + self.get32_by_small_index_cold((high_ten << 6) | low_six) + } + } + /// Lookup trie value by non-Basic Multilingual Plane Scalar Value. 
/// /// The return value may be bogus (not necessarily `error_value`) is the argument is actually in @@ -1429,6 +1530,8 @@ impl Iterator for CodePointMapRangeIterator<'_, T> { /// All implementations of `TypedCodePointTrie` are reviewable in this module. trait Seal {} +impl<'trie, T: TrieValue> Seal for CodePointTrie<'trie, T> {} + /// Trait for writing trait bounds for monomorphizing over either /// `FastCodePointTrie` or `SmallCodePointTrie`. #[allow(private_bounds)] // Permit sealing @@ -1460,6 +1563,22 @@ pub trait TypedCodePointTrie<'trie, T: TrieValue>: Seal { } } + /// Lookup trie value by Latin1 Code Point without branching on trie type. + #[inline(always)] + fn get8(&self, latin1: u8) -> T { + self.as_untyped_ref().get8(latin1) + } + + /// Lookup trie value by ASCII Code Point without branching on trie type. + /// + /// # Safety + /// + /// `ascii` must be less than 128. + #[inline(always)] + unsafe fn get7(&self, ascii: u8) -> T { + self.as_untyped_ref().get7(ascii) + } + /// Lookup trie value by non-Basic Multilingual Plane Scalar Value without branching on trie type. #[inline(always)] fn get32_supplementary(&self, supplementary: u32) -> T { @@ -1521,6 +1640,66 @@ pub trait TypedCodePointTrie<'trie, T: TrieValue>: Seal { } } + /// Returns the value that is associated with a two-byte UTF-8 sequence. + /// + /// `high_five` is the low five bits of the lead byte of a two-byte UTF-8 sequence. + /// `low_six` is the low six bits of the trail byte of a two-byte UTF-8 sequence. + /// + /// # Safety + /// + /// `high_five` must not have bit positions other than the lowest 5 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariants are + /// violated or `high_five` represents non-shortest form. 
+ #[inline(always)] + unsafe fn get_utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + self.as_untyped_ref().get_utf8_two_byte(high_five, low_six) + } + + /// Returns the value that is associated with a three-byte UTF-8 or WTF-8 sequence. + /// + /// `high_ten` is the low four bits of the lead byte of a three-byte UTF-8 or WTF-8 sequence shifted left by 6 followed by the low six bits of the first trail byte. + /// `low_six` is the low six bits of the last trail byte of a three-byte UTF-8 or WTF-8 sequence. + /// + /// Sequences representing surrogates (WTF-8) are allowed. + /// + /// # Safety + /// + /// `high_ten` must not have bit positions other than the lowest 10 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariants are + /// violated or `high_ten` is out of range for a three-byte WTF-8 (or UTF-8) + /// sequence. + #[inline(always)] + unsafe fn get_utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + debug_assert!(low_six <= 0b111_111); // Safety invariant. + debug_assert!(high_ten <= 0b1111_111_111); // Not actually a _safety_ invariant for this impl. + debug_assert!(high_ten > 0b11_111); // Non-shortest form; not safety invariant. + + debug_assert_eq!(Self::TRIE_TYPE, self.as_untyped_ref().header.trie_type); + let fast_max = match Self::TRIE_TYPE { + TrieType::Fast => FAST_TYPE_FAST_INDEXING_MAX, + TrieType::Small => SMALL_TYPE_FAST_INDEXING_MAX, + }; + + // Keep only the prefix bits: + let max_bit_prefix = fast_max >> FAST_TYPE_SHIFT; + if high_ten <= max_bit_prefix { + // SAFETY: The caller is responsible for upholding the safety + // invariant for `low_six` and we just checked the safety + // invariant of `high_ten`. 
+ self.as_untyped_ref().get_bit_prefix_suffix_assuming_fast_index(high_ten as usize, low_six as usize) + } else { + self.as_untyped_ref().get32_by_small_index_cold((high_ten << 6) | low_six) + } + } + /// Returns a reference to the wrapped `CodePointTrie`. fn as_untyped_ref(&self) -> &CodePointTrie<'trie, T>; @@ -1570,6 +1749,36 @@ impl<'trie, T: TrieValue> TypedCodePointTrie<'trie, T> for FastCodePointTrie<'tr // being correct and the exclusive ways of obtaining `Self`. unsafe { self.as_untyped_ref().get32_assuming_fast_index(code_point) } } + + /// Returns the value that is associated with a three-byte UTF-8 or WTF-8 sequence. + /// + /// `high_ten` is the low four bits of the lead byte of a three-byte UTF-8 or WTF-8 sequence shifted left by 6 followed by the low six bits of the first trail byte. + /// `low_six` is the low six bits of the last trail byte of a three-byte UTF-8 or WTF-8 sequence. + /// + /// Sequences representing surrogates (WTF-8) are allowed. + /// + /// # Safety + /// + /// `high_ten` must not have bit positions other than the lowest 10 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + /// + /// # Panics + /// + /// With debug assertions enabled, panics if the above safety invariants are + /// violated or `high_ten` is out of range for a three-byte WTF-8 (or UTF-8) + /// sequence. + #[inline(always)] + unsafe fn get_utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + debug_assert!(low_six <= 0b111_111); // Safety invariant. + debug_assert!(high_ten <= 0b1111_111_111); // Safety invariant. + debug_assert!(high_ten > 0b11_111); // Non-shortest form; not safety invariant. + debug_assert_eq!(Self::TRIE_TYPE, TrieType::Fast); + debug_assert_eq!(self.as_untyped_ref().header.trie_type, TrieType::Fast); + // SAFETY: The highest character representable as a three-byte + // UTF-8 sequence is U+FFFF, which is `FAST_TYPE_FAST_INDEXING_MAX`. 
+ self.inner.get_bit_prefix_suffix_assuming_fast_index(high_ten as usize, low_six as usize) + } + } impl<'trie, T: TrieValue> Seal for FastCodePointTrie<'trie, T> {} @@ -1671,6 +1880,194 @@ pub enum Typed { Small(S), } +/// Trait for writing trait bounds for monomorphizing over either +/// `CodePointTrie`, `FastCodePointTrie`, or `SmallCodePointTrie`. +/// +/// Method naming intentionally differs from the method naming on +/// those types in order to disambiguate. +#[allow(private_bounds)] // Permit sealing +pub trait AbstractCodePointTrie<'trie, T: TrieValue>: Seal { + /// Look up trie value by an ASCII character. + /// + /// # Safety + /// + /// `ascii` must be less than 128. + unsafe fn ascii(&self, ascii: u8) -> T; + + /// Look up trie value by a two-byte UTF-8 sequence. + /// + /// `high_five` is the low five bits of the lead byte of a two-byte UTF-8 sequence. + /// `low_six` is the low six bits of the trail byte of a two-byte UTF-8 sequence. + /// + /// # Safety + /// + /// `high_five` must not have bit positions other than the lowest 5 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + unsafe fn utf8_two_byte(&self, high_five: u32, low_six: u32) -> T; + + /// Look up trie value by a three-byte UTF-8 or WTF-8 sequence. + /// + /// `high_ten` is the low four bits of the lead byte of a three-byte UTF-8 or WTF-8 sequence shifted left by 6 followed by the low six bits of the first trail byte. + /// `low_six` is the low six bits of the last trail byte of a three-byte UTF-8 or WTF-8 sequence. + /// + /// Sequences representing surrogates (WTF-8) are allowed. + /// + /// # Safety + /// + /// `high_ten` must not have bit positions other than the lowest 10 set to 1. + /// `low_six` must not have bit positions other than the lowest 6 set to 1. + unsafe fn utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T; + + /// Look up trie value by a Latin1 character. 
+ fn latin1(&self, latin1: u8) -> T; + + /// Look up trie value by a Basic Multilingual Plane character. + /// + /// Surrogate values are allowed. + fn bmp(&self, bmp: u16) -> T; + + /// Look up trie value by a non-Basic Multilingual Plane character. + /// + /// The behavior is memory-safe nonsense if the argument is not + /// actually a non-Basic Multilingual Plane character. + fn supplementary(&self, supplementary: u32) -> T; + + /// Look up trie value by a Unicode Scalar Value. + fn scalar(&self, scalar: char) -> T; + + /// Look up trie value by Unicode Code Point. + /// + /// Surrogate values are allowed. Out of range input + /// results in the error value. + fn code_point(&self, code_point: u32) -> T; +} + +impl<'trie, T: TrieValue> AbstractCodePointTrie<'trie, T> for FastCodePointTrie<'trie, T> { + #[inline(always)] + unsafe fn ascii(&self, ascii: u8) -> T { + self.get7(ascii) + } + + #[inline(always)] + unsafe fn utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + self.get_utf8_two_byte(high_five, low_six) + } + + #[inline(always)] + unsafe fn utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + self.get_utf8_three_byte(high_ten, low_six) + } + + #[inline(always)] + fn latin1(&self, latin1: u8) -> T { + self.get8(latin1) + } + + #[inline(always)] + fn bmp(&self, bmp: u16) -> T { + self.get16(bmp) + } + + #[inline(always)] + fn supplementary(&self, supplementary: u32) -> T { + self.get32_supplementary(supplementary) + } + + #[inline(always)] + fn scalar(&self, scalar: char) -> T { + self.get(scalar) + } + + #[inline(always)] + fn code_point(&self, code_point: u32) -> T { + self.get32(code_point) + } +} + +impl<'trie, T: TrieValue> AbstractCodePointTrie<'trie, T> for SmallCodePointTrie<'trie, T> { + #[inline(always)] + unsafe fn ascii(&self, ascii: u8) -> T { + self.get7(ascii) + } + + #[inline(always)] + unsafe fn utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + self.get_utf8_two_byte(high_five, low_six) + } + + #[inline(always)] + 
unsafe fn utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + self.get_utf8_three_byte(high_ten, low_six) + } + + #[inline(always)] + fn latin1(&self, latin1: u8) -> T { + self.get8(latin1) + } + + #[inline(always)] + fn bmp(&self, bmp: u16) -> T { + self.get16(bmp) + } + + #[inline(always)] + fn supplementary(&self, supplementary: u32) -> T { + self.get32_supplementary(supplementary) + } + + #[inline(always)] + fn scalar(&self, scalar: char) -> T { + self.get(scalar) + } + + #[inline(always)] + fn code_point(&self, code_point: u32) -> T { + self.get32(code_point) + } +} + +impl<'trie, T: TrieValue> AbstractCodePointTrie<'trie, T> for CodePointTrie<'trie, T> { + #[inline(always)] + unsafe fn ascii(&self, ascii: u8) -> T { + self.get7(ascii) + } + + #[inline(always)] + unsafe fn utf8_two_byte(&self, high_five: u32, low_six: u32) -> T { + self.get_utf8_two_byte(high_five, low_six) + } + + #[inline(always)] + unsafe fn utf8_three_byte(&self, high_ten: u32, low_six: u32) -> T { + self.get_utf8_three_byte(high_ten, low_six) + } + + #[inline(always)] + fn latin1(&self, latin1: u8) -> T { + self.get8(latin1) + } + + #[inline(always)] + fn bmp(&self, bmp: u16) -> T { + self.get16(bmp) + } + + #[inline(always)] + fn supplementary(&self, supplementary: u32) -> T { + self.get32_supplementary(supplementary) + } + + #[inline(always)] + fn scalar(&self, scalar: char) -> T { + self.get(scalar) + } + + #[inline(always)] + fn code_point(&self, code_point: u32) -> T { + self.get32(code_point) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/components/collections/src/codepointtrie/error.rs b/components/collections/src/codepointtrie/error.rs index 383949d9561..4168c3f1de8 100644 --- a/components/collections/src/codepointtrie/error.rs +++ b/components/collections/src/codepointtrie/error.rs @@ -25,6 +25,9 @@ pub enum Error { /// [`CodePointTrie`](super::CodePointTrie) must be constructed from data vector long enough to accommodate fast-path access 
#[displaydoc("CodePointTrie must be constructed from data vector long enough to accommodate fast-path access")] DataTooShortForFastAccess, + /// [`CodePointTrie`](super::CodePointTrie) must be constructed from data vector long enough to accommodate direct ASCII access + #[displaydoc("CodePointTrie must be constructed from data vector long enough to accommodate direct ASCII access")] + DataTooShortForAsciiAccess, } impl core::error::Error for Error {} diff --git a/components/collections/src/codepointtrie/mod.rs b/components/collections/src/codepointtrie/mod.rs index dfc5a29ffc2..b6a7fe4b271 100644 --- a/components/collections/src/codepointtrie/mod.rs +++ b/components/collections/src/codepointtrie/mod.rs @@ -40,6 +40,7 @@ pub mod toml; #[cfg(feature = "serde")] mod serde; +pub use cptrie::AbstractCodePointTrie; pub use cptrie::CodePointMapRange; pub use cptrie::CodePointMapRangeIterator; pub use cptrie::CodePointTrie; diff --git a/components/collections/src/codepointtrie/serde.rs b/components/collections/src/codepointtrie/serde.rs index adfb2d8aa08..483b3f6d6de 100644 --- a/components/collections/src/codepointtrie/serde.rs +++ b/components/collections/src/codepointtrie/serde.rs @@ -60,6 +60,9 @@ where super::CodePointTrieError::DataTooShortForFastAccess => { return Err(D::Error::custom("CodePointTrie must be constructed from data vector long enough to accommodate fast-path access")); } + super::CodePointTrieError::DataTooShortForAsciiAccess => { + return Err(D::Error::custom("CodePointTrie must be constructed from data vector long enough to accommodate direct ASCII access")); + } } } };