Skip to content

Commit cbe6c2b

Browse files
committed
feat[cuda]: sequence
Signed-off-by: Andrew Duffy <[email protected]>
1 parent 469e31f commit cbe6c2b

File tree

9 files changed

+235
-18
lines changed

9 files changed

+235
-18
lines changed

Cargo.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

encodings/sequence/src/array.rs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,22 @@ pub struct SequenceMetadata {
5959
multiplier: Option<vortex_proto::scalar::ScalarValue>,
6060
}
6161

62+
/// Components of [`SequenceArray`].
63+
pub struct SequenceArrayParts {
64+
pub base: PValue,
65+
pub multiplier: PValue,
66+
pub len: usize,
67+
pub ptype: PType,
68+
pub nullability: Nullability,
69+
}
70+
6271
#[derive(Clone, Debug)]
6372
/// An array representing the equation `A[i] = base + i * multiplier`.
6473
pub struct SequenceArray {
6574
base: PValue,
6675
multiplier: PValue,
6776
dtype: DType,
68-
pub(crate) length: usize,
77+
pub(crate) len: usize,
6978
stats_set: ArrayStats,
7079
}
7180

@@ -124,7 +133,7 @@ impl SequenceArray {
124133
base,
125134
multiplier,
126135
dtype,
127-
length,
136+
len: length,
128137
// TODO(joe): add stats, on construct or on use?
129138
stats_set: Default::default(),
130139
}
@@ -164,7 +173,7 @@ impl SequenceArray {
164173
}
165174

