Skip to content

Commit 2693a18

Browse files
authored
Feature: Add more Filter<_> implementations (#6173)
Allows us to filter by a `MaskValues` if we want to, as well as by indices or slices directly. This is just a nice to have for using the code in `vortex-compute` on top of the incoming changes in #6152 Signed-off-by: Connor Tsui <[email protected]>
1 parent af072b6 commit 2693a18

File tree

6 files changed

+482
-142
lines changed

6 files changed

+482
-142
lines changed

vortex-buffer/src/bit/mod.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,40 @@ pub use buf_mut::*;
2626
pub use view::*;
2727

2828
/// Get the bit value at `index` out of `buf`.
29+
///
30+
/// # Panics
31+
///
32+
/// Panics if `index` is not between 0 and length of `buf * 8`.
2933
#[inline(always)]
3034
pub fn get_bit(buf: &[u8], index: usize) -> bool {
3135
buf[index / 8] & (1 << (index % 8)) != 0
3236
}
3337

34-
/// Get the bit value at `index` out of `buf` without bounds checking
38+
/// Get the bit value at `index` out of `buf` without bounds checking.
3539
///
3640
/// # Safety
3741
///
38-
/// `index` must be between 0 and length of `buf`
42+
/// `index` must be between 0 and length of `buf * 8`.
3943
#[inline(always)]
4044
pub unsafe fn get_bit_unchecked(buf: *const u8, index: usize) -> bool {
4145
(unsafe { *buf.add(index / 8) } & (1 << (index % 8))) != 0
4246
}
4347

44-
/// Set the bit value at `index` in `buf` without bounds checking
48+
/// Set the bit value at `index` in `buf` without bounds checking.
4549
///
4650
/// # Safety
4751
///
48-
/// `index` must be between 0 and length of `buf`
52+
/// `index` must be between 0 and length of `buf * 8`.
4953
#[inline(always)]
5054
pub unsafe fn set_bit_unchecked(buf: *mut u8, index: usize) {
5155
unsafe { *buf.add(index / 8) |= 1 << (index % 8) };
5256
}
5357

54-
/// Unset the bit value at `index` in `buf` without bounds checking
58+
/// Unset the bit value at `index` in `buf` without bounds checking.
5559
///
5660
/// # Safety
5761
///
58-
/// `index` must be between 0 and length of `buf`
62+
/// `index` must be between 0 and length of `buf * 8`.
5963
#[inline(always)]
6064
pub unsafe fn unset_bit_unchecked(buf: *mut u8, index: usize) {
6165
unsafe { *buf.add(index / 8) &= !(1 << (index % 8)) };

vortex-compute/src/filter/bitbuffer.rs

Lines changed: 143 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,55 +9,177 @@ use vortex_buffer::get_bit_unchecked;
99
use vortex_buffer::set_bit_unchecked;
1010
use vortex_buffer::unset_bit_unchecked;
1111
use vortex_mask::Mask;
12+
use vortex_mask::MaskValues;
1213

1314
use crate::filter::Filter;
1415

1516
impl Filter<Mask> for &BitBuffer {
1617
type Output = BitBuffer;
1718

1819
fn filter(self, selection_mask: &Mask) -> BitBuffer {
20+
// We delegate checking that the mask length is equal to self to the `MaskValues`
21+
// filter implementation below.
22+
23+
match selection_mask {
24+
Mask::AllTrue(_) => self.clone(),
25+
Mask::AllFalse(_) => BitBuffer::empty(),
26+
Mask::Values(v) => self.filter(v.as_ref()),
27+
}
28+
}
29+
}
30+
31+
impl Filter<MaskValues> for &BitBuffer {
32+
type Output = BitBuffer;
33+
34+
fn filter(self, mask_values: &MaskValues) -> BitBuffer {
1935
assert_eq!(
20-
selection_mask.len(),
36+
mask_values.len(),
2137
self.len(),
2238
"Selection mask length must equal the mask length"
2339
);
2440

25-
match selection_mask {
26-
Mask::AllTrue(_) => self.clone(),
27-
Mask::AllFalse(_) => BitBuffer::empty(),
28-
Mask::Values(v) => {
29-
filter_indices(self.inner().as_ref(), self.offset(), v.indices()).freeze()
41+
self.filter(mask_values.indices())
42+
}
43+
}
44+
45+
impl Filter<[usize]> for &BitBuffer {
46+
type Output = BitBuffer;
47+
48+
/// Filters by indices.
49+
///
50+
/// The caller should ensure that the indices are strictly increasing, otherwise the resulting
51+
/// buffer might have strange values.
52+
///
53+
/// # Panics
54+
///
55+
/// Panics if any index is out of bounds. With the additional constraint that the indices are
56+
/// strictly increasing, the length of the indices must be less than or equal to the length of
57+
/// `self`.
58+
fn filter(self, indices: &[usize]) -> BitBuffer {
59+
let bools = self.inner().as_ref();
60+
let bit_offset = self.offset();
61+
62+
// FIXME(ngates): this is slower than it could be!
63+
BitBufferMut::collect_bool(indices.len(), |idx| {
64+
let idx = *unsafe { indices.get_unchecked(idx) };
65+
get_bit(bools, bit_offset + idx) // Panics if out of bounds.
66+
})
67+
.freeze()
68+
}
69+
}
70+
71+
impl Filter<[(usize, usize)]> for &BitBuffer {
72+
type Output = BitBuffer;
73+
74+
/// Filters by ranges of indices.
75+
///
76+
/// The caller should ensure that the ranges are strictly increasing, otherwise the resulting
77+
/// buffer might have strange values.
78+
///
79+
/// # Panics
80+
///
81+
/// Panics if any range is out of bounds. With the additional constraint that the ranges are
82+
/// strictly increasing, the length of the `slices` array must be less than or equal to the
83+
/// length of `self`.
84+
fn filter(self, slices: &[(usize, usize)]) -> BitBuffer {
85+
let bools = self.inner().as_ref();
86+
let bit_offset = self.offset();
87+
let output_len: usize = slices.iter().map(|(start, end)| end - start).sum();
88+
89+
let mut out = BitBufferMut::with_capacity(output_len);
90+
91+
// FIXME(ngates): this is slower than it could be!
92+
for &(start, end) in slices {
93+
for idx in start..end {
94+
out.append(get_bit(bools, bit_offset + idx)); // Panics if out of bounds.
3095
}
3196
}
97+
98+
out.freeze()
3299
}
33100
}
34101

35102
impl Filter<Mask> for &mut BitBufferMut {
36103
type Output = ();
37104

38105
fn filter(self, selection_mask: &Mask) {
39-
assert_eq!(
40-
selection_mask.len(),
41-
self.len(),
42-
"Selection mask length must equal the mask length"
43-
);
106+
// We delegate checking that the mask length is equal to self to the `MaskValues`
107+
// filter implementation below.
44108

45109
match selection_mask {
46110
Mask::AllTrue(_) => {}
47111
Mask::AllFalse(_) => self.clear(),
48-
Mask::Values(v) => {
49-
*self = filter_indices(self.inner().as_slice(), self.offset(), v.indices())
50-
}
112+
Mask::Values(v) => self.filter(v.as_ref()),
51113
}
52114
}
53115
}
54116

55-
fn filter_indices(bools: &[u8], bit_offset: usize, indices: &[usize]) -> BitBufferMut {
56-
// FIXME(ngates): this is slower than it could be!
57-
BitBufferMut::collect_bool(indices.len(), |idx| {
58-
let idx = *unsafe { indices.get_unchecked(idx) };
59-
get_bit(bools, bit_offset + idx)
60-
})
117+
impl Filter<MaskValues> for &mut BitBufferMut {
118+
type Output = ();
119+
120+
fn filter(self, mask_values: &MaskValues) {
121+
assert_eq!(
122+
mask_values.len(),
123+
self.len(),
124+
"Selection mask length must equal the mask length"
125+
);
126+
127+
// BitBufferMut filtering always uses indices for simplicity.
128+
self.filter(mask_values.indices())
129+
}
130+
}
131+
132+
impl Filter<[usize]> for &mut BitBufferMut {
133+
type Output = ();
134+
135+
/// Filters by indices.
136+
///
137+
/// The caller should ensure that the indices are strictly increasing, otherwise the resulting
138+
/// buffer might have strange values.
139+
///
140+
/// # Panics
141+
///
142+
/// Panics if any index is out of bounds. With the additional constraint that the indices are
143+
/// strictly increasing, the length of the indices must be less than or equal to the length of
144+
/// `self`.
145+
fn filter(self, indices: &[usize]) {
146+
let bools = self.inner().as_slice();
147+
let bit_offset = self.offset();
148+
149+
// FIXME(ngates): this is slower than it could be!
150+
*self = BitBufferMut::collect_bool(indices.len(), |idx| {
151+
let idx = *unsafe { indices.get_unchecked(idx) };
152+
get_bit(bools, bit_offset + idx) // Panics if out of bounds.
153+
});
154+
}
155+
}
156+
157+
impl Filter<[(usize, usize)]> for &mut BitBufferMut {
158+
type Output = ();
159+
160+
/// Filters by ranges of indices.
161+
///
162+
/// The caller should ensure that the ranges are strictly increasing, otherwise the resulting
163+
/// buffer might have strange values.
164+
///
165+
/// # Panics
166+
///
167+
/// Panics if any range is out of bounds. With the additional constraint that the ranges are
168+
/// strictly increasing, the length of the `slices` array must be less than or equal to the
169+
/// length of `self`.
170+
fn filter(self, slices: &[(usize, usize)]) {
171+
let bools = self.inner().as_slice();
172+
let bit_offset = self.offset();
173+
let output_len: usize = slices.iter().map(|(start, end)| end - start).sum();
174+
175+
let mut out = BitBufferMut::with_capacity(output_len);
176+
for &(start, end) in slices {
177+
for idx in start..end {
178+
out.append(get_bit(bools, bit_offset + idx)); // Panics if out of bounds.
179+
}
180+
}
181+
*self = out;
182+
}
61183
}
62184

63185
impl<const NB: usize> Filter<BitView<'_, NB>> for &BitBuffer {
@@ -121,7 +243,7 @@ mod test {
121243
#[test]
122244
fn filter_bool_by_index_test() {
123245
let buf = bitbuffer![1 1 0];
124-
let filtered = filter_indices(buf.inner().as_ref(), 0, &[0, 2]).freeze();
246+
let filtered = (&buf).filter([0usize, 2].as_slice());
125247
assert_eq!(2, filtered.len());
126248
assert_eq!(filtered, bitbuffer![1 0])
127249
}

0 commit comments

Comments
 (0)