Skip to content
This repository was archived by the owner on May 6, 2024. It is now read-only.

Commit 0ad0de0

Browse files
authored
[POAE7-2912] Comparison Operations Vectorization (#410)
* Enable vectorization for comparison operations. * Support avx2 byteToBit(). * Add is_column_var flag to Analyzer::Expr. * Support comparison operations vectorization. * Fix between() return type issue. * Fix format. * Fix issue.
1 parent aedf4e3 commit 0ad0de0

15 files changed

+437
-54
lines changed

cpp/src/cider/exec/nextgen/context/CodegenContext.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,13 @@ std::string AggExprsInfo::getAggName(SQLAgg agg_type,
258258
}
259259

260260
namespace codegen_utils {
261+
262+
// Returns true when `value`'s type tag equals `Primary` and its sub-type tag
// equals `Sub`. `Sub` defaults to JITTypeTag::INVALID, so a call that names
// only `Primary` also requires the value's sub-type tag to be INVALID.
template <jitlib::JITTypeTag Primary,
263+
jitlib::JITTypeTag Sub = jitlib::JITTypeTag::INVALID>
264+
inline bool isaTypedJITValue(const jitlib::JITValue& value) {
265+
return value.getValueTypeTag() == Primary && value.getValueSubTypeTag() == Sub;
266+
}
267+
261268
jitlib::JITValuePointer getArrowArrayLength(jitlib::JITValuePointer& arrow_array) {
262269
CHECK(arrow_array->getValueTypeTag() == JITTypeTag::POINTER);
263270
CHECK(arrow_array->getValueSubTypeTag() == JITTypeTag::INT8);
@@ -397,5 +404,17 @@ void bitBufferAnd(jitlib::JITValuePointer& output,
397404
.params_vector = {output.get(), a.get(), b.get(), bit_num.get()}});
398405
}
399406

407+
// Emits a call to the runtime function "convert_bool_to_bit", passing
// (byte, bit, len) as its arguments.
//
// Preconditions, each enforced with CHECK:
//   byte - JIT value tagged POINTER with sub-type INT8
//   bit  - JIT value tagged POINTER with sub-type INT8
//   len  - JIT value tagged INT64 (sub-type tag must be INVALID, the default
//          of isaTypedJITValue)
//
// NOTE(review): presumably `byte` holds one boolean per byte and `bit`
// receives the packed bitmap for `len` entries - confirm against
// CiderBitUtils::byteToBit, which the runtime function forwards to.
void convertByteBoolToBit(jitlib::JITValuePointer& byte,
408+
jitlib::JITValuePointer& bit,
409+
jitlib::JITValuePointer& len) {
410+
CHECK((isaTypedJITValue<JITTypeTag::POINTER, JITTypeTag::INT8>(byte)));
411+
CHECK((isaTypedJITValue<JITTypeTag::POINTER, JITTypeTag::INT8>(bit)));
412+
CHECK((isaTypedJITValue<JITTypeTag::INT64>(len)));
413+
414+
// The call is emitted into the JIT function that owns `byte`.
auto& func = byte->getParentJITFunction();
415+
func.emitRuntimeFunctionCall(
416+
"convert_bool_to_bit",
417+
JITFunctionEmitDescriptor{.params_vector = {byte.get(), bit.get(), len.get()}});
418+
}
400419
} // namespace codegen_utils
401420
} // namespace cider::exec::nextgen::context

cpp/src/cider/exec/nextgen/context/CodegenContext.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,10 @@ void bitBufferAnd(jitlib::JITValuePointer& output,
312312
jitlib::JITValuePointer& a,
313313
jitlib::JITValuePointer& b,
314314
jitlib::JITValuePointer& bit_num);
315+
316+
// Emits a JIT call to the "convert_bool_to_bit" runtime function with
// (byte, bit, len) as arguments. The definition CHECKs that `byte` and `bit`
// are POINTER/INT8 JIT values and that `len` is an INT64 JIT value.
void convertByteBoolToBit(jitlib::JITValuePointer& byte,
317+
jitlib::JITValuePointer& bit,
318+
jitlib::JITValuePointer& len);
315319
} // namespace codegen_utils
316320
} // namespace cider::exec::nextgen::context
317321