166175
pub(crate) fn index_value(&self, idx: usize) -> PValue {
167-
assert!(idx < self.length, "index_value({idx}): index out of bounds");
176+
assert!(idx < self.len, "index_value({idx}): index out of bounds");
168177

169178
match_each_native_ptype!(self.ptype(), |P| {
170179
let base = self.base.cast::<P>();
@@ -177,9 +186,19 @@ impl SequenceArray {
177186

178187
/// Returns the validated final value of a sequence array
179188
pub fn last(&self) -> PValue {
180-
Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
189+
Self::try_last(self.base, self.multiplier, self.ptype(), self.len)
181190
.vortex_expect("validated array")
182191
}
192+
193+
pub fn into_parts(self) -> SequenceArrayParts {
194+
SequenceArrayParts {
195+
base: self.base,
196+
multiplier: self.multiplier,
197+
len: self.len,
198+
ptype: self.dtype.as_ptype(),
199+
nullability: self.dtype.nullability(),
200+
}
201+
}
183202
}
184203

185204
impl VTable for SequenceVTable {
@@ -355,7 +374,7 @@ fn execute_iter<P: NativePType, I: Iterator<Item = usize>>(
355374

356375
impl BaseArrayVTable<SequenceVTable> for SequenceVTable {
357376
fn len(array: &SequenceArray) -> usize {
358-
array.length
377+
array.len
359378
}
360379

361380
fn dtype(array: &SequenceArray) -> &DType {
@@ -374,14 +393,14 @@ impl BaseArrayVTable<SequenceVTable> for SequenceVTable {
374393
array.base.hash(state);
375394
array.multiplier.hash(state);
376395
array.dtype.hash(state);
377-
array.length.hash(state);
396+
array.len.hash(state);
378397
}
379398

380399
fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
381400
array.base == other.base
382401
&& array.multiplier == other.multiplier
383402
&& array.dtype == other.dtype
384-
&& array.length == other.length
403+
&& array.len == other.len
385404
}
386405
}
387406

encodings/sequence/src/kernel.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ impl ExecuteParentKernel<SequenceVTable> for SequenceCompareKernel {
8989
// Constant is null - result is all null for comparisons
9090
let nullability = array.dtype().nullability() | constant.dtype().nullability();
9191
let result_array =
92-
ConstantArray::new(Scalar::null(DType::Bool(nullability)), array.length).to_array();
92+
ConstantArray::new(Scalar::null(DType::Bool(nullability)), array.len).to_array();
9393
return Ok(Some(result_array.execute(ctx)?));
9494
};
9595

@@ -125,22 +125,22 @@ fn compare_eq_neq(
125125

126126
// Check if there exists an integer solution to const = base + idx * multiplier
127127
let Some(set_idx) =
128-
find_intersection_scalar(array.base(), array.multiplier(), array.length, constant)
128+
find_intersection_scalar(array.base(), array.multiplier(), array.len, constant)
129129
else {
130130
let result_array = ConstantArray::new(
131131
Scalar::new(DType::Bool(nullability), not_match_val.into()),
132-
array.length,
132+
array.len,
133133
)
134134
.to_array();
135135
return Ok(Some(result_array.execute(ctx)?));
136136
};
137137
let idx = set_idx as u64;
138-
let len = array.length as u64;
138+
let len = array.len as u64;
139139

140140
if len == 1 && set_idx == 0 {
141141
let result_array = ConstantArray::new(
142142
Scalar::new(DType::Bool(nullability), match_val.into()),
143-
array.length,
143+
array.len,
144144
)
145145
.to_array();
146146
return Ok(Some(result_array.execute(ctx)?));
@@ -179,31 +179,31 @@ fn compare_ordering(
179179
let transition = find_transition_point(
180180
array.base(),
181181
array.multiplier(),
182-
array.length,
182+
array.len,
183183
constant,
184184
operator,
185185
);
186186

187187
let result_array = match transition {
188188
Transition::AllTrue => ConstantArray::new(
189189
Scalar::new(DType::Bool(nullability), true.into()),
190-
array.length,
190+
array.len,
191191
)
192192
.to_array(),
193193
Transition::AllFalse => ConstantArray::new(
194194
Scalar::new(DType::Bool(nullability), false.into()),
195-
array.length,
195+
array.len,
196196
)
197197
.to_array(),
198198
Transition::FalseToTrue(idx) => {
199199
// [0..idx) is false, [idx..len) is true
200-
let ends = buffer![idx as u64, array.length as u64].into_array();
200+
let ends = buffer![idx as u64, array.len as u64].into_array();
201201
let values = BoolArray::new(bitbuffer![false, true], nullability.into()).into_array();
202202
RunEndArray::try_new(ends, values)?.into_array()
203203
}
204204
Transition::TrueToFalse(idx) => {
205205
// [0..idx) is true, [idx..len) is false
206-
let ends = buffer![idx as u64, array.length as u64].into_array();
206+
let ends = buffer![idx as u64, array.len as u64].into_array();
207207
let values = BoolArray::new(bitbuffer![true, false], nullability.into()).into_array();
208208
RunEndArray::try_new(ends, values)?.into_array()
209209
}

encodings/sequence/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ mod kernel;
99
/// Represents the equation A\[i\] = a * i + b.
1010
/// This can be used for compression, fast comparisons and also for row ids.
1111
pub use array::SequenceArray;
12+
pub use array::SequenceArrayParts;
1213
/// Represents the equation A\[i\] = a * i + b.
1314
/// This can be used for compression, fast comparisons and also for row ids.
1415
pub use array::SequenceVTable;

vortex-cuda/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ _test-harness = []
2323
[dependencies]
2424
arc-swap = { workspace = true }
2525
async-trait = { workspace = true }
26-
cudarc = { workspace = true }
26+
cudarc = { workspace = true, features = ["f16"] }
2727
fastlanes = { workspace = true }
2828
futures = { workspace = true, features = ["executor"] }
2929
kanal = { workspace = true }
@@ -41,6 +41,7 @@ vortex-fastlanes = { workspace = true }
4141
vortex-io = { workspace = true }
4242
vortex-mask = { workspace = true }
4343
vortex-nvcomp = { path = "nvcomp" }
44+
vortex-sequence = { workspace = true }
4445
vortex-session = { workspace = true }
4546
vortex-utils = { workspace = true }
4647
vortex-zigzag = { workspace = true }
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#include <stdint.h>
5+
6+
template<typename ValueT>
7+
__device__ void sequence(
8+
ValueT *const output,
9+
ValueT base,
10+
ValueT multiplier,
11+
uint64_t len
12+
) {
13+
const uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
14+
if (idx >= len) {
15+
return;
16+
}
17+
18+
output[idx] = static_cast<ValueT>(idx) * multiplier + base;
19+
}
20+
21+
#define GENERATE_KERNEL(ValueT, suffix) \
22+
extern "C" __global__ void sequence_##suffix( \
23+
ValueT *const output, \
24+
ValueT base, \
25+
ValueT multiplier, \
26+
uint64_t len \
27+
) { \
28+
sequence(output, base, multiplier, len); \
29+
}
30+
31+
GENERATE_KERNEL(uint8_t, u8);
32+
GENERATE_KERNEL(uint16_t, u16);
33+
GENERATE_KERNEL(uint32_t, u32);
34+
GENERATE_KERNEL(uint64_t, u64);
35+
GENERATE_KERNEL(int8_t, i8);
36+
GENERATE_KERNEL(int16_t, i16);
37+
GENERATE_KERNEL(int32_t, i32);
38+
GENERATE_KERNEL(int64_t, i64);
39+
GENERATE_KERNEL(float, f32);
40+
GENERATE_KERNEL(double, f64);

vortex-cuda/src/kernel/encodings/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@ mod alp;
55
mod bitpacked;
66
mod decimal_byte_parts;
77
mod for_;
8+
mod sequence;
89
mod zigzag;
910
mod zstd;
1011

1112
pub use alp::ALPExecutor;
1213
pub use bitpacked::BitPackedExecutor;
1314
pub use decimal_byte_parts::DecimalBytePartsExecutor;
1415
pub use for_::FoRExecutor;
16+
pub use sequence::SequenceExecutor;
1517
pub use zigzag::ZigZagExecutor;
1618
pub use zstd::ZstdExecutor;
1719
pub use zstd::ZstdKernelPrep;

0 commit comments

Comments
 (0)