@@ -9,55 +9,177 @@ use vortex_buffer::get_bit_unchecked;
99use vortex_buffer:: set_bit_unchecked;
1010use vortex_buffer:: unset_bit_unchecked;
1111use vortex_mask:: Mask ;
12+ use vortex_mask:: MaskValues ;
1213
1314use crate :: filter:: Filter ;
1415
1516impl Filter < Mask > for & BitBuffer {
1617 type Output = BitBuffer ;
1718
1819 fn filter ( self , selection_mask : & Mask ) -> BitBuffer {
20+ // We delegate checking that the mask length is equal to self to the `MaskValues`
21+ // filter implementation below.
22+
23+ match selection_mask {
24+ Mask :: AllTrue ( _) => self . clone ( ) ,
25+ Mask :: AllFalse ( _) => BitBuffer :: empty ( ) ,
26+ Mask :: Values ( v) => self . filter ( v. as_ref ( ) ) ,
27+ }
28+ }
29+ }
30+
31+ impl Filter < MaskValues > for & BitBuffer {
32+ type Output = BitBuffer ;
33+
34+ fn filter ( self , mask_values : & MaskValues ) -> BitBuffer {
1935 assert_eq ! (
20- selection_mask . len( ) ,
36+ mask_values . len( ) ,
2137 self . len( ) ,
2238 "Selection mask length must equal the mask length"
2339 ) ;
2440
25- match selection_mask {
26- Mask :: AllTrue ( _) => self . clone ( ) ,
27- Mask :: AllFalse ( _) => BitBuffer :: empty ( ) ,
28- Mask :: Values ( v) => {
29- filter_indices ( self . inner ( ) . as_ref ( ) , self . offset ( ) , v. indices ( ) ) . freeze ( )
41+ self . filter ( mask_values. indices ( ) )
42+ }
43+ }
44+
45+ impl Filter < [ usize ] > for & BitBuffer {
46+ type Output = BitBuffer ;
47+
48+ /// Filters by indices.
49+ ///
50+ /// The caller should ensure that the indices are strictly increasing, otherwise the resulting
51+ /// buffer might have strange values.
52+ ///
53+ /// # Panics
54+ ///
55+ /// Panics if any index is out of bounds. With the additional constraint that the indices are
56+ /// strictly increasing, the length of the indices must be less than or equal to the length of
57+ /// `self`.
58+ fn filter ( self , indices : & [ usize ] ) -> BitBuffer {
59+ let bools = self . inner ( ) . as_ref ( ) ;
60+ let bit_offset = self . offset ( ) ;
61+
62+ // FIXME(ngates): this is slower than it could be!
63+ BitBufferMut :: collect_bool ( indices. len ( ) , |idx| {
64+ let idx = * unsafe { indices. get_unchecked ( idx) } ;
65+ get_bit ( bools, bit_offset + idx) // Panics if out of bounds.
66+ } )
67+ . freeze ( )
68+ }
69+ }
70+
71+ impl Filter < [ ( usize , usize ) ] > for & BitBuffer {
72+ type Output = BitBuffer ;
73+
74+ /// Filters by ranges of indices.
75+ ///
76+ /// The caller should ensure that the ranges are strictly increasing, otherwise the resulting
77+ /// buffer might have strange values.
78+ ///
79+ /// # Panics
80+ ///
81+ /// Panics if any range is out of bounds. With the additional constraint that the ranges are
82+ /// strictly increasing, the length of the `slices` array must be less than or equal to the
83+ /// length of `self`.
84+ fn filter ( self , slices : & [ ( usize , usize ) ] ) -> BitBuffer {
85+ let bools = self . inner ( ) . as_ref ( ) ;
86+ let bit_offset = self . offset ( ) ;
87+ let output_len: usize = slices. iter ( ) . map ( |( start, end) | end - start) . sum ( ) ;
88+
89+ let mut out = BitBufferMut :: with_capacity ( output_len) ;
90+
91+ // FIXME(ngates): this is slower than it could be!
92+ for & ( start, end) in slices {
93+ for idx in start..end {
94+ out. append ( get_bit ( bools, bit_offset + idx) ) ; // Panics if out of bounds.
3095 }
3196 }
97+
98+ out. freeze ( )
3299 }
33100}
34101
35102impl Filter < Mask > for & mut BitBufferMut {
36103 type Output = ( ) ;
37104
38105 fn filter ( self , selection_mask : & Mask ) {
39- assert_eq ! (
40- selection_mask. len( ) ,
41- self . len( ) ,
42- "Selection mask length must equal the mask length"
43- ) ;
106+ // We delegate checking that the mask length is equal to self to the `MaskValues`
107+ // filter implementation below.
44108
45109 match selection_mask {
46110 Mask :: AllTrue ( _) => { }
47111 Mask :: AllFalse ( _) => self . clear ( ) ,
48- Mask :: Values ( v) => {
49- * self = filter_indices ( self . inner ( ) . as_slice ( ) , self . offset ( ) , v. indices ( ) )
50- }
112+ Mask :: Values ( v) => self . filter ( v. as_ref ( ) ) ,
51113 }
52114 }
53115}
54116
55- fn filter_indices ( bools : & [ u8 ] , bit_offset : usize , indices : & [ usize ] ) -> BitBufferMut {
56- // FIXME(ngates): this is slower than it could be!
57- BitBufferMut :: collect_bool ( indices. len ( ) , |idx| {
58- let idx = * unsafe { indices. get_unchecked ( idx) } ;
59- get_bit ( bools, bit_offset + idx)
60- } )
117+ impl Filter < MaskValues > for & mut BitBufferMut {
118+ type Output = ( ) ;
119+
120+ fn filter ( self , mask_values : & MaskValues ) {
121+ assert_eq ! (
122+ mask_values. len( ) ,
123+ self . len( ) ,
124+ "Selection mask length must equal the mask length"
125+ ) ;
126+
127+ // BitBufferMut filtering always uses indices for simplicity.
128+ self . filter ( mask_values. indices ( ) )
129+ }
130+ }
131+
132+ impl Filter < [ usize ] > for & mut BitBufferMut {
133+ type Output = ( ) ;
134+
135+ /// Filters by indices.
136+ ///
137+ /// The caller should ensure that the indices are strictly increasing, otherwise the resulting
138+ /// buffer might have strange values.
139+ ///
140+ /// # Panics
141+ ///
142+ /// Panics if any index is out of bounds. With the additional constraint that the indices are
143+ /// strictly increasing, the length of the indices must be less than or equal to the length of
144+ /// `self`.
145+ fn filter ( self , indices : & [ usize ] ) {
146+ let bools = self . inner ( ) . as_slice ( ) ;
147+ let bit_offset = self . offset ( ) ;
148+
149+ // FIXME(ngates): this is slower than it could be!
150+ * self = BitBufferMut :: collect_bool ( indices. len ( ) , |idx| {
151+ let idx = * unsafe { indices. get_unchecked ( idx) } ;
152+ get_bit ( bools, bit_offset + idx) // Panics if out of bounds.
153+ } ) ;
154+ }
155+ }
156+
157+ impl Filter < [ ( usize , usize ) ] > for & mut BitBufferMut {
158+ type Output = ( ) ;
159+
160+ /// Filters by ranges of indices.
161+ ///
162+ /// The caller should ensure that the ranges are strictly increasing, otherwise the resulting
163+ /// buffer might have strange values.
164+ ///
165+ /// # Panics
166+ ///
167+ /// Panics if any range is out of bounds. With the additional constraint that the ranges are
168+ /// strictly increasing, the length of the `slices` array must be less than or equal to the
169+ /// length of `self`.
170+ fn filter ( self , slices : & [ ( usize , usize ) ] ) {
171+ let bools = self . inner ( ) . as_slice ( ) ;
172+ let bit_offset = self . offset ( ) ;
173+ let output_len: usize = slices. iter ( ) . map ( |( start, end) | end - start) . sum ( ) ;
174+
175+ let mut out = BitBufferMut :: with_capacity ( output_len) ;
176+ for & ( start, end) in slices {
177+ for idx in start..end {
178+ out. append ( get_bit ( bools, bit_offset + idx) ) ; // Panics if out of bounds.
179+ }
180+ }
181+ * self = out;
182+ }
61183}
62184
63185impl < const NB : usize > Filter < BitView < ' _ , NB > > for & BitBuffer {
@@ -121,7 +243,7 @@ mod test {
121243 #[ test]
122244 fn filter_bool_by_index_test ( ) {
123245 let buf = bitbuffer ! [ 1 1 0 ] ;
124- let filtered = filter_indices ( buf. inner ( ) . as_ref ( ) , 0 , & [ 0 , 2 ] ) . freeze ( ) ;
246+ let filtered = ( & buf) . filter ( [ 0usize , 2 ] . as_slice ( ) ) ;
125247 assert_eq ! ( 2 , filtered. len( ) ) ;
126248 assert_eq ! ( filtered, bitbuffer![ 1 0 ] )
127249 }
0 commit comments