cpp/src/cider/exec/nextgen/operators/OperatorRuntimeFunctions.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,4 +205,10 @@ extern "C" ALWAYS_INLINE int64_t extract_join_row_id(int8_t* buffer, int64_t ind
205205
return join_base_value[index].batch_offset;
206206
}
207207

208+
// Runtime function with C linkage that forwards its arguments unchanged to
// CiderBitUtils::byteToBit. JIT-generated code calls it by the name
// "convert_bool_to_bit" (see codegen_utils::convertByteBoolToBit).
// NOTE(review): presumably `byte` is an array of `len` one-byte booleans and
// `bit` is the packed output bitmap - confirm against byteToBit's contract,
// including any padding it expects for its AVX2 path.
extern "C" NEVER_INLINE void convert_bool_to_bit(uint8_t* byte,
209+
uint8_t* bit,
210+
size_t len) {
211+
CiderBitUtils::byteToBit(byte, bit, len);
212+
}
213+
208214
#endif // NEXTGEN_CIDER_FUNCTION_RUNTIME_FUNCTIONS_H

cpp/src/cider/exec/nextgen/operators/RowToColumnNode.cpp

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@ class ColumnWriter {
3535
// Builds a writer for one output expression.
//   ctx, expr       - stored in context_ / expr_.
//   index           - stored in the reference member index_.
//   arrow_array_len - stored in the reference member arrow_array_len_.
//   bitwise_bool    - stored in bitwise_bool_; setFixSizeRawData reads it for
//                     kBOOLEAN columns: when true the value is written through
//                     the set_null_vector_bit[_opt] runtime call, when false it
//                     is stored through a BOOL-typed pointer into a buffer
//                     whose length is padded up to a multiple of 32.
// arrow_array_ and buffers_ are initialised from getArrowArrayFromCTX() and
// getArrowArrayBuffersFromCTX().
// index_ and arrow_array_len_ are references, so the objects passed in must
// outlive this writer.
ColumnWriter(context::CodegenContext& ctx,
3636
ExprPtr& expr,
3737
JITValuePointer& index,
38-
JITValuePointer& arrow_array_len)
38+
JITValuePointer& arrow_array_len,
39+
bool bitwise_bool)
3940
: context_(ctx)
4041
, expr_(expr)
4142
, arrow_array_(getArrowArrayFromCTX())
4243
, buffers_(getArrowArrayBuffersFromCTX())
4344
, index_(index)
44-
, arrow_array_len_(arrow_array_len) {}
45+
, arrow_array_len_(arrow_array_len)
46+
, bitwise_bool_(bitwise_bool) {}
4547

