Skip to content

Commit a615f31

Browse files
committed
Rust/2024/19: improve solution
1 parent 0ffc103 commit a615f31

File tree

3 files changed

+147
-10
lines changed

3 files changed

+147
-10
lines changed

Rust/2024/19.rs

+14-10
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,34 @@
11
#![feature(test)]
22

3+
use aoc::trie::Trie;
34
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
45

56
#[derive(Debug)]
67
struct Input {
7-
patterns: Vec<String>,
8-
designs: Vec<String>,
8+
patterns: Trie<u8>,
9+
designs: Vec<Vec<u8>>,
910
}
1011

1112
fn setup(input: &str) -> Input {
1213
let mut lines = input.trim().lines();
13-
let patterns = lines.next().unwrap().split(", ").map(Into::into).collect();
14+
let patterns = lines
15+
.next()
16+
.unwrap()
17+
.split(", ")
18+
.map(|p| p.bytes())
19+
.collect();
1420
let designs = lines.skip(1).map(Into::into).collect();
1521
Input { patterns, designs }
1622
}
1723

18-
fn count(design: &str, patterns: &[String]) -> usize {
24+
fn count(design: &[u8], patterns: &Trie<u8>) -> usize {
1925
let mut dp = vec![0; design.len() + 1];
2026
dp[design.len()] = 1;
2127
for i in (0..design.len()).rev() {
22-
for p in patterns {
23-
if i + p.len() > design.len() || !design[i..].starts_with(p) {
24-
continue;
25-
}
26-
dp[i] += dp[i + p.len()];
27-
}
28+
dp[i] = patterns
29+
.prefix_matches(design[i..].iter().copied())
30+
.map(|len| dp[i + len])
31+
.sum();
2832
}
2933
dp[0]
3034
}

Rust/lib/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ pub mod iter_ext;
1010
pub mod math;
1111
pub mod parsing;
1212
pub mod range;
13+
pub mod trie;
1314
pub mod tuples;
1415

1516
extern crate test;

Rust/lib/trie.rs

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
use std::hash::Hash;
2+
3+
use rustc_hash::FxHashMap;
4+
5+
#[derive(Debug, Clone)]
6+
pub struct Trie<T>(Vec<TrieNode<T>>);
7+
8+
#[derive(Debug, Clone)]
9+
struct TrieNode<T> {
10+
flag: bool,
11+
next: FxHashMap<T, usize>,
12+
}
13+
14+
impl<T> Trie<T> {
15+
pub fn new() -> Self {
16+
Self::default()
17+
}
18+
}
19+
20+
impl<T: Eq + Hash> Trie<T> {
21+
pub fn insert(&mut self, item: impl IntoIterator<Item = T>) -> bool {
22+
let mut node = 0;
23+
for x in item {
24+
match self.0[node].next.get(&x) {
25+
Some(&next) => node = next,
26+
None => {
27+
let next = self.0.len();
28+
self.0.push(TrieNode::default());
29+
self.0[node].next.insert(x, next);
30+
node = next;
31+
}
32+
}
33+
}
34+
!std::mem::replace(&mut self.0[node].flag, true)
35+
}
36+
37+
pub fn contains(&self, item: impl IntoIterator<Item = T>) -> bool {
38+
let mut node = 0;
39+
for x in item {
40+
match self.0[node].next.get(&x) {
41+
Some(&next) => node = next,
42+
None => return false,
43+
}
44+
}
45+
self.0[node].flag
46+
}
47+
48+
pub fn prefix_matches<U: IntoIterator<Item = T>>(
49+
&self,
50+
item: U,
51+
) -> impl Iterator<Item = usize> + use<'_, T, U> {
52+
self.0[0].flag.then_some(0).into_iter().chain(
53+
item.into_iter()
54+
.scan(0, |node, x| {
55+
self.0[*node].next.get(&x).map(|&next| {
56+
*node = next;
57+
self.0[*node].flag
58+
})
59+
})
60+
.enumerate()
61+
.flat_map(|(i, flag)| flag.then_some(i + 1)),
62+
)
63+
}
64+
}
65+
66+
impl<T> Default for Trie<T> {
67+
fn default() -> Self {
68+
Self(vec![TrieNode::default()])
69+
}
70+
}
71+
72+
impl<T> Default for TrieNode<T> {
73+
fn default() -> Self {
74+
Self {
75+
flag: false,
76+
next: Default::default(),
77+
}
78+
}
79+
}
80+
81+
impl<I2> FromIterator<I2> for Trie<I2::Item>
82+
where
83+
I2: IntoIterator,
84+
I2::Item: Eq + Hash,
85+
{
86+
fn from_iter<I1: IntoIterator<Item = I2>>(iter: I1) -> Self {
87+
let mut trie = Self::new();
88+
for item in iter {
89+
trie.insert(item);
90+
}
91+
trie
92+
}
93+
}
94+
95+
#[cfg(test)]
96+
mod tests {
97+
use super::*;
98+
99+
#[test]
100+
fn trie() {
101+
let mut t = ["foo", "bar", "baz"]
102+
.into_iter()
103+
.map(|x| x.chars())
104+
.collect::<Trie<_>>();
105+
106+
assert!(t.contains("foo".chars()));
107+
assert!(t.contains("bar".chars()));
108+
assert!(t.contains("baz".chars()));
109+
assert!(!t.contains("test".chars()));
110+
assert!(!t.contains("baa".chars()));
111+
112+
assert!(t.insert("test".chars()));
113+
assert!(t.contains("test".chars()));
114+
assert!(!t.insert("test".chars()));
115+
assert!(t.contains("test".chars()));
116+
}
117+
118+
#[test]
119+
fn prefix_matches() {
120+
let t = ["", "123", "12345", "1234567", "test", "12xy"]
121+
.into_iter()
122+
.map(|x| x.chars())
123+
.collect::<Trie<_>>();
124+
125+
assert_eq!(
126+
t.prefix_matches("123456789".chars()).collect::<Vec<_>>(),
127+
[0, 3, 5, 7]
128+
);
129+
130+
assert_eq!(t.prefix_matches("".chars()).collect::<Vec<_>>(), [0]);
131+
}
132+
}

0 commit comments

Comments
 (0)