4648
void write() {
4749
switch (expr_->get_type_info().get_type()) {
@@ -329,21 +331,38 @@ class ColumnWriter {
329331

330332
JITValuePointer setFixSizeRawData(utils::FixSizeJITExprValue& fixsize_val) {
331333
if (expr_->get_type_info().get_type() == kBOOLEAN) {
332-
auto raw_data_buffer = context_.getJITFunction()->createLocalJITValue(
333-
[this]() { return allocateBitwiseBuffer(1); });
334+
auto raw_data_buffer = context_.getJITFunction()->createLocalJITValue([this]() {
335+
if (!bitwise_bool_) {
336+
// Memory padding for AVX2
337+
JITValuePointer prev_len(arrow_array_len_);
338+
auto padded_len = (prev_len + 31) / 32 * 32;
339+
arrow_array_len_.replace(padded_len);
340+
auto mem = allocateBitwiseBuffer(1);
341+
arrow_array_len_.replace(prev_len);
342+
return mem;
343+
} else {
344+
return allocateBitwiseBuffer(1);
345+
}
346+
});
334347
// leverage existing set_null_vector but need opposite value as input
335348
// TODO: (yma11) need check in UT
336-
std::string fname = "set_null_vector_bit";
337-
if (context_.getCodegenOptions().set_null_bit_vector_opt) {
338-
fname = "set_null_vector_bit_opt";
349+
if (bitwise_bool_) {
350+
std::string fname = "set_null_vector_bit";
351+
if (context_.getCodegenOptions().set_null_bit_vector_opt) {
352+
fname = "set_null_vector_bit_opt";
353+
}
354+
context_.getJITFunction()->emitRuntimeFunctionCall(
355+
fname,
356+
JITFunctionEmitDescriptor{
357+
.ret_type = JITTypeTag::VOID,
358+
.params_vector = {{raw_data_buffer.get(),
359+
index_.get(),
360+
(!fixsize_val.getValue()).get()}}});
361+
} else {
362+
auto bool_buffer = raw_data_buffer->castPointerSubType(JITTypeTag::BOOL);
363+
bool_buffer[index_] = *fixsize_val.getValue();
339364
}
340-
context_.getJITFunction()->emitRuntimeFunctionCall(
341-
fname,
342-
JITFunctionEmitDescriptor{
343-
.ret_type = JITTypeTag::VOID,
344-
.params_vector = {{raw_data_buffer.get(),
345-
index_.get(),
346-
(!fixsize_val.getValue()).get()}}});
365+
347366
return raw_data_buffer;
348367
} else {
349368
auto raw_data_buffer = context_.getJITFunction()->createLocalJITValue([this]() {
@@ -415,6 +434,7 @@ class ColumnWriter {
415434
utils::JITExprValue& buffers_;
416435
JITValuePointer& index_;
417436
JITValuePointer& arrow_array_len_;
437+
bool bitwise_bool_;
418438
};
419439

420440
TranslatorPtr RowToColumnNode::toTranslator(const TranslatorPtr& succ) {
@@ -441,10 +461,11 @@ void RowToColumnTranslator::codegenImpl(SuccessorEmitter successor_wrapper,
441461
// Get input ArrowArray length from previous C2RNode
442462
auto prev_c2r_node = static_cast<RowToColumnNode*>(node_.get())->getColumnToRowNode();
443463
auto input_array_len = prev_c2r_node->getColumnRowNum();
464+
bool bitwise_bool = static_cast<RowToColumnNode*>(node_.get())->writeBitwiseBool();
444465

445466
for (int64_t i = 0; i < exprs.size(); ++i) {
446467
ExprPtr& expr = exprs[i];
447-
ColumnWriter writer(context, expr, output_index, input_array_len);
468+
ColumnWriter writer(context, expr, output_index, input_array_len, bitwise_bool);
448469
writer.write();
449470
}
450471
// Update index

cpp/src/cider/exec/nextgen/operators/RowToColumnNode.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,29 @@
2727
namespace cider::exec::nextgen::operators {
2828
// Operator node named "RowToColumnNode" whose output value type is
// JITExprValueType::BATCH. It records the preceding ColumnToRowNode and a
// `bitwise_bool` flag (default true) that RowToColumnTranslator::codegenImpl
// passes on to each ColumnWriter.
class RowToColumnNode : public OpNode {
2929
public:
30-
RowToColumnNode(ExprPtrVector&& output_exprs, ColumnToRowNode* prev_c2r)
30+
RowToColumnNode(ExprPtrVector&& output_exprs,
31+
ColumnToRowNode* prev_c2r,
32+
bool bitwise_bool = true)
3133
: OpNode("RowToColumnNode", std::move(output_exprs), JITExprValueType::BATCH)
32-
, prev_c2r_node_(prev_c2r) {}
34+
, prev_c2r_node_(prev_c2r)
35+
, bitwise_bool_(bitwise_bool) {}
3336

34-
RowToColumnNode(const ExprPtrVector& output_exprs, ColumnToRowNode* prev_c2r)
37+
RowToColumnNode(const ExprPtrVector& output_exprs,
38+
ColumnToRowNode* prev_c2r,
39+
bool bitwise_bool = true)
3540
: OpNode("RowToColumnNode", output_exprs, JITExprValueType::BATCH)
36-
, prev_c2r_node_(prev_c2r) {}
41+
, prev_c2r_node_(prev_c2r)
42+
, bitwise_bool_(bitwise_bool) {}
3743

3844
TranslatorPtr toTranslator(const TranslatorPtr& successor = nullptr) override;
3945

4046
ColumnToRowNode* getColumnToRowNode() { return prev_c2r_node_; }
4147

48+
// Returns the bitwise_bool flag supplied at construction.
bool writeBitwiseBool() const { return bitwise_bool_; }
49+
4250
private:
4351
// NOTE(review): raw pointer stored as passed; it appears to be non-owning
// (no destructor is declared in this class) - confirm.
ColumnToRowNode* prev_c2r_node_;
52+
bool bitwise_bool_;
4453
};
4554

4655
class RowToColumnTranslator : public Translator {

0 commit comments

Comments
 (0)