diff --git a/UserConfig.json b/UserConfig.json index 4754c4fd9..94ced9437 100644 --- a/UserConfig.json +++ b/UserConfig.json @@ -6,6 +6,7 @@ "matmul_unroll_factor": 1, "matmul_unroll_jam_factor": 4, "matmul_num_vec_registers": 16, + "use_columnar": false, "use_cuda": false, "use_vectorized_exec": false, "use_obj_ref_mgnt": true, @@ -13,6 +14,7 @@ "use_mlir_codegen": false, "vectorized_single_queue": false, "debug_llvm": false, + "explain_columnar": false, "explain_kernels": false, "explain_llvm": false, "explain_parsing": false, diff --git a/src/api/cli/DaphneUserConfig.h b/src/api/cli/DaphneUserConfig.h index bccbe0b92..41765b121 100644 --- a/src/api/cli/DaphneUserConfig.h +++ b/src/api/cli/DaphneUserConfig.h @@ -36,6 +36,8 @@ class DaphneLogger; */ struct DaphneUserConfig { // Remember to update UserConfig.json accordingly! + + bool use_columnar = false; bool use_cuda = false; bool use_vectorized_exec = false; bool use_distributed = false; @@ -62,6 +64,7 @@ struct DaphneUserConfig { bool enable_profiling = false; bool debug_llvm = false; + bool explain_columnar = false; bool explain_kernels = false; bool explain_llvm = false; bool explain_parsing = false; diff --git a/src/api/internal/daphne_internal.cpp b/src/api/internal/daphne_internal.cpp index 4a8cac63a..8791b843d 100644 --- a/src/api/internal/daphne_internal.cpp +++ b/src/api/internal/daphne_internal.cpp @@ -230,6 +230,9 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int "(e.g., dense/sparse)")); static alias selectMatrixReprAlias( // to still support the longer old form "select-matrix-representations", aliasopt(selectMatrixRepr), desc("Alias for --select-matrix-repr")); + static opt useColumnar("columnar", cat(daphneOptions), + desc("Use columnar operations instead of frame/matrix operations for relational query " + "processing where possible")); static opt cuda("cuda", cat(daphneOptions), desc("Use CUDA")); static opt fpgaopencl("fpgaopencl", cat(daphneOptions), desc("Use 
FPGAOPENCL")); static opt libDir("libdir", cat(daphneOptions), @@ -276,6 +279,7 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int "(path to a kernel catalog JSON file).")); enum ExplainArgs { + columnar, kernels, llvm, parsing, @@ -301,6 +305,7 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int clEnumVal(parsing, "Show DaphneIR after parsing"), clEnumVal(parsing_simplified, "Show DaphneIR after parsing and some simplifications"), clEnumVal(sql, "Show DaphneIR after SQL parsing"), + clEnumVal(columnar, "Show DaphneIR after lowering to columnar operations"), clEnumVal(property_inference, "Show DaphneIR after property inference"), clEnumVal(select_matrix_repr, "Show DaphneIR after selecting " "physical matrix representations"), @@ -420,6 +425,7 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int user_config.debugMultiThreading = debugMultiThreading; user_config.prePartitionRows = prePartitionRows; user_config.distributedBackEndSetup = distributedBackEndSetup; + user_config.use_columnar = useColumnar; if (user_config.use_distributed) { if (user_config.distributedBackEndSetup != ALLOCATION_TYPE::DIST_MPI && user_config.distributedBackEndSetup != ALLOCATION_TYPE::DIST_GRPC_SYNC && @@ -451,6 +457,9 @@ int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int for (auto explain : explainArgList) { switch (explain) { + case columnar: + user_config.explain_columnar = true; + break; case kernels: user_config.explain_kernels = true; break; diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index 0e5c0c4fd..1287700aa 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -117,11 +117,26 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { pm.addNestedPass(mlir::daphne::createInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); + if 
(userConfig_.use_columnar) { + // Rewrite certain matrix/frame ops from linear/relational algebra to columnar ops from column algebra. + pm.addPass(mlir::daphne::createRewriteToColumnarOpsPass()); + // Infer the result types of the newly created columnar ops. + pm.addNestedPass(mlir::daphne::createInferencePass()); + // Simplify the IR. + pm.addPass(mlir::createCanonicalizerPass()); + // Remove unused ops after simplifications. + // TODO The CSE pass seems to eliminate only "one row" of dead code at a time, so we need it as many times as + // the longest chain of ops we reduce; how to apply CSE until a fixpoint? + for (size_t i = 0; i < 5; i++) + pm.addPass(mlir::createCSEPass()); + } + if (userConfig_.explain_columnar) + pm.addPass(mlir::daphne::createPrintIRPass("IR after lowering to columnar ops:")); + if (selectMatrixRepresentations_) { pm.addNestedPass(mlir::daphne::createSelectMatrixRepresentationsPass(userConfig_)); pm.addNestedPass(mlir::createCanonicalizerPass()); } - if (userConfig_.explain_select_matrix_repr) pm.addPass(mlir::daphne::createPrintIRPass("IR after selecting matrix representations:")); diff --git a/src/compiler/inference/InferencePass.cpp b/src/compiler/inference/InferencePass.cpp index 2fe63de74..68ce1a05f 100644 --- a/src/compiler/inference/InferencePass.cpp +++ b/src/compiler/inference/InferencePass.cpp @@ -134,6 +134,8 @@ class InferencePass : public PassWrapper()) t = ft.withSameColumnTypes(); + else if (auto ct = t.dyn_cast()) + t = ct.withSameValueType(); op->getResult(i).setType(t); } return WalkResult::advance(); @@ -482,10 +484,14 @@ class InferencePass : public PassWrapper()) return llvm::isa(mt.getElementType()); - if (auto ft = resType.dyn_cast()) + if (auto ft = resType.dyn_cast()) { for (Type ct : ft.getColumnTypes()) if (llvm::isa(ct)) return true; + return false; + } + if (auto ct = resType.dyn_cast()) + return ct.getValueType().isa(); return false; }); } diff --git a/src/compiler/inference/TypeInferenceUtils.cpp 
b/src/compiler/inference/TypeInferenceUtils.cpp index 2b0f4da65..77cc9f7b3 100644 --- a/src/compiler/inference/TypeInferenceUtils.cpp +++ b/src/compiler/inference/TypeInferenceUtils.cpp @@ -100,6 +100,8 @@ DataTypeCode getDataTypeCode(mlir::Type t) { return DataTypeCode::FRAME; if (llvm::isa(t)) return DataTypeCode::MATRIX; + if (llvm::isa(t)) + return DataTypeCode::COLUMN; if (auto lt = t.dyn_cast()) return getDataTypeCode(lt.getElementType()); if (CompilerUtils::isScaType(t)) diff --git a/src/compiler/inference/TypeInferenceUtils.h b/src/compiler/inference/TypeInferenceUtils.h index 4ea8ebb45..35ab9e239 100644 --- a/src/compiler/inference/TypeInferenceUtils.h +++ b/src/compiler/inference/TypeInferenceUtils.h @@ -46,6 +46,7 @@ enum class DataTypeCode : uint8_t { // The greater the number, the more general the type. SCALAR, // least general MATRIX, + COLUMN, FRAME, UNKNOWN // most general @@ -136,6 +137,8 @@ template mlir::Type inferTypeByTraits(O *op) { resDtc = DataTypeCode::SCALAR; else if (op->template hasTrait()) resDtc = DataTypeCode::MATRIX; + else if (op->template hasTrait()) + resDtc = DataTypeCode::COLUMN; else if (op->template hasTrait()) resDtc = DataTypeCode::FRAME; @@ -251,8 +254,9 @@ template mlir::Type inferTypeByTraits(O *op) { resVts.push_back(argVts[i][0]); break; } + case DataTypeCode::COLUMN: // fall-through intended case DataTypeCode::SCALAR: - // Append the value type of this input scalar to + // Append the value type of this input scalar/column to // the result column types. 
resVts.push_back(argVts[i][0]); break; @@ -268,6 +272,7 @@ template mlir::Type inferTypeByTraits(O *op) { } break; case DataTypeCode::MATRIX: // fall-through intended + case DataTypeCode::COLUMN: // fall-through intended case DataTypeCode::SCALAR: resVts = {mostGeneralVt(argVts, numArgsConsider)}; break; @@ -286,7 +291,7 @@ template mlir::Type inferTypeByTraits(O *op) { // Create the result type // -------------------------------------------------------------------- - // It is important to recreate matrix and frame types (not reuse those from + // It is important to recreate matrix, frame, and column types (not reuse those from // the inputs) to get rid of any additional properties (shape, etc.). switch (resDtc) { case DataTypeCode::UNKNOWN: @@ -301,6 +306,9 @@ template mlir::Type inferTypeByTraits(O *op) { case DataTypeCode::FRAME: resTy = daphne::FrameType::get(ctx, resVts); break; + case DataTypeCode::COLUMN: + resTy = daphne::ColumnType::get(ctx, mostGeneralVt(resVts)); + break; } if (resIsList) diff --git a/src/compiler/lowering/CMakeLists.txt b/src/compiler/lowering/CMakeLists.txt index fcbc1c995..d9cb1ed98 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ b/src/compiler/lowering/CMakeLists.txt @@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRDaphneTransforms LowerToLLVMPass.cpp PhyOperatorSelectionPass.cpp RewriteToCallKernelOpPass.cpp + RewriteToColumnarOpsPass.cpp SpecializeGenericFunctionsPass.cpp VectorizeComputationsPass.cpp DaphneOptPass.cpp diff --git a/src/compiler/lowering/LowerToLLVMPass.cpp b/src/compiler/lowering/LowerToLLVMPass.cpp index d6800ff4b..a30dc4199 100644 --- a/src/compiler/lowering/LowerToLLVMPass.cpp +++ b/src/compiler/lowering/LowerToLLVMPass.cpp @@ -913,6 +913,8 @@ void DaphneLowerToLLVMPass::runOnOperation() { [&](daphne::FrameType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); typeConverter.addConversion( [&](daphne::ListType t) { return 
LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); + typeConverter.addConversion( + [&](daphne::ColumnType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); typeConverter.addConversion( [&](daphne::StringType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 8)); }); typeConverter.addConversion([&](daphne::VariadicPackType t) { diff --git a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp index 8fea5a0a2..6a61d2442 100644 --- a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp +++ b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp @@ -127,12 +127,15 @@ class KernelReplacement : public RewritePattern { mlir::Type adaptType(mlir::Type t, bool generalizeToStructure) const { MLIRContext *mctx = t.getContext(); - if (generalizeToStructure && t.isa()) + if (generalizeToStructure && t.isa()) return mlir::daphne::StructureType::get(mctx); if (auto mt = t.dyn_cast()) return mt.withSameElementTypeAndRepr(); if (t.isa()) return mlir::daphne::FrameType::get(mctx, {mlir::daphne::UnknownType::get(mctx)}); + if (auto ct = t.dyn_cast()) + return ct.withSameValueType(); if (auto lt = t.dyn_cast()) return mlir::daphne::ListType::get(mctx, adaptType(lt.getElementType(), generalizeToStructure)); if (auto mrt = t.dyn_cast()) { diff --git a/src/compiler/lowering/RewriteToColumnarOpsPass.cpp b/src/compiler/lowering/RewriteToColumnarOpsPass.cpp new file mode 100644 index 000000000..e6edb5bd3 --- /dev/null +++ b/src/compiler/lowering/RewriteToColumnarOpsPass.cpp @@ -0,0 +1,549 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace mlir; + +namespace { + +// ******************************************************************************** +// Helper Functions +// ******************************************************************************** + +/** + * @brief Applies a columnar projection (`ColProjectOp`) to all columns of the given frame. + * + * @param rewriter a `PatternRewriter` + * @param loc the location to use for newly created ops + * @param argFrm the frame whose columns shall be projected + * @param selPosLstCol the position list column to use for the columnar projection + * @param colLabelsStr the labels of the columns to project (only read, hence passed by const reference) + * @param cols the projected columns (used for the result) + * @param colLabelsVal the labels of the columns (newly created `mlir::daphne::ConstantOp`s, used for the result) + */ +void projectAllColumns(PatternRewriter &rewriter, Location loc, Value argFrm, Value selPosLstCol, + const std::vector &colLabelsStr, std::vector &cols, + std::vector &colLabelsVal) { + // Some types. + Type u = daphne::UnknownType::get(rewriter.getContext()); + Type cu = daphne::ColumnType::get(rewriter.getContext(), u); + Type mu = daphne::MatrixType::get(rewriter.getContext(), u); + + for (const std::string &colLabelStr : colLabelsStr) { + // Create a ConstantOp with the label of the column, keep it for later use in CreateFrameOp.
+ Value colLabelVal = rewriter.create(loc, colLabelStr); + colLabelsVal.push_back(colLabelVal); + // Extract the column by its label; the outcome is a single-column frame. + Value argColFrm = rewriter.create(loc, u, argFrm, colLabelVal); + // Cast the single-column frame to a column. + Value argColCol = rewriter.create(loc, cu, argColFrm); + // Apply the columnar projection; the outcome is a column. + Value resColCol = rewriter.create(loc, u, argColCol, selPosLstCol); + // Cast the column to a matrix, keep it for later use in CreateFrameOp. + Value resColMat = rewriter.create(loc, mu, resColCol); + cols.push_back(resColMat); + } +} + +template LogicalResult replaceCompareOp(PatternRewriter &rewriter, MatCmpOp op) { + // Get the location of the op to replace; we will use it for all newly created ops. + Location loc = op->getLoc(); + + // Some types. + Type u = daphne::UnknownType::get(rewriter.getContext()); + Type cu = daphne::ColumnType::get(rewriter.getContext(), u); + Type mu = daphne::MatrixType::get(rewriter.getContext(), u); + + // Get the left input, which is a matrix (otherwise, op would be legal). + Value lhsMat = op.getLhs(); + // Cast the left input to a column. + Value lhsCol = rewriter.create(loc, cu, lhsMat); + + // Get the right input, which is a scalar (otherwise, op would be legal). + Value rhsSca = op.getRhs(); + + // Create the columnar comparison op. The outcome is a position list represented as a column. + Value resPosLstCol = rewriter.create(loc, u, lhsCol, rhsSca); + + // Cast the position list to a single-column matrix. + Value resPosLstMat = rewriter.create(loc, mu, resPosLstCol); + + // Convert the position list to a bit vector and replace the original comparison op by it. 
+ Value numRows = rewriter.create(loc, rewriter.getIndexType(), lhsMat); + rewriter.replaceOpWithNewOp(op, op.getRes().getType(), resPosLstMat, numRows); + + return success(); +} + +template LogicalResult replaceLogicalOp(PatternRewriter &rewriter, MatLogOp op) { + // Get the location of the op to replace; we will use it for all newly created ops. + Location loc = op->getLoc(); + + // Some types. + Type u = daphne::UnknownType::get(rewriter.getContext()); + Type cu = daphne::ColumnType::get(rewriter.getContext(), u); + Type mu = daphne::MatrixType::get(rewriter.getContext(), u); + + // Get the left and right inputs, which are matrices (otherwise, op would be legal) containing bitmaps. + Value lhsBitmapMat = op.getLhs(); + Value rhsBitmapMat = op.getRhs(); + // Convert the left and right inputs to position lists. + Value lhsPosLstMat = rewriter.create(loc, u, lhsBitmapMat); + Value rhsPosLstMat = rewriter.create(loc, u, rhsBitmapMat); + // Cast the input position lists to columns. + Value lhsPosLstCol = rewriter.create(loc, cu, lhsPosLstMat); + Value rhsPosLstCol = rewriter.create(loc, cu, rhsPosLstMat); + + // Create the columnar set op. The outcome is a position list represented as a column. + Value resPosLstCol = rewriter.create(loc, u, lhsPosLstCol, rhsPosLstCol); + + // Cast the position list to a single-column matrix. + Value resPosLstMat = rewriter.create(loc, mu, resPosLstCol); + + // Convert the position list to a bit vector and replace the original logical op by it. + Value numRows = rewriter.create(loc, rewriter.getIndexType(), lhsBitmapMat); + rewriter.replaceOpWithNewOp(op, op.getRes().getType(), resPosLstMat, numRows); + + return success(); +} + +template LogicalResult replaceExtractRowOp(PatternRewriter &rewriter, FrmOp op, Value selPosLstMat) { + // Get the location of the op to replace; we will use it for all newly created ops. + Location loc = op->getLoc(); + + // Some types. 
+ Type u = daphne::UnknownType::get(rewriter.getContext()); + Type cu = daphne::ColumnType::get(rewriter.getContext(), u); + + // We need the information which rows to extract as a position list contained in a column. + Value selPosLstCol = rewriter.create(loc, cu, selPosLstMat); + + Value src = op.getSource(); + + auto srcFrmTy = src.getType().dyn_cast(); + // For now, we only replace the op, if it is applied to a frame. + // TODO We could also support matrices as source. + if (!srcFrmTy) + return failure(); + + std::vector *srcColLabels = srcFrmTy.getLabels(); + // The column labels of the source must be known (current requirement by projectAllColumns()). + // TODO We could relax this requirement. + if (!srcColLabels) + return failure(); + + // Project all columns of the source separately using columnar operations. + std::vector cols; + std::vector colLabels; + projectAllColumns(rewriter, loc, src, selPosLstCol, *srcColLabels, cols, colLabels); + + // Replace the op by a new CreateFrameOp consisting of the individually processed columns. + // The result type is the same as before the rewrite. This includes the column order and column labels. + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), cols, colLabels); + + return success(); +} + +template LogicalResult replaceBinaryOp(PatternRewriter &rewriter, MatBinOp op) { + // Get the location of the op to replace; we will use it for all newly created ops. + Location loc = op->getLoc(); + + // Some types. + Type u = daphne::UnknownType::get(rewriter.getContext()); + Type cu = daphne::ColumnType::get(rewriter.getContext(), u); + + // Get the left and right inputs, which are matrices (otherwise, op would be legal). + Value lhsMat = op.getLhs(); + Value rhsMat = op.getRhs(); + // Cast the inputs to columns. + Value lhsCol = rewriter.create(loc, cu, lhsMat); + Value rhsCol = rewriter.create(loc, cu, rhsMat); + + // Create the columnar binary op, whose result is a column. 
+ Value resCol = rewriter.create(loc, u, lhsCol, rhsCol); + + // Cast the result column to a single-column matrix and replace the original binary op by it. + rewriter.replaceOpWithNewOp(op, op.getRes().getType(), resCol); + + return success(); +} + +template +LogicalResult replaceAllAggOp(PatternRewriter &rewriter, MatAllAggOp op) { + // Get the location of the op to replace; we will use it for all newly created ops. + Location loc = op->getLoc(); + + // Some types. + Type u = daphne::UnknownType::get(rewriter.getContext()); + Type cu = daphne::ColumnType::get(rewriter.getContext(), u); + + // Get the input, which is a matrix (otherwise, op would be legal). + Value argMat = op.getArg(); + // Cast the input to a column. + Value argCol = rewriter.create(loc, cu, argMat); + + // Create the columnar aggregation op, whose result is a single-element column. + Value resCol = rewriter.create(loc, u, argCol); + + // Cast the result column to a scalar (same type as before the rewrite) and replace the original aggregation op by + // it. + rewriter.replaceOpWithNewOp(op, op.getRes().getType(), resCol); + + return success(); +} + +// ******************************************************************************** +// Rewrite Patterns +// ******************************************************************************** + +struct ColumnarOpReplacement : public RewritePattern { + + ColumnarOpReplacement(MLIRContext *context, PatternBenefit benefit = 1) + : RewritePattern(Pattern::MatchAnyOpTypeTag(), benefit, context) {} + + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + // Note that all unknown data/value types we introduce when creating new ops are inferred in subsequent compiler + // passes. We only specify the necessary type information, i.e., the result data type of casts.
+ Type u = daphne::UnknownType::get(rewriter.getContext()); + Type cu = daphne::ColumnType::get(rewriter.getContext(), u); + Type mu = daphne::MatrixType::get(rewriter.getContext(), u); + + if (auto cmpOp = llvm::dyn_cast(op)) { + return replaceCompareOp(rewriter, cmpOp); + } else if (auto cmpOp = llvm::dyn_cast(op)) { + return replaceCompareOp(rewriter, cmpOp); + } else if (auto cmpOp = llvm::dyn_cast(op)) { + return replaceCompareOp(rewriter, cmpOp); + } else if (auto cmpOp = llvm::dyn_cast(op)) { + return replaceCompareOp(rewriter, cmpOp); + } else if (auto cmpOp = llvm::dyn_cast(op)) { + return replaceCompareOp(rewriter, cmpOp); + } else if (auto cmpOp = llvm::dyn_cast(op)) { + return replaceCompareOp(rewriter, cmpOp); + } else if (auto logOp = llvm::dyn_cast(op)) { + return replaceLogicalOp(rewriter, logOp); + } else if (auto logOp = llvm::dyn_cast(op)) { + return replaceLogicalOp(rewriter, logOp); + } else if (auto frOp = llvm::dyn_cast(op)) { + // Get the location of the op to replace; we will use it for all newly created ops. + Location loc = frOp->getLoc(); + + // For FilterRowOp, the rows to extract are given as a bit vector (0/1 values) contained in a matrix. + // replaceExtractRowOp() needs them as a position list contained in a matrix. + Value selBitVecMat = frOp.getSelectedRows(); + Value selPosLstMat = rewriter.create(loc, u, selBitVecMat); + + return replaceExtractRowOp(rewriter, frOp, selPosLstMat); + } else if (auto erOp = llvm::dyn_cast(op)) { + // For ExtractRowOp, the rows to extract are given as a position list contained in a matrix, which is what + // replaceExtractRowOp() needs.
+ Value selPosLstMat = erOp.getSelectedRows(); + + return replaceExtractRowOp(rewriter, erOp, selPosLstMat); + } else if (auto binOp = llvm::dyn_cast(op)) { + return replaceBinaryOp(rewriter, binOp); + } else if (auto binOp = llvm::dyn_cast(op)) { + return replaceBinaryOp(rewriter, binOp); + } else if (auto aggOp = llvm::dyn_cast(op)) { + return replaceAllAggOp(rewriter, aggOp); + } else if (auto ijOp = llvm::dyn_cast(op)) { + // Get the location of the op to replace; we will use it for all newly created ops. + Location loc = ijOp->getLoc(); + + Value lhs = ijOp.getLhs(); // left input frame + Value rhs = ijOp.getRhs(); // right input frame + Value lhsOn = ijOp.getLhsOn(); // key column label in the left input + Value rhsOn = ijOp.getRhsOn(); // key column label in the right input + + auto lhsFrmTy = lhs.getType().dyn_cast(); + auto rhsFrmTy = rhs.getType().dyn_cast(); + // Both inputs must be frames. + if (!lhsFrmTy || !rhsFrmTy) + return failure(); + + // Extract the key columns as single-column frames and cast them to columns. + Value lhsKeysMat = rewriter.create(loc, u, lhs, lhsOn); + Value rhsKeysMat = rewriter.create(loc, u, rhs, rhsOn); + Value lhsKeysCol = rewriter.create(loc, cu, lhsKeysMat); + Value rhsKeysCol = rewriter.create(loc, cu, rhsKeysMat); + + // Perform the columnar inner join. + auto cjOp = rewriter.create(loc, u, u, lhsKeysCol, rhsKeysCol, ijOp.getNumRowRes()); + Value lhsPosLstCol = cjOp.getResLhsPos(); + Value rhsPosLstCol = cjOp.getResRhsPos(); + + std::vector *lhsColLabels = lhsFrmTy.getLabels(); + std::vector *rhsColLabels = rhsFrmTy.getLabels(); + // The column labels of both inputs must be known (current requirement by projectAllColumns()). + // TODO We could relax this requirement. + if (!lhsColLabels || !rhsColLabels) + return failure(); + + // Project all columns of the left and right inputs on the matching rows separately using columnar + // operations.
+ std::vector resCols; + std::vector resColLabels; + projectAllColumns(rewriter, loc, lhs, lhsPosLstCol, *lhsColLabels, resCols, resColLabels); + projectAllColumns(rewriter, loc, rhs, rhsPosLstCol, *rhsColLabels, resCols, resColLabels); + + // Replace the InnerJoinOp by a new CreateFrameOp consisting of the individually processed columns. + // The result type is the same as before the rewrite. This includes the column order and column labels. + rewriter.replaceOpWithNewOp(op, ijOp.getRes().getType(), resCols, resColLabels); + + return success(); + } else if (auto sjOp = llvm::dyn_cast(op)) { + // Get the location of the op to replace; we will use it for all newly created ops. + Location loc = sjOp->getLoc(); + + Value lhs = sjOp.getLhs(); // left input frame + Value rhs = sjOp.getRhs(); // right input frame + Value lhsOn = sjOp.getLhsOn(); // key column label in the left input + Value rhsOn = sjOp.getRhsOn(); // key column label in the right input + + auto lhsFrmTy = lhs.getType().dyn_cast(); + auto rhsFrmTy = rhs.getType().dyn_cast(); + // Both inputs must be frames. + if (!lhsFrmTy || !rhsFrmTy) + return failure(); + + // Extract the key columns as single-column frames and cast them to columns. + Value lhsKeysMat = rewriter.create(loc, u, lhs, lhsOn); + Value rhsKeysMat = rewriter.create(loc, u, rhs, rhsOn); + Value lhsKeysCol = rewriter.create(loc, cu, lhsKeysMat); + Value rhsKeysCol = rewriter.create(loc, cu, rhsKeysMat); + + // Perform the columnar semi join. + Value lhsPosLstCol = + rewriter.create(loc, u, lhsKeysCol, rhsKeysCol, sjOp.getNumRowRes()); + + std::vector *lhsColLabels = lhsFrmTy.getLabels(); + // The column labels of the left input must be known (current requirement by projectAllColumns()). + // TODO We could relax this requirement. + if (!lhsColLabels) + return failure(); + + // Project all columns of the left input on the matching rows separately using columnar operations.
+ std::vector resCols; + std::vector resColLabels; + projectAllColumns(rewriter, loc, lhs, lhsPosLstCol, *lhsColLabels, resCols, resColLabels); + + // Replace the SemiJoinOp by two new results: a new CreateFrameOp consisting of the individually processed + // columns and the positions of the rows in the left input that have a join partner in the right input. The + // result types are the same as before the rewrite. This includes the column order and column labels. + Value res = rewriter.create(loc, sjOp.getRes().getType(), resCols, resColLabels); + Value lhsPosLstMat = rewriter.create(loc, sjOp.getLhsTids().getType(), lhsPosLstCol); + rewriter.replaceOp(op, {res, lhsPosLstMat}); + + return success(); + } else if (auto gOp = llvm::dyn_cast(op)) { + // Get the location of the op to replace; we will use it for all newly created ops. + Location loc = gOp->getLoc(); + + Value arg = gOp.getFrame(); // input frame + ValueRange keyLabels = gOp.getKeyCol(); // labels of the columns to group on + ValueRange aggLabels = gOp.getAggCol(); // labels of the columns to aggregate + + auto argFrmTy = arg.getType().dyn_cast(); + // The input must be a frame. + if (!argFrmTy) + return failure(); + + std::vector resCols; + std::vector resColLabels; + + // Find out the group ids and representative positions. + std::vector keyCols; + // Process the first key column. + Value keyMat = rewriter.create(loc, u, arg, keyLabels[0]); + Value keyCol = rewriter.create(loc, cu, keyMat); + keyCols.push_back(keyCol); + auto cgfOp = rewriter.create(loc, u, u, keyCol); + Value grpIds = cgfOp.getResGrpIds(); + Value reprPos = cgfOp.getResReprPos(); + // Process the remaining key columns. + for (size_t i = 1; i < keyLabels.size(); i++) { + keyMat = rewriter.create(loc, u, arg, keyLabels[i]); + keyCol = rewriter.create(loc, cu, keyMat); + keyCols.push_back(keyCol); + // Use the group ids from the previous ColGroupFirstOp/ColGroupNextOp. 
+ auto cgnOp = rewriter.create(loc, u, u, keyCol, grpIds); + grpIds = cgnOp.getResGrpIds(); + reprPos = cgnOp.getResReprPos(); + } + + // Extract the representatives from all key columns. + for (size_t i = 0; i < keyCols.size(); i++) { + Value keyColReprsCol = rewriter.create(loc, u, keyCols[i], reprPos); + Value keyColReprsMat = rewriter.create(loc, mu, keyColReprsCol); + resCols.push_back(keyColReprsMat); + resColLabels.push_back(keyLabels[i]); + } + + // Grouped aggregation on all aggregation columns. + Value numDistinct = rewriter.create(loc, u, reprPos); + for (size_t i = 0; i < aggLabels.size(); i++) { + Value aggMat = rewriter.create(loc, u, arg, aggLabels[i]); + Value aggCol = rewriter.create(loc, cu, aggMat); + Value aggedCol = rewriter.create(loc, u, aggCol, grpIds, numDistinct); + Value aggedMat = rewriter.create(loc, mu, aggedCol); + resCols.push_back(aggedMat); + // TODO Don't hardcode "SUM(", do it like the group kernel on frames does it. + Value newLabel = rewriter.create( + loc, u, + rewriter.create( + loc, u, rewriter.create(loc, std::string("SUM(")), aggLabels[i]), + rewriter.create(loc, std::string(")"))); + resColLabels.push_back(newLabel); + } + + // Replace the GroupOp by a new CreateFrameOp consisting of the individually processed columns. + // The result type is the same as before the rewrite. This includes the column order and column labels. + rewriter.replaceOpWithNewOp(op, gOp.getRes().getType(), resCols, resColLabels); + + return success(); + } + + // This should never happen (all ops to be replaced should be handled above). + return failure(); + } +}; + +// ******************************************************************************** +// Compiler Pass +// ******************************************************************************** + +/** + * @brief Rewrites certain matrix/frame ops from linear/relational algebra to columnar ops from column algebra. 
+ * + * The general idea is to identify and replace individual matrix/frame ops that (depending on the op and the types, + * shapes, etc. of its arguments) could be expressed by columnar ops. Then, each of these ops is replaced in isolation + * by creating casts/conversions of its arguments as needed, creating the columnar op(s), and creating casts/conversions + * of the results as needed. In the end, the results of the rewritten DAG of operations are the same as of the replaced + * op. After these replacements of individual ops, the IR may contain lots of redundant operations or operations + * eliminating each other's effects. Such issues are not addressed by this pass, but are subject to simplifications in + * subsequent passes. + */ +struct RewriteToColumnarOpsPass : public PassWrapper> { + + void runOnOperation() final; +}; +} // namespace + +void RewriteToColumnarOpsPass::runOnOperation() { + auto module = getOperation(); + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addLegalOp(); + + // Rewrite elementwise comparisons, but only if the left-hand-side operand is a matrix with exactly one column and + // the right-hand-side operand is a scalar. + target.addDynamicallyLegalOp([](Operation *op) { + Type lhsTy = op->getOperand(0).getType(); + Type rhsTy = op->getOperand(1).getType(); + return !(CompilerUtils::isMatTypeWithSingleCol(lhsTy) && CompilerUtils::isScaType(rhsTy)); + }); + // Rewrite elementwise logical operations, but only if their arguments are the results of columnar select ops or set + // ops (these are the ops that result from relational algebra selection). However, depending on the order these + // rewrite patterns are applied, the arguments of EwAndOp/EwOrOp may not have been rewritten yet. In that case, we + // check if the arguments are defined by operations that would normally get rewritten by this pass.
+ target.addDynamicallyLegalOp([](Operation *op) { + for (size_t i = 0; i < op->getNumOperands(); i++) { + // Check each argument individually. + Value arg = op->getOperand(i); + if (auto plbmcOp = arg.getDefiningOp()) { + // In case the inputs have already been rewritten. + if (auto cOp = plbmcOp.getArg().getDefiningOp()) { + bool isArgColTy = cOp.getArg().getType().isa(); + bool isResMatTy = cOp.getRes().getType().isa(); + if (isArgColTy && isResMatTy) { + if (auto defOp = cOp.getArg().getDefiningOp()) { + // TODO We could define a trait for these ops. + if (!llvm::isa(defOp)) + continue; + } + } + } + } else if (auto defOp = arg.getDefiningOp()) { + // In case the inputs have not been rewritten yet. + // TODO Double-check if these ops are really illegal ones. + if (llvm::isa(defOp)) + continue; + } + return true; + } + return false; + }); + // Rewrite elementwise binary ops (other than comparisons), but only if both operands are a matrix with exactly one + // column. + target.addDynamicallyLegalOp([](Operation *op) { + Type lhsTy = op->getOperand(0).getType(); + Type rhsTy = op->getOperand(1).getType(); + return !(CompilerUtils::isMatTypeWithSingleCol(lhsTy) && CompilerUtils::isMatTypeWithSingleCol(rhsTy)); + }); + // Rewrite full aggregation ops, but only if the argument is a matrix with a single column. + target.addDynamicallyLegalOp([](Operation *op) { + Type argTy = op->getOperand(0).getType(); + return !CompilerUtils::isMatTypeWithSingleCol(argTy); + }); + // Rewrite FilterRowOp and ExtractRowOp, but only if the source is a frame. + // TODO We could also support matrix inputs. + // TODO Check if the frame labels are known (current requirement of the rewrite code, which could be relaxed). + target.addDynamicallyLegalOp([](Operation *op) { + Type argTy = op->getOperand(0).getType(); + return !llvm::isa(argTy); + }); + // Always rewrite InnerJoinOp and SemiJoinOp. + target.addIllegalOp(); + // Rewrite GroupOp, but only if all aggregation functions are SUM. 
+ // TODO We could also support other aggregation functions. + target.addDynamicallyLegalOp([](Operation *op) { + auto gOp = llvm::dyn_cast(op); + ArrayAttr aggFuncs = gOp.getAggFuncs(); + return !llvm::all_of(aggFuncs.getValue(), [](Attribute af) { + return af.dyn_cast().getValue() == daphne::GroupEnum::SUM; + }); + ; + }); + + patterns.add(&getContext()); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +std::unique_ptr daphne::createRewriteToColumnarOpsPass() { return std::make_unique(); } diff --git a/src/compiler/utils/CompilerUtils.h b/src/compiler/utils/CompilerUtils.h index 54c758851..c135bf85a 100644 --- a/src/compiler/utils/CompilerUtils.h +++ b/src/compiler/utils/CompilerUtils.h @@ -182,6 +182,18 @@ struct CompilerUtils { if (generalizeToStructure) return "Structure"; return "Frame"; + } else if (auto colTy = t.dyn_cast()) { + if (generalizeToStructure) + return "Structure"; + // For columns of strings we use `std::string` as the value type, while for string scalars we use + // `const char *` as the value type. Thus, we need this special case here. Maybe we can do it without a + // special case in the future. + std::string vtName; + if (colTy.getValueType().isa()) + vtName = "std::string"; + else + vtName = mlirTypeToCppTypeName(colTy.getValueType(), angleBrackets, false); + return angleBrackets ? ("Column<" + vtName + ">") : ("Column_" + vtName); } else if (auto lstTy = t.dyn_cast()) { if (generalizeToStructure) return "Structure"; @@ -248,7 +260,8 @@ struct CompilerUtils { * @brief Returns `true` if the given type is a DAPHNE data object type; or `false`, otherwise. 
*/ [[maybe_unused]] static bool isObjType(mlir::Type t) { - return llvm::isa(t); + return llvm::isa(t); } /** @@ -256,6 +269,14 @@ struct CompilerUtils { */ [[maybe_unused]] static bool hasObjType(mlir::Value v) { return isObjType(v.getType()); } + /** + * @brief Returns `true` if the given type is a DAPHNE matrix that has exactly one column; or `false`, otherwise. + */ + static bool isMatTypeWithSingleCol(mlir::Type t) { + auto mt = t.dyn_cast(); + return mt && mt.getNumCols() == 1; + } + /** * @brief Returns `true` if the given type is a DAPHNE scalar type (or: value type); or `false`, otherwise. */ @@ -286,6 +307,7 @@ struct CompilerUtils { * - For the unknown type, the unknown type is returned. * - For matrices, the value type is extracted. * - For frames, an error is thrown. + * - For columns, the value type is extracted. * - For lists, this function is called recursively on the element type. * - For scalars, the type itself is returned. * - For anything else, an error is thrown. @@ -301,6 +323,8 @@ struct CompilerUtils { if (auto ft = t.dyn_cast()) throw std::runtime_error( "getValueType() doesn't support frames yet"); // TODO maybe use the most general value type + if (auto ct = t.dyn_cast()) + return ct.getValueType(); if (auto lt = t.dyn_cast()) return getValueType(lt.getElementType()); if (isScaType(t)) @@ -318,6 +342,7 @@ struct CompilerUtils { * - For the unknown type, the unknown type is returned (single-element sequence). * - For matrices, the value type is extracted (single-element sequence). * - For frames, the sequence of column types is returned. + * - For columns, the value type is extracted (single-element sequence). * - For lists, this function is called recursively on the element type. * - For scalars, the type itself is returned (single-element sequence). * - For anything else, an error is thrown. 
@@ -332,6 +357,8 @@ struct CompilerUtils { return {mt.getElementType()}; if (auto ft = t.dyn_cast()) return ft.getColumnTypes(); + if (auto ct = t.dyn_cast()) + return {ct.getValueType()}; if (auto lt = t.dyn_cast()) return getValueTypes(lt.getElementType()); if (isScaType(t)) @@ -350,6 +377,7 @@ struct CompilerUtils { * - For the unknown type, an error is thrown. * - For matrices, the value type is set to the given value type. * - For frames, an error is thrown. + * - For columns, the value type is set to the given value type. * - For lists, this function is called recursively on the element type. * - For scalars, the given value type is returned. * - For anything else, an error is thrown. @@ -365,6 +393,8 @@ struct CompilerUtils { return mt.withElementType(vt); if (auto ft = t.dyn_cast()) throw std::runtime_error("setValueType() doesn't support frames yet"); // TODO + if (auto ct = t.dyn_cast()) + return ct.withValueType(vt); if (auto lt = t.dyn_cast()) return setValueType(lt.getElementType(), vt); if (isScaType(t)) diff --git a/src/ir/daphneir/Canonicalize.cpp b/src/ir/daphneir/Canonicalize.cpp index a70067d45..e769bdc3b 100644 --- a/src/ir/daphneir/Canonicalize.cpp +++ b/src/ir/daphneir/Canonicalize.cpp @@ -595,4 +595,197 @@ mlir::LogicalResult mlir::daphne::EwMinusOp::canonicalize(mlir::daphne::EwMinusO return mlir::success(); } return mlir::failure(); +} + +/** + * @brief Eliminates redundant conversions of a position list to a bitmap and back to a position list. + * + * This pattern frequently occurs during lowering to columnar operations. This simplification rewrite avoids the + * unnecessary creation of the intermediate bitmap. + */ +mlir::LogicalResult mlir::daphne::ConvertBitmapToPosListOp::canonicalize(mlir::daphne::ConvertBitmapToPosListOp bmplcOp, + PatternRewriter &rewriter) { + if (auto plbmcOp = bmplcOp.getArg().getDefiningOp()) { + // The ConvertPosListToBitmapOp has a second argument for the number of rows (bitmap size). 
No matter what this + // size is, we always get back the original position list. The only exception would be if the bitmap size is + // less than the greatest position in the original position list (information loss or error during conversion + // from position list to bitmap). However, this case is not relevant to us at the moment. + rewriter.replaceOp(bmplcOp, plbmcOp.getArg()); + return mlir::success(); + } + return mlir::failure(); +} + +/** + * @brief Simplifies the extraction of a single column (by its label) from a frame in various situations. + */ +mlir::LogicalResult mlir::daphne::ExtractColOp::canonicalize(mlir::daphne::ExtractColOp ecOp, + PatternRewriter &rewriter) { + Value src = ecOp.getSource(); + Value sel = ecOp.getSelectedCols(); + if (auto srcFrmTy = src.getType().dyn_cast()) { + if (sel.getType().isa()) { + if (srcFrmTy.getNumCols() == 1) { + // Eliminate the extraction of a single column (by its label) from a frame with a single column which + // has exactly that label. + + if (std::vector *labels = srcFrmTy.getLabels()) { + std::pair selConst = CompilerUtils::isConstant(sel); + if (selConst.first && selConst.second == (*labels)[0]) { + rewriter.replaceOp(ecOp, src); + return mlir::success(); + } + } + } else if (auto cfOp = src.getDefiningOp()) { + // Replace the extraction of a single column (by its label) from the result of a CreateFrameOp, by the + // respective input column of the CreateFrameOp. This simplification rewrite can help us avoid the + // unnecessary creation of frames when we are interested only in certain columns later on. It can lead + // to the elimination of the operations creating later unused input columns of the CreateFrameOp. + + // CreateFrameOp always has an even number of arguments. The first half are the columns and the second + // half are the labels. + const size_t cfNumCols = cfOp->getNumOperands() / 2; + // Search for the column with the specified label. 
+ for (size_t i = 0; i < cfNumCols; i++) { + Value label = cfOp->getOperand(cfNumCols + i); + if (label == sel) { + // We need to insert an additional cast of the CreateFrameOp's input column to the result type + // of the ExtractColOp, because usually, the arguments to CreateFrameOp are single-column + // *matrices*, while the result of ExtractColOp on a frame is a single-column *frame*. Such + // additional casts will often be eliminated through the canonicalization of CastOp later. + Type resTy = ecOp.getResult().getType(); + if (auto resFrmTy = resTy.dyn_cast()) { + // Reset the labels to unknown, because TODO + Value casted = rewriter + .create( + ecOp.getLoc(), resFrmTy.withLabels(nullptr), cfOp->getOperand(i)) + .getResult(); + // If the result of the ExtractColOp is a (single-column) frame (typically the case), we + // must make sure that it gets the right column label after the rewrite. The label is an + // argument to the CreateFrameOp and might not be known as a compile-time constant at the + // point in time when this rewrite happens (the label could be the result of a complex + // string expression, which is resolved later during compile-time through constant folding + // or only at run-time). Thus, we additionally insert a SetColLabelsOp which reuses exactly + // the same input as the CreateFrameOp for the label. + Value labeled = + rewriter.create(ecOp.getLoc(), resTy, casted, label) + .getResult(); + rewriter.replaceOp(ecOp, labeled); + return success(); + } + } + } + } else if (auto cbOp = src.getDefiningOp()) { + // Push down the extraction of a single column (by its label) from the result of a ColBindOp to the + // argument of the ColBindOp that has the column with the desired label. 
+ + std::pair selConst = CompilerUtils::isConstant(sel); + if (selConst.first) { + auto tryOneArg = [&](Value arg) { + if (auto argFrmTy = arg.getType().dyn_cast()) + if (argFrmTy.getLabels() != nullptr) + for (std::string label : *(argFrmTy.getLabels())) + if (label == selConst.second) { + rewriter.replaceOpWithNewOp(ecOp, ecOp.getType(), + arg, sel); + return true; + } + return false; + }; + if (tryOneArg(cbOp.getLhs()) || tryOneArg(cbOp.getRhs())) + return mlir::success(); + } + } else if (auto ecOp2 = src.getDefiningOp()) { + // Eliminate two subsequent extractions of a single column (by its label) with the same label. + + if (ecOp2.getSelectedCols() == sel) { + rewriter.replaceOpWithNewOp(ecOp, ecOp.getResult().getType(), + ecOp2.getResult()); + return mlir::success(); + } + } + } + } + return mlir::failure(); +} + +/** + * @brief Simplifies various patterns of ops that end with a CastOp. + * + * Simple examples include the elimination of trivial casts (casting from A to A) and the simplification of chains of + * casts (e.g., casting from A to B to C becomes casting from A to C, in case there is no information loss). Besides + * that, there are patterns that by-pass ops that create the input of a CastOp and patterns that by-pass the CastOp + * itself. + */ +mlir::LogicalResult mlir::daphne::CastOp::canonicalize(mlir::daphne::CastOp cOp, PatternRewriter &rewriter) { + // TODO Maybe skip property casts, or separate them, combine multiple property casts. + + // Replace cast "a -> a". + if (cOp.isTrivialCast()) { + rewriter.replaceOp(cOp, cOp.getArg()); + return mlir::success(); + } + + // Replace cast "a -> b -> c" to cast "a -> c", if the cast "a -> b" does not lose information, because if "b" + // contains the same information as "a", we could directly cast from "a" to "c". + // It does not matter if "b -> c" may lose information, because "b" does not have more information than "a". 
+ if (auto cOp0 = cOp.getArg().getDefiningOp()) { + if (!cOp0.mightLoseInformation()) { + rewriter.replaceOpWithNewOp(cOp, cOp.getRes().getType(), cOp0.getArg()); + return mlir::success(); + } + } + + // Bypass operations manipulating frame column labels in case their result is casted to a data type that does not + // support column labels (matrix/column). This is sound, since the column label information from the frame would be + // gone after the cast anyway. + if (cOp.getArg().getType().isa() && + cOp.getRes().getType().isa()) { + if (auto sclOp = cOp.getArg().getDefiningOp()) { + rewriter.replaceOpWithNewOp(cOp, cOp.getRes().getType(), sclOp.getArg()); + return mlir::success(); + } + if (auto sclpOp = cOp.getArg().getDefiningOp()) { + rewriter.replaceOpWithNewOp(cOp, cOp.getRes().getType(), sclpOp.getArg()); + return mlir::success(); + } + } + + // Bypass CreateFrameOp (with a single input matrix) followed by CastOp to the same matrix type as the + // CreateFrameOp's input matrix. In such cases, the result of the CastOp is simply the argument of the + // CreateFrameOp. + // TODO check if the matrix has a single column, otherwise, we would bypass the check in createframe + if (cOp.getArg().getType().isa() && cOp.getRes().getType().isa()) + if (auto cfOp = cOp.getArg().getDefiningOp()) + if (cfOp.getCols().size() == 1 && cfOp.getCols()[0].getType() == cOp.getRes().getType()) { + rewriter.replaceOp(cOp, cfOp.getCols()[0]); + return mlir::success(); + } + + return mlir::failure(); +} + +/** + * @brief Eliminates a SetColLabelsOp if the input frame already has the new column labels. + */ +mlir::LogicalResult mlir::daphne::SetColLabelsOp::canonicalize(mlir::daphne::SetColLabelsOp sclOp, + PatternRewriter &rewriter) { + if (auto argFrmTy = sclOp.getArg().getType().dyn_cast()) { // if the arg is a frame + if (std::vector *argLabels = argFrmTy.getLabels()) { // if the arg's labels are known + // Compare the arg's labels with the new labels. 
+ mlir::ValueRange newLabels = sclOp.getLabels(); + if (argLabels->size() != newLabels.size()) + return mlir::failure(); + for (size_t i = 0; i < newLabels.size(); i++) { + std::pair labelConst = CompilerUtils::isConstant(newLabels[i]); + if (!labelConst.first || + labelConst.second != (*argLabels)[i]) // the new label is not known or differs from the arg label + return mlir::failure(); + } + // The arg frame already has the new labels, so we can replace this SetColLabelsOp by its arg. + rewriter.replaceOp(sclOp, sclOp.getArg()); + return mlir::success(); + } + } + return mlir::failure(); } \ No newline at end of file diff --git a/src/ir/daphneir/DaphneDialect.cpp b/src/ir/daphneir/DaphneDialect.cpp index 6c543d314..5f480c44a 100644 --- a/src/ir/daphneir/DaphneDialect.cpp +++ b/src/ir/daphneir/DaphneDialect.cpp @@ -206,6 +206,22 @@ mlir::Type mlir::daphne::DaphneDialect::parseType(mlir::DialectAsmParser &parser return mlir::daphne::HandleType::get(parser.getBuilder().getContext(), dataType); } else if (keyword == "String") { return StringType::get(parser.getBuilder().getContext()); + } else if (keyword == "Column") { + if (parser.parseLess()) + return nullptr; + ssize_t numRows = -1; + if (parser.parseOptionalQuestion()) + // Parse #rows if there was no '?'. 
+ if (parser.parseInteger(numRows)) + return nullptr; + if (parser.parseXInDimensionList()) + return nullptr; + mlir::Type vt; + if (parser.parseType(vt)) + return nullptr; + if (parser.parseGreater()) + return nullptr; + return ColumnType::get(parser.getBuilder().getContext(), vt, numRows); } else if (keyword == "DaphneContext") { return mlir::daphne::DaphneContextType::get(parser.getBuilder().getContext()); } else { @@ -257,6 +273,8 @@ void mlir::daphne::DaphneDialect::printType(mlir::Type type, mlir::DialectAsmPri } else os << '?'; os << '>'; + } else if (auto t = type.dyn_cast()) { + os << "Column<" << unknownStrIf(t.getNumRows()) << "x" << t.getValueType() << '>'; } else if (auto t = type.dyn_cast()) { os << "List<" << t.getElementType() << '>'; } else if (auto handle = type.dyn_cast()) { @@ -405,3 +423,12 @@ ::mlir::LogicalResult mlir::daphne::HandleType::verify(::llvm::function_ref<::ml } else return emitError() << "only matrix type is supported for handle atm, got: " << dataType; } + +::mlir::LogicalResult mlir::daphne::ColumnType::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + Type valueType, ssize_t numRows) { + if (!CompilerUtils::isScaType(valueType) && !llvm::isa(valueType)) + return mlir::failure(); + if (numRows < -1) + return mlir::failure(); + return mlir::success(); +} \ No newline at end of file diff --git a/src/ir/daphneir/DaphneInferShapeOpInterface.cpp b/src/ir/daphneir/DaphneInferShapeOpInterface.cpp index 3a6aabb8a..c12aac089 100644 --- a/src/ir/daphneir/DaphneInferShapeOpInterface.cpp +++ b/src/ir/daphneir/DaphneInferShapeOpInterface.cpp @@ -43,6 +43,8 @@ std::pair getShape(Value v) { return std::make_pair(mt.getNumRows(), mt.getNumCols()); if (auto ft = t.dyn_cast()) return std::make_pair(ft.getNumRows(), ft.getNumCols()); + if (auto ct = t.dyn_cast()) + return std::make_pair(ct.getNumRows(), 1); if (CompilerUtils::isScaType(t)) return std::make_pair(1, 1); diff --git 
a/src/ir/daphneir/DaphneInferTypesOpInterface.cpp b/src/ir/daphneir/DaphneInferTypesOpInterface.cpp index f760303a4..036a076d1 100644 --- a/src/ir/daphneir/DaphneInferTypesOpInterface.cpp +++ b/src/ir/daphneir/DaphneInferTypesOpInterface.cpp @@ -68,41 +68,126 @@ std::vector daphne::CastOp::inferTypes() { Type resultType = getRes().getType(); auto matrixArgument = argumentType.dyn_cast(); auto frameArgument = argumentType.dyn_cast(); + auto columnArgument = argumentType.dyn_cast(); auto matrixResult = resultType.dyn_cast(); + auto frameResult = resultType.dyn_cast(); + auto columnResult = resultType.dyn_cast(); + + if (matrixResult) { + if (!llvm::isa(matrixResult.getElementType())) + // The result type is a matrix with a known value type. We leave the result type as it is. We do not + // overwrite the value type, since this could drop information that was explicitly encoded in the CastOp. + return {resultType}; + else { + // The result is a matrix with an unknown value type. We infer the value type from the argument. + + // The argument is a matrix; we use its value type for the result. + if (matrixArgument) + return {matrixResult.withElementType(matrixArgument.getElementType())}; + + // The argument is a column; we use its value type for the result. + if (columnArgument) + return {matrixResult.withElementType(columnArgument.getValueType())}; + + // The argument is a frame; we use the value type of its only column for the results; if the argument has + // more than one column, we throw an exception. + if (frameArgument) { + auto argumentColumnTypes = frameArgument.getColumnTypes(); + /*if (argumentColumnTypes.size() == 0) + return {resultType}; + else*/ + if (argumentColumnTypes.size() != 1) { + // TODO We could use the most general of the column types. 
+ throw ErrorHandler::compilerError(getLoc(), "InferTypesOpInterface (daphne::CastOp::inferTypes)", + "currently CastOp cannot infer the value type of its " + "output matrix, if the input is a multi-column frame"); + } + return {matrixResult.withElementType(argumentColumnTypes[0])}; + } - // If the result type is not a matrix or a matrix with so far unknown value type, then we - // we leave the result type as it is. We do not reset it to - // unknown, since this could drop information that was explicitly - // encoded in the CastOp. - if (!matrixResult || !llvm::isa(matrixResult.getElementType())) - return {resultType}; + // The argument is a scalar; we use its type for the value type of the result. + if (CompilerUtils::isScaType(argumentType)) + return {daphne::MatrixType::get(getContext(), argumentType)}; - // The argument is a matrix, result is a matrix; we use its value type for the result. - if (matrixArgument) - return {matrixResult.withElementType(matrixArgument.getElementType())}; - - // The argument is a frame, result is a matrix; we use the value type of its only - // column for the results; if the argument has more than one - // column, we throw an exception. - if (frameArgument) { - auto argumentColumnTypes = frameArgument.getColumnTypes(); - if (argumentColumnTypes.size() != 1) { - // TODO We could use the most general of the column types. - throw ErrorHandler::compilerError(getLoc(), "InferTypesOpInterface (daphne::CastOp::inferTypes)", - "currently CastOp cannot infer the value type of its " - "output matrix, if the input is a multi-column frame"); + // The argument is some unsupported type; this is an error. 
+ throw std::runtime_error( + "CastOp::inferTypes(): the argument is neither a supported data type nor a supported value type"); } - return {matrixResult.withElementType(argumentColumnTypes[0])}; - } + } else if (frameResult) { + std::vector resultColumnTypesBefore = frameResult.getColumnTypes(); + std::vector resultColumnTypesAfter; + for (size_t i = 0; i < resultColumnTypesBefore.size(); i++) { + if (!llvm::isa(resultColumnTypesBefore[i])) + // The value type of this frame column is known. We leave it as it is. We do not overwrite this column's + // value type, since this could drop information that was explicitly encoded in the CastOp. + resultColumnTypesAfter.push_back(resultColumnTypesBefore[i]); + else { + // The value type of this frame column is unknown. We infer the value type of this frame column from the + // argument. + + if (matrixArgument) + // The argument is a matrix; we use its value type for this frame column. + resultColumnTypesAfter.push_back(matrixArgument.getElementType()); + else if (columnArgument) + // The argument is a column; we use its value type for this frame column. + resultColumnTypesAfter.push_back(columnArgument.getValueType()); + else if (frameArgument) + // The argument is a frame; we use the value type of its corresponding column for this frame column. + // TODO double-check that the #cols matches + resultColumnTypesAfter.push_back(frameArgument.getColumnTypes()[i]); + else if (CompilerUtils::isScaType(argumentType)) + // The argument is a scalar; we use its type for the value type of the result. + resultColumnTypesAfter.push_back(argumentType); + else { + // The argument is some unsupported type; this is an error. 
+ throw std::runtime_error("CastOp::inferTypes(): the argument is neither a supported data type nor " + "a supported value type"); + } + } + } + return {frameResult.withColumnTypes(resultColumnTypesAfter)}; + } else if (columnResult) { + if (!llvm::isa(columnResult.getValueType())) + // The result type is a column with a known value type. We leave the result type as it is. We do not + // overwrite the value type, since this could drop information that was explicitly encoded in the CastOp. + return {resultType}; + else { + // The result is a column with an unknown value type. We infer the value type from the argument. + + // The argument is a matrix; we use its value type for the result. + if (matrixArgument) + return {columnResult.withValueType(matrixArgument.getElementType())}; + + // The argument is a column; we use its value type for the result. + if (columnArgument) + return {columnResult.withValueType(columnArgument.getValueType())}; + + // The argument is a frame; we use the value type of its only column for the results; if the argument has + // more than one column, we throw an exception. + if (frameArgument) { + auto argumentColumnTypes = frameArgument.getColumnTypes(); + /*if (argumentColumnTypes.size() == 0) + return {resultType}; + else*/ + if (argumentColumnTypes.size() != 1) { + // TODO We could use the most general of the column types. + throw ErrorHandler::compilerError(getLoc(), "InferTypesOpInterface (daphne::CastOp::inferTypes)", + "currently CastOp cannot infer the value type of its " + "output column, if the input is a multi-column frame"); + } + return {columnResult.withValueType(argumentColumnTypes[0])}; + } - // The argument is a scalar, result is a matrix; we use its type for the value type - // of the result. - if (CompilerUtils::isScaType(argumentType)) - return {daphne::MatrixType::get(getContext(), argumentType)}; + // The argument is a scalar; we use its type for the value type of the result. 
+ if (CompilerUtils::isScaType(argumentType)) + return {daphne::ColumnType::get(getContext(), argumentType)}; - // The argument is some unsupported type; this is an error. - throw std::runtime_error( - "CastOp::inferTypes(): the argument is neither a supported data type nor a supported value type"); + // The argument is some unsupported type; this is an error. + throw std::runtime_error( + "CastOp::inferTypes(): the argument is neither a supported data type nor a supported value type"); + } + } else + return {resultType}; } std::vector daphne::ExtractColOp::inferTypes() { @@ -289,6 +374,24 @@ std::vector daphne::GroupOp::inferTypes() { return {daphne::FrameType::get(ctx, newColumnTypes)}; } +std::vector daphne::ColJoinOp::inferTypes() { + MLIRContext *ctx = getContext(); + Builder builder(ctx); + return {daphne::ColumnType::get(ctx, builder.getIndexType()), daphne::ColumnType::get(ctx, builder.getIndexType())}; +} + +std::vector daphne::ColGroupFirstOp::inferTypes() { + MLIRContext *ctx = getContext(); + Builder builder(ctx); + return {daphne::ColumnType::get(ctx, builder.getIndexType()), daphne::ColumnType::get(ctx, builder.getIndexType())}; +} + +std::vector daphne::ColGroupNextOp::inferTypes() { + MLIRContext *ctx = getContext(); + Builder builder(ctx); + return {daphne::ColumnType::get(ctx, builder.getIndexType()), daphne::ColumnType::get(ctx, builder.getIndexType())}; +} + std::vector daphne::ExtractOp::inferTypes() { throw ErrorHandler::compilerError(getLoc(), "InferTypesOpInterface", "type inference not implemented for ExtractOp"); // TODO diff --git a/src/ir/daphneir/DaphneInferTypesOpInterface.h b/src/ir/daphneir/DaphneInferTypesOpInterface.h index 05f90d6d7..727399870 100644 --- a/src/ir/daphneir/DaphneInferTypesOpInterface.h +++ b/src/ir/daphneir/DaphneInferTypesOpInterface.h @@ -78,6 +78,13 @@ template class DataTypeSca : public TraitBase class DataTypeMat : public TraitBase {}; +/** + * @brief The data type (of the single result) is always `Column`. 
+ * + * Assumes that the operation has always exactly one result. + */ +template class DataTypeCol : public TraitBase {}; + /** * @brief The data type (of the single result) is always `Frame`. * diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index 04438a263..c8c4958d2 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -136,7 +136,8 @@ def Daphne_CreateFrameOp : Daphne_Op<"createFrame", [ SameVariadicOperandSize, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + Pure ]> { let arguments = (ins Variadic:$cols, Variadic:$labels); let results = (outs FrameOrU:$res); @@ -663,7 +664,7 @@ def Daphne_ExtractColOp : Daphne_Op<"extractCol", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, CUDASupport + DeclareOpInterfaceMethods, CUDASupport, Pure ]> { let summary = "Copies the specified columns from the argument to the result."; @@ -684,6 +685,8 @@ def Daphne_ExtractColOp : Daphne_Op<"extractCol", [ let arguments = (ins MatrixOrFrame:$source, AnyTypeOf<[Selection, StrScalar, Unknown]>:$selectedCols); let results = (outs MatrixOrFrame:$res); + + let hasCanonicalizeMethod = 1; } def Daphne_SliceColOp : Daphne_Op<"sliceCol", [ @@ -760,7 +763,8 @@ def Daphne_ColBindOp : Daphne_BindOp<"colBind", [ ValueTypesConcat, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - NumRowsFromAllArgs, NumColsFromSumOfAllArgs, CUDASupport + NumRowsFromAllArgs, NumColsFromSumOfAllArgs, CUDASupport, + Pure ]>; def Daphne_RowBindOp : Daphne_BindOp<"rowBind", [ @@ -1206,7 +1210,7 @@ def Daphne_QueryOp : Daphne_Op<"query", [ def Daphne_FilterRowOp : Daphne_Op<"filterRow", [ TypeFromFirstArg, DeclareOpInterfaceMethods, - NumColsFromArg + NumColsFromArg, Pure ]> { let summary = "Filters the rows of a data object according to a bit vector"; @@ -1284,6 +1288,186 @@ def Daphne_GroupOp : Daphne_Op<"group", [ let results = 
(outs FrameOrU:$res); } +// **************************************************************************** +// Columnar algebra +// **************************************************************************** +// TODO These columnar operations (plus the Column type) could become a separate dialect. + +class Daphne_ColSelectCmpOp traits = []> +: Daphne_Op { + let summary = "Selection on columnar data; input: data; output: sorted positions."; + let description = [{ + Compares each value of the left-hand-side input column to the right-hand-side input scalar using the concrete + comparison operation. The output column contains the positions (counting starts at zero) of the matching + values in the left-hand-side input column. The output column is sorted. + + Note that this operation can only compare a column to a scalar. To compare two columns, consider using + `ColCalcBinaryOp`. + }]; + + let arguments = (ins ColumnOrU:$lhsData, AnyTypeOf<[AnyScalar, Unknown]>:$rhsData); + let results = (outs ColumnOrU:$resPos); +} +def Daphne_ColSelectEqOp : Daphne_ColSelectCmpOp<"colSelectEq">; +def Daphne_ColSelectNeqOp : Daphne_ColSelectCmpOp<"colSelectNeq">; +def Daphne_ColSelectGtOp : Daphne_ColSelectCmpOp<"colSelectGt">; +def Daphne_ColSelectGeOp : Daphne_ColSelectCmpOp<"colSelectGe">; +def Daphne_ColSelectLtOp : Daphne_ColSelectCmpOp<"colSelectLt">; +def Daphne_ColSelectLeOp : Daphne_ColSelectCmpOp<"colSelectLe">; + +def Daphne_ColProjectOp : Daphne_Op<"colProject", [ + DataTypeCol, ValueTypeFromFirstArg +]> { + let summary = "Projection on columnar data; input: data, positions; output: data."; + let description = [{ + Extracts the data elements addressed by the positions in the right-hand-side input column from the + left-hand-side input column. The input positions must be valid positions in the input data column (not out of + bounds); other than that, there are no restrictions, i.e., the positions do not need to be sorted and may + contain duplicates. 
The elements in the output data column correspond to the input positions, i.e., are in the + same order. + }]; + + let arguments = (ins ColumnOrU:$lhsData, ColumnOrU:$rhsPos); + let results = (outs ColumnOrU:$resData); +} + +class Daphne_ColSetOp traits = []> +: Daphne_Op { + let arguments = (ins ColumnOrU:$lhsPos, ColumnOrU:$rhsPos); + let results = (outs ColumnOrU:$resPos); +} +def Daphne_ColIntersectOp : Daphne_ColSetOp<"colIntersect"> { + let summary = "Set intersection of columnar position lists; input: sorted and unique positions; output: sorted and unique positions."; + let description = [{ + Intersects the two given position list columns, i.e., returns a column containing the positions present in both + inputs. Both inputs must be sorted and unique (i.e., must not contain duplicates). The output positions are + sorted and unique, too. + }]; +} +def Daphne_ColMergeOp : Daphne_ColSetOp<"colMerge"> { + let summary = "Set union of columnar position lists; input: sorted and unique positions; output: sorted and unique positions."; + let description = [{ + Merges the two given position list columns, i.e., returns a column containing the positions present in any + input. Both inputs must be sorted and unique (i.e., must not contain duplicates). The output positions are + sorted and unique, too. + }]; +} + +def Daphne_ColJoinOp : Daphne_Op<"colJoin", [ + DeclareOpInterfaceMethods +]> { + let summary = "N:1 equi-join on columnar data; input: data, unique data; output: sorted positions, positions."; + let description = [{ + Compares each element in the left-hand-side input column to each element in the right-hand-side input column. + For each match, outputs the pair of positions of the matching elements in the left-hand-side input and + right-hand-side input. The values in the right-hand-side input must be unique (no duplicates). The two output + position lists have the same length. 
When used for a primary-key foreign-key join, the left-hand-side input + is the foreign key column and the right-hand-side input is the primary key column. + }]; + + let arguments = (ins ColumnOrU:$lhsData, ColumnOrU:$rhsData, Size:$numRes); + let results = (outs ColumnOrU:$resLhsPos, ColumnOrU:$resRhsPos); +} + +def Daphne_ColSemiJoinOp : Daphne_Op<"colSemiJoin", [ + DataTypeCol, ValueTypeSize +]> { + let summary = "Semi-join on columnar data; input: data, data, output: sorted positions."; + let description = [{ + Outputs the position of each element in the left-hand-side input column for which there is a matching element + in the right-hand-side input column. The output positions are sorted. + }]; + + let arguments = (ins ColumnOrU:$lhsData, ColumnOrU:$rhsData, Size:$numRes); + let results = (outs ColumnOrU:$resLhsPos); +} + +def Daphne_ColGroupFirstOp : Daphne_Op<"colGroupFirst", [ + DeclareOpInterfaceMethods +]> { + let summary = "Initial grouping step on columnar data; input: data; output: positions (group ids), positions (of representatives)."; + let description = [{ + Maps each distinct value in the input data column to a unique group id. Group ids are consecutive integers + starting at zero. Outputs (1) a column of group ids for each element in the input data column (same length as + the input data column), and (2) a column of positions of one representative data element per group (the i-th + element is the position of any input data element belonging to group id i). This operation is the first step + when grouping by one or multiple columns. 
+ }]; + + let arguments = (ins ColumnOrU:$argData); + let results = (outs ColumnOrU:$resGrpIds, ColumnOrU:$resReprPos); +} + +def Daphne_ColGroupNextOp : Daphne_Op<"colGroupNext", [ + DeclareOpInterfaceMethods +]> { + let summary = "Subsequent grouping step on columnar data; input: data, positions (group ids); output: positions (group ids), positions (of representatives)."; + let description = [{ + Maps each distinct combination of a value in the left-hand-side input data column and the right-hand-side group + id column to a unique group id. The two input columns must have the same length. Group ids are consecutive + integers starting at zero. Outputs (1) a column of group ids for each element in the input data column (same + length as the input data column), and (2) a column of positions of one representative data element per group + (the i-th element is the position of any input data element belonging to group id i). When grouping by multiple + columns, this operation is applied after `colGroupFirst` to add one more key column to the grouping. + }]; + + let arguments = (ins ColumnOrU:$argData, ColumnOrU:$argGrpIds); + let results = (outs ColumnOrU:$resGrpIds, ColumnOrU:$resReprPos); +} + +class Daphne_ColCalcBinaryOp traits = []> +: Daphne_Op { + let summary = "Elementwise binary operation on columnar data; input: data; output: data."; + let description = [{ + Applies the given binary operation to all corresponding pairs in the two input data columns. The two input data + columns must have the same length. The output data column has the same length as the inputs. + }]; + + // TODO Support scalars as arguments. 
+ let arguments = (ins ColumnOrU:$lhsData, ColumnOrU:$rhsData); + let results = (outs ColumnOrU:$resData); +} +def Daphne_ColCalcSubOp : Daphne_ColCalcBinaryOp<"colCalcSub">; +def Daphne_ColCalcMulOp : Daphne_ColCalcBinaryOp<"colCalcMul">; + +class Daphne_ColAllAggOp traits = []> +: Daphne_Op { + let summary = "Full aggregation of columnar data; input: data; output: data."; + let description = [{ + Aggregates all elements in the input data column using the given aggregation function. The result is a column + containing a single data element. + }]; + + let arguments = (ins ColumnOrU:$arg); + let results = (outs ColumnOrU:$res); +} +def Daphne_ColAllAggSumOp : Daphne_ColAllAggOp<"colSumAll">; + +class Daphne_ColGrpAggOp traits = []> +: Daphne_Op { + let summary = "Grouped aggregation of columnar data; input: data, positions (group ids); output: data."; + let description = [{ + Aggregates all elements in the input data column according to the input group ids using the specified + aggregation function. The two input columns must have the same length. The input group ids are typically the + output of the `colGroupFirst` or `colGroupNext` operation. The output column has one element per input group + (the i-th element is the aggregate for group id i). 
+ }]; + + let arguments = (ins ColumnOrU:$data, ColumnOrU:$groupIds, Size:$numDistinct); + let results = (outs ColumnOrU:$res); +} +def Daphne_ColGrpAggSumOp : Daphne_ColGrpAggOp<"colSumGrp">; + // **************************************************************************** // Frame label manipulation // **************************************************************************** @@ -1291,10 +1475,13 @@ def Daphne_GroupOp : Daphne_Op<"group", [ def Daphne_SetColLabelsOp : Daphne_Op<"setColLabels", [ TypeFromFirstArg, DeclareOpInterfaceMethods, - ShapeFromArg + ShapeFromArg, + Pure ]> { let arguments = (ins FrameOrU:$arg, Variadic:$labels); let results = (outs FrameOrU:$res); + + let hasCanonicalizeMethod = 1; } def Daphne_SetColLabelsPrefixOp : Daphne_Op<"setColLabelsPrefix", [ @@ -1319,18 +1506,41 @@ def Daphne_ToStringOp : Daphne_Op<"toString", [DataTypeSca, ValueTypeStr]> { def Daphne_CastOp : Daphne_Op<"cast", [ DeclareOpInterfaceMethods, - ShapeFromArg + ShapeFromArg, + Pure ]> { // Note that the requested result type is not an argument, but should be // specified as the output type when creating a CastOp. 
- let arguments = (ins AnyTypeOf<[MatrixOrFrame, AnyScalar, Unknown]>:$arg); - let results = (outs AnyTypeOf<[MatrixOrFrame, AnyScalar, Unknown]>:$res); + let arguments = (ins AnyTypeOf<[MatrixOrFrame, ColumnOrU, AnyScalar, Unknown]>:$arg); + let results = (outs AnyTypeOf<[MatrixOrFrame, ColumnOrU, AnyScalar, Unknown]>:$res); let hasFolder = 1; let extraClassDeclaration = [{ bool isTrivialCast() { - return getArg().getType() == getRes().getType(); + Type argTy = getArg().getType(); + Type resTy = getRes().getType(); + auto argFrmTy = argTy.dyn_cast(); + auto resFrmTy = resTy.dyn_cast(); + if(argFrmTy && resFrmTy) { + // TODO avoid such a special case for frames, represent the optionality of labels in a different + // way than a pointer (maybe llvm::Option/std::optional), sth that supports op== to be true on different instances + if(argFrmTy.getColumnTypes() != resFrmTy.getColumnTypes()) + return false; + if(argFrmTy.getNumRows() != resFrmTy.getNumRows()) + return false; + if(argFrmTy.getNumCols() != resFrmTy.getNumCols()) + return false; + std::vector *argLabels = argFrmTy.getLabels(); + std::vector *resLabels = resFrmTy.getLabels(); + if((argLabels == nullptr) != (resLabels == nullptr)) + return false; + if(argLabels && *argLabels != *resLabels) + return false; + return true; + } else + return argTy == resTy; } + /** * @brief Checks if this cast just removes detailed properties from a matrix/frame type, * to make them unknown instead. 
@@ -1348,7 +1558,60 @@ def Daphne_CastOp : Daphne_Op<"cast", [ return (argMatTy && resMatTy && argMatTy.isSpecializationOf(resMatTy)) || (argFrmTy && resFrmTy && argFrmTy.isSpecializationOf(resFrmTy)); } + + bool mightLoseInformation() { + Type argTy = getArg().getType(); + Type resTy = getRes().getType(); + + auto argMatTy = argTy.dyn_cast(); + auto argFrmTy = argTy.dyn_cast(); + auto argColTy = argTy.dyn_cast(); + + auto resMatTy = resTy.dyn_cast(); + auto resFrmTy = resTy.dyn_cast(); + auto resColTy = resTy.dyn_cast(); + + // Note: We only consider data/value types, but not the properties. + // TODO What if "b" has more property info that would be important for further inference (is that possible)? + + // If we cast from a matrix to a frame and all columns of the frame have the value type of the matrix, then we certainly do not lose information. + if(argMatTy && resFrmTy) { + Type argMatValTy = argMatTy.getElementType(); + std::vector resFrmColTys = resFrmTy.getColumnTypes(); + for (Type resFrmColTy : resFrmColTys) + // TODO It would be okay if resFrmColTy can represent all values of argMatValTy. + if (argMatValTy != resFrmColTy) + return true; + return false; + } + // If we cast from a frame to a matrix and all columns of the frame have the value type of the matrix, then we certainly do not lose information. + if(argFrmTy && resMatTy) { + std::vector argFrmColTys = argFrmTy.getColumnTypes(); + Type resMatValTy = resMatTy.getElementType(); + for (Type argFrmColTy : argFrmColTys) + // TODO It would be okay if resMatValTy can represent all values of argFrmColTy. + if (argFrmColTy != resMatValTy) + return true; + return false; + } + // If we cast from a matrix to a column and both have the same value type, then we certainly do not lose information. + if (argMatTy && resColTy && argMatTy.getElementType() == resColTy.getValueType()) + return false; + // If we cast from a column to a matrix and both have the same value type, then we certainly do not lose information. 
+ if (argColTy && resMatTy && argColTy.getValueType() == resMatTy.getElementType()) + return false; + // If we cast from a column to a single-column frame and both have the same value type, then we certainly do not lose information. + if (argColTy && resFrmTy && argColTy.getValueType() == resFrmTy.getColumnTypes()[0]) + return false; + // If we cast from a single-column frame to a column and both have the same value type, then we certainly do not lose information. + if (argFrmTy && resColTy && argFrmTy.getColumnTypes()[0] == resColTy.getValueType()) + return false; + + return true; + } }]; + + let hasCanonicalizeMethod = 1; } def Daphne_RenameOp : Daphne_Op<"rename", [ @@ -1370,6 +1633,31 @@ def Daphne_GetColIdxOp : Daphne_Op<"getColIdx", [DataTypeSca, ValueTypeSize]>{ let arguments = (ins Frame:$frame, StrScalar:$columnName); let results = (outs Size:$res); } + +def Daphne_ConvertPosListToBitmapOp : Daphne_Op<"convertPosListToBitmap", [ + TypeFromFirstArg, + ShapeFromArg, + Pure +]> { + let summary = "Creates a single-column matrix of zero/one entries (bit vector), where the entries addressed by the numbers in the single-column input matrix (position list) are set to one."; + + let arguments = (ins MatrixOrU:$arg, Size:$numRowsRes); + let results = (outs MatrixOrU:$res); +} + +def Daphne_ConvertBitmapToPosListOp : Daphne_Op<"convertBitmapToPosList", [ + DataTypeMat, ValueTypeSize, + ShapeFromArg, + Pure +]> { + let summary = "Given a single-column matrix with only zero/one entries (bit vector), creates a sorted single-column matrix of the positions of ones in the input (position list)."; + + let arguments = (ins MatrixOrU:$arg); + let results = (outs MatrixOrU:$res); + + let hasCanonicalizeMethod = 1; +} + // **************************************************************************** // Distributed Operations // **************************************************************************** diff --git a/src/ir/daphneir/DaphneTypeInferenceTraits.td 
b/src/ir/daphneir/DaphneTypeInferenceTraits.td index ea882fdf0..2787e266f 100644 --- a/src/ir/daphneir/DaphneTypeInferenceTraits.td +++ b/src/ir/daphneir/DaphneTypeInferenceTraits.td @@ -32,6 +32,7 @@ def DataTypeFromArgs: NativeOpTrait<"DataTypeFromArgs">; def DataTypeSca: NativeOpTrait<"DataTypeSca">; def DataTypeMat: NativeOpTrait<"DataTypeMat">; def DataTypeFrm: NativeOpTrait<"DataTypeFrm">; +def DataTypeCol: NativeOpTrait<"DataTypeCol">; // ============================================================================ // Value type diff --git a/src/ir/daphneir/DaphneTypes.td b/src/ir/daphneir/DaphneTypes.td index c0d307fe7..d2687119b 100644 --- a/src/ir/daphneir/DaphneTypes.td +++ b/src/ir/daphneir/DaphneTypes.td @@ -226,6 +226,74 @@ def MatrixOrFrame : AnyTypeOf<[Matrix, Frame, Unknown]>; def MatrixOrU : AnyTypeOf<[Matrix, Unknown]>; def FrameOrU : AnyTypeOf<[Frame, Unknown]>; +def Column : Daphne_Type<"Column"> { + let summary = "column"; + + let parameters = (ins "::mlir::Type":$valueType, "ssize_t":$numRows); + + let genVerifyDecl = 1; + + let builders = [ + // Creates a ColumnType from mere value type information, with all other parameters reset. + TypeBuilder<(ins "::mlir::Type":$valueType), [{ + return Base::get($_ctxt, valueType, -1); + }]>, + ]; + + let extraClassDeclaration = [{ + // The following methods return a ColumnType with the same parameters + // as this ColumnType, except for one parameter, which is replaced by a + // new value. + + ::mlir::daphne::ColumnType withValueType(::mlir::Type valueType) { + return get(getContext(), valueType, getNumRows()); + } + + ::mlir::daphne::ColumnType withShape(ssize_t numRows) { + return get(getContext(), getValueType(), numRows); + } + + // The following method returns a ColumnType with the same value type + // as this ColumnType, but all other parameters reset. 
+ + ::mlir::daphne::ColumnType withSameValueType() { + return get(getContext(), getValueType()); + } + + /** + * @brief Check if `other` has strictly less information than this column type. + * @param other the other column type to compare to + * @return true if this type has more information, false otherwise + */ + bool isSpecializationOf(::mlir::daphne::ColumnType other) { + auto valueType = getValueType(); + auto otherValueType = other.getValueType(); + if(!otherValueType.isa() && valueType != otherValueType) + return false; + + if (other.getNumRows() != -1 && getNumRows() != other.getNumRows()) + return false; + + return true; + } + }]; +} +def ColumnOrU : AnyTypeOf<[Column, Unknown]>; + +// TODO We could omit the "Of" and have it default to ScalarType. +// A type constraint checking if a type is a column whose value type is one of +// the given types (whereby Unknown is always allowed automatically). +// Reuses MLIR's ContainerType constraint in a way inspired by MLIR's VectorOf. +class ColumnOf allowedTypes> : AnyTypeOf<[ + ContainerType< + AnyTypeOf, + Column.predicate, + "$_self.dyn_cast<::mlir::daphne::ColumnType>().getValueType()", + "column" + >, + Unknown +]>; + def List : Daphne_Type<"List"> { let summary = "list"; diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index 464aeb577..4aa400654 100644 --- a/src/ir/daphneir/Passes.h +++ b/src/ir/daphneir/Passes.h @@ -65,6 +65,7 @@ std::unique_ptr createPhyOperatorSelectionPass(); std::unique_ptr createPrintIRPass(std::string message = ""); std::unique_ptr createProfilingPass(); std::unique_ptr createRewriteSqlOpPass(); +std::unique_ptr createRewriteToColumnarOpsPass(); std::unique_ptr createRewriteToCallKernelOpPass(const DaphneUserConfig &cfg, std::unordered_map &usedLibPaths); std::unique_ptr createSelectMatrixRepresentationsPass(const DaphneUserConfig &cfg); diff --git a/src/parser/catalog/KernelCatalogParser.cpp b/src/parser/catalog/KernelCatalogParser.cpp index 3fabec824..44ec25d49 
100644 --- a/src/parser/catalog/KernelCatalogParser.cpp +++ b/src/parser/catalog/KernelCatalogParser.cpp @@ -69,6 +69,10 @@ KernelCatalogParser::KernelCatalogParser(mlir::MLIRContext *mctx) { mlir::Type ltCSR = mlir::daphne::ListType::get(mctx, mtCSR); typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(ltCSR), ltCSR); + // Column type. + mlir::Type ct = mlir::daphne::ColumnType::get(mctx, st); + typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(ct), ct); + // MemRef type. if (!st.isa()) { // DAPHNE's StringType is not supported as the element type of a diff --git a/src/parser/config/ConfigParser.cpp b/src/parser/config/ConfigParser.cpp index 47d3def14..ed8875d57 100644 --- a/src/parser/config/ConfigParser.cpp +++ b/src/parser/config/ConfigParser.cpp @@ -57,6 +57,8 @@ void ConfigParser::readUserConfig(const std::string &filename, DaphneUserConfig config.use_phy_op_selection = jf.at(DaphneConfigJsonParams::USE_PHY_OP_SELECTION).get(); if (keyExists(jf, DaphneConfigJsonParams::USE_MLIR_CODEGEN)) config.use_mlir_codegen = jf.at(DaphneConfigJsonParams::USE_MLIR_CODEGEN).get(); + if (keyExists(jf, DaphneConfigJsonParams::USE_COLUMNAR)) + config.use_columnar = jf.at(DaphneConfigJsonParams::USE_COLUMNAR).get(); if (keyExists(jf, DaphneConfigJsonParams::MATMUL_VEC_SIZE_BITS)) config.matmul_vec_size_bits = jf.at(DaphneConfigJsonParams::MATMUL_VEC_SIZE_BITS).get(); if (keyExists(jf, DaphneConfigJsonParams::MATMUL_TILE)) @@ -80,6 +82,8 @@ void ConfigParser::readUserConfig(const std::string &filename, DaphneUserConfig config.vectorized_single_queue = jf.at(DaphneConfigJsonParams::VECTORIZED_SINGLE_QUEUE).get(); if (keyExists(jf, DaphneConfigJsonParams::DEBUG_LLVM)) config.debug_llvm = jf.at(DaphneConfigJsonParams::DEBUG_LLVM).get(); + if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_COLUMNAR)) + config.explain_columnar = jf.at(DaphneConfigJsonParams::EXPLAIN_COLUMNAR).get(); if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_KERNELS)) config.explain_kernels = 
jf.at(DaphneConfigJsonParams::EXPLAIN_KERNELS).get(); if (keyExists(jf, DaphneConfigJsonParams::EXPLAIN_LLVM)) diff --git a/src/parser/config/JsonParams.h b/src/parser/config/JsonParams.h index 3a6717497..b05dd2826 100644 --- a/src/parser/config/JsonParams.h +++ b/src/parser/config/JsonParams.h @@ -30,6 +30,7 @@ struct DaphneConfigJsonParams { inline static const std::string USE_IPA_CONST_PROPA = "use_ipa_const_propa"; inline static const std::string USE_PHY_OP_SELECTION = "use_phy_op_selection"; inline static const std::string USE_MLIR_CODEGEN = "use_mlir_codegen"; + inline static const std::string USE_COLUMNAR = "use_columnar"; inline static const std::string MATMUL_VEC_SIZE_BITS = "matmul_vec_size_bits"; inline static const std::string MATMUL_TILE = "matmul_tile"; inline static const std::string MATMUL_FIXED_TILE_SIZES = "matmul_fixed_tile_sizes"; @@ -42,6 +43,7 @@ struct DaphneConfigJsonParams { inline static const std::string VECTORIZED_SINGLE_QUEUE = "vectorized_single_queue"; inline static const std::string DEBUG_LLVM = "debug_llvm"; + inline static const std::string EXPLAIN_COLUMNAR = "explain_columnar"; inline static const std::string EXPLAIN_KERNELS = "explain_kernels"; inline static const std::string EXPLAIN_LLVM = "explain_llvm"; inline static const std::string EXPLAIN_PARSING = "explain_parsing"; @@ -85,9 +87,11 @@ struct DaphneConfigJsonParams { USE_IPA_CONST_PROPA, USE_PHY_OP_SELECTION, USE_MLIR_CODEGEN, + USE_COLUMNAR, CUDA_FUSE_ANY, VECTORIZED_SINGLE_QUEUE, DEBUG_LLVM, + EXPLAIN_COLUMNAR, EXPLAIN_KERNELS, EXPLAIN_LLVM, EXPLAIN_PARSING, diff --git a/src/runtime/local/datagen/GenGivenVals.h b/src/runtime/local/datagen/GenGivenVals.h index 0d62d1889..2b50219b1 100644 --- a/src/runtime/local/datagen/GenGivenVals.h +++ b/src/runtime/local/datagen/GenGivenVals.h @@ -18,6 +18,7 @@ #define SRC_RUNTIME_LOCAL_DATAGEN_GENGIVENVALS_H #include +#include #include #include @@ -42,10 +43,10 @@ template struct GenGivenVals { // 
**************************************************************************** /** - * @brief A very simple data generator which populates a matrix with the + * @brief A very simple data generator which populates a nxm data object (e.g., a matrix) with the * elements of the given `std::vector`. * - * Meant only for small matrices, mainly as a utility for testing and + * Meant only for small data objects, mainly as a utility for testing and * debugging. Note that it can easily be used with an initializer list as * follows: * @@ -56,11 +57,11 @@ template struct GenGivenVals { * ``` * * @param numRows The number of rows. - * @param elements The data elements to populate the matrix with. Their number + * @param elements The data elements to populate the data object with. Their number * must be divisible by `numRows`. * @param minNumNonZeros The minimum number of non-zeros to reserve space for * in a sparse matrix. - * @return A matrix of the specified data type `DT` containing the provided + * @return A data object of the specified data type `DT` containing the provided * data elements. */ template @@ -68,6 +69,31 @@ DT *genGivenVals(size_t numRows, const std::vector &elements, s return GenGivenVals
::generate(numRows, elements, minNumNonZeros); } +/** + * @brief A very simple data generator which populates a nx1 data object (e.g., a matrix) with the + * elements of the given `std::vector`. + * + * Meant only for small data objects, mainly as a utility for testing and + * debugging. Note that it can easily be used with an initializer list as + * follows: + * + * ```c++ + * // Generates the matrix 3 + * // 1 + * // 4 + * auto m = genGivenVals>({3, 1, 4}); + * ``` + * + * @param elements The data elements to populate the data object with. + * @param minNumNonZeros The minimum number of non-zeros to reserve space for + * in a sparse matrix. + * @return A data object of the specified data type `DT` with a single column containing the provided + * data elements. + */ +template DT *genGivenVals(const std::vector &elements, size_t minNumNonZeros = 0) { + return genGivenVals
(elements.size(), elements, minNumNonZeros); +} + // **************************************************************************** // (Partial) template specializations for different data/value types // **************************************************************************** @@ -172,4 +198,19 @@ template struct GenGivenVals> { } }; +// ---------------------------------------------------------------------------- +// Column +// ---------------------------------------------------------------------------- + +template struct GenGivenVals> { + static Column *generate(size_t numRows, const std::vector &elements, size_t minNumNonZeros = 0) { + if (numRows != elements.size()) + throw std::runtime_error("GenGivenVals>: the given number of rows must match the number of " + "elements in the given vector"); + auto res = DataObjectFactory::create>(numRows, false); + std::copy(elements.begin(), elements.end(), res->getValues()); + return res; + } +}; + #endif // SRC_RUNTIME_LOCAL_DATAGEN_GENGIVENVALS_H \ No newline at end of file diff --git a/src/runtime/local/datastructures/Column.h b/src/runtime/local/datastructures/Column.h new file mode 100644 index 000000000..ca8375b6a --- /dev/null +++ b/src/runtime/local/datastructures/Column.h @@ -0,0 +1,143 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include + +/** + * @brief A data structure that represents a single column in columnar query processing, all elements have the same + * value type and are stored contiguously in memory. + */ +template class Column : public Structure { + // `using`, so that we do not need to prefix each occurrence of these fields from the super-classes. + using Structure::numCols; + using Structure::numRows; + + std::shared_ptr values{}; + + // Grant DataObjectFactory access to the private constructors and destructors. + template friend DataType *DataObjectFactory::create(ArgTypes...); + template friend void DataObjectFactory::destroy(const DataType *obj); + + Column(size_t numRows, bool zero) + : Structure(numRows, 1), values(std::shared_ptr(new ValueType[numRows])) { + if (zero) + std::fill(values.get(), values.get() + numRows, ValueTypeUtils::defaultValue); + } + + Column(size_t numRows, std::shared_ptr &values) : Structure(numRows, 1), values(values) {} + + ~Column() override = default; + + void printValue(std::ostream &os, ValueType val) const { + if constexpr (std::is_same::value || std::is_same::value) + os << static_cast(val); + else + os << val; + } + + public: + template using WithValueType = Column; + + /** + * @brief The common type of all values in this column. 
+ */ + using VT = ValueType; + + static std::string getName() { return "Column"; } + + size_t getNumDims() const override { return 1; } + + size_t getNumItems() const override { return this->numRows; } + + void print(std::ostream &os) const override { + os << "Column(" << numRows << ", " << ValueTypeUtils::cppNameFor << ')' << std::endl; + + for (size_t r = 0; r < numRows; r++) { + printValue(os, values.get()[r]); + os << std::endl; + } + } + + Column *sliceRow(size_t rl, size_t ru) const override { + throw std::runtime_error("slicing has not been implemented for Column yet"); + } + + Column *sliceCol(size_t cl, size_t cu) const override { + throw std::runtime_error("slicing has not been implemented for Column yet"); + } + + Column *slice(size_t rl, size_t ru, size_t cl, size_t cu) const override { + throw std::runtime_error("slicing has not been implemented for Column yet"); + } + + size_t serialize(std::vector &buf) const override { + throw std::runtime_error("serialization has not been implemented for Column yet"); + } + + void shrinkNumRows(size_t numRows) { + if (numRows > this->numRows) + throw std::runtime_error("Column (shrinkNumRows): number of rows can only be shrunk"); + // TODO Here we could reduce the allocated size of the values array. + this->numRows = numRows; + } + + const ValueType *getValues() const { return values.get(); } + + ValueType *getValues() { return values.get(); } + + std::shared_ptr getValuesSharedPtr() const { return values; } + + bool operator==(const Column &rhs) const { + // Note that we do not use the generic `get` interface here since this operator is meant to be used for writing + // tests for, besides others, those generic interfaces. 
+ + if (this == &rhs) + return true; + + const size_t numRows = this->getNumRows(); + + if (numRows != rhs.getNumRows()) + return false; + + const ValueType *valuesLhs = this->getValues(); + const ValueType *valuesRhs = rhs.getValues(); + + if (valuesLhs == valuesRhs) + return true; + + for (size_t r = 0; r < numRows; r++) + if (valuesLhs[r] != valuesRhs[r]) + return false; + + return true; + } +}; + +template std::ostream &operator<<(std::ostream &os, const Column &obj) { + obj.print(os); + return os; +} \ No newline at end of file diff --git a/src/runtime/local/kernels/CastObj.h b/src/runtime/local/kernels/CastObj.h index 3984276d4..2b6d6c460 100644 --- a/src/runtime/local/kernels/CastObj.h +++ b/src/runtime/local/kernels/CastObj.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -317,3 +318,91 @@ template class CastObj, MatrixfinishAppend(); } }; + +// ---------------------------------------------------------------------------- +// Column <- DenseMatrix +// ---------------------------------------------------------------------------- + +template class CastObj, DenseMatrix> { + + public: + static void apply(Column *&res, const DenseMatrix *arg, DCTX(ctx)) { + const size_t numRows = arg->getNumRows(); + const size_t numCols = arg->getNumCols(); + if (numCols == 1) { + // The input matrix has a single column. + const size_t rowSkipArg = arg->getRowSkip(); + if (rowSkipArg == 1) { + // The input's single column is stored contiguously. + // Reuse the input's memory for the result (zero-copy). + res = DataObjectFactory::create>(numRows, arg->getValuesSharedPtr()); + } else { + // The input's single column is not stored contiguously. + // Copy the input data to the result. 
+ res = DataObjectFactory::create>(numRows, false); + const VT *valuesArg = arg->getValues(); + VT *valuesRes = res->getValues(); + for (size_t r = 0; r < numRows; r++) { + valuesRes[r] = *valuesArg; + valuesArg += rowSkipArg; + } + } + } else { + // The input matrix has zero or multiple columns. + throw std::runtime_error("CastObj::apply: cannot cast a matrix with zero or mutliple columns to Column"); + } + } +}; + +// ---------------------------------------------------------------------------- +// DenseMatrix <- Column +// ---------------------------------------------------------------------------- + +template class CastObj, Column> { + + public: + static void apply(DenseMatrix *&res, const Column *arg, DCTX(ctx)) { + res = DataObjectFactory::create>(arg->getNumRows(), 1, arg->getValuesSharedPtr()); + } +}; + +// ---------------------------------------------------------------------------- +// Column <- Frame +// ---------------------------------------------------------------------------- + +template class CastObj, Frame> { + + public: + static void apply(Column *&res, const Frame *arg, DCTX(ctx)) { + const size_t numRows = arg->getNumRows(); + const size_t numCols = arg->getNumCols(); + if (numCols == 1 && arg->getColumnType(0) == ValueTypeUtils::codeFor) { + // The input frame has a single column of the result's value type. + // Zero-cost cast from frame to Column. + // TODO This case could even be used for (un)signed integers of the + // same width, involving a reinterpret cast of the pointers. + // TODO Can we avoid this const_cast? + res = DataObjectFactory::create>(numRows, arg->getColumn(0)->getValuesSharedPtr()); + } else { + // The input frame has multiple columns and/or other value types + // than the result. 
+ throw std::runtime_error("CastObj::apply: cannot cast Frame with mutliple columns to Column"); + } + } +}; + +// ---------------------------------------------------------------------------- +// Frame <- Column +// ---------------------------------------------------------------------------- + +template class CastObj> { + + public: + static void apply(Frame *&res, const Column *arg, DCTX(ctx)) { + std::vector colMats; + DenseMatrix *argMat = nullptr; + castObj>(argMat, arg, ctx); + colMats.push_back(argMat); + res = DataObjectFactory::create(colMats, nullptr); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/CastObjSca.h b/src/runtime/local/kernels/CastObjSca.h index 9252bdf60..ef33a93b1 100644 --- a/src/runtime/local/kernels/CastObjSca.h +++ b/src/runtime/local/kernels/CastObjSca.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -85,3 +86,16 @@ template struct CastObjSca { return res; } }; + +// ---------------------------------------------------------------------------- +// Scalar <- Column +// ---------------------------------------------------------------------------- + +template struct CastObjSca> { + static VTRes apply(const Column *arg, DCTX(ctx)) { + const size_t numRows = arg->getNumRows(); + if (numRows != 1) + throw std::runtime_error("cast column to scalar: column must have exactly one element"); + return static_cast(*arg->getValues()); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/CheckEq.h b/src/runtime/local/kernels/CheckEq.h index a1a31fea2..2b1776bb5 100644 --- a/src/runtime/local/kernels/CheckEq.h +++ b/src/runtime/local/kernels/CheckEq.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -100,4 +101,12 @@ template struct CheckEq> { template struct CheckEq> { static bool apply(const Matrix *lhs, const Matrix *rhs, DCTX(ctx)) { return *lhs == *rhs; } +}; + +// 
---------------------------------------------------------------------------- +// Column +// ---------------------------------------------------------------------------- + +template struct CheckEq> { + static bool apply(const Column *lhs, const Column *rhs, DCTX(ctx)) { return *lhs == *rhs; } }; \ No newline at end of file diff --git a/src/runtime/local/kernels/CmpOpCode.h b/src/runtime/local/kernels/CmpOpCode.h new file mode 100644 index 000000000..e2f592321 --- /dev/null +++ b/src/runtime/local/kernels/CmpOpCode.h @@ -0,0 +1,82 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +// **************************************************************************** +// Enum for comparison op codes and their names +// **************************************************************************** + +enum class CmpOpCode { EQ, NEQ, GT, GE, LT, LE }; + +/** + * @brief Array of the "names" of the `CmpOpCode`s. + * + * Must contain the same elements as `CmpOpCode` in the same order, + * such that we can obtain the name corresponding to a `CmpOpCode` `opCode` + * by `comparison_op_codes[static_cast(opCode)]`. 
+ */ +static std::string_view comparison_op_codes[] = {"EQ", "NEQ", "GT", "GE", "LT", "LE"}; + +// **************************************************************************** +// Specification which comparison ops should be supported on which value types +// **************************************************************************** + +/** + * @brief Template constant specifying if the given comparison operation + * should be supported on arguments of the given value types. + * + * @tparam VTRes The result value type. + * @tparam VTLhs The left-hand-side argument value type. + * @tparam VTRhs The right-hand-side argument value type. + * @tparam op The binary operation. + */ +template static constexpr bool supportsCmpOp = false; + +// Macros for concisely specifying which comparison operations should be +// supported on which value types. + +// Generates code specifying that the comparison operation `Op` should be supported +// on the value type `VT` (for the result and the two arguments, for +// simplicity). +#define SUPPORT(Op, VT) template <> constexpr bool supportsCmpOp = true; + +#define SUPPORT_ALL_VTS(Op) \ + SUPPORT(Op, double) \ + SUPPORT(Op, float) \ + SUPPORT(Op, int64_t) \ + SUPPORT(Op, int32_t) \ + SUPPORT(Op, int8_t) \ + SUPPORT(Op, uint64_t) \ + SUPPORT(Op, uint32_t) \ + SUPPORT(Op, uint8_t) \ + SUPPORT(Op, bool) \ + SUPPORT(Op, std::string) + +SUPPORT_ALL_VTS(EQ); +SUPPORT_ALL_VTS(NEQ); +SUPPORT_ALL_VTS(GT); +SUPPORT_ALL_VTS(GE); +SUPPORT_ALL_VTS(LT); +SUPPORT_ALL_VTS(LE); + +// Undefine helper macros. 
+#undef SUPPORT +#undef SUPPORT_ALL_VTS \ No newline at end of file diff --git a/src/runtime/local/kernels/ColAggAll.h b/src/runtime/local/kernels/ColAggAll.h new file mode 100644 index 000000000..9866eb4f5 --- /dev/null +++ b/src/runtime/local/kernels/ColAggAll.h @@ -0,0 +1,59 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel aggregating all elements of the argument into a single value. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColAggAll { + static void apply(AggOpCode opCode, DTRes *&res, const DTArg *arg, DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColAggAll` specialization. + * + * @param opCode The aggregation operation to apply. + * @param res The result; created by the kernel if it is `nullptr`. + * @param arg The data to aggregate. + * @param ctx The context object, passed through to the kernel. + */ +template void colAggAll(AggOpCode opCode, DTRes *&res, const DTArg *arg, DCTX(ctx)) { + ColAggAll::apply(opCode, res, arg, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column <- Column +// 
---------------------------------------------------------------------------- + +/** + * Aggregates by obtaining a DenseMatrix from the argument column (castObj) and delegating to aggAll. + * The scalar result is stored in element 0 of the result column, which is created with one element + * if `res` is nullptr. + */ +template struct ColAggAll, Column> { + static void apply(AggOpCode opCode, Column *&res, const Column *arg, DCTX(ctx)) { + DenseMatrix *argMat = nullptr; + castObj>(argMat, arg, ctx); + /* NOTE(review): argMat is produced by castObj and is not released anywhere in this kernel; confirm whether it has to be destroyed here to avoid leaking the temporary object. */ + VTRes resSca = aggAll(opCode, argMat, ctx); + if (res == nullptr) + res = DataObjectFactory::create>(1, false); + /* A caller-provided result is not resized; only its first element is overwritten. */ + res->getValues()[0] = resSca; + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ColAggGrp.h b/src/runtime/local/kernels/ColAggGrp.h new file mode 100644 index 000000000..617d9b9e4 --- /dev/null +++ b/src/runtime/local/kernels/ColAggGrp.h @@ -0,0 +1,88 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel for grouped aggregation: aggregates the data elements per group id. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColAggGrp { + static void apply(AggOpCode opCode, DTRes *&res, const DTData *data, const DTGrpIds *grpIds, size_t numDistinct, + DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColAggGrp` specialization. + * + * @param opCode The aggregation operation to apply within each group. + * @param res The result (one aggregate per group); created by the kernel if it is `nullptr`. + * @param data The data elements to aggregate. + * @param grpIds The group id of each data element; must have as many elements as `data`. + * @param numDistinct The number of groups; valid group ids are `0 .. numDistinct - 1`. + * @param ctx The context object, passed through to the kernel. + */ +template +void colAggGrp(AggOpCode opCode, DTRes *&res, const DTData *data, const DTGrpIds *grpIds, size_t numDistinct, + DCTX(ctx)) { + ColAggGrp::apply(opCode, res, data, grpIds, numDistinct, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column <- Column, Column +// ---------------------------------------------------------------------------- + +/** + * Throws std::runtime_error if `data` and `grpIds` differ in size, if a group id is outside + * `0 .. numDistinct - 1`, or if `opCode` is not a pure binary reduction. + */ +template struct ColAggGrp, Column, Column> { + static void apply(AggOpCode opCode, Column *&res, const Column *data, const Column *grpIds, + size_t numDistinct, DCTX(ctx)) { + const size_t numData = data->getNumRows(); + + if (numData != grpIds->getNumRows()) + throw std::runtime_error("input data and input group ids must have the same number of elements"); + + const VTData *valuesData = data->getValues(); + const VTPos *valuesGrpIds = grpIds->getValues(); + + if (res == nullptr) + res = DataObjectFactory::create>(numDistinct, false); + + VTData *valuesRes = res->getValues(); + + // Initialize the accumulator of each group with the neutral element of the aggregation function.
+ std::fill(valuesRes, valuesRes + numDistinct, AggOpCodeUtils::getNeutral(opCode)); + + // Perform the grouped aggregation. + EwBinaryScaFuncPtr func; + if (AggOpCodeUtils::isPureBinaryReduction(opCode)) { + func = getEwBinaryScaFuncPtr(AggOpCodeUtils::getBinaryOpCode(opCode)); + + for (size_t r = 0; r < numData; r++) { + VTPos grpId = valuesGrpIds[r]; + /* The result has numDistinct slots (indexes 0 .. numDistinct - 1), so grpId == numDistinct is already out of bounds. */ + if (grpId < 0 || static_cast<size_t>(grpId) >= numDistinct) + throw std::runtime_error("out-of-bounds access"); + valuesRes[grpId] = func(valuesRes[grpId], valuesData[r], ctx); + } + } else + throw std::runtime_error("unsupported op code"); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ColCalcBinary.h b/src/runtime/local/kernels/ColCalcBinary.h new file mode 100644 index 000000000..8d57acb59 --- /dev/null +++ b/src/runtime/local/kernels/ColCalcBinary.h @@ -0,0 +1,66 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel applying an elementwise binary operation to two data arguments. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColCalcBinary { + static void apply(BinaryOpCode opCode, DTResData *&resData, const DTLhsData *lhsData, const DTRhsData *rhsData, + DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColCalcBinary` specialization. + * + * @param opCode The binary operation to apply. + * @param resData The result data. + * @param lhsData The left-hand-side data. + * @param rhsData The right-hand-side data. + * @param ctx The context object, passed through to the kernel. + */ +template +void colCalcBinary(BinaryOpCode opCode, DTResData *&resData, const DTLhsData *lhsData, const DTRhsData *rhsData, + DCTX(ctx)) { + ColCalcBinary::apply(opCode, resData, lhsData, rhsData, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column <- Column, Column +// ---------------------------------------------------------------------------- + +/** + * Implemented by round-tripping through DenseMatrix: both inputs are cast to DenseMatrix, + * ewBinaryMat computes the result matrix, and the result matrix is cast back to a Column. + */ +template +struct ColCalcBinary, Column, Column> { + static void apply(BinaryOpCode opCode, Column *&resData, const Column *lhsData, + const Column *rhsData, DCTX(ctx)) { + DenseMatrix *resDataMat = nullptr; + DenseMatrix *lhsDataMat = nullptr; + DenseMatrix *rhsDataMat = nullptr; + castObj>(lhsDataMat, lhsData, ctx); + castObj>(rhsDataMat, rhsData, ctx); + ewBinaryMat(opCode, resDataMat, lhsDataMat, rhsDataMat, ctx); + castObj>(resData, resDataMat, ctx); + /* NOTE(review): lhsDataMat, rhsDataMat, and resDataMat are temporaries created in this kernel and none of them is released here; confirm whether they have to be destroyed to avoid leaking them. */ + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ColGroupFirst.h b/src/runtime/local/kernels/ColGroupFirst.h new file mode 100644 index 000000000..b21cccf82 --- /dev/null +++ 
b/src/runtime/local/kernels/ColGroupFirst.h @@ -0,0 +1,79 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel grouping the elements of a single data argument by their value. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColGroupFirst { + static void apply(DTResGrpIds *&resGrpIds, DTResReprPos *&resReprPos, const DTArgData *argData, DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColGroupFirst` specialization. + * + * @param resGrpIds Result: for each input element, the (zero-based) id of its group. + * @param resReprPos Result: for each group, the position of the first input element belonging to it. + * @param argData The data to group. + * @param ctx The context object, passed through to the kernel. + */ +template +void colGroupFirst(DTResGrpIds *&resGrpIds, DTResReprPos *&resReprPos, const DTArgData *argData, DCTX(ctx)) { + ColGroupFirst::apply(resGrpIds, resReprPos, argData, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column, Column <- Column +// ---------------------------------------------------------------------------- + +/** + * Group ids are assigned in order of first occurrence of a value, starting at 0. + */ +template struct ColGroupFirst, Column, Column> { + static void apply(Column 
*&resGrpIds, Column *&resReprPos, const Column *argData, DCTX(ctx)) { + const size_t numArgData = argData->getNumRows(); + + /* Both results are created with one slot per input element; resReprPos is shrunk to the number of groups at the end. */ + if (resGrpIds == nullptr) + resGrpIds = DataObjectFactory::create>(numArgData, false); + if (resReprPos == nullptr) + resReprPos = DataObjectFactory::create>(numArgData, false); + VTPos *valuesResGrpIds = resGrpIds->getValues(); + VTPos *valuesResReprPos = resReprPos->getValues(); + VTPos *valuesResReprPosBeg = valuesResReprPos; + + const VTData *valuesArgData = argData->getValues(); + /* Maps a data value to (group id + 1); operator[] value-initializes a missing entry to 0, which marks "not seen yet". */ + std::unordered_map grpIds; + for (size_t r = 0; r < numArgData; r++) { + VTPos &grpId = grpIds[valuesArgData[r]]; + if (!grpId) { // the value was not found + grpId = grpIds.size(); /* the new entry is already counted by size(), so this is the new group id + 1 */ + *valuesResReprPos = r; /* r is the first occurrence of this value, i.e., the group's representative */ + valuesResReprPos++; + } + *valuesResGrpIds = grpId - 1; // -1 because we use a zero entry in ht to indicate a newly created entry + valuesResGrpIds++; + } + + /* Only as many representative positions as there are groups were written. */ + resReprPos->shrinkNumRows(valuesResReprPos - valuesResReprPosBeg); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ColGroupNext.h b/src/runtime/local/kernels/ColGroupNext.h new file mode 100644 index 000000000..7a14213e0 --- /dev/null +++ b/src/runtime/local/kernels/ColGroupNext.h @@ -0,0 +1,90 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include + +#include +#include + +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel refining an existing grouping by the values of a further data argument. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColGroupNext { + static void apply(DTResGrpIds *&resGrpIds, DTResReprPos *&resReprPos, const DTArgData *argData, + const DTArgGrpIds *argGrpIds, DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColGroupNext` specialization. + * + * @param resGrpIds Result: for each input element, the (zero-based) id of its new group. + * @param resReprPos Result: for each new group, the position of the first input element belonging to it. + * @param argData The data to group by. + * @param argGrpIds The existing group id of each element; must have as many elements as `argData`. + * @param ctx The context object, passed through to the kernel. + */ +template +void colGroupNext(DTResGrpIds *&resGrpIds, DTResReprPos *&resReprPos, const DTArgData *argData, + const DTArgGrpIds *argGrpIds, DCTX(ctx)) { + ColGroupNext::apply(resGrpIds, resReprPos, argData, argGrpIds, + ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column, Column <- Column, Column +// ---------------------------------------------------------------------------- + +/** + * Two elements end up in the same new group iff they agree in both their data value and their + * existing group id. New group ids are assigned in order of first occurrence, starting at 0. + * Throws std::runtime_error if `argData` and `argGrpIds` differ in size. + */ +template +struct ColGroupNext, Column, Column, Column> { + static void apply(Column *&resGrpIds, Column *&resReprPos, const Column *argData, + const Column *argGrpIds, DCTX(ctx)) { + const size_t numArgData = argData->getNumRows(); + + if (numArgData != argGrpIds->getNumRows()) + throw std::runtime_error("input data and input group ids must have the same number of elements"); + + if (resGrpIds == nullptr) + resGrpIds = DataObjectFactory::create>(numArgData, false); + if (resReprPos == nullptr) + resReprPos = DataObjectFactory::create>(numArgData, false); + VTPos *valuesResGrpIds = resGrpIds->getValues(); + VTPos *valuesResReprPos = 
resReprPos->getValues(); + VTPos *valuesResReprPosBeg = valuesResReprPos; + + const VTData *valuesArgData = argData->getValues(); + const VTPos *valuesArgGrpIds = argGrpIds->getValues(); + /* Maps (data value, existing group id) to (new group id + 1); a value-initialized 0 marks "not seen yet". */ + // We have to use std::map, since std::pair is not hashable. + std::map, VTPos> grpIds; + for (size_t r = 0; r < numArgData; r++) { + VTPos &grpId = grpIds[std::make_pair(valuesArgData[r], valuesArgGrpIds[r])]; + if (!grpId) { // The value was not found. + grpId = grpIds.size(); /* the new entry is already counted by size(), so this is the new group id + 1 */ + *valuesResReprPos = r; /* r is the first occurrence of this combination, i.e., the group's representative */ + valuesResReprPos++; + } + *valuesResGrpIds = grpId - 1; // -1 because we use a zero entry in ht to indicate a newly created entry + valuesResGrpIds++; + } + + /* Only as many representative positions as there are new groups were written. */ + resReprPos->shrinkNumRows(valuesResReprPos - valuesResReprPosBeg); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ColIntersect.h b/src/runtime/local/kernels/ColIntersect.h new file mode 100644 index 000000000..d4091f2f4 --- /dev/null +++ b/src/runtime/local/kernels/ColIntersect.h @@ -0,0 +1,80 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include + +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel intersecting two position lists. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColIntersect { + static void apply(DTResPos *&resPos, const DTLhsPos *lhsPos, const DTRhsPos *rhsPos, DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColIntersect` specialization. + * + * @param resPos Result: the positions contained in both inputs; created by the kernel if it is `nullptr`. + * @param lhsPos The left-hand-side position list. + * @param rhsPos The right-hand-side position list. + * @param ctx The context object, passed through to the kernel. + */ +template +void colIntersect(DTResPos *&resPos, const DTLhsPos *lhsPos, const DTRhsPos *rhsPos, DCTX(ctx)) { + ColIntersect::apply(resPos, lhsPos, rhsPos, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column <- Column, Column +// ---------------------------------------------------------------------------- + +/** + * Single pass over both inputs that always advances the side with the smaller current element. + * This finds all common positions only if both inputs are sorted in ascending order. + */ +template struct ColIntersect, Column, Column> { + static void apply(Column *&resPos, const Column *lhsPos, const Column *rhsPos, DCTX(ctx)) { + const size_t numLhsPos = lhsPos->getNumRows(); + const size_t numRhsPos = rhsPos->getNumRows(); + + /* The intersection cannot have more elements than the smaller input. */ + if (resPos == nullptr) + resPos = DataObjectFactory::create>(std::min(numLhsPos, numRhsPos), false); + + const VTPos *valuesLhsPos = lhsPos->getValues(); + const VTPos *valuesRhsPos = rhsPos->getValues(); + const VTPos *valuesLhsPosEnd = valuesLhsPos + numLhsPos; + const VTPos *valuesRhsPosEnd = valuesRhsPos + numRhsPos; + VTPos *valuesResPos = resPos->getValues(); + VTPos *valuesResPosBeg = valuesResPos; + + /* Stop as soon as one input is exhausted; the remainder of the other one cannot contribute. */ + while (valuesLhsPos < valuesLhsPosEnd && valuesRhsPos < valuesRhsPosEnd) { + if (*valuesLhsPos < *valuesRhsPos) + valuesLhsPos++; + 
else if (*valuesRhsPos < *valuesLhsPos) + valuesRhsPos++; + else { // *valuesLhsPos == *valuesRhsPos + *valuesResPos = *valuesLhsPos; + valuesLhsPos++; + valuesRhsPos++; + valuesResPos++; + } + } + + /* Cut the result down to the number of positions actually written. */ + resPos->shrinkNumRows(valuesResPos - valuesResPosBeg); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ColJoin.h b/src/runtime/local/kernels/ColJoin.h new file mode 100644 index 000000000..b6fb498da --- /dev/null +++ b/src/runtime/local/kernels/ColJoin.h @@ -0,0 +1,93 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include + +#include + +#include +#include <stdexcept> + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel for an equi-join of two data arguments, returning the matching positions of both sides. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColJoin { + static void apply(DTResLhsPos *&resLhsPos, DTResRhsPos *&resRhsPos, const DTLhsData *lhsData, + const DTRhsData *rhsData, int64_t numRes, DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColJoin` specialization. + * + * @param resLhsPos Result: for each match, the position of the matching element in `lhsData`. + * @param resRhsPos Result: for each match, the position of the matching element in `rhsData`. + * @param lhsData The left-hand-side data (probe side). + * @param rhsData The right-hand-side data (build side). + * @param numRes The number of result slots to allocate, or -1 to allocate one slot per element of `lhsData`. + * @param ctx The context object, passed through to the kernel. + */ +template +void colJoin(DTResLhsPos *&resLhsPos, DTResRhsPos *&resRhsPos, const DTLhsData *lhsData, const DTRhsData *rhsData, + int64_t numRes, DCTX(ctx)) { + ColJoin::apply(resLhsPos, resRhsPos, lhsData, rhsData, numRes, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column, Column <- Column, Column +// ---------------------------------------------------------------------------- + +/** + * Hash join: a hash table over `rhsData` is probed with every element of `lhsData`. + * Throws std::runtime_error if `numRes` is negative but not -1, or if the result position + * lists are too small for the join result. + */ +template +struct ColJoin, Column, Column, Column> { + static void apply(Column *&resLhsPos, Column *&resRhsPos, const Column *lhsData, + const Column *rhsData, int64_t numRes, DCTX(ctx)) { + const size_t numLhsData = lhsData->getNumRows(); + const size_t numRhsData = rhsData->getNumRows(); + + if (numRes == -1) + // Assuming FK-PK join. + numRes = numLhsData; + else if (numRes < 0) + throw std::runtime_error("colJoin: numRes must be -1 or non-negative"); + + if (resLhsPos == nullptr) + resLhsPos = DataObjectFactory::create>(numRes, false); + if (resRhsPos == nullptr) + resRhsPos = DataObjectFactory::create>(numRes, false); + VTPos *valuesResLhsPos = resLhsPos->getValues(); + VTPos *valuesResRhsPos = resRhsPos->getValues(); + + /* Number of matches that fit into both result lists (numRes is only a hint and the results may have been provided by the caller). */ + const size_t capLhs = resLhsPos->getNumRows(); + const size_t capRhs = resRhsPos->getNumRows(); + const size_t capRes = capLhs < capRhs ? capLhs : capRhs; + + // Build phase.
+ /* Maps each rhs value to its position; for duplicate rhs values, the last position wins. */ + absl::flat_hash_map ht; + const VTData *valuesRhsData = rhsData->getValues(); + for (size_t r = 0; r < numRhsData; r++) + ht[valuesRhsData[r]] = r; + + // Probe phase. + /* Every lhs element produces at most one match, since the hash table holds one position per value. */ + const VTData *valuesLhsData = lhsData->getValues(); + size_t posRes = 0; + for (size_t r = 0; r < numLhsData; r++) { + auto it = ht.find(valuesLhsData[r]); + if (it != ht.end()) { + /* Never write past the end of the result lists. */ + if (posRes >= capRes) + throw std::runtime_error("colJoin: insufficient space in the result position lists"); + valuesResLhsPos[posRes] = r; + valuesResRhsPos[posRes] = it->second; + posRes++; + } + } + + /* Cut both results down to the number of matches actually found. */ + resLhsPos->shrinkNumRows(posRes); + resRhsPos->shrinkNumRows(posRes); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ColMerge.h b/src/runtime/local/kernels/ColMerge.h new file mode 100644 index 000000000..fdb7a65b4 --- /dev/null +++ b/src/runtime/local/kernels/ColMerge.h @@ -0,0 +1,93 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include + +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel merging two position lists into one. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColMerge { + static void apply(DTResPos *&resPos, const DTLhsPos *lhsPos, const DTRhsPos *rhsPos, DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColMerge` specialization. + * + * @param resPos Result: the merged position list; created by the kernel if it is `nullptr`. + * @param lhsPos The left-hand-side position list. + * @param rhsPos The right-hand-side position list. + * @param ctx The context object, passed through to the kernel. + */ +template +void colMerge(DTResPos *&resPos, const DTLhsPos *lhsPos, const DTRhsPos *rhsPos, DCTX(ctx)) { + ColMerge::apply(resPos, lhsPos, rhsPos, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column <- Column, Column +// ---------------------------------------------------------------------------- + +/** + * Single pass over both inputs that always emits the smaller current element; an element that + * is current in both inputs at the same time is emitted only once. For inputs sorted in + * ascending order, the result is sorted as well. + */ +template struct ColMerge, Column, Column> { + static void apply(Column *&resPos, const Column *lhsPos, const Column *rhsPos, DCTX(ctx)) { + const size_t numLhsPos = lhsPos->getNumRows(); + const size_t numRhsPos = rhsPos->getNumRows(); + + /* Upper bound for the result size: no element is shared by both inputs. */ + if (resPos == nullptr) + resPos = DataObjectFactory::create>(numLhsPos + numRhsPos, false); + + const VTPos *valuesLhsPos = lhsPos->getValues(); + const VTPos *valuesRhsPos = rhsPos->getValues(); + const VTPos *valuesLhsPosEnd = valuesLhsPos + numLhsPos; + const VTPos *valuesRhsPosEnd = valuesRhsPos + numRhsPos; + VTPos *valuesResPos = resPos->getValues(); + VTPos *valuesResPosBeg = valuesResPos; + + while (valuesLhsPos < valuesLhsPosEnd && valuesRhsPos < valuesRhsPosEnd) { + if (*valuesLhsPos < *valuesRhsPos) { + *valuesResPos = *valuesLhsPos; + 
valuesLhsPos++; + } else if (*valuesRhsPos < *valuesLhsPos) { + *valuesResPos = *valuesRhsPos; + valuesRhsPos++; + } else { // *valuesLhsPos == *valuesRhsPos + *valuesResPos = *valuesLhsPos; + valuesLhsPos++; + valuesRhsPos++; + } + valuesResPos++; + } + // One or both operands have been consumed, but the other one might still contain positions. + while (valuesLhsPos < valuesLhsPosEnd) { + *valuesResPos = *valuesLhsPos; + valuesResPos++; + valuesLhsPos++; + } + while (valuesRhsPos < valuesRhsPosEnd) { + *valuesResPos = *valuesRhsPos; + valuesResPos++; + valuesRhsPos++; + } + + /* Cut the result down to the number of positions actually written. */ + resPos->shrinkNumRows(valuesResPos - valuesResPosBeg); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ColProject.h b/src/runtime/local/kernels/ColProject.h new file mode 100644 index 000000000..e8c1fd287 --- /dev/null +++ b/src/runtime/local/kernels/ColProject.h @@ -0,0 +1,71 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include + +#include + +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel extracting the data elements at the given positions. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColProject { + static void apply(DTResData *&resData, const DTLhsData *lhsData, const DTRhsPos *rhsPos, DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColProject` specialization. + * + * @param resData Result: the selected data elements, in the order of `rhsPos`; created by the kernel if it is `nullptr`. + * @param lhsData The data to select from. + * @param rhsPos The positions of the elements to select. + * @param ctx The context object, passed through to the kernel. + */ +template +void colProject(DTResData *&resData, const DTLhsData *lhsData, const DTRhsPos *rhsPos, DCTX(ctx)) { + ColProject::apply(resData, lhsData, rhsPos, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column <- Column, Column +// ---------------------------------------------------------------------------- + +/** + * Computes resData[r] = lhsData[rhsPos[r]] for every position in `rhsPos`. + * Throws std::runtime_error if a position is negative or not less than the number of elements of `lhsData`. + */ +template struct ColProject, Column, Column> { + static void apply(Column *&resData, const Column *lhsData, const Column *rhsPos, DCTX(ctx)) { + const size_t numLhsData = lhsData->getNumRows(); + const size_t numRhsPos = rhsPos->getNumRows(); + /* One result element per requested position. */ + const size_t numResData = numRhsPos; + + if (resData == nullptr) + resData = DataObjectFactory::create>(numResData, false); + + const VTData *valuesLhsData = lhsData->getValues(); + const VTPos *valuesRhsPos = rhsPos->getValues(); + VTData *valuesResData = resData->getValues(); + for (size_t r = 0; r < numRhsPos; r++) { + const VTPos pos = valuesRhsPos[r]; + /* Validate every position before using it as an index into the data. */ + if (pos < 0 || static_cast(pos) >= numLhsData) + throw std::runtime_error("out-of-bounds access"); + valuesResData[r] = valuesLhsData[pos]; + } + } +}; \ No newline at end 
of file diff --git a/src/runtime/local/kernels/ColSelectCmp.h b/src/runtime/local/kernels/ColSelectCmp.h new file mode 100644 index 000000000..070f49e0b --- /dev/null +++ b/src/runtime/local/kernels/ColSelectCmp.h @@ -0,0 +1,95 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include <stdexcept> + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel selecting the positions of the data elements that fulfil a comparison with a scalar. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColSelectCmp { + static void apply(CmpOpCode opCode, DTResPos *&resPos, const DTLhsData *lhsData, VTRhsData rhsData, + DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColSelectCmp` specialization. + * + * @param opCode The comparison to evaluate as `lhsData[r] <op> rhsData`. + * @param resPos Result: the positions `r` for which the comparison holds, in ascending order; created by the kernel if it is `nullptr`. + * @param lhsData The data elements to compare. + * @param rhsData The scalar to compare each data element with. + * @param ctx The context object, passed through to the kernel. + */ +template +void colSelectCmp(CmpOpCode opCode, DTResPos *&resPos, const DTLhsData *lhsData, VTRhsData rhsData, DCTX(ctx)) { + ColSelectCmp::apply(opCode, resPos, lhsData, rhsData, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column 
<- Column, scalar +// ---------------------------------------------------------------------------- + +/** + * Throws std::runtime_error if `opCode` is none of the `CmpOpCode` enumerators. + */ +template +struct ColSelectCmp, Column, VTRhsData> { + static void apply(CmpOpCode opCode, Column *&resPos, const Column *lhsData, VTRhsData rhsData, + DCTX(ctx)) { + const size_t numLhsData = lhsData->getNumRows(); + + /* In the worst case, every element qualifies. */ + if (resPos == nullptr) + resPos = DataObjectFactory::create>(numLhsData, false); + + const VTLhsData *valuesLhsData = lhsData->getValues(); + VTPos *valuesResPos = resPos->getValues(); + + /* Resolve the op code to a comparison function once, outside the loop over the data. */ + bool (*func)(VTLhsData, VTRhsData) = nullptr; + switch (opCode) { + case CmpOpCode::EQ: + func = [](VTLhsData lhs, VTRhsData rhs) { return lhs == rhs; }; + break; + case CmpOpCode::NEQ: + func = [](VTLhsData lhs, VTRhsData rhs) { return lhs != rhs; }; + break; + case CmpOpCode::GT: + func = [](VTLhsData lhs, VTRhsData rhs) { return lhs > rhs; }; + break; + case CmpOpCode::GE: + func = [](VTLhsData lhs, VTRhsData rhs) { return lhs >= rhs; }; + break; + case CmpOpCode::LT: + func = [](VTLhsData lhs, VTRhsData rhs) { return lhs < rhs; }; + break; + case CmpOpCode::LE: + func = [](VTLhsData lhs, VTRhsData rhs) { return lhs <= rhs; }; + break; + } + /* An op code value outside the enum (e.g., produced by a cast) matches no case; fail instead of calling a null function pointer. Checking after the switch (rather than in a default case) keeps compiler warnings about unhandled enumerators. */ + if (func == nullptr) + throw std::runtime_error("unsupported op code"); + + size_t numResPos = 0; + for (size_t r = 0; r < numLhsData; r++) + if (func(valuesLhsData[r], rhsData)) { + valuesResPos[numResPos] = r; + numResPos++; + } + + /* Cut the result down to the number of qualifying positions. */ + resPos->shrinkNumRows(numResPos); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ColSemiJoin.h b/src/runtime/local/kernels/ColSemiJoin.h new file mode 100644 index 000000000..c8755bde4 --- /dev/null +++ b/src/runtime/local/kernels/ColSemiJoin.h @@ -0,0 +1,83 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +#include +#include <stdexcept> + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +/** + * @brief Kernel for a semi-join: selects the positions of the left-hand-side elements that have a match on the right-hand side. + * + * The primary template only fixes the interface (`apply` is deleted); the actual + * implementations are provided by the (partial) template specializations below. + */ +template struct ColSemiJoin { + static void apply(DTResLhsPos *&resLhsPos, const DTLhsData *lhsData, const DTRhsData *rhsData, int64_t numRes, + DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +/** + * @brief Forwards all arguments to the `apply` function of the matching `ColSemiJoin` specialization. + * + * @param resLhsPos Result: the positions of the elements of `lhsData` whose value occurs in `rhsData`, in ascending order. + * @param lhsData The left-hand-side data (probe side). + * @param rhsData The right-hand-side data (build side). + * @param numRes The number of result slots to allocate, or -1 to allocate one slot per element of `lhsData`. + * @param ctx The context object, passed through to the kernel. + */ +template +void colSemiJoin(DTResLhsPos *&resLhsPos, const DTLhsData *lhsData, const DTRhsData *rhsData, int64_t numRes, + DCTX(ctx)) { + ColSemiJoin::apply(resLhsPos, lhsData, rhsData, numRes, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Column <- Column, Column +// ---------------------------------------------------------------------------- + +/** + * Hash-based: a hash set over `rhsData` is probed with every element of `lhsData`. + * Throws std::runtime_error if `numRes` is negative but not -1, or if the result position + * list is too small for the result. + */ +template struct ColSemiJoin, Column, Column> { + static void apply(Column *&resLhsPos, const Column *lhsData, const Column *rhsData, + int64_t numRes, DCTX(ctx)) { + const size_t numLhsData = lhsData->getNumRows(); + const size_t numRhsData = rhsData->getNumRows(); + + if (numRes == -1) + 
// Assuming FK-PK join. + numRes = numLhsData; + else if (numRes < 0) + throw std::runtime_error("colSemiJoin: numRes must be -1 or non-negative"); + + if (resLhsPos == nullptr) + resLhsPos = DataObjectFactory::create>(numRes, false); + VTPos *valuesResLhsPos = resLhsPos->getValues(); + + /* Number of positions that fit into the result list (numRes is only a hint and the result may have been provided by the caller). */ + const size_t capRes = resLhsPos->getNumRows(); + + // Build phase. + absl::flat_hash_set ht; + const VTData *valuesRhsData = rhsData->getValues(); + for (size_t r = 0; r < numRhsData; r++) + ht.insert(valuesRhsData[r]); + + // Probe phase. + const VTData *valuesLhsData = lhsData->getValues(); + size_t posRes = 0; + for (size_t r = 0; r < numLhsData; r++) + if (ht.count(valuesLhsData[r])) { + /* Never write past the end of the result list. */ + if (posRes >= capRes) + throw std::runtime_error("colSemiJoin: insufficient space in the result position list"); + valuesResLhsPos[posRes++] = r; + } + + /* Cut the result down to the number of matches actually found. */ + resLhsPos->shrinkNumRows(posRes); + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/ConvertBitmapToPosList.h b/src/runtime/local/kernels/ConvertBitmapToPosList.h new file mode 100644 index 000000000..d10f38a81 --- /dev/null +++ b/src/runtime/local/kernels/ConvertBitmapToPosList.h @@ -0,0 +1,74 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +template struct ConvertBitmapToPosList { + static void apply(DTRes *&res, const DTArg *arg, DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +template void convertBitmapToPosList(DTRes *&res, const DTArg *arg, DCTX(ctx)) { + ConvertBitmapToPosList::apply(res, arg, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// DenseMatrix <- DenseMatrix +// ---------------------------------------------------------------------------- + +template struct ConvertBitmapToPosList, DenseMatrix> { + static void apply(DenseMatrix *&res, const DenseMatrix *arg, DCTX(ctx)) { + const size_t numColsArg = arg->getNumCols(); + if (numColsArg != 1) + throw std::runtime_error("the argument must have exactly one column but has " + std::to_string(numColsArg) + + " columns"); + + const size_t numRowsArg = arg->getNumRows(); + + if (res == nullptr) + res = DataObjectFactory::create>(numRowsArg, 1, false); + + const VTArg *valuesArg = arg->getValues(); + VTRes *valuesRes = res->getValues(); + size_t numRowsRes = 0; + + for (size_t r = 0; r < numRowsArg; r++) { + if (*valuesArg == 1) { + *valuesRes = r; + valuesRes += res->getRowSkip(); + numRowsRes++; + } + valuesArg += arg->getRowSkip(); + } + + res->shrinkNumRows(numRowsRes); + } +}; \ No newline at end of file diff --git 
a/src/runtime/local/kernels/ConvertPosListToBitmap.h b/src/runtime/local/kernels/ConvertPosListToBitmap.h new file mode 100644 index 000000000..0621caf60 --- /dev/null +++ b/src/runtime/local/kernels/ConvertPosListToBitmap.h @@ -0,0 +1,70 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +// **************************************************************************** +// Struct for partial template specialization +// **************************************************************************** + +template struct ConvertPosListToBitmap { + static void apply(DTRes *&res, const DTArg *arg, const size_t numRowsRes, DCTX(ctx)) = delete; +}; + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +template +void convertPosListToBitmap(DTRes *&res, const DTArg *arg, size_t numRowsRes, DCTX(ctx)) { + ConvertPosListToBitmap::apply(res, arg, numRowsRes, ctx); +} + +// **************************************************************************** +// (Partial) template specializations for different data/value types +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// DenseMatrix <- DenseMatrix +// 
---------------------------------------------------------------------------- + +template struct ConvertPosListToBitmap, DenseMatrix> { + static void apply(DenseMatrix *&res, const DenseMatrix *arg, size_t numRowsRes, DCTX(ctx)) { + const size_t numColsArg = arg->getNumCols(); + if (numColsArg != 1) + throw std::runtime_error("the argument must have exactly one column but has " + std::to_string(numColsArg) + + " columns"); + + if (res == nullptr) + res = DataObjectFactory::create>(numRowsRes, 1, true); + + const VTArg *valuesArg = arg->getValues(); + VTRes *valuesRes = res->getValues(); + + for (size_t r = 0; r < arg->getNumRows(); r++) { + const size_t pos = *valuesArg; + if (pos > numRowsRes) + throw std::runtime_error("out-of-bounds access: trying to set position " + std::to_string(pos) + + " in a column matrix with " + std::to_string(numRowsRes) + " rows"); + valuesRes[pos * res->getRowSkip()] = 1; + valuesArg += arg->getRowSkip(); + } + } +}; \ No newline at end of file diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 2d795df07..c663cc499 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -423,7 +423,26 @@ ["Frame", ["DenseMatrix", "double"]], ["Frame", ["DenseMatrix", "int64_t"]], ["Frame", ["DenseMatrix", "uint64_t"]], - + [["Column", "double"], ["DenseMatrix", "double"]], + [["Column", "int64_t"], ["DenseMatrix", "int64_t"]], + [["Column", "uint64_t"], ["DenseMatrix", "uint64_t"]], + [["Column", "size_t"], ["DenseMatrix", "size_t"]], + [["Column", "std::string"], ["DenseMatrix", "std::string"]], + [["DenseMatrix", "double"], ["Column", "double"]], + [["DenseMatrix", "int64_t"], ["Column", "int64_t"]], + [["DenseMatrix", "uint64_t"], ["Column", "uint64_t"]], + [["DenseMatrix", "size_t"], ["Column", "size_t"]], + [["DenseMatrix", "std::string"], ["Column", "std::string"]], + [["Column", "double"], "Frame"], + [["Column", "int64_t"], "Frame"], + [["Column", 
"uint64_t"], "Frame"], + [["Column", "size_t"], "Frame"], + [["Column", "std::string"], "Frame"], + ["Frame", ["Column", "double"]], + ["Frame", ["Column", "int64_t"]], + ["Frame", ["Column", "uint64_t"]], + ["Frame", ["Column", "size_t"]], + ["Frame", ["Column", "std::string"]], [ ["DenseMatrix", "double"], ["DenseMatrix", "double"] @@ -980,7 +999,11 @@ ["size_t", ["DenseMatrix", "size_t"]], ["double", "Frame"], - ["int64_t", "Frame"] + ["int64_t", "Frame"], + + ["double", ["Column", "double"]], + ["int64_t", ["Column", "int64_t"]], + ["uint64_t", ["Column", "uint64_t"]] ] }, { @@ -6981,6 +7004,553 @@ [["CSRMatrix", "double"]], [["CSRMatrix", "int64_t"]] ] + }, + { + "kernelTemplate": { + "header": "ColSelectCmp.h", + "opName": "colSelectCmp", + "returnType": "void", + "templateParams": [ + { + "name": "DTResPos", + "isDataType": true + }, + { + "name": "DTLhsData", + "isDataType": true + }, + { + "name": "VTRhsData", + "isDataType": false + } + ], + "runtimeParams": [ + { + "type": "CmpOpCode", + "name": "opCode" + }, + { + "type": "DTResPos *&", + "name": "resPos" + }, + { + "type": "const DTLhsData *", + "name": "lhsData" + }, + { + "type": "VTRhsData", + "name": "rhsData" + } + ] + }, + "instantiations": [ + [["Column", "size_t"], ["Column", "double"], "double"], + [["Column", "size_t"], ["Column", "int64_t"], "int64_t"], + [["Column", "size_t"], ["Column", "uint64_t"], "uint64_t"], + [["Column", "size_t"], ["Column", "std::string"], "const char *"] + ], + "opCodes": ["EQ", "NEQ", "GT", "GE", "LT", "LE"] + }, + { + "kernelTemplate": { + "header": "ColProject.h", + "opName": "colProject", + "returnType": "void", + "templateParams": [ + { + "name": "DTResData", + "isDataType": true + }, + { + "name": "DTLhsData", + "isDataType": true + }, + { + "name": "DTRhsPos", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "DTResData *&", + "name": "resData" + }, + { + "type": "const DTLhsData *", + "name": "lhsData" + }, + { + "type": "const DTRhsPos *", + 
"name": "rhsPos" + } + ] + }, + "instantiations": [ + [["Column", "double"], ["Column", "double"], ["Column", "size_t"]], + [["Column", "int64_t"], ["Column", "int64_t"], ["Column", "size_t"]], + [["Column", "uint64_t"], ["Column", "uint64_t"], ["Column", "size_t"]], + [["Column", "std::string"], ["Column", "std::string"], ["Column", "size_t"]] + ] + }, + { + "kernelTemplate": { + "header": "ColIntersect.h", + "opName": "colIntersect", + "returnType": "void", + "templateParams": [ + { + "name": "DTResPos", + "isDataType": true + }, + { + "name": "DTLhsPos", + "isDataType": true + }, + { + "name": "DTRhsPos", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "DTResPos *&", + "name": "resPos" + }, + { + "type": "const DTLhsPos *", + "name": "lhsPos" + }, + { + "type": "const DTRhsPos *", + "name": "rhsPos" + } + ] + }, + "instantiations": [ + [["Column", "size_t"], ["Column", "size_t"], ["Column", "size_t"]] + ] + }, + { + "kernelTemplate": { + "header": "ColMerge.h", + "opName": "colMerge", + "returnType": "void", + "templateParams": [ + { + "name": "DTResPos", + "isDataType": true + }, + { + "name": "DTLhsPos", + "isDataType": true + }, + { + "name": "DTRhsPos", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "DTResPos *&", + "name": "resPos" + }, + { + "type": "const DTLhsPos *", + "name": "lhsPos" + }, + { + "type": "const DTRhsPos *", + "name": "rhsPos" + } + ] + }, + "instantiations": [ + [["Column", "size_t"], ["Column", "size_t"], ["Column", "size_t"]] + ] + }, + { + "kernelTemplate": { + "header": "ColJoin.h", + "opName": "colJoin", + "returnType": "void", + "templateParams": [ + { + "name": "DTResLhsPos", + "isDataType": true + }, + { + "name": "DTResRhsPos", + "isDataType": true + }, + { + "name": "DTLhsData", + "isDataType": true + }, + { + "name": "DTRhsData", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "DTResLhsPos *&", + "name": "resLhsPos" + }, + { + "type": "DTResRhsPos *&", + "name": "resRhsPos" + 
}, + { + "type": "const DTLhsData *", + "name": "lhsData" + }, + { + "type": "const DTRhsData *", + "name": "rhsData" + }, + { + "type": "int64_t", + "name": "numRes" + } + ] + }, + "instantiations": [ + [["Column", "size_t"], ["Column", "size_t"], ["Column", "double"], ["Column", "double"]], + [["Column", "size_t"], ["Column", "size_t"], ["Column", "int64_t"], ["Column", "int64_t"]], + [["Column", "size_t"], ["Column", "size_t"], ["Column", "uint64_t"], ["Column", "uint64_t"]], + [["Column", "size_t"], ["Column", "size_t"], ["Column", "std::string"], ["Column", "std::string"]] + ] + }, + { + "kernelTemplate": { + "header": "ColSemiJoin.h", + "opName": "colSemiJoin", + "returnType": "void", + "templateParams": [ + { + "name": "DTResLhsPos", + "isDataType": true + }, + { + "name": "DTLhsData", + "isDataType": true + }, + { + "name": "DTRhsData", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "DTResLhsPos *&", + "name": "resLhsPos" + }, + { + "type": "const DTLhsData *", + "name": "lhsData" + }, + { + "type": "const DTRhsData *", + "name": "rhsData" + }, + { + "type": "int64_t", + "name": "numRes" + } + ] + }, + "instantiations": [ + [["Column", "size_t"], ["Column", "double"], ["Column", "double"]], + [["Column", "size_t"], ["Column", "int64_t"], ["Column", "int64_t"]], + [["Column", "size_t"], ["Column", "uint64_t"], ["Column", "uint64_t"]], + [["Column", "size_t"], ["Column", "std::string"], ["Column", "std::string"]] + ] + }, + { + "kernelTemplate": { + "header": "ColGroupFirst.h", + "opName": "colGroupFirst", + "returnType": "void", + "templateParams": [ + { + "name": "DTResGrpIds", + "isDataType": true + }, + { + "name": "DTResReprPos", + "isDataType": true + }, + { + "name": "DTArgData", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "DTResGrpIds *&", + "name": "resGrpIds" + }, + { + "type": "DTResReprPos *&", + "name": "resReprPos" + }, + { + "type": "const DTArgData *", + "name": "argData" + } + ] + }, + "instantiations": 
[ + [["Column", "size_t"], ["Column", "size_t"], ["Column", "double"]], + [["Column", "size_t"], ["Column", "size_t"], ["Column", "int64_t"]], + [["Column", "size_t"], ["Column", "size_t"], ["Column", "uint64_t"]], + [["Column", "size_t"], ["Column", "size_t"], ["Column", "std::string"]] + ] + }, + { + "kernelTemplate": { + "header": "ColGroupNext.h", + "opName": "colGroupNext", + "returnType": "void", + "templateParams": [ + { + "name": "DTResGrpIds", + "isDataType": true + }, + { + "name": "DTResReprPos", + "isDataType": true + }, + { + "name": "DTArgData", + "isDataType": true + }, + { + "name": "DTArgGrpIds", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "DTResGrpIds *&", + "name": "resGrpIds" + }, + { + "type": "DTResReprPos *&", + "name": "resReprPos" + }, + { + "type": "const DTArgData *", + "name": "argData" + }, + { + "type": "const DTArgGrpIds *", + "name": "argGrpIds" + } + ] + }, + "instantiations": [ + [["Column", "size_t"], ["Column", "size_t"], ["Column", "double"], ["Column", "size_t"]], + [["Column", "size_t"], ["Column", "size_t"], ["Column", "int64_t"], ["Column", "size_t"]], + [["Column", "size_t"], ["Column", "size_t"], ["Column", "uint64_t"], ["Column", "size_t"]], + [["Column", "size_t"], ["Column", "size_t"], ["Column", "std::string"], ["Column", "size_t"]] + ] + }, + { + "kernelTemplate": { + "header": "ColCalcBinary.h", + "opName": "colCalcBinary", + "returnType": "void", + "templateParams": [ + { + "name": "DTResData", + "isDataType": true + }, + { + "name": "DTLhsData", + "isDataType": true + }, + { + "name": "DTRhsData", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "BinaryOpCode", + "name": "opCode" + }, + { + "type": "DTResData *&", + "name": "resData" + }, + { + "type": "const DTLhsData *", + "name": "lhsData" + }, + { + "type": "const DTRhsData *", + "name": "rhsData" + } + ] + }, + "instantiations": [ + [["Column", "double"], ["Column", "double"], ["Column", "double"]], + [["Column", 
"int64_t"], ["Column", "int64_t"], ["Column", "int64_t"]], + [["Column", "uint64_t"], ["Column", "uint64_t"], ["Column", "uint64_t"]] + ], + "opCodes": ["SUB", "MUL"] + }, + { + "kernelTemplate": { + "header": "ColAggAll.h", + "opName": "colAggAll", + "returnType": "void", + "templateParams": [ + { + "name": "DTRes", + "isDataType": true + }, + { + "name": "DTArg", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "AggOpCode", + "name": "opCode" + }, + { + "type": "DTRes *&", + "name": "res" + }, + { + "type": "const DTArg *", + "name": "arg" + } + ] + }, + "instantiations": [ + [["Column", "double"], ["Column", "double"]], + [["Column", "int64_t"], ["Column", "int64_t"]], + [["Column", "uint64_t"], ["Column", "uint64_t"]] + ], + "opCodes": ["SUM"] + }, + { + "kernelTemplate": { + "header": "ColAggGrp.h", + "opName": "colAggGrp", + "returnType": "void", + "templateParams": [ + { + "name": "DTRes", + "isDataType": true + }, + { + "name": "DTData", + "isDataType": true + }, + { + "name": "DTGrpIds", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "AggOpCode", + "name": "opCode" + }, + { + "type": "DTRes *&", + "name": "res" + }, + { + "type": "const DTData *", + "name": "data" + }, + { + "type": "const DTGrpIds *", + "name": "grpIds" + }, + { + "type": "size_t", + "name": "numDistinct" + } + ] + }, + "instantiations": [ + [["Column", "double"], ["Column", "double"], ["Column", "size_t"]], + [["Column", "int64_t"], ["Column", "int64_t"], ["Column", "size_t"]], + [["Column", "uint64_t"], ["Column", "uint64_t"], ["Column", "size_t"]] + ], + "opCodes": ["SUM"] + }, + { + "kernelTemplate": { + "header": "ConvertPosListToBitmap.h", + "opName": "convertPosListToBitmap", + "returnType": "void", + "templateParams": [ + { + "name": "DTRes", + "isDataType": true + }, + { + "name": "DTArg", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "DTRes *&", + "name": "res" + }, + { + "type": "const DTArg *", + "name": "arg" + }, + { + 
"type": "size_t", + "name": "numRowsRes" + } + ] + }, + "instantiations": [ + [["DenseMatrix", "size_t"], ["DenseMatrix", "size_t"]], + [["DenseMatrix", "int64_t"], ["DenseMatrix", "int64_t"]] + ] + }, + { + "kernelTemplate": { + "header": "ConvertBitmapToPosList.h", + "opName": "convertBitmapToPosList", + "returnType": "void", + "templateParams": [ + { + "name": "DTRes", + "isDataType": true + }, + { + "name": "DTArg", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "DTRes *&", + "name": "res" + }, + { + "type": "const DTArg *", + "name": "arg" + } + ] + }, + "instantiations": [ + [["DenseMatrix", "size_t"], ["DenseMatrix", "size_t"]], + [["DenseMatrix", "int64_t"], ["DenseMatrix", "int64_t"]] + ] } - ] diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ea548b82c..f05a7b2ed 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -46,6 +46,7 @@ set(TEST_SOURCES api/cli/scoping/ScopingTest.cpp api/cli/scriptargs/ScriptArgsTest.cpp api/cli/secondorder/SecondOrderTest.cpp + api/cli/sql/ColumnarTest.cpp api/cli/sql/SQLTest.cpp api/cli/sql/SQLResultTest.cpp api/cli/syntax/SyntaxTest.cpp @@ -97,11 +98,24 @@ set(TEST_SOURCES runtime/local/kernels/CastScaObjTest.cpp runtime/local/kernels/CheckEqTest.cpp runtime/local/kernels/CheckEqApproxTest.cpp + runtime/local/kernels/ColAggAllTest.cpp + runtime/local/kernels/ColAggGrpTest.cpp runtime/local/kernels/ColBindTest.cpp + runtime/local/kernels/ColCalcBinaryTest.cpp + runtime/local/kernels/ColGroupFirstTest.cpp + runtime/local/kernels/ColGroupNextTest.cpp + runtime/local/kernels/ColIntersectTest.cpp + runtime/local/kernels/ColJoinTest.cpp + runtime/local/kernels/ColMergeTest.cpp + runtime/local/kernels/ColProjectTest.cpp + runtime/local/kernels/ColSelectCmpTest.cpp + runtime/local/kernels/ColSemiJoinTest.cpp runtime/local/kernels/CondMatMatMatTest.cpp runtime/local/kernels/CondMatMatScaTest.cpp runtime/local/kernels/CondMatScaMatTest.cpp runtime/local/kernels/CondMatScaScaTest.cpp + 
runtime/local/kernels/ConvertBitmapToPosListTest.cpp + runtime/local/kernels/ConvertPosListToBitmapTest.cpp runtime/local/kernels/CreateFrameTest.cpp runtime/local/kernels/CTableTest.cpp runtime/local/kernels/DiagMatrixTest.cpp @@ -160,8 +174,6 @@ set(TEST_SOURCES runtime/local/kernels/TriTest.cpp runtime/local/vectorized/MultiThreadedKernelTest.cpp - -# runtime/local/kernels/Morphstore/ProjectTest.cpp ) if(USE_CUDA AND CMAKE_CUDA_COMPILER) diff --git a/test/api/cli/sql/ColumnarTest.cpp b/test/api/cli/sql/ColumnarTest.cpp new file mode 100644 index 000000000..21329cf78 --- /dev/null +++ b/test/api/cli/sql/ColumnarTest.cpp @@ -0,0 +1,142 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include + +#include + +#include +#include +#include + +const std::string dirPath = "test/api/cli/sql/"; + +/** + * @brief Checks if the specified `requiredOps` are contained and if the specified `disallowedOps` are not contained in + * the IR after lowering to columnar ops. + * + * Runs the specified script in DAPHNE (using --columnar if specified), checks for successful completion and correct + * result, and checks for the specified required and disallowed ops in the IR. Besides the explicitly specified + * disallowed ops, ops for the conversion between position lists and bit vectors are always implicitly disallowed, + * because they shall not remain in the IR. 
+ * + * @param exp The expected output on `cout`. + * @param scriptFilePath The path to the DaphneDSL script to execute. + * @param useColumnar Whether to invoke DAPHNE with the `--columnar` flag. + * @param requiredOps List of DaphneIR op mnemonics that must be contained in the IR. + * @param disallowedOps List of DaphneIR op mnemonics that must not be contained in the IR. + */ +void checkOps(const std::string &exp, const std::string &scriptFilePath, bool useColumnar, + const std::vector &requiredOps, const std::vector &disallowedOps) { + // Ops converting between position lists and bit vectors. These should never remain after simplification rewrites. + std::vector cnvOps = {mlir::daphne::ConvertPosListToBitmapOp::getOperationName().str(), + mlir::daphne::ConvertBitmapToPosListOp::getOperationName().str()}; + + std::stringstream out; + std::stringstream err; + int status; + // TODO If we had arguments like `--no-columnar`, we could avoid the if-then-else here and simply use sth like + // `useColumnar ? "--columnar" : "--no-columnar"`. + if (useColumnar) + status = runDaphne(out, err, "--explain", "columnar", "--columnar", scriptFilePath.c_str()); + else + status = runDaphne(out, err, "--explain", "columnar", scriptFilePath.c_str()); + + CHECK(status == StatusCode::SUCCESS); + CHECK(out.str() == exp); + for (const std::string &opName : requiredOps) + CHECK_THAT(err.str(), Catch::Contains(opName)); + for (const std::string &opName : disallowedOps) + CHECK_THAT(err.str(), !Catch::Contains(opName)); + for (std::string &opName : cnvOps) + CHECK_THAT(err.str(), !Catch::Contains(opName)); +} + +TEST_CASE("columnar", TAG_SQL) { + // The operations to check for in the IR. The elements of this array correspond to the test cases columnar_*. Each + // element is a pair of (1) the decisive frame operations that should be used when running the script without + // --columnar, and (2) the decisive column operations that should be used when running the script with --columnar. 
+ std::pair /*frmOps*/, std::vector /*colOps*/> expectedOps[] = { + // columnar_1 + {{}, {}}, + // columnar_2 + {{mlir::daphne::EwGeOp::getOperationName().str(), mlir::daphne::EwEqOp::getOperationName().str(), + mlir::daphne::EwLeOp::getOperationName().str(), mlir::daphne::EwAndOp::getOperationName().str(), + mlir::daphne::EwOrOp::getOperationName().str(), mlir::daphne::FilterRowOp::getOperationName().str()}, + {mlir::daphne::ColSelectGeOp::getOperationName().str(), mlir::daphne::ColSelectEqOp::getOperationName().str(), + mlir::daphne::ColSelectLeOp::getOperationName().str(), mlir::daphne::ColIntersectOp::getOperationName().str(), + mlir::daphne::ColMergeOp::getOperationName().str(), mlir::daphne::ColProjectOp::getOperationName().str()}}, + // columnar_3 + {{mlir::daphne::AllAggSumOp::getOperationName().str()}, + {mlir::daphne::ColAllAggSumOp::getOperationName().str()}}, + // columnar_4 + {{mlir::daphne::GroupOp::getOperationName().str()}, + {mlir::daphne::ColGroupFirstOp::getOperationName().str(), + mlir::daphne::ColGrpAggSumOp::getOperationName().str()}}, + // columnar_5 + {{mlir::daphne::GroupOp::getOperationName().str()}, + {mlir::daphne::ColGroupFirstOp::getOperationName().str(), + mlir::daphne::ColGroupNextOp::getOperationName().str(), + mlir::daphne::ColGrpAggSumOp::getOperationName().str()}}, + // columnar_6 + {{mlir::daphne::EwEqOp::getOperationName().str(), mlir::daphne::EwLtOp::getOperationName().str(), + mlir::daphne::FilterRowOp::getOperationName().str(), mlir::daphne::InnerJoinOp::getOperationName().str(), + mlir::daphne::AllAggSumOp::getOperationName().str()}, + {mlir::daphne::ColSelectEqOp::getOperationName().str(), mlir::daphne::ColSelectLtOp::getOperationName().str(), + mlir::daphne::ColJoinOp::getOperationName().str(), mlir::daphne::ColAllAggSumOp::getOperationName().str()}}, + // columnar_7 + // semi join is not generated by the SQL parser yet (it uses inner join instead) + {{mlir::daphne::EwEqOp::getOperationName().str(), 
mlir::daphne::EwLtOp::getOperationName().str(), + mlir::daphne::EwOrOp::getOperationName().str(), + mlir::daphne::InnerJoinOp::getOperationName().str(), /*mlir::daphne::SemiJoinOp::getOperationName().str(),*/ + mlir::daphne::GroupOp::getOperationName().str()}, + {mlir::daphne::ColSelectEqOp::getOperationName().str(), mlir::daphne::ColSelectLtOp::getOperationName().str(), + mlir::daphne::ColMergeOp::getOperationName().str(), + mlir::daphne::ColJoinOp::getOperationName().str(), /*mlir::daphne::ColSemiJoinOp::getOperationName().str(),*/ + mlir::daphne::ColGroupFirstOp::getOperationName().str(), + mlir::daphne::ColGrpAggSumOp::getOperationName().str()}} + // end + }; + + // We have multiple test queries, each of which is expressed in both DaphneDSL and SQL. + for (size_t i = 1; i <= 7; i++) { + for (std::string lang : {"daphnedsl", "sql"}) { + DYNAMIC_SECTION("columnar_" << i << "_" << lang << ".daphne") { + const std::string scriptFilePath = dirPath + "columnar_" + std::to_string(i) + "_" + lang + ".daphne"; + + // Read the expected result. + const std::string refFilePath = dirPath + "columnar_" + std::to_string(i) + ".txt"; + const std::string exp = readTextFile(refFilePath); + + // Check if DAPHNE runs successfully and produces the correct result both without and with --columnar. + // Here, we don't use --explain and expect cerr to be empty. + compareDaphneToStr(exp, scriptFilePath); + compareDaphneToStr(exp, scriptFilePath, "--columnar"); + + // Check if DAPHNE uses the expected operations and doesn't use the unexpected operations to run the + // query/script (frame ops without --columnar, column ops with --columnar). 
+ std::vector frmOps = expectedOps[i - 1].first; + std::vector colOps = expectedOps[i - 1].second; + checkOps(exp, scriptFilePath, false, frmOps, colOps); + checkOps(exp, scriptFilePath, true, colOps, frmOps); + } + } + } +} \ No newline at end of file diff --git a/test/api/cli/sql/columnar_1.txt b/test/api/cli/sql/columnar_1.txt new file mode 100644 index 000000000..37150afe5 --- /dev/null +++ b/test/api/cli/sql/columnar_1.txt @@ -0,0 +1,4 @@ +Frame(3x2, [foo:int64_t, bar:std::string]) +1 x +2 y +3 z diff --git a/test/api/cli/sql/columnar_1_daphnedsl.daphne b/test/api/cli/sql/columnar_1_daphnedsl.daphne new file mode 100644 index 000000000..0e61c8149 --- /dev/null +++ b/test/api/cli/sql/columnar_1_daphnedsl.daphne @@ -0,0 +1,10 @@ +// Super simple query. + +r = { + "a": [1, 2, 3], + "b": ["x", "y", "z"] +}; + +res = createFrame(as.matrix(r[, "a"]), as.matrix(r[, "b"]), "foo", "bar"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_1_sql.daphne b/test/api/cli/sql/columnar_1_sql.daphne new file mode 100644 index 000000000..c1d6475f1 --- /dev/null +++ b/test/api/cli/sql/columnar_1_sql.daphne @@ -0,0 +1,12 @@ +// Super simple query. + +r = { + "a": [1, 2, 3], + "b": ["x", "y", "z"] +}; + +registerView("r", r); + +res = sql("SELECT r.a AS foo, r.b AS bar FROM r;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_2.txt b/test/api/cli/sql/columnar_2.txt new file mode 100644 index 000000000..d79527936 --- /dev/null +++ b/test/api/cli/sql/columnar_2.txt @@ -0,0 +1,6 @@ +Frame(5x1, [foo:double]) +1.1 +3.3 +4.4 +6.6 +7.7 diff --git a/test/api/cli/sql/columnar_2_daphnedsl.daphne b/test/api/cli/sql/columnar_2_daphnedsl.daphne new file mode 100644 index 000000000..81ee494c4 --- /dev/null +++ b/test/api/cli/sql/columnar_2_daphnedsl.daphne @@ -0,0 +1,15 @@ +// One table, multiple filters (AND, OR, BETWEEN), including a filter on a string column. 
+ +r = { + "a": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], + "b": [ 10, 5, 20, 15, 5, 25, 20, 5], + "c": ["x", "x", "y", "x", "z", "z", "x", "y"] +}; + +rb = as.matrix(r[, "b"]); +rc = as.matrix(r[, "c"]); +selR = (rb >= 10 && rc == "x") || (rb >= 20 && rb <= 30); +r = createFrame(as.matrix(r[, "a"]), "foo"); +res = r[[selR, ]]; + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_2_sql.daphne b/test/api/cli/sql/columnar_2_sql.daphne new file mode 100644 index 000000000..45f9e427b --- /dev/null +++ b/test/api/cli/sql/columnar_2_sql.daphne @@ -0,0 +1,13 @@ +// One table, multiple filters (AND, OR, BETWEEN), including a filter on a string column. + +r = { + "a": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], + "b": [ 10, 5, 20, 15, 5, 25, 20, 5], + "c": ["x", "x", "y", "x", "z", "z", "x", "y"] +}; + +registerView("r", r); + +res = sql("SELECT r.a AS foo FROM r WHERE (r.b >= 10 AND r.c = 'x') OR (r.b BETWEEN 20 AND 30)"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_3.txt b/test/api/cli/sql/columnar_3.txt new file mode 100644 index 000000000..acbddf419 --- /dev/null +++ b/test/api/cli/sql/columnar_3.txt @@ -0,0 +1,2 @@ +Frame(1x1, [foo:int64_t]) +35 diff --git a/test/api/cli/sql/columnar_3_daphnedsl.daphne b/test/api/cli/sql/columnar_3_daphnedsl.daphne new file mode 100644 index 000000000..ad04a6014 --- /dev/null +++ b/test/api/cli/sql/columnar_3_daphnedsl.daphne @@ -0,0 +1,11 @@ +// One table, full aggregation. + +r = { + "a": [1.1, 2.2, 3.3], + "b": [ 10, 5, 20] +}; + +s = sum(as.matrix(r[, "b"])); +res = createFrame(as.matrix(s), "foo"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_3_sql.daphne b/test/api/cli/sql/columnar_3_sql.daphne new file mode 100644 index 000000000..477cdfd05 --- /dev/null +++ b/test/api/cli/sql/columnar_3_sql.daphne @@ -0,0 +1,12 @@ +// One table, full aggregation. 
+ +r = { + "a": [1.1, 2.2, 3.3], + "b": [ 10, 5, 20] +}; + +registerView("r", r); + +res = sql("SELECT sum(r.b) AS foo FROM r;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_4.txt b/test/api/cli/sql/columnar_4.txt new file mode 100644 index 000000000..6d97491a1 --- /dev/null +++ b/test/api/cli/sql/columnar_4.txt @@ -0,0 +1,4 @@ +Frame(3x2, [key:std::string, agg:int64_t]) +x 50 +y 25 +z 30 diff --git a/test/api/cli/sql/columnar_4_daphnedsl.daphne b/test/api/cli/sql/columnar_4_daphnedsl.daphne new file mode 100644 index 000000000..4c5c8ae41 --- /dev/null +++ b/test/api/cli/sql/columnar_4_daphnedsl.daphne @@ -0,0 +1,12 @@ +// One table, group-by on a single column with a single aggregate. + +r = { + "a": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], + "b": [ 10, 5, 20, 15, 5, 25, 20, 5], + "c": ["x", "x", "y", "x", "z", "z", "x", "y"] +}; + +res = groupSum(r, "c", "b"); +res = setColLabels(res, "key", "agg"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_4_sql.daphne b/test/api/cli/sql/columnar_4_sql.daphne new file mode 100644 index 000000000..7de5dad9d --- /dev/null +++ b/test/api/cli/sql/columnar_4_sql.daphne @@ -0,0 +1,13 @@ +// One table, group-by on a single column with a single aggregate. 
+ +r = { + "a": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], + "b": [ 10, 5, 20, 15, 5, 25, 20, 5], + "c": ["x", "x", "y", "x", "z", "z", "x", "y"] +}; + +registerView("r", r); + +res = sql("SELECT r.c AS key, sum(r.b) AS agg FROM r GROUP by r.c;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_5.txt b/test/api/cli/sql/columnar_5.txt new file mode 100644 index 000000000..0de5efc4d --- /dev/null +++ b/test/api/cli/sql/columnar_5.txt @@ -0,0 +1,5 @@ +Frame(4x4, [r.c:std::string, r.d:int64_t, sum(r.a):double, sum(r.b):int64_t]) +x 0 8.8 30 +x 1 6.6 20 +y 0 12.1 25 +z 1 12.1 30 diff --git a/test/api/cli/sql/columnar_5_daphnedsl.daphne b/test/api/cli/sql/columnar_5_daphnedsl.daphne new file mode 100644 index 000000000..af5fd8567 --- /dev/null +++ b/test/api/cli/sql/columnar_5_daphnedsl.daphne @@ -0,0 +1,20 @@ +// One table, group-by on multiple columns with multiple aggregates. + +r = { + "a": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], + "b": [ 10, 5, 20, 15, 5, 25, 20, 5], + "c": ["x", "x", "y", "x", "z", "z", "x", "y"], + "d": [ 0, 1, 0, 1, 1, 1, 0, 0] +}; + +resA = groupSum(r, "c", "d", "a"); +resB = groupSum(r, "c", "d", "b"); +resAc = as.matrix(resA[, "c"]); +resAd = as.matrix(resA[, "d"]); +resBc = as.matrix(resB[, "c"]); +resBd = as.matrix(resB[, "d"]); +if (sum(resAc != resBc) || sum(resAd != resBd)) + stop(); +res = createFrame(resAc, resAd, as.matrix(resA[, "SUM(a)"]), as.matrix(resB[, "SUM(b)"]), "r.c", "r.d", "sum(r.a)", "sum(r.b)"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_5_sql.daphne b/test/api/cli/sql/columnar_5_sql.daphne new file mode 100644 index 000000000..c9a4268a5 --- /dev/null +++ b/test/api/cli/sql/columnar_5_sql.daphne @@ -0,0 +1,14 @@ +// One table, group-by on multiple columns with multiple aggregates. 
+ +r = { + "a": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8], + "b": [ 10, 5, 20, 15, 5, 25, 20, 5], + "c": ["x", "x", "y", "x", "z", "z", "x", "y"], + "d": [ 0, 1, 0, 1, 1, 1, 0, 0] +}; + +registerView("r", r); + +res = sql("SELECT r.c, r.d, sum(r.a), sum(r.b) FROM r GROUP BY r.c, r.d;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_6.txt b/test/api/cli/sql/columnar_6.txt new file mode 100644 index 000000000..c8f6ddf8d --- /dev/null +++ b/test/api/cli/sql/columnar_6.txt @@ -0,0 +1,2 @@ +Frame(1x1, [foo:int64_t]) +120 diff --git a/test/api/cli/sql/columnar_6_daphnedsl.daphne b/test/api/cli/sql/columnar_6_daphnedsl.daphne new file mode 100644 index 000000000..57d4d87a8 --- /dev/null +++ b/test/api/cli/sql/columnar_6_daphnedsl.daphne @@ -0,0 +1,25 @@ +// Join of two tables (both with filters), full aggregation. + +f = { + "a" : [ 1, 1, 4, 3, 2, 1, 1, 0, 1, 2, 0, 4], + "b" : [10, 30, 20, 10, 20, 30, 30, 20, 10, 20, 30, 10], + "did": [ 3, 0, 2, 1, 4, 3, 2, 2, 0, 1, 2, 1] +}; +d = { + "id": [ 0, 1, 2, 3, 4], + "x" : [10, 20, 10, 20, 30] +}; + +selD = as.matrix(d[, "x"]) == 10; +d = d[, "id"]; +d = d[[selD, ]]; + +selF = as.matrix(f[, "a"]) < 3; +f = cbind(f[, "b"], f[, "did"]); +f = f[[selF, ]]; + +f = innerJoin(f, d, "did", "id"); + +res = createFrame(as.matrix(sum(as.matrix(f[, "b"]))), "foo"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_6_sql.daphne b/test/api/cli/sql/columnar_6_sql.daphne new file mode 100644 index 000000000..3a699d2da --- /dev/null +++ b/test/api/cli/sql/columnar_6_sql.daphne @@ -0,0 +1,18 @@ +// Join of two tables (both with filters), full aggregation. 
+ +f = { + "a" : [ 1, 1, 4, 3, 2, 1, 1, 0, 1, 2, 0, 4], + "b" : [10, 30, 20, 10, 20, 30, 30, 20, 10, 20, 30, 10], + "did": [ 3, 0, 2, 1, 4, 3, 2, 2, 0, 1, 2, 1] +}; +d = { + "id": [ 0, 1, 2, 3, 4], + "x" : [10, 20, 10, 20, 30] +}; + +registerView("f", f); +registerView("d", d); + +res = sql("SELECT sum(f.b) AS foo FROM f INNER JOIN d ON f.did = d.id WHERE d.x = 10 AND f.a < 3;"); + +print(res); \ No newline at end of file diff --git a/test/api/cli/sql/columnar_7.txt b/test/api/cli/sql/columnar_7.txt new file mode 100644 index 000000000..c3361be80 --- /dev/null +++ b/test/api/cli/sql/columnar_7.txt @@ -0,0 +1,3 @@ +Frame(2x2, [key:int64_t, agg:int64_t]) +0 30 +1 100 diff --git a/test/api/cli/sql/columnar_7_daphnedsl.daphne b/test/api/cli/sql/columnar_7_daphnedsl.daphne new file mode 100644 index 000000000..16c2c783e --- /dev/null +++ b/test/api/cli/sql/columnar_7_daphnedsl.daphne @@ -0,0 +1,43 @@ +// Join of three tables (each with filters, one semi-join), group-by on multiple columns with multiple aggregates, +// order-by on multiple columns, calculation in the projection (inside an aggregatin function). 
+ +f = { + "a" : [ 1, 1, 4, 3, 2, 1, 1, 0, 1, 2, 0, 4], + "b" : [10, 30, 20, 10, 20, 30, 30, 20, 10, 20, 30, 10], + "d1id": [ 3, 0, 2, 1, 4, 3, 2, 2, 0, 1, 2, 1], + "d2id": [30, 20, 50, 40, 0, 60, 10, 0, 10, 40, 60, 50] +}; +d1 = { + "id": [ 0, 1, 2, 3, 4], + "x" : [10, 20, 10, 20, 30] +}; +d2 = { + "id": [ 0, 10, 20, 30, 40, 50, 60], + "y" : ["a", "b", "a", "a", "c", "b", "a"], + "z" : [1.1, 2.2, 1.1, 3.3, 2.2, 1.1, 2.2] +}; + +selD1 = as.matrix(d1[, "x"]) == 10; +d1 = d1[[selD1, ]]; + +selD2 = as.matrix(d2[, "y"]) == "b" || as.matrix(d2[, "z"]) == 1.1; +d2 = d2[, "id"]; +d2 = d2[[selD2, ]]; + +selF = as.matrix(f[, "a"]) < 3; +f = f[[selF, ]]; + +keys, tids = semiJoin(f, d2, "d2id", "id"); +f = cbind(f[, "a"], cbind(f[, "b"], f[, "d1id"])); +f = f[tids, ]; + +f = innerJoin(f, d1, "d1id", "id"); + +tmp = createFrame(as.matrix(f[, "b"]) + as.matrix(f[, "x"]), "tmp"); +f = cbind(f, tmp); + +res = groupSum(f, "a", "tmp"); +res = order(res, 0, true, false); +res = setColLabels(res, "key", "agg"); + +print(res); diff --git a/test/api/cli/sql/columnar_7_sql.daphne b/test/api/cli/sql/columnar_7_sql.daphne new file mode 100644 index 000000000..3bede4f6b --- /dev/null +++ b/test/api/cli/sql/columnar_7_sql.daphne @@ -0,0 +1,32 @@ +// Join of three tables (each with filters, one semi-join), group-by on multiple columns with multiple aggregates, +// order-by on multiple columns, calculation in the projection (inside an aggregatin function). 
+ +f = { + "a" : [ 1, 1, 4, 3, 2, 1, 1, 0, 1, 2, 0, 4], + "b" : [10, 30, 20, 10, 20, 30, 30, 20, 10, 20, 30, 10], + "d1id": [ 3, 0, 2, 1, 4, 3, 2, 2, 0, 1, 2, 1], + "d2id": [30, 20, 50, 40, 0, 60, 10, 0, 10, 40, 60, 50] +}; +d1 = { + "id": [ 0, 1, 2, 3, 4], + "x" : [10, 20, 10, 20, 30] +}; +d2 = { + "id": [ 0, 10, 20, 30, 40, 50, 60], + "y" : ["a", "b", "a", "a", "c", "b", "a"], + "z" : [1.1, 2.2, 1.1, 3.3, 2.2, 1.1, 2.2] +}; + +registerView("f", f); +registerView("d1", d1); +registerView("d2", d2); + +res = sql(" + SELECT f.a AS key, sum(f.b + d1.x) AS agg + FROM f INNER JOIN d1 ON f.d1id = d1.id INNER JOIN d2 ON f.d2id = d2.id + WHERE d1.x = 10 AND (d2.y = 'b' OR d2.z = 1.1) AND f.a < 3 + GROUP BY f.a + ORDER BY f.a; +"); + +print(res); \ No newline at end of file diff --git a/test/runtime/local/kernels/CastObjScaTest.cpp b/test/runtime/local/kernels/CastObjScaTest.cpp index aa3e8975d..e1a0fcba5 100644 --- a/test/runtime/local/kernels/CastObjScaTest.cpp +++ b/test/runtime/local/kernels/CastObjScaTest.cpp @@ -115,3 +115,36 @@ TEMPLATE_TEST_CASE("castObjSca, frame to scalar, non-single-element", TAG_KERNEL CHECK_THROWS(res = castObjSca(arg, nullptr)); DataObjectFactory::destroy(argC0, arg); } + +TEMPLATE_TEST_CASE("castObjSca, column to scalar, single-element", TAG_KERNELS, double, float, int64_t, uint64_t, + int32_t, uint32_t) { + using VTRes = TestType; + + SECTION("Column to VTRes") { + auto arg = genGivenVals>(1, {static_cast(2)}); + VTRes exp = VTRes(2); + VTRes res = castObjSca>(arg, nullptr); + CHECK(res == exp); + DataObjectFactory::destroy(arg); + } + SECTION("Column to VTRes") { + auto arg = genGivenVals>(1, {static_cast(2.2)}); + VTRes exp = VTRes(2.2); + VTRes res = castObjSca>(arg, nullptr); + CHECK(res == exp); + DataObjectFactory::destroy(arg); + } +} + +TEMPLATE_TEST_CASE("castObjSca, column to scalar, non-single-element", TAG_KERNELS, double, int64_t, uint32_t) { + using VT = TestType; + + Column *arg = nullptr; + SECTION("zero-element") { arg = 
DataObjectFactory::create>(0, false); } + SECTION("multi-element (nx1)") { arg = genGivenVals>(2, {VT(1), VT(2)}); } + // 1xm column is not possible + // nxm column is not possible + VT res; + CHECK_THROWS(res = castObjSca>(arg, nullptr)); + DataObjectFactory::destroy(arg); +} diff --git a/test/runtime/local/kernels/CastObjTest.cpp b/test/runtime/local/kernels/CastObjTest.cpp index d66d748e0..c7eb934e0 100644 --- a/test/runtime/local/kernels/CastObjTest.cpp +++ b/test/runtime/local/kernels/CastObjTest.cpp @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -451,4 +452,129 @@ TEMPLATE_TEST_CASE("CastObj DenseMatrix to CSRMatrix", TAG_KERNELS, double, floa DataObjectFactory::destroy(m0, d0, res0); DataObjectFactory::destroy(m1, d1, res1); DataObjectFactory::destroy(m2, d2, res2); +} + +TEMPLATE_PRODUCT_TEST_CASE("castObj, column to matrix", TAG_KERNELS, (DenseMatrix), (double, int64_t, uint32_t)) { + using DTRes = TestType; + using VT = typename DTRes::VT; + + std::vector vals = {VT(0.0), VT(1.1), VT(2.2), VT(3.3)}; + + auto arg = genGivenVals>(vals.size(), vals); + auto exp = genGivenVals>(vals.size(), vals); + + DTRes *res = nullptr; + castObj>(res, arg, nullptr); + + CHECK(*res == *exp); + + DataObjectFactory::destroy(arg, exp, res); +} + +TEMPLATE_PRODUCT_TEST_CASE("castObj, matrix to column, single-column", TAG_KERNELS, (DenseMatrix), + (double, int64_t, uint32_t)) { + using DTArg = TestType; + using VT = typename DTArg::VT; + + std::vector vals = {VT(0.0), VT(1.1), VT(2.2), VT(3.3)}; + + auto arg = genGivenVals(vals.size(), vals); + auto exp = genGivenVals>(vals.size(), vals); + + Column *res = nullptr; + castObj, DTArg>(res, arg, nullptr); + + CHECK(*res == *exp); + + DataObjectFactory::destroy(arg, exp, res); +} + +TEMPLATE_PRODUCT_TEST_CASE("castObj, matrix to column, single-column, view", TAG_KERNELS, (DenseMatrix), + (double, int64_t, uint32_t)) { + using DTArg = TestType; + using VT = typename DTArg::VT; + + std::vector valsArgOrig = 
{VT(0.0), VT(1.1), VT(2.2), VT(3.3), VT(4.4), VT(5.5)}; + std::vector valsExp = {VT(3.3), VT(5.5)}; + + auto argOrig = genGivenVals(valsArgOrig.size() / 2, valsArgOrig); + auto arg = DataObjectFactory::create(argOrig, 1, 3, 1, 2); // view into argOrig + auto exp = genGivenVals>(valsExp.size(), valsExp); + + Column *res = nullptr; + castObj, DTArg>(res, arg, nullptr); + + CHECK(*res == *exp); + + DataObjectFactory::destroy(arg, exp, res); +} + +TEMPLATE_PRODUCT_TEST_CASE("castObj, matrix to column, multi-column", TAG_KERNELS, (DenseMatrix), + (double, int64_t, uint32_t)) { + using DTArg = TestType; + using VT = typename DTArg::VT; + + std::vector vals = {VT(0.0), VT(1.1), VT(2.2), VT(3.3)}; + + auto arg = genGivenVals(vals.size() / 2, vals); + + Column *res = nullptr; + CHECK_THROWS(castObj, DTArg>(res, arg, nullptr)); + + DataObjectFactory::destroy(arg); + if (res) + DataObjectFactory::destroy(res); +} + +TEMPLATE_TEST_CASE("castObj, column to frame", TAG_KERNELS, double, int64_t, uint32_t) { + using VT = TestType; + + std::vector vals = {VT(0.0), VT(1.1), VT(2.2), VT(3.3)}; + + auto arg = genGivenVals>(vals.size(), vals); + auto expC0 = genGivenVals>(vals.size(), vals); + std::vector expCs = {expC0}; + auto exp = DataObjectFactory::create(expCs, nullptr); + + Frame *res = nullptr; + castObj>(res, arg, nullptr); + + CHECK(*res == *exp); + + DataObjectFactory::destroy(arg, expC0, exp, res); +} + +TEMPLATE_TEST_CASE("castObj, frame to column, single-column", TAG_KERNELS, double, int64_t, uint32_t) { + using VT = TestType; + + std::vector vals = {VT(0.0), VT(1.1), VT(2.2), VT(3.3)}; + + auto argC0 = genGivenVals>(vals.size(), vals); + std::vector expCs = {argC0}; + auto arg = DataObjectFactory::create(expCs, nullptr); + auto exp = genGivenVals>(vals.size(), vals); + + Column *res = nullptr; + castObj, Frame>(res, arg, nullptr); + + CHECK(*res == *exp); + + DataObjectFactory::destroy(argC0, arg, exp, res); +} + +TEMPLATE_TEST_CASE("castObj, frame to column, 
multi-column", TAG_KERNELS, double, int64_t, uint32_t) { + using VT = TestType; + + std::vector vals = {VT(0.0), VT(1.1), VT(2.2), VT(3.3)}; + + auto argC0 = genGivenVals>(vals.size(), vals); + std::vector expCs = {argC0, argC0}; + auto arg = DataObjectFactory::create(expCs, nullptr); + + Column *res = nullptr; + CHECK_THROWS(castObj, Frame>(res, arg, nullptr)); + + DataObjectFactory::destroy(argC0, arg); + if (res) + DataObjectFactory::destroy(res); } \ No newline at end of file diff --git a/test/runtime/local/kernels/ColAggAllTest.cpp b/test/runtime/local/kernels/ColAggAllTest.cpp new file mode 100644 index 000000000..c12ffe9c4 --- /dev/null +++ b/test/runtime/local/kernels/ColAggAllTest.cpp @@ -0,0 +1,151 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +#define TEST_NAME(opName) "ColAggAll (" opName ")" +#define DATA_TYPES Column +#define NUM_VALUE_TYPES double, uint32_t, int8_t + +template +void checkColAggAllAndDestroy(AggOpCode opCode, const DTArg *arg, const DTRes *exp) { + DTRes *res = nullptr; + colAggAll(opCode, res, arg, nullptr); + CHECK(*res == *exp); + DataObjectFactory::destroy(arg, exp, res); +} + +// **************************************************************************** +// Valid arguments +// **************************************************************************** + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("sum"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + AggOpCode opCode = AggOpCode::SUM; + DT *arg = nullptr; + DT *exp = nullptr; + + SECTION("empty input") { + arg = DataObjectFactory::create
(0, false); + exp = genGivenVals
({AggOpCodeUtils::getNeutral(opCode)}); + } + SECTION("non-empty input") { + arg = genGivenVals
({VT(2), VT(1), VT(3)}); + exp = genGivenVals
({VT(6)}); + } + + checkColAggAllAndDestroy(opCode, arg, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("prod"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + AggOpCode opCode = AggOpCode::PROD; + DT *arg = nullptr; + DT *exp = nullptr; + + SECTION("empty input") { + arg = DataObjectFactory::create
(0, false); + exp = genGivenVals
({AggOpCodeUtils::getNeutral(opCode)}); + } + SECTION("non-empty input, without zero") { + arg = genGivenVals
({VT(2), VT(1), VT(0)}); + exp = genGivenVals
({VT(0)}); + } + SECTION("non-empty input, with zero") { + arg = genGivenVals
({VT(2), VT(1), VT(3)}); + exp = genGivenVals
({VT(6)}); + } + + checkColAggAllAndDestroy(opCode, arg, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("min"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + AggOpCode opCode = AggOpCode::MIN; + DT *arg = nullptr; + DT *exp = nullptr; + + SECTION("empty input") { + arg = DataObjectFactory::create
(0, false); + exp = genGivenVals
({AggOpCodeUtils::getNeutral(opCode)}); + } + SECTION("non-empty input") { + arg = genGivenVals
({VT(2), VT(1), VT(3), VT(2)}); + exp = genGivenVals
({VT(1)}); + } + + checkColAggAllAndDestroy(opCode, arg, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("max"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + AggOpCode opCode = AggOpCode::MAX; + DT *arg = nullptr; + DT *exp = nullptr; + + SECTION("empty input") { + arg = DataObjectFactory::create
(0, false); + exp = genGivenVals
({AggOpCodeUtils::getNeutral(opCode)}); + } + SECTION("non-empty input") { + arg = genGivenVals
({VT(2), VT(1), VT(3), VT(2)}); + exp = genGivenVals
({VT(3)}); + } + + checkColAggAllAndDestroy(opCode, arg, exp); +} + +// TODO IDXMIX +// TODO IDXMAX +// TODO MEAN +// TODO STDDEV +// TODO VAR + +// **************************************************************************** +// Invalid arguments +// **************************************************************************** + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("some invalid op-code"), TAG_KERNELS, (DATA_TYPES), (double)) { + using DT = TestType; + + auto arg = genGivenVals
({1}); + + DT *res = nullptr; + CHECK_THROWS(colAggAll(static_cast(999), res, arg, nullptr)); + + DataObjectFactory::destroy(arg); + if (res) + DataObjectFactory::destroy(res); +} \ No newline at end of file diff --git a/test/runtime/local/kernels/ColAggGrpTest.cpp b/test/runtime/local/kernels/ColAggGrpTest.cpp new file mode 100644 index 000000000..42f1232ff --- /dev/null +++ b/test/runtime/local/kernels/ColAggGrpTest.cpp @@ -0,0 +1,256 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +#define TEST_NAME(opName) "ColAggGrp (" opName ")" +#define DATA_TYPES Column +#define NUM_VALUE_TYPES double, uint32_t, int8_t + +template +void checkColAggGrpAndDestroy(AggOpCode opCode, const DTData *data, const DTGrpIds *grpIds, size_t numDistinct, + const DTRes *exp) { + DTRes *res = nullptr; + colAggGrp(opCode, res, data, grpIds, numDistinct, nullptr); + CHECK(*res == *exp); + DataObjectFactory::destroy(data, grpIds, exp, res); +} + +// **************************************************************************** +// Valid arguments +// **************************************************************************** + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("sum"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + using VTPos = size_t; + using DTPos = typename DTData::template WithValueType; + + DTData *data = nullptr; + DTPos *grpIds = nullptr; + size_t numDistinct = 0; + DTData *exp = nullptr; + + // Empty input. + SECTION("empty input") { + data = DataObjectFactory::create(0, false); + grpIds = DataObjectFactory::create(0, false); + numDistinct = 0; + exp = DataObjectFactory::create(0, false); + } + // Non-empty input. + // - Given n values, there could be: 1 group, k groups (1 < k < n), or n groups. 
+ SECTION("non-empty input, 1 group") { + data = genGivenVals({VTData(2), VTData(1), VTData(3)}); + grpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + numDistinct = 1; + exp = genGivenVals({VTData(6)}); + } + SECTION("non-empty input, k groups") { + data = genGivenVals({VTData(2), VTData(1), VTData(4), VTData(1), VTData(3), VTData(2)}); + grpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(2), VTPos(1), VTPos(0), VTPos(1)}); + numDistinct = 3; + exp = genGivenVals({VTData(6), VTData(3), VTData(4)}); + } + SECTION("non-empty input, n groups") { + data = genGivenVals({VTData(2), VTData(1), VTData(3)}); + grpIds = genGivenVals({VTPos(1), VTPos(2), VTPos(0)}); + numDistinct = 3; + exp = genGivenVals({VTData(3), VTData(2), VTData(1)}); + } + + checkColAggGrpAndDestroy(AggOpCode::SUM, data, grpIds, numDistinct, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("prod"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + using VTPos = size_t; + using DTPos = typename DTData::template WithValueType; + + DTData *data = nullptr; + DTPos *grpIds = nullptr; + size_t numDistinct = 0; + DTData *exp = nullptr; + + // Empty input. + SECTION("empty input") { + data = DataObjectFactory::create(0, false); + grpIds = DataObjectFactory::create(0, false); + numDistinct = 0; + exp = DataObjectFactory::create(0, false); + } + // Non-empty input. + // - Given n values, there could be: 1 group, k groups (1 < k < n), or n groups. 
+ SECTION("non-empty input, 1 group") { + data = genGivenVals({VTData(2), VTData(1), VTData(3)}); + grpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + numDistinct = 1; + exp = genGivenVals({VTData(6)}); + } + SECTION("non-empty input, k groups") { + data = genGivenVals({VTData(2), VTData(1), VTData(4), VTData(1), VTData(3), VTData(2)}); + grpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(2), VTPos(1), VTPos(0), VTPos(1)}); + numDistinct = 3; + exp = genGivenVals({VTData(6), VTData(2), VTData(4)}); + } + SECTION("non-empty input, n groups") { + data = genGivenVals({VTData(2), VTData(1), VTData(3)}); + grpIds = genGivenVals({VTPos(1), VTPos(2), VTPos(0)}); + numDistinct = 3; + exp = genGivenVals({VTData(3), VTData(2), VTData(1)}); + } + + checkColAggGrpAndDestroy(AggOpCode::PROD, data, grpIds, numDistinct, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("min"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + using VTPos = size_t; + using DTPos = typename DTData::template WithValueType; + + DTData *data = nullptr; + DTPos *grpIds = nullptr; + size_t numDistinct = 0; + DTData *exp = nullptr; + + // Empty input. + SECTION("empty input") { + data = DataObjectFactory::create(0, false); + grpIds = DataObjectFactory::create(0, false); + numDistinct = 0; + exp = DataObjectFactory::create(0, false); + } + // Non-empty input. + // - Given n values, there could be: 1 group, k groups (1 < k < n), or n groups. 
+ SECTION("non-empty input, 1 group") { + data = genGivenVals({VTData(2), VTData(1), VTData(3)}); + grpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + numDistinct = 1; + exp = genGivenVals({VTData(1)}); + } + SECTION("non-empty input, k groups") { + data = genGivenVals({VTData(2), VTData(1), VTData(4), VTData(1), VTData(3), VTData(2)}); + grpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(2), VTPos(1), VTPos(0), VTPos(1)}); + numDistinct = 3; + exp = genGivenVals({VTData(1), VTData(1), VTData(4)}); + } + SECTION("non-empty input, n groups") { + data = genGivenVals({VTData(2), VTData(1), VTData(3)}); + grpIds = genGivenVals({VTPos(1), VTPos(2), VTPos(0)}); + numDistinct = 3; + exp = genGivenVals({VTData(3), VTData(2), VTData(1)}); + } + + checkColAggGrpAndDestroy(AggOpCode::MIN, data, grpIds, numDistinct, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("max"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + using VTPos = size_t; + using DTPos = typename DTData::template WithValueType; + + DTData *data = nullptr; + DTPos *grpIds = nullptr; + size_t numDistinct = 0; + DTData *exp = nullptr; + + // Empty input. + SECTION("empty input") { + data = DataObjectFactory::create(0, false); + grpIds = DataObjectFactory::create(0, false); + numDistinct = 0; + exp = DataObjectFactory::create(0, false); + } + // Non-empty input. + // - Given n values, there could be: 1 group, k groups (1 < k < n), or n groups. 
+ SECTION("non-empty input, 1 group") { + data = genGivenVals({VTData(2), VTData(1), VTData(3)}); + grpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + numDistinct = 1; + exp = genGivenVals({VTData(3)}); + } + SECTION("non-empty input, k groups") { + data = genGivenVals({VTData(2), VTData(1), VTData(4), VTData(1), VTData(3), VTData(2)}); + grpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(2), VTPos(1), VTPos(0), VTPos(1)}); + numDistinct = 3; + exp = genGivenVals({VTData(3), VTData(2), VTData(4)}); + } + SECTION("non-empty input, n groups") { + data = genGivenVals({VTData(2), VTData(1), VTData(3)}); + grpIds = genGivenVals({VTPos(1), VTPos(2), VTPos(0)}); + numDistinct = 3; + exp = genGivenVals({VTData(3), VTData(2), VTData(1)}); + } + + checkColAggGrpAndDestroy(AggOpCode::MAX, data, grpIds, numDistinct, exp); +} + +// TODO IDXMIX +// TODO IDXMAX +// TODO MEAN +// TODO STDDEV +// TODO VAR + +// **************************************************************************** +// Invalid arguments +// **************************************************************************** + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("some invalid op-code"), TAG_KERNELS, (DATA_TYPES), (double)) { + using DTData = TestType; + using DTPos = typename DTData::template WithValueType; + + auto data = genGivenVals({1}); + auto grpIds = genGivenVals({0}); + size_t numDistinct = 1; + + DTData *res = nullptr; + CHECK_THROWS(colAggGrp(static_cast(999), res, data, grpIds, numDistinct, nullptr)); + + DataObjectFactory::destroy(data); + DataObjectFactory::destroy(grpIds); + if (res) + DataObjectFactory::destroy(res); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("any") ": size mismatch", TAG_KERNELS, (DATA_TYPES), (double)) { + using DTData = TestType; + using DTPos = typename DTData::template WithValueType; + + auto data = genGivenVals({1, 2, 3}); + auto grpIds = genGivenVals({0, 1}); + size_t numDistinct = 2; + + DTData *res = nullptr; + CHECK_THROWS(colAggGrp(AggOpCode::SUM, res, data, grpIds, 
numDistinct, nullptr)); + + DataObjectFactory::destroy(data, grpIds); + if (res) + DataObjectFactory::destroy(res); +} \ No newline at end of file diff --git a/test/runtime/local/kernels/ColCalcBinaryTest.cpp b/test/runtime/local/kernels/ColCalcBinaryTest.cpp new file mode 100644 index 000000000..cf545c467 --- /dev/null +++ b/test/runtime/local/kernels/ColCalcBinaryTest.cpp @@ -0,0 +1,676 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +#define TEST_NAME(opName) "ColCalcBinary (" opName ")" +#define DATA_TYPES Column +#define NUM_VALUE_TYPES double, uint32_t, int8_t +#define STR_VALUE_TYPES std::string + +template +void checkColCalcBinaryAndDestroy(BinaryOpCode opCode, const DTArg *lhs, const DTArg *rhs, const DTRes *exp) { + DTRes *res = nullptr; + colCalcBinary(opCode, res, lhs, rhs, nullptr); + CHECK(*res == *exp); + DataObjectFactory::destroy(lhs, rhs, exp, res); +} + +// **************************************************************************** +// Valid arguments +// **************************************************************************** + +// ---------------------------------------------------------------------------- +// Arithmetic +// ---------------------------------------------------------------------------- + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("add"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(4), VT(6)}); + exp = genGivenVals
({VT(1), VT(6), VT(9)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::ADD, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("sub"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(1), VT(6), VT(9)}); + rhs = genGivenVals
({VT(1), VT(2), VT(3)}); + exp = genGivenVals
({VT(0), VT(4), VT(6)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::SUB, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("mul"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(4), VT(6)}); + exp = genGivenVals
({VT(0), VT(8), VT(18)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::MUL, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("div"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(8), VT(18)}); + rhs = genGivenVals
({VT(1), VT(2), VT(3)}); + exp = genGivenVals
({VT(0), VT(4), VT(6)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::DIV, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("pow"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(2), VT(2), VT(2)}); + rhs = genGivenVals
({VT(0), VT(1), VT(3)}); + exp = genGivenVals
({VT(1), VT(2), VT(8)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::POW, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("mod"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(8), VT(0), VT(2)}); + rhs = genGivenVals
({VT(3), VT(5), VT(3)}); + exp = genGivenVals
({VT(2), VT(0), VT(2)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::MOD, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("log"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(1), VT(2), VT(8)}); + rhs = genGivenVals
({VT(2), VT(2), VT(2)}); + exp = genGivenVals
({VT(0), VT(1), VT(3)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::LOG, lhs, rhs, exp); +} + +// ---------------------------------------------------------------------------- +// Comparisons +// ---------------------------------------------------------------------------- + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("eq"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(2), VT(2), VT(1)}); + exp = genGivenVals
({VT(1), VT(0), VT(1), VT(0)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::EQ, lhs, rhs, exp); +} +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("eq"), TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + using VTRes = int64_t; + using DTRes = typename DT::template WithValueType; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DTRes *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT("str0"), VT("str1"), VT("str2"), VT("str3")}); + rhs = genGivenVals
({VT("str0"), VT("str2"), VT("str2"), VT("str1")}); + exp = genGivenVals({VTRes(1), VTRes(0), VTRes(1), VTRes(0)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::EQ, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("neq"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(2), VT(2), VT(1)}); + exp = genGivenVals
({VT(0), VT(1), VT(0), VT(1)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::NEQ, lhs, rhs, exp); +} +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("neq"), TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + using VTRes = int64_t; + using DTRes = typename DT::template WithValueType; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DTRes *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT("str0"), VT("str1"), VT("str2"), VT("str3")}); + rhs = genGivenVals
({VT("str0"), VT("str2"), VT("str2"), VT("str1")}); + exp = genGivenVals({VTRes(0), VTRes(1), VTRes(0), VTRes(1)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::NEQ, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("lt"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(2), VT(2), VT(1)}); + exp = genGivenVals
({VT(0), VT(1), VT(0), VT(0)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::LT, lhs, rhs, exp); +} +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("lt"), TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + using VTRes = int64_t; + using DTRes = typename DT::template WithValueType; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DTRes *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT("str0"), VT("str1"), VT("str2"), VT("str3")}); + rhs = genGivenVals
({VT("str0"), VT("str2"), VT("str2"), VT("str1")}); + exp = genGivenVals({VTRes(0), VTRes(1), VTRes(0), VTRes(0)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::LT, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("le"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(2), VT(2), VT(1)}); + exp = genGivenVals
({VT(1), VT(1), VT(1), VT(0)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::LE, lhs, rhs, exp); +} +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("le"), TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + using VTRes = int64_t; + using DTRes = typename DT::template WithValueType; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DTRes *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT("str0"), VT("str1"), VT("str2"), VT("str3")}); + rhs = genGivenVals
({VT("str0"), VT("str2"), VT("str2"), VT("str1")}); + exp = genGivenVals({VTRes(1), VTRes(1), VTRes(1), VTRes(0)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::LE, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("gt"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(2), VT(2), VT(1)}); + exp = genGivenVals
({VT(0), VT(0), VT(0), VT(1)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::GT, lhs, rhs, exp); +} +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("gt"), TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + using VTRes = int64_t; + using DTRes = typename DT::template WithValueType; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DTRes *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT("str0"), VT("str1"), VT("str2"), VT("str3")}); + rhs = genGivenVals
({VT("str0"), VT("str2"), VT("str2"), VT("str1")}); + exp = genGivenVals({VTRes(0), VTRes(0), VTRes(0), VTRes(1)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::GT, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("ge"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(2), VT(2), VT(1)}); + exp = genGivenVals
({VT(1), VT(0), VT(1), VT(1)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::GE, lhs, rhs, exp); +} +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("ge"), TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + using VTRes = int64_t; + using DTRes = typename DT::template WithValueType; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DTRes *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT("str0"), VT("str1"), VT("str2"), VT("str3")}); + rhs = genGivenVals
({VT("str0"), VT("str2"), VT("str2"), VT("str1")}); + exp = genGivenVals({VTRes(1), VTRes(0), VTRes(1), VTRes(1)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::GE, lhs, rhs, exp); +} + +// ---------------------------------------------------------------------------- +// Min/max +// ---------------------------------------------------------------------------- + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("min"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(2), VT(2), VT(1)}); + exp = genGivenVals
({VT(0), VT(1), VT(2), VT(1)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::MIN, lhs, rhs, exp); +} +#if 0 // not supported yet +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("min"), TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT("str0"), VT("str1"), VT("str2"), VT("str3")}); + rhs = genGivenVals
({VT("str0"), VT("str2"), VT("str2"), VT("str1")}); + exp = genGivenVals
({VT("str0"), VT("str1"), VT("str2"), VT("str1")}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::MIN, lhs, rhs, exp); +} +#endif + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("max"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(2), VT(3)}); + rhs = genGivenVals
({VT(0), VT(2), VT(2), VT(1)}); + exp = genGivenVals
({VT(0), VT(2), VT(2), VT(3)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::MAX, lhs, rhs, exp); +} +#if 0 // not supported yet +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("max"), TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT("str0"), VT("str1"), VT("str2"), VT("str3")}); + rhs = genGivenVals
({VT("str0"), VT("str2"), VT("str2"), VT("str1")}); + exp = genGivenVals
({VT("str0"), VT("str2"), VT("str2"), VT("str3")}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::MAX, lhs, rhs, exp); +} +#endif + +// ---------------------------------------------------------------------------- +// Logical +// ---------------------------------------------------------------------------- + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("and"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(0), VT(1), VT(3), VT(1)}); + rhs = genGivenVals
({VT(0), VT(0), VT(1), VT(1), VT(0), VT(3)}); + exp = genGivenVals
({VT(0), VT(0), VT(0), VT(1), VT(0), VT(1)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::AND, lhs, rhs, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("or"), TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(0), VT(1), VT(0), VT(1), VT(3), VT(1)}); + rhs = genGivenVals
({VT(0), VT(0), VT(1), VT(1), VT(0), VT(3)}); + exp = genGivenVals
({VT(0), VT(1), VT(1), VT(1), VT(1), VT(1)}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::OR, lhs, rhs, exp); +} + +// ---------------------------------------------------------------------------- +// Strings +// ---------------------------------------------------------------------------- + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("concat"), TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + DT *lhs = nullptr; + DT *rhs = nullptr; + DT *exp = nullptr; + + SECTION("empty inputs") { + lhs = DataObjectFactory::create
(0, false); + rhs = DataObjectFactory::create
(0, false); + exp = DataObjectFactory::create
(0, false); + } + SECTION("non-empty inputs") { + lhs = genGivenVals
({VT(""), VT("abc"), VT(""), VT("abc")}); + rhs = genGivenVals
({VT(""), VT(""), VT("de"), VT("de")}); + exp = genGivenVals
({VT(""), VT("abc"), VT("de"), VT("abcde")}); + } + + checkColCalcBinaryAndDestroy(BinaryOpCode::CONCAT, lhs, rhs, exp); +} + +// **************************************************************************** +// Invalid arguments +// **************************************************************************** + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("some invalid op-code"), TAG_KERNELS, (DATA_TYPES), (double)) { + using DT = TestType; + + auto arg = genGivenVals
({1}); + + DT *res = nullptr; + CHECK_THROWS(colCalcBinary(static_cast(999), res, arg, arg, nullptr)); + + DataObjectFactory::destroy(arg); + if (res) + DataObjectFactory::destroy(res); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("any") ": size mismatch", TAG_KERNELS, (DATA_TYPES), (double)) { + using DT = TestType; + + auto lhs = genGivenVals
({1, 2, 3}); + auto rhs = genGivenVals
({4, 5}); + + DT *res = nullptr; + CHECK_THROWS(colCalcBinary(BinaryOpCode::ADD, res, lhs, rhs, nullptr)); + + DataObjectFactory::destroy(lhs, rhs); + if (res) + DataObjectFactory::destroy(res); +} \ No newline at end of file diff --git a/test/runtime/local/kernels/ColGroupFirstTest.cpp b/test/runtime/local/kernels/ColGroupFirstTest.cpp new file mode 100644 index 000000000..259f8f4c4 --- /dev/null +++ b/test/runtime/local/kernels/ColGroupFirstTest.cpp @@ -0,0 +1,175 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +#include + +#define TEST_NAME "ColGroupFirst" +#define DATA_TYPES Column +#define NUM_VALUE_TYPES double, uint32_t, int8_t +#define STR_VALUE_TYPES std::string + +// This is the same as "valid args, string data", just with numeric input data (keep consistent). +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *argData = nullptr; + DTPos *resGrpIdsExp = nullptr; + DTPos *resReprPosExp = nullptr; + + // Empty input data. 
+ SECTION("empty argData") { + argData = DataObjectFactory::create(0, false); + resGrpIdsExp = DataObjectFactory::create(0, false); + resReprPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data, one distinct value. + SECTION("non-empty argData (one distinct value)") { + argData = genGivenVals({VTData(1.1), VTData(1.1), VTData(1.1)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + resReprPosExp = genGivenVals({VTPos(0)}); + } + // Non-empty input data, unique values. + // - The input values could be sorted or unsorted. + SECTION("non-empty argData (unique values, sorted)") { + argData = genGivenVals({VTData(1.1), VTData(2.2), VTData(3.3), VTData(4.4), VTData(5.5)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (unique values, unsorted)") { + argData = genGivenVals({VTData(3.3), VTData(1.1), VTData(2.2), VTData(5.5), VTData(4.4)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + // Non-empty input data, a few distinct values. 
+ // - The input values could be + // - sorted + // - unsorted and clustered (all occurrences of each distinct value in a contiguous subsequence) + // - unsorted and unclustered (occurrences of a distinct value can be separated from each other) + SECTION("non-empty argData (a few distinct values, sorted)") { + argData = genGivenVals( + {VTData(1.1), VTData(1.1), VTData(2.2), VTData(3.3), VTData(3.3), VTData(3.3), VTData(4.4)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1), VTPos(2), VTPos(2), VTPos(2), VTPos(3)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(3), VTPos(6)}); + } + SECTION("non-empty argData (a few distinct values, unsorted+clustered)") { + argData = genGivenVals( + {VTData(2.2), VTData(1.1), VTData(1.1), VTData(4.4), VTData(3.3), VTData(3.3), VTData(3.3)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(1), VTPos(2), VTPos(3), VTPos(3), VTPos(3)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (a few distinct values, unsorted+unclustered)") { + argData = genGivenVals( + {VTData(2.2), VTData(3.3), VTData(1.1), VTData(4.4), VTData(3.3), VTData(3.3), VTData(1.1)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(1), VTPos(1), VTPos(2)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3)}); + } + + DTPos *resGrpIdsFnd = nullptr; + DTPos *resReprPosFnd = nullptr; + colGroupFirst(resGrpIdsFnd, resReprPosFnd, argData, nullptr); + CHECK(*resGrpIdsFnd == *resGrpIdsExp); + CHECK(*resReprPosFnd == *resReprPosExp); + + DataObjectFactory::destroy(argData, resGrpIdsFnd, resReprPosFnd, resGrpIdsExp, resReprPosExp); +} + +// This is the same as "valid args, numeric data", just with string-valued input data (keep consistent). 
+// - Used the following regex replace: `VTData\((\d+\.\d+)\)` -> `VTData("str\1")` +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *argData = nullptr; + DTPos *resGrpIdsExp = nullptr; + DTPos *resReprPosExp = nullptr; + + // Empty input data. + SECTION("empty argData") { + argData = DataObjectFactory::create(0, false); + resGrpIdsExp = DataObjectFactory::create(0, false); + resReprPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data, one distinct value. + SECTION("non-empty argData (one distinct value)") { + argData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str1.1")}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + resReprPosExp = genGivenVals({VTPos(0)}); + } + // Non-empty input data, unique values. + // - The input values could be sorted or unsorted. + SECTION("non-empty argData (unique values, sorted)") { + argData = genGivenVals( + {VTData("str1.1"), VTData("str2.2"), VTData("str3.3"), VTData("str4.4"), VTData("str5.5")}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (unique values, unsorted)") { + argData = genGivenVals( + {VTData("str3.3"), VTData("str1.1"), VTData("str2.2"), VTData("str5.5"), VTData("str4.4")}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + // Non-empty input data, a few distinct values. 
+ // - The input values could be + // - sorted + // - unsorted and clustered (all occurrences of each distinct value in a contiguous subsequence) + // - unsorted and unclustered (occurrences of a distinct value can be separated from each other) + SECTION("non-empty argData (a few distinct values, sorted)") { + argData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str3.3"), + VTData("str3.3"), VTData("str3.3"), VTData("str4.4")}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1), VTPos(2), VTPos(2), VTPos(2), VTPos(3)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(3), VTPos(6)}); + } + SECTION("non-empty argData (a few distinct values, unsorted+clustered)") { + argData = genGivenVals({VTData("str2.2"), VTData("str1.1"), VTData("str1.1"), VTData("str4.4"), + VTData("str3.3"), VTData("str3.3"), VTData("str3.3")}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(1), VTPos(2), VTPos(3), VTPos(3), VTPos(3)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (a few distinct values, unsorted+unclustered)") { + argData = genGivenVals({VTData("str2.2"), VTData("str3.3"), VTData("str1.1"), VTData("str4.4"), + VTData("str3.3"), VTData("str3.3"), VTData("str1.1")}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(1), VTPos(1), VTPos(2)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3)}); + } + + DTPos *resGrpIdsFnd = nullptr; + DTPos *resReprPosFnd = nullptr; + colGroupFirst(resGrpIdsFnd, resReprPosFnd, argData, nullptr); + CHECK(*resGrpIdsFnd == *resGrpIdsExp); + CHECK(*resReprPosFnd == *resReprPosExp); + + DataObjectFactory::destroy(argData, resGrpIdsFnd, resReprPosFnd, resGrpIdsExp, resReprPosExp); +} + +// There are no invalid input data for the colGroupFirst-kernel, so no tests with invalid data here. 
\ No newline at end of file diff --git a/test/runtime/local/kernels/ColGroupNextTest.cpp b/test/runtime/local/kernels/ColGroupNextTest.cpp new file mode 100644 index 000000000..24bad3ac5 --- /dev/null +++ b/test/runtime/local/kernels/ColGroupNextTest.cpp @@ -0,0 +1,384 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +#include + +#define TEST_NAME "ColGroupNext" +#define DATA_TYPES Column +#define NUM_VALUE_TYPES double, uint32_t, int8_t +#define STR_VALUE_TYPES std::string + +// This is the same as "valid args, string data", just with numeric input data (keep consistent). +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *argData = nullptr; + DTPos *argGrpIds = nullptr; + DTPos *resGrpIdsExp = nullptr; + DTPos *resReprPosExp = nullptr; + + // Empty input data. + + SECTION("empty argData, empty argGrpIds") { + argData = DataObjectFactory::create(0, false); + argGrpIds = DataObjectFactory::create(0, false); + resGrpIdsExp = DataObjectFactory::create(0, false); + resReprPosExp = DataObjectFactory::create(0, false); + } + + // Non-empty input data. 
+ // - argData and argGrpIds must always have the same number of rows. + // - argGrpIds (the group ids of a previous grouping step on another column) could be + // - one distinct value + // - unique values + // - multiple distinct values + // for each of those, we test various cases of argData. + + // argGrpIds: one distinct value. + // - The previous grouping (argGrpIds) does not have an impact on the results, i.e., argData alone determines the + // results. + // - We reuse the test cases of the colGroupFirst-kernel. + + // Non-empty input data, argData (one distinct value), argGrpIds (one distinct value). + SECTION("non-empty argData (one distinct value), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals({VTData(1.1), VTData(1.1), VTData(1.1)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + resReprPosExp = genGivenVals({VTPos(0)}); + } + // Non-empty input data, argData (unique values), argGrpIds (one distinct value). + // - The input values could be sorted or unsorted. 
+ SECTION("non-empty argData (unique values, sorted), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals({VTData(1.1), VTData(2.2), VTData(3.3), VTData(4.4), VTData(5.5)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (unique values, unsorted), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals({VTData(3.3), VTData(1.1), VTData(2.2), VTData(5.5), VTData(4.4)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + // Non-empty input data, argData (multiple distinct values), argGrpIds (one distinct value). + // - The input values could be + // - sorted + // - unsorted and clustered (all occurrences of each distinct value in a contiguous subsequence) + // - unsorted and unclustered (occurrences of a distinct value can be separated from each other) + SECTION("non-empty argData (multiple distinct values, sorted), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals( + {VTData(1.1), VTData(1.1), VTData(2.2), VTData(3.3), VTData(3.3), VTData(3.3), VTData(4.4)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1), VTPos(2), VTPos(2), VTPos(2), VTPos(3)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(3), VTPos(6)}); + } + SECTION( + "non-empty argData (multiple distinct values, unsorted+clustered), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals( + {VTData(2.2), VTData(1.1), VTData(1.1), VTData(4.4), VTData(3.3), VTData(3.3), VTData(3.3)}); + argGrpIds 
= genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(1), VTPos(2), VTPos(3), VTPos(3), VTPos(3)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (multiple distinct values, unsorted+unclustered), non-empty argGrpIds (one distinct " + "value)") { + argData = genGivenVals( + {VTData(2.2), VTData(3.3), VTData(1.1), VTData(4.4), VTData(3.3), VTData(3.3), VTData(1.1)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(1), VTPos(1), VTPos(2)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3)}); + } + + // argGrpIds: unique values. + // - The resulting grouping is unique, irrespective of the values in argData. + // - We test just a few cases here: argData have one distinct value, multiple distinct values, or unique values. 
+ + SECTION("non-empty argData (one distinct value), non-empty argGrpIds (unique values)") { + argData = genGivenVals({VTData(1.1), VTData(1.1), VTData(1.1), VTData(1.1), VTData(1.1)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (multiple distinct values), non-empty argGrpIds (unique values)") { + argData = genGivenVals({VTData(1.1), VTData(2.2), VTData(1.1), VTData(2.2), VTData(2.2)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (unique values), non-empty argGrpIds (unique values)") { + argData = genGivenVals({VTData(1.1), VTData(3.3), VTData(2.2), VTData(4.4), VTData(5.5)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + + // argGrpIds: multiple distinct values. + // - The input grouping is refined depending on the values in argData. 
+ // - argData could have + // - one distinct value (no impact on the results) + // - one distinct value per input group (no impact on the results) + // - multiple distinct values per input group, not shared across input groups (refinement of the input groups) + // - multiple distinct values per input group, shared across input groups (refinement of the input groups) + SECTION("non-empty argData (one distinct value), non-empty argGrpIds (multiple distinct values)") { + argData = genGivenVals({VTData(1.1), VTData(1.1), VTData(1.1), VTData(1.1), VTData(1.1), VTData(1.1)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(4)}); + } + SECTION("non-empty argData (one distinct value per input group), non-empty argGrpIds (multiple distinct values)") { + argData = genGivenVals({VTData(1.1), VTData(2.2), VTData(1.1), VTData(1.1), VTData(3.3), VTData(2.2)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(4)}); + } + SECTION("non-empty argData (multiple distinct values per input group, not shared across input groups), non-empty " + "argGrpIds (multiple distinct values)") { + argData = genGivenVals({VTData(1.1), VTData(3.3), VTData(2.2), VTData(1.1), VTData(4.4), VTData(3.3)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(0), VTPos(3), VTPos(1)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(4)}); + } + SECTION("non-empty argData (multiple distinct values per input group, shared across input groups), non-empty " + "argGrpIds (multiple distinct values)") { + argData 
= genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2), VTData(1.1), VTData(2.2), VTData(3.3)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(3)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(0), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(4), VTPos(5)}); + } + + DTPos *resGrpIdsFnd = nullptr; + DTPos *resReprPosFnd = nullptr; + colGroupNext(resGrpIdsFnd, resReprPosFnd, argData, argGrpIds, nullptr); + CHECK(*resGrpIdsFnd == *resGrpIdsExp); + CHECK(*resReprPosFnd == *resReprPosExp); + + DataObjectFactory::destroy(argData, argGrpIds, resGrpIdsFnd, resReprPosFnd, resGrpIdsExp, resReprPosExp); +} + +// This is the same as "valid args, numeric data", just with string-valued input data (keep consistent). +// - Used the following regex replace: `VTData\((\d+\.\d+)\)` -> `VTData("str\1")` +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *argData = nullptr; + DTPos *argGrpIds = nullptr; + DTPos *resGrpIdsExp = nullptr; + DTPos *resReprPosExp = nullptr; + + // Empty input data. + + SECTION("empty argData, empty argGrpIds") { + argData = DataObjectFactory::create(0, false); + argGrpIds = DataObjectFactory::create(0, false); + resGrpIdsExp = DataObjectFactory::create(0, false); + resReprPosExp = DataObjectFactory::create(0, false); + } + + // Non-empty input data. + // - argData and argGrpIds must always have the same number of rows. + // - argGrpIds (the group ids of a previous grouping step on another column) could be + // - one distinct value + // - unique values + // - multiple distinct values + // for each of those, we test various cases of argData. + + // argGrpIds: one distinct value.
+ // - The previous grouping (argGrpIds) does not have an impact on the results, i.e., argData alone determines the + // results. + // - We reuse the test cases of the colGroupFirst-kernel. + + // Non-empty input data, argData (one distinct value), argGrpIds (one distinct value). + SECTION("non-empty argData (one distinct value), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str1.1")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(0), VTPos(0)}); + resReprPosExp = genGivenVals({VTPos(0)}); + } + // Non-empty input data, argData (unique values), argGrpIds (one distinct value). + // - The input values could be sorted or unsorted. + SECTION("non-empty argData (unique values, sorted), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals( + {VTData("str1.1"), VTData("str2.2"), VTData("str3.3"), VTData("str4.4"), VTData("str5.5")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (unique values, unsorted), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals( + {VTData("str3.3"), VTData("str1.1"), VTData("str2.2"), VTData("str5.5"), VTData("str4.4")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + // Non-empty input data, argData (multiple distinct values), argGrpIds (one distinct value).
+ // - The input values could be + // - sorted + // - unsorted and clustered (all occurrences of each distinct value in a contiguous subsequence) + // - unsorted and unclustered (occurrences of a distinct value can be separated from each other) + SECTION("non-empty argData (multiple distinct values, sorted), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str3.3"), + VTData("str3.3"), VTData("str3.3"), VTData("str4.4")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1), VTPos(2), VTPos(2), VTPos(2), VTPos(3)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(3), VTPos(6)}); + } + SECTION( + "non-empty argData (multiple distinct values, unsorted+clustered), non-empty argGrpIds (one distinct value)") { + argData = genGivenVals({VTData("str2.2"), VTData("str1.1"), VTData("str1.1"), VTData("str4.4"), + VTData("str3.3"), VTData("str3.3"), VTData("str3.3")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(1), VTPos(2), VTPos(3), VTPos(3), VTPos(3)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (multiple distinct values, unsorted+unclustered), non-empty argGrpIds (one distinct " + "value)") { + argData = genGivenVals({VTData("str2.2"), VTData("str3.3"), VTData("str1.1"), VTData("str4.4"), + VTData("str3.3"), VTData("str3.3"), VTData("str1.1")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0), VTPos(0)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(1), VTPos(1), VTPos(2)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3)}); + } + + // argGrpIds: unique values.
+ // - The resulting grouping is unique, irrespective of the values in argData. + // - We test just a few cases here: argData have one distinct value, multiple distinct values, or unique values. + + SECTION("non-empty argData (one distinct value), non-empty argGrpIds (unique values)") { + argData = genGivenVals( + {VTData("str1.1"), VTData("str1.1"), VTData("str1.1"), VTData("str1.1"), VTData("str1.1")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (multiple distinct values), non-empty argGrpIds (unique values)") { + argData = genGivenVals( + {VTData("str1.1"), VTData("str2.2"), VTData("str1.1"), VTData("str2.2"), VTData("str2.2")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty argData (unique values), non-empty argGrpIds (unique values)") { + argData = genGivenVals( + {VTData("str1.1"), VTData("str3.3"), VTData("str2.2"), VTData("str4.4"), VTData("str5.5")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + + // argGrpIds: multiple distinct values. + // - The input grouping is refined depending on the values in argData.
+ // - argData could have + // - one distinct value (no impact on the results) + // - one distinct value per input group (no impact on the results) + // - multiple distinct values per input group, not shared across input groups (refinement of the input groups) + // - multiple distinct values per input group, shared across input groups (refinement of the input groups) + SECTION("non-empty argData (one distinct value), non-empty argGrpIds (multiple distinct values)") { + argData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str1.1"), VTData("str1.1"), + VTData("str1.1"), VTData("str1.1")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(4)}); + } + SECTION("non-empty argData (one distinct value per input group), non-empty argGrpIds (multiple distinct values)") { + argData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str1.1"), VTData("str1.1"), + VTData("str3.3"), VTData("str2.2")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(4)}); + } + SECTION("non-empty argData (multiple distinct values per input group, not shared across input groups), non-empty " + "argGrpIds (multiple distinct values)") { + argData = genGivenVals({VTData("str1.1"), VTData("str3.3"), VTData("str2.2"), VTData("str1.1"), + VTData("str4.4"), VTData("str3.3")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(1)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(0), VTPos(3), VTPos(1)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(4)}); + } + SECTION("non-empty argData (multiple distinct values per input group,
shared across input groups), non-empty " + "argGrpIds (multiple distinct values)") { + argData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str1.1"), + VTData("str2.2"), VTData("str3.3")}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1), VTPos(0), VTPos(0), VTPos(2), VTPos(3)}); + resGrpIdsExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(0), VTPos(3), VTPos(4)}); + resReprPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(4), VTPos(5)}); + } + + DTPos *resGrpIdsFnd = nullptr; + DTPos *resReprPosFnd = nullptr; + colGroupNext(resGrpIdsFnd, resReprPosFnd, argData, argGrpIds, nullptr); + CHECK(*resGrpIdsFnd == *resGrpIdsExp); + CHECK(*resReprPosFnd == *resReprPosExp); + + DataObjectFactory::destroy(argData, argGrpIds, resGrpIdsFnd, resReprPosFnd, resGrpIdsExp, resReprPosExp); +} + +// We only use numeric value types for the input data here, since this test case is mainly about the input sizes and +// the basic functionality for string value types has been tested above. +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": invalid args", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *argData = nullptr; + DTPos *argGrpIds = nullptr; + + // argData and argGrpIds must have the same number of elements. + + // One input empty, the other input non-empty. + SECTION("empty argData, non-empty argGrpIds") { + argData = DataObjectFactory::create(0, false); + argGrpIds = genGivenVals({0}); + } + SECTION("non-empty argData, empty argGrpIds") { + argData = genGivenVals({VTData(1.1)}); + argGrpIds = DataObjectFactory::create(0, false); + } + // Both inputs non-empty, but mismatching sizes.
+ SECTION("non-empty argData, non-empty argGrpIds (mismatching sizes)") { + argData = genGivenVals({VTData(1.1), VTData(2.2), VTData(3.3)}); + argGrpIds = genGivenVals({VTPos(0), VTPos(1)}); + } + + DTPos *resGrpIdsFnd = nullptr; + DTPos *resReprPosFnd = nullptr; + CHECK_THROWS(colGroupNext(resGrpIdsFnd, resReprPosFnd, argData, argGrpIds, nullptr)); + + DataObjectFactory::destroy(argData, argGrpIds); + if (resGrpIdsFnd) + DataObjectFactory::destroy(resGrpIdsFnd); + if (resReprPosFnd) + DataObjectFactory::destroy(resReprPosFnd); +} \ No newline at end of file diff --git a/test/runtime/local/kernels/ColIntersectTest.cpp b/test/runtime/local/kernels/ColIntersectTest.cpp new file mode 100644 index 000000000..b06be8e28 --- /dev/null +++ b/test/runtime/local/kernels/ColIntersectTest.cpp @@ -0,0 +1,109 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +#define TEST_NAME "ColIntersect" +#define DATA_TYPES Column +#define VALUE_TYPES int64_t, uint32_t, int8_t, size_t + +// The left-hand-side (lhs) and right-hand-side (rhs) input positions for the colIntersect-kernel must both be sorted +// and unique. For performance reasons, the kernel does not check if this requirement is fulfilled. Thus, we do not test +// if unsorted or non-unique inputs are detected. Other than unsorted or non-unique inputs, there are no invalid inputs.
+// Thus, we don't test any invalid inputs here. + +// The colIntersect-kernel is meant to work on positions (not on data). Thus, we only test with integral value types. + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args", TAG_KERNELS, (DATA_TYPES), (VALUE_TYPES)) { + using DTPos = TestType; + + DTPos *lhsPos = nullptr; + DTPos *rhsPos = nullptr; + DTPos *resPosExp = nullptr; + + // At least one empty input, hence empty result. + SECTION("empty lhsPos, empty rhsPos") { + lhsPos = DataObjectFactory::create(0, false); + rhsPos = DataObjectFactory::create(0, false); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("empty lhsPos, non-empty rhsPos") { + lhsPos = DataObjectFactory::create(0, false); + rhsPos = genGivenVals({0, 1, 3}); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsPos, empty rhsPos") { + lhsPos = genGivenVals({0, 1, 3}); + rhsPos = DataObjectFactory::create(0, false); + resPosExp = DataObjectFactory::create(0, false); + } + // Two non-empty inputs, with empty result (i.e., lhsPos and rhsPos are disjoint). + // - All lhs positions could be before all rhs positions ("lhs-before-rhs") or vice versa ("rhs-before-lhs"). + // - The ranges of the lhs and rhs positions could be "overlapping". + SECTION("non-empty lhsPos, non-empty rhsPos, empty resPos (lhs-before-rhs)") { + lhsPos = genGivenVals({0, 1, 3}); + rhsPos = genGivenVals({7, 8, 10, 12}); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsPos, non-empty rhsPos, empty resPos (rhs-before-lhs)") { + lhsPos = genGivenVals({7, 8, 10, 12}); + rhsPos = genGivenVals({0, 1, 3}); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsPos, non-empty rhsPos, empty resPos (overlapping)") { + lhsPos = genGivenVals({0, 3, 8, 10}); + rhsPos = genGivenVals({1, 7, 12}); + resPosExp = DataObjectFactory::create(0, false); + } + // Two non-empty inputs, with non-empty result (i.e., lhsPos and rhsPos are not disjoint).
+ // - The lhs and rhs positions could be the "same". + // - The lhs positions could be a subset of the rhs positions ("lhs-in-rhs") or vice versa ("rhs-in-lhs"). + // - None of the inputs could be a subset of the other one but they could still be "overlapping". + SECTION("non-empty lhsPos, non-empty rhsPos, non-empty resPos (same)") { + lhsPos = genGivenVals({0, 1, 3}); + rhsPos = genGivenVals({0, 1, 3}); + resPosExp = genGivenVals({0, 1, 3}); + } + SECTION("non-empty lhsPos, non-empty rhsPos, non-empty resPos (lhs-in-rhs)") { + lhsPos = genGivenVals({0, 3}); + rhsPos = genGivenVals({0, 1, 3}); + resPosExp = genGivenVals({0, 3}); + } + SECTION("non-empty lhsPos, non-empty rhsPos, non-empty resPos (rhs-in-lhs)") { + lhsPos = genGivenVals({0, 1, 3}); + rhsPos = genGivenVals({0, 3}); + resPosExp = genGivenVals({0, 3}); + } + SECTION("non-empty lhsPos, non-empty rhsPos, non-empty resPos (overlapping)") { + lhsPos = genGivenVals({0, 1}); + rhsPos = genGivenVals({1, 3}); + resPosExp = genGivenVals({1}); + } + + DTPos *resPosFnd = nullptr; + colIntersect(resPosFnd, lhsPos, rhsPos, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsPos, rhsPos, resPosExp, resPosFnd); +} \ No newline at end of file diff --git a/test/runtime/local/kernels/ColJoinTest.cpp b/test/runtime/local/kernels/ColJoinTest.cpp new file mode 100644 index 000000000..a85778c13 --- /dev/null +++ b/test/runtime/local/kernels/ColJoinTest.cpp @@ -0,0 +1,554 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +#include + +// These test cases are essentially the same as in ColSemiJoinTest.cpp (keep consistent). +// - The only difference is that the colJoin-kernel has the additional result resRhsPos. +// - Note that we assume the rhsData to be unique (primary key), so we don't consider expanding joins here. + +#define TEST_NAME "ColJoin" +#define DATA_TYPES Column +#define NUM_VALUE_TYPES double, uint32_t, int8_t +#define STR_VALUE_TYPES std::string + +// This is the same as "valid args, string data", just with numeric input data (keep consistent). +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + DTData *rhsData = nullptr; + DTPos *resLhsPosExp = nullptr; + DTPos *resRhsPosExp = nullptr; + + // Empty input data. + SECTION("empty lhsData, empty rhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = DataObjectFactory::create(0, false); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("empty lhsData, non-empty rhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData, empty rhsData") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + rhsData = DataObjectFactory::create(0, false); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + + // Non-empty input data.
+ // - The lhsData is assumed to be a foreign key. + // - The lhsData could be unique or non-unique. + // - The lhsData could be sorted or unsorted. + // - The rhsData is assumed to be a primary key. + // - The rhsData is assumed to be unique. + // - The rhsData could be sorted or unsorted. + // - Matches: the lhsData could contain + // - the values in rhsData (all) + // - a subset of the values in rhsData (subset) + // - a superset of the values in rhsData (superset) + // - no values in rhsData (none) + + // (all) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted),
non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2), VTData(4.4), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1), VTPos(2), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2), VTData(4.4), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(2), VTPos(1), VTPos(1)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2), VTData(1.1), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1), VTPos(0), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2), VTData(1.1), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(0), VTPos(1)}); + } + // (subset) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + } +
SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), subset") { + lhsData = genGivenVals({VTData(2.2), VTData(1.1)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + resRhsPosExp = genGivenVals({VTPos(1), VTPos(0)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData(2.2), VTData(1.1)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + resRhsPosExp = genGivenVals({VTPos(2), VTPos(0)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(1.1)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1),
VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(0)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(1.1)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(0)}); + } + // (superset) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4), VTData(5.5)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4), VTData(5.5)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData(5.5), VTData(1.1), VTData(4.4), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData(5.5), VTData(1.1), VTData(4.4), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData
(non-unique, sorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2), VTData(4.4), VTData(4.4), VTData(5.5)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1), VTPos(2), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2), VTData(4.4), VTData(4.4), VTData(5.5)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(2), VTPos(1), VTPos(1)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData(5.5), VTData(1.1), VTData(4.4), VTData(2.2), VTData(1.1), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3), VTPos(4), VTPos(5)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1), VTPos(0), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData(5.5), VTData(1.1), VTData(4.4), VTData(2.2), VTData(1.1), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3), VTPos(4), VTPos(5)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(0), VTPos(1)}); + } + // (none) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(22.2), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); +
resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(22.2), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(44.4), VTData(22.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(44.4), VTData(22.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(11.1), VTData(22.2), VTData(44.4), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(11.1), VTData(22.2), VTData(44.4), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), none") { + lhsData =
genGivenVals({VTData(11.1), VTData(44.4), VTData(22.2), VTData(11.1), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(44.4), VTData(22.2), VTData(11.1), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + + DTPos *resLhsPosFnd = nullptr; + DTPos *resRhsPosFnd = nullptr; + colJoin(resLhsPosFnd, resRhsPosFnd, lhsData, rhsData, -1, nullptr); + CHECK(*resLhsPosFnd == *resLhsPosExp); + CHECK(*resRhsPosFnd == *resRhsPosExp); + + DataObjectFactory::destroy(lhsData, rhsData, resLhsPosExp, resRhsPosExp, resLhsPosFnd, resRhsPosFnd); +} + +// This is the same as "valid args, numeric data", just with string-valued input data (keep consistent). +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + DTData *rhsData = nullptr; + DTPos *resLhsPosExp = nullptr; + DTPos *resRhsPosExp = nullptr; + + // Empty input data.
+ SECTION("empty lhsData, empty rhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = DataObjectFactory::create(0, false); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("empty lhsData, non-empty rhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData, empty rhsData") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + rhsData = DataObjectFactory::create(0, false); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + + // Non-empty input data. + // - The lhsData is assumed to be a foreign key. + // - The lhsData could be unique or non-unique. + // - The lhsData could be sorted or unsorted. + // - The rhsData is assumed to be a primary key. + // - The rhsData is assumed to be unique. + // - The rhsData could be sorted or unsorted. 
+ // - Matches: the lhsData could contain + // - the values in rhsData (all) + // - a subset of the values in rhsData (subset) + // - a superset of the values in rhsData (superset) + // - no values in rhsData (none) + + // (all) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals( + {VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), VTData("str4.4")}); + 
rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1), VTPos(2), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals( + {VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), VTData("str4.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(2), VTPos(1), VTPos(1)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals( + {VTData("str1.1"), VTData("str4.4"), VTData("str2.2"), VTData("str1.1"), VTData("str4.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1), VTPos(0), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals( + {VTData("str1.1"), VTData("str4.4"), VTData("str2.2"), VTData("str1.1"), VTData("str4.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(0), VTPos(1)}); + } + // (subset) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), subset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + resRhsPosExp = 
genGivenVals({VTPos(0), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), subset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), subset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), subset") { + lhsData = 
genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str1.1")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(0)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str1.1")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(0)}); + } + // (superset) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), VTData("str5.5")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), VTData("str5.5")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData("str5.5"), VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, 
unsorted), superset") { + lhsData = genGivenVals({VTData("str5.5"), VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), + VTData("str4.4"), VTData("str5.5")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(1), VTPos(2), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), + VTData("str4.4"), VTData("str5.5")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(0), VTPos(2), VTPos(1), VTPos(1)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData("str5.5"), VTData("str1.1"), VTData("str4.4"), VTData("str2.2"), + VTData("str1.1"), VTData("str4.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3), VTPos(4), VTPos(5)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(2), VTPos(1), VTPos(0), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData("str5.5"), VTData("str1.1"), VTData("str4.4"), 
VTData("str2.2"), + VTData("str1.1"), VTData("str4.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3), VTPos(4), VTPos(5)}); + resRhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(0), VTPos(1)}); + } + // (none) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals({VTData("str11.1"), VTData("str22.2"), VTData("str44.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals({VTData("str11.1"), VTData("str22.2"), VTData("str44.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals({VTData("str11.1"), VTData("str44.4"), VTData("str22.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals({VTData("str11.1"), VTData("str44.4"), VTData("str22.2")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals( + {VTData("str11.1"), VTData("str11.1"), VTData("str22.2"), 
VTData("str44.4"), VTData("str44.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals( + {VTData("str11.1"), VTData("str11.1"), VTData("str22.2"), VTData("str44.4"), VTData("str44.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals( + {VTData("str11.1"), VTData("str44.4"), VTData("str222.2"), VTData("str11.1"), VTData("str44.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals( + {VTData("str11.1"), VTData("str44.4"), VTData("str22.2"), VTData("str11.1"), VTData("str44.4")}); + rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")}); + resLhsPosExp = DataObjectFactory::create(0, false); + resRhsPosExp = DataObjectFactory::create(0, false); + } + + DTPos *resLhsPosFnd = nullptr; + DTPos *resRhsPosFnd = nullptr; + colJoin(resLhsPosFnd, resRhsPosFnd, lhsData, rhsData, -1, nullptr); + CHECK(*resLhsPosFnd == *resLhsPosExp); + CHECK(*resRhsPosFnd == *resRhsPosExp); + + DataObjectFactory::destroy(lhsData, rhsData, resLhsPosExp, resRhsPosExp, resLhsPosFnd, resRhsPosFnd); +} + +// The only possible invalid arg would be a too low numRes. We don't test this case here. 
\ No newline at end of file diff --git a/test/runtime/local/kernels/ColMergeTest.cpp b/test/runtime/local/kernels/ColMergeTest.cpp new file mode 100644 index 000000000..656bae895 --- /dev/null +++ b/test/runtime/local/kernels/ColMergeTest.cpp @@ -0,0 +1,102 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +#define TEST_NAME "ColMerge" +#define DATA_TYPES Column +#define VALUE_TYPES int64_t, uint32_t, int8_t, size_t + +// The left-hand-side (lhs) and right-hand-side (rhs) input positions for the colMerge-kernel must both be sorted +// and unique. For performance reasons, the kernel does not check if this requirement is fulfilled. Thus, we do not test +// if unsorted or non-unique inputs are detected. Other than unsorted or non-unique inputs, there are no invalid inputs. +// Thus, we don't test any invalid inputs here. + +// The colMerge-kernel is meant to work on positions (not on data). Thus, we only test with integral value types. + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args", TAG_KERNELS, (DATA_TYPES), (VALUE_TYPES)) { + using DTPos = TestType; + + DTPos *lhsPos = nullptr; + DTPos *rhsPos = nullptr; + DTPos *resPosExp = nullptr; + + // At least one empty input.
+ SECTION("empty lhsPos, empty rhsPos") { + lhsPos = DataObjectFactory::create(0, false); + rhsPos = DataObjectFactory::create(0, false); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("empty lhsPos, non-empty rhsPos") { + lhsPos = DataObjectFactory::create(0, false); + rhsPos = genGivenVals({0, 1, 3}); + resPosExp = genGivenVals({0, 1, 3}); + } + SECTION("non-empty lhsPos, empty rhsPos") { + lhsPos = genGivenVals({0, 1, 3}); + rhsPos = DataObjectFactory::create(0, false); + resPosExp = genGivenVals({0, 1, 3}); + } + // Two non-empty inputs, hence non-empty result. + // - The lhs and rhs positions could be the "same". + // - The lhs positions could be a subset of the rhs positions ("lhs-in-rhs") or vice versa ("rhs-in-lhs"). + // - All lhs positions could be before all rhs positions ("lhs-before-rhs") or vice versa ("rhs-before-lhs"). + // - None of the inputs could be a subset of the other one but they could still be "overlapping". + SECTION("non-empty lhsPos, non-empty rhsPos (same)") { + lhsPos = genGivenVals({0, 1, 3}); + rhsPos = genGivenVals({0, 1, 3}); + resPosExp = genGivenVals({0, 1, 3}); + } + SECTION("non-empty lhsPos, non-empty rhsPos (lhs-in-rhs)") { + lhsPos = genGivenVals({0, 3}); + rhsPos = genGivenVals({0, 1, 3}); + resPosExp = genGivenVals({0, 1, 3}); + } + SECTION("non-empty lhsPos, non-empty rhsPos (rhs-in-lhs)") { + lhsPos = genGivenVals({0, 1, 3}); + rhsPos = genGivenVals({0, 3}); + resPosExp = genGivenVals({0, 1, 3}); + } + SECTION("non-empty lhsPos, non-empty rhsPos (lhs-before-rhs)") { + lhsPos = genGivenVals({0, 1, 3}); + rhsPos = genGivenVals({7, 8, 10, 12}); + resPosExp = genGivenVals({0, 1, 3, 7, 8, 10, 12}); + } + SECTION("non-empty lhsPos, non-empty rhsPos (rhs-before-lhs)") { + lhsPos = genGivenVals({7, 8, 10, 12}); + rhsPos = genGivenVals({0, 1, 3}); + resPosExp = genGivenVals({0, 1, 3, 7, 8, 10, 12}); + } + SECTION("non-empty lhsPos, non-empty rhsPos (overlapping)") { + lhsPos = genGivenVals({0, 3, 8, 10}); + rhsPos
= genGivenVals({1, 7, 12}); + resPosExp = genGivenVals({0, 1, 3, 7, 8, 10, 12}); + } + + DTPos *resPosFnd = nullptr; + colMerge(resPosFnd, lhsPos, rhsPos, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsPos, rhsPos, resPosExp, resPosFnd); +} \ No newline at end of file diff --git a/test/runtime/local/kernels/ColProjectTest.cpp b/test/runtime/local/kernels/ColProjectTest.cpp new file mode 100644 index 000000000..7caa76ab1 --- /dev/null +++ b/test/runtime/local/kernels/ColProjectTest.cpp @@ -0,0 +1,309 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +#include + +#define TEST_NAME "ColProject" +#define DATA_TYPES Column +#define NUM_VALUE_TYPES double, uint32_t, int8_t +#define STR_VALUE_TYPES std::string + +// TODO make #cols in genGivenVals optional (but still consistent with the other data types, maybe take the #elems as +// #rows, if #rows not specified) + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + DTPos *rhsPos = nullptr; + DTData *resDataExp = nullptr; + + // Empty input positions. 
+ // - The input data could be empty or non-empty, but the concrete input data doesn't matter. + SECTION("empty data, empty positions") { + lhsData = DataObjectFactory::create(0, false); + rhsPos = DataObjectFactory::create(0, false); + resDataExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty data, empty positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = DataObjectFactory::create(0, false); + resDataExp = DataObjectFactory::create(0, false); + } + // Non-empty input positions. + // - The input data must be non-empty, but the concrete input data doesn't matter. + // - The input positions are characterized by the following properties, which can be freely combined: + // - They can contain "all" or just a "subset" of the valid positions. + // - They can be "unique" (no repetitions) or "non-unique" (repetitions). + // - They can be "sorted" or "unsorted". + SECTION("non-empty data, non-empty (all, unique, sorted) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, 1, 2, 3}); + resDataExp = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + } + SECTION("non-empty data, non-empty (all, unique, unsorted) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, 3, 1, 2}); + resDataExp = genGivenVals({VTData(100.0), VTData(103.3), VTData(101.1), VTData(102.2)}); + } + SECTION("non-empty data, non-empty (all, non-unique, sorted) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, 1, 1, 2, 3, 3, 3}); + resDataExp = genGivenVals( + {VTData(100.0), VTData(101.1), VTData(101.1), VTData(102.2), VTData(103.3), VTData(103.3), VTData(103.3)}); + } + SECTION("non-empty data, non-empty (all, non-unique, unsorted) positions") { + lhsData = 
genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, 3, 3, 1, 2, 1, 3}); + resDataExp = genGivenVals( + {VTData(100.0), VTData(103.3), VTData(103.3), VTData(101.1), VTData(102.2), VTData(101.1), VTData(103.3)}); + } + SECTION("non-empty data, non-empty (subset, unique, sorted) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({2, 3}); + resDataExp = genGivenVals({VTData(102.2), VTData(103.3)}); + } + SECTION("non-empty data, non-empty (subset, unique, unsorted) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({3, 2}); + resDataExp = genGivenVals({VTData(103.3), VTData(102.2)}); + } + SECTION("non-empty data, non-empty (subset, non-unique, sorted) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({2, 3, 3, 3}); + resDataExp = genGivenVals({VTData(102.2), VTData(103.3), VTData(103.3), VTData(103.3)}); + } + SECTION("non-empty data, non-empty (subset, non-unique, unsorted) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({3, 3, 2, 3}); + resDataExp = genGivenVals({VTData(103.3), VTData(103.3), VTData(102.2), VTData(103.3)}); + } + + DTData *resDataFnd = nullptr; + colProject(resDataFnd, lhsData, rhsPos, nullptr); + // TODO check that resDataFnd is not nullptr anymore (maybe do that in all kernel test cases) + CHECK(*resDataFnd == *resDataExp); + + DataObjectFactory::destroy(lhsData, rhsPos, resDataExp, resDataFnd); +} + +// This is the same as "valid args, numeric data", just with string-valued input data. 
+TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + DTPos *rhsPos = nullptr; + DTData *resDataExp = nullptr; + + // Empty input positions. + // - The input data could be empty or non-empty, but the concrete input data doesn't matter. + SECTION("empty data, empty positions") { + lhsData = DataObjectFactory::create(0, false); + rhsPos = DataObjectFactory::create(0, false); + resDataExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty data, empty positions") { + lhsData = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + rhsPos = DataObjectFactory::create(0, false); + resDataExp = DataObjectFactory::create(0, false); + } + // Non-empty input positions. + // - The input data must be non-empty, but the concrete input data doesn't matter. + // - The input positions are characterized by the following properties, which can be freely combined: + // - They can contain "all" or just a "subset" of the valid positions. + // - They can be "unique" (no repetitions) or "non-unique" (repetitions). + // - They can be "sorted" or "unsorted". 
+ SECTION("non-empty data, non-empty (all, unique, sorted) positions") { + lhsData = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + rhsPos = genGivenVals({0, 1, 2, 3}); + resDataExp = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + } + SECTION("non-empty data, non-empty (all, unique, unsorted) positions") { + lhsData = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + rhsPos = genGivenVals({0, 3, 1, 2}); + resDataExp = genGivenVals({"str100.0", "str103.3", "str101.1", "str102.2"}); + } + SECTION("non-empty data, non-empty (all, non-unique, sorted) positions") { + lhsData = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + rhsPos = genGivenVals({0, 1, 1, 2, 3, 3, 3}); + resDataExp = + genGivenVals({"str100.0", "str101.1", "str101.1", "str102.2", "str103.3", "str103.3", "str103.3"}); + } + SECTION("non-empty data, non-empty (all, non-unique, unsorted) positions") { + lhsData = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + rhsPos = genGivenVals({0, 3, 3, 1, 2, 1, 3}); + resDataExp = + genGivenVals({"str100.0", "str103.3", "str103.3", "str101.1", "str102.2", "str101.1", "str103.3"}); + } + SECTION("non-empty data, non-empty (subset, unique, sorted) positions") { + lhsData = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + rhsPos = genGivenVals({2, 3}); + resDataExp = genGivenVals({"str102.2", "str103.3"}); + } + SECTION("non-empty data, non-empty (subset, unique, unsorted) positions") { + lhsData = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + rhsPos = genGivenVals({3, 2}); + resDataExp = genGivenVals({"str103.3", "str102.2"}); + } + SECTION("non-empty data, non-empty (subset, non-unique, sorted) positions") { + lhsData = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + rhsPos = genGivenVals({2, 3, 3, 3}); + resDataExp = genGivenVals({"str102.2", "str103.3", "str103.3", "str103.3"}); + } + SECTION("non-empty data, non-empty 
(subset, non-unique, unsorted) positions") { + lhsData = genGivenVals({"str100.0", "str101.1", "str102.2", "str103.3"}); + rhsPos = genGivenVals({3, 3, 2, 3}); + resDataExp = genGivenVals({"str103.3", "str103.3", "str102.2", "str103.3"}); + } + + DTData *resDataFnd = nullptr; + colProject(resDataFnd, lhsData, rhsPos, nullptr); + // TODO check that resDataFnd is not nullptr anymore (maybe do that in all kernel test cases) + CHECK(*resDataFnd == *resDataExp); + + DataObjectFactory::destroy(lhsData, rhsPos, resDataExp, resDataFnd); +} + +// We only use numeric value types for the input data here, since this test case is mainly about the input positions and +// the basic functionality for string value types has been tested above. +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": invalid args, unsigned positions", TAG_KERNELS, (DATA_TYPES), + (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + DTPos *rhsPos = nullptr; + + // Empty input data, any non-empty input positions are invalid. + SECTION("empty data, non-empty positions") { + lhsData = DataObjectFactory::create(0, false); + rhsPos = genGivenVals({0}); + } + // Non-empty input data, invalid non-empty positions. + // - The input data are non-empty, but the concrete input data doesn't matter. + // - The invalid input positions could be characterized by the following properties, which can be freely combined: + // - The positions could be "one-too-high" or "far-too-high". + // - There could be "1-of-1 invalid" position or "1-of-many" invalid positions (other options are possible, too, + // but we don't check them here). 
+ SECTION("non-empty data, invalid non-empty (one-too-high, 1-of-1) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({4}); + } + SECTION("non-empty data, invalid non-empty (one-too-high, 1-of-many) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, 4, 2}); + } + SECTION("non-empty data, invalid non-empty (far-too-high, 1-of-1) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({1000000}); + } + SECTION("non-empty data, invalid non-empty (far-too-high, 1-of-many) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, 1000000, 2}); + } + + DTData *resDataFnd = nullptr; + CHECK_THROWS(colProject(resDataFnd, lhsData, rhsPos, nullptr)); + + DataObjectFactory::destroy(lhsData, rhsPos); + if (resDataFnd) + DataObjectFactory::destroy(resDataFnd); +} + +// We only use numeric value types for the input data here, since this test case is mainly about the input positions and +// the basic functionality for string value types has been tested above. +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": invalid args, signed positions", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = ssize_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + DTPos *rhsPos = nullptr; + + // Empty input data, any non-empty input positions are invalid. + SECTION("empty data, non-empty positions") { + lhsData = DataObjectFactory::create(0, false); + rhsPos = genGivenVals({0}); + } + // Non-empty input data, invalid non-empty positions. + // - The input data are non-empty, but the concrete input data doesn't matter. 
+ // - The invalid input positions could be characterized by the following properties, which can be freely combined: + // - The positions could be "one-too-high", "far-too-high" or "one-too-low" (negative, -1), or "far-too-low" + // (negative). + // - There could be "1-of-1 invalid" position or "1-of-many" invalid positions (other options are possible, too, + // but we don't check them here). + SECTION("non-empty data, invalid non-empty (one-too-high, 1-of-1) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({4}); + } + SECTION("non-empty data, invalid non-empty (one-too-high, 1-of-many) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, 4, 2}); + } + SECTION("non-empty data, invalid non-empty (far-too-high, 1-of-1) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({1000000}); + } + SECTION("non-empty data, invalid non-empty (far-too-high, 1-of-many) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, 1000000, 2}); + } + SECTION("non-empty data, invalid non-empty (one-too-low, 1-of-1) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({-1}); + } + SECTION("non-empty data, invalid non-empty (one-too-low, 1-of-many) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, -1, 2}); + } + SECTION("non-empty data, invalid non-empty (far-too-low, 1-of-1) positions") { + lhsData = genGivenVals({VTData(100.0), VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({-1000000}); + } + SECTION("non-empty data, invalid non-empty (far-too-low, 1-of-many) positions") { + lhsData = genGivenVals({VTData(100.0), 
VTData(101.1), VTData(102.2), VTData(103.3)}); + rhsPos = genGivenVals({0, -1000000, 2}); + } + + DTData *resDataFnd = nullptr; + CHECK_THROWS(colProject(resDataFnd, lhsData, rhsPos, nullptr)); + + DataObjectFactory::destroy(lhsData, rhsPos); + if (resDataFnd) + DataObjectFactory::destroy(resDataFnd); +} \ No newline at end of file diff --git a/test/runtime/local/kernels/ColSelectCmpTest.cpp b/test/runtime/local/kernels/ColSelectCmpTest.cpp new file mode 100644 index 000000000..45fd1297a --- /dev/null +++ b/test/runtime/local/kernels/ColSelectCmpTest.cpp @@ -0,0 +1,543 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +#include + +#define TEST_NAME(opName) "ColSelectCmp (" opName ")" +#define DATA_TYPES Column +#define NUM_VALUE_TYPES double, uint32_t, int8_t +#define STR_VALUE_TYPES std::string + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("eq") ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. 
+ SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = VTData(123); + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. + SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(2.2); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(1.1); + resPosExp = genGivenVals({0, 3}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(1.1)}); + rhsData = VTData(1.1); + resPosExp = genGivenVals({0, 1, 2}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::EQ, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +// This is the same as "valid args, numeric data", just with string-valued input data. +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("eq") ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = "str123"; + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str2.2"; + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str1.1"; + resPosExp = genGivenVals({0, 3}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({"str1.1", "str1.1", "str1.1"}); + rhsData = "str1.1"; + resPosExp = genGivenVals({0, 1, 2}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::EQ, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("neq") ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), + (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = VTData(123); + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(1.1)}); + rhsData = VTData(1.1); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(1.1); + resPosExp = genGivenVals({1, 2}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(2.2); + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::NEQ, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +// This is the same as "valid args, numeric data", just with string-valued input data. +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("neq") ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = "str123"; + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({"str1.1", "str1.1", "str1.1"}); + rhsData = "str1.1"; + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str1.1"; + resPosExp = genGivenVals({1, 2}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str2.2"; + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::NEQ, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("gt") ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = VTData(123); + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(6.6); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(2.2); + resPosExp = genGivenVals({1, 2}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(0.0); + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::GT, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +// This is the same as "valid args, numeric data", just with string-valued input data. +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("gt") ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = "str123"; + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str6.6"; + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str2.2"; + resPosExp = genGivenVals({1, 2}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str0.0"; + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::GT, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("ge") ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = VTData(123); + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(6.6); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(3.3); + resPosExp = genGivenVals({1, 2}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(1.1); + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::GE, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +// This is the same as "valid args, numeric data", just with string-valued input data. +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("ge") ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = "str123"; + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str6.6"; + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str3.3"; + resPosExp = genGivenVals({1, 2}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str1.1"; + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::GE, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("lt") ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = VTData(123); + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(1.1); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(3.3); + resPosExp = genGivenVals({0, 3}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(6.6); + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::LT, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +// This is the same as "valid args, numeric data", just with string-valued input data. +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("lt") ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = "str123"; + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str1.1"; + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str3.3"; + resPosExp = genGivenVals({0, 3}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str6.6"; + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::LT, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("le") ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = VTData(123); + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(0.0); + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(3.3); + resPosExp = genGivenVals({0, 2, 3}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({VTData(1.1), VTData(5.5), VTData(3.3), VTData(1.1)}); + rhsData = VTData(5.5); + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::LE, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} + +// This is the same as "valid args, numeric data", just with string-valued input data. +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("le") ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + VTData rhsData; + DTPos *resPosExp = nullptr; + + // Empty input data -> empty result positions. + SECTION("empty lhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = "str123"; + resPosExp = DataObjectFactory::create(0, false); + } + // Non-empty input data. + // - "no", "some", or "all" of the lhs input values could match the rhs input value. 
+ SECTION("non-empty lhsData (no matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str0.0"; + resPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (some matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str3.3"; + resPosExp = genGivenVals({0, 2, 3}); + } + SECTION("non-empty lhsData (all matches)") { + lhsData = genGivenVals({"str1.1", "str5.5", "str3.3", "str1.1"}); + rhsData = "str5.5"; + resPosExp = genGivenVals({0, 1, 2, 3}); + } + + DTPos *resPosFnd = nullptr; + colSelectCmp(CmpOpCode::LE, resPosFnd, lhsData, rhsData, nullptr); + CHECK(*resPosFnd == *resPosExp); + + DataObjectFactory::destroy(lhsData, resPosExp, resPosFnd); +} \ No newline at end of file diff --git a/test/runtime/local/kernels/ColSemiJoinTest.cpp b/test/runtime/local/kernels/ColSemiJoinTest.cpp new file mode 100644 index 000000000..d4cee9564 --- /dev/null +++ b/test/runtime/local/kernels/ColSemiJoinTest.cpp @@ -0,0 +1,478 @@ +/* + * Copyright 2025 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +#include + +// These test cases are essentially the same as in ColJoinTest.cpp (keep consistent). +// - The only difference is that the colJoin-kernel has the additional result resRhsPos. 
+// - Note that we assume the rhsData to be unique (primary key), so we don't consider expanding joins here. + +#define TEST_NAME "ColSemiJoin" +#define DATA_TYPES Column +#define NUM_VALUE_TYPES double, uint32_t, int8_t +#define STR_VALUE_TYPES std::string + +// This is the same as "valid args, string data", just with numeric input data (keep consistent). +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, numeric data", TAG_KERNELS, (DATA_TYPES), (NUM_VALUE_TYPES)) { + using DTData = TestType; + using VTData = typename DTData::VT; + + using VTPos = size_t; + using DTPos = typename TestType::template WithValueType; + + DTData *lhsData = nullptr; + DTData *rhsData = nullptr; + DTPos *resLhsPosExp = nullptr; + + // Empty input data. + SECTION("empty lhsData, empty rhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = DataObjectFactory::create(0, false); + resLhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("empty lhsData, non-empty rhsData") { + lhsData = DataObjectFactory::create(0, false); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData, empty rhsData") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + rhsData = DataObjectFactory::create(0, false); + resLhsPosExp = DataObjectFactory::create(0, false); + } + + // Non-empty input data. + // - The lhsData is assumed to be a foreign key. + // - The lhsData could be unique or non-unique. + // - The lhsData could be sorted or unsorted. + // - The rhsData is assumed to be a primary key. + // - The rhsData is assumed to be unique. + // - The rhsData could be sorted or unsorted. 
+ // - Matches: the lhsData could contain + // - the values in rhsData (all) + // - a subset of the values in rhsData (subset) + // - a superset of the values in rhsData (superset) + // - no values in rhsData (none) + + // (all) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2), VTData(4.4), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2), VTData(4.4), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), 
VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2), VTData(1.1), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), all") { + lhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2), VTData(1.1), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + // (subset) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), subset") { + // lhsData is given in descending order so that the "unsorted" label of this section actually holds. + lhsData = genGivenVals({VTData(2.2), VTData(1.1)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), subset") { + // lhsData is given in descending order so that the "unsorted" label of this section actually holds. + lhsData = genGivenVals({VTData(2.2), VTData(1.1)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), subset") {
+ lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(1.1)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), subset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(1.1)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + // (superset) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4), VTData(5.5)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4), VTData(5.5)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData(5.5), VTData(1.1), VTData(4.4), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + 
resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3)}); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData(5.5), VTData(1.1), VTData(4.4), VTData(2.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2), VTData(4.4), VTData(4.4), VTData(5.5)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData(1.1), VTData(1.1), VTData(2.2), VTData(4.4), VTData(4.4), VTData(5.5)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), superset") { + lhsData = genGivenVals({VTData(5.5), VTData(1.1), VTData(4.4), VTData(2.2), VTData(1.1), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3), VTPos(4), VTPos(5)}); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), superset") { + lhsData = genGivenVals({VTData(5.5), VTData(1.1), VTData(4.4), VTData(2.2), VTData(1.1), VTData(4.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3), VTPos(4), VTPos(5)}); + } + // (none) + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(22.2), 
VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(22.2), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(44.4), VTData(22.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(44.4), VTData(22.2)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(11.1), VTData(22.2), VTData(44.4), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(11.1), VTData(22.2), VTData(44.4), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)}); + resLhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), none") { + lhsData = genGivenVals({VTData(11.1), VTData(44.4), VTData(22.2), VTData(11.1), VTData(44.4)}); + rhsData = genGivenVals({VTData(1.1), VTData(2.2), VTData(4.4)}); + resLhsPosExp = DataObjectFactory::create(0, false); + } + SECTION("non-empty
lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), none") {
+        lhsData = genGivenVals({VTData(11.1), VTData(44.4), VTData(22.2), VTData(11.1), VTData(44.4)});
+        rhsData = genGivenVals({VTData(1.1), VTData(4.4), VTData(2.2)});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+
+    DTPos *resLhsPosFnd = nullptr;
+    colSemiJoin(resLhsPosFnd, lhsData, rhsData, -1, nullptr);
+    CHECK(*resLhsPosFnd == *resLhsPosExp);
+
+    DataObjectFactory::destroy(lhsData, rhsData, resLhsPosExp, resLhsPosFnd);
+}
+
+// This is the same as "valid args, numeric data", just with string-valued input data (keep consistent).
+TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME ": valid args, string data", TAG_KERNELS, (DATA_TYPES), (STR_VALUE_TYPES)) {
+    using DTData = TestType;
+    using VTData = typename DTData::VT;
+
+    using VTPos = size_t;
+    using DTPos = typename TestType::template WithValueType;
+
+    DTData *lhsData = nullptr;     // left-hand input (treated as a foreign key, see below)
+    DTData *rhsData = nullptr;     // right-hand input (treated as a primary key, see below)
+    DTPos *resLhsPosExp = nullptr; // expected result: the positions in lhsData whose value occurs in rhsData
+
+    // Empty input data (the expected position list is empty whenever one of the inputs is empty).
+    SECTION("empty lhsData, empty rhsData") {
+        lhsData = DataObjectFactory::create(0, false);
+        rhsData = DataObjectFactory::create(0, false);
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+    SECTION("empty lhsData, non-empty rhsData") {
+        lhsData = DataObjectFactory::create(0, false);
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+    SECTION("non-empty lhsData, empty rhsData") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        rhsData = DataObjectFactory::create(0, false);
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+
+    // Non-empty input data.
+    // - The lhsData is assumed to be a foreign key.
+    //   - The lhsData could be unique or non-unique.
+    //   - The lhsData could be sorted or unsorted.
+    // - The rhsData is assumed to be a primary key.
+    //   - The rhsData is assumed to be unique.
+    //   - The rhsData could be sorted or unsorted.
+    // - Matches: the lhsData could contain
+    //   - the values in rhsData (all)
+    //   - a subset of the values in rhsData (subset)
+    //   - a superset of the values in rhsData (superset)
+    //   - no values in rhsData (none)
+
+    // (all)
+    SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), all") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), all") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), all") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), all") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), all") {
+        lhsData = genGivenVals(
+            {VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), VTData("str4.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)});
+    }
+    SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), all") {
+        lhsData = genGivenVals(
+            {VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), VTData("str4.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)});
+    }
+    SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), all") {
+        lhsData = genGivenVals(
+            {VTData("str1.1"), VTData("str4.4"), VTData("str2.2"), VTData("str1.1"), VTData("str4.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)});
+    }
+    SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), all") {
+        lhsData = genGivenVals(
+            {VTData("str1.1"), VTData("str4.4"), VTData("str2.2"), VTData("str1.1"), VTData("str4.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)});
+    }
+    // (subset) - NOTE(review): the two "(unique, unsorted)" lhsData below are in sorted order - confirm
+    SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), subset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)});
+    }
+    SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), subset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)});
+    }
+    SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), subset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)});
+    }
+    SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), subset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1)});
+    }
+    SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), subset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), subset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), subset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str1.1")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), subset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str1.1")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    // (superset)
+    SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), superset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), VTData("str5.5")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), superset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4"), VTData("str5.5")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2)});
+    }
+    SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), superset") {
+        lhsData = genGivenVals({VTData("str5.5"), VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3)});
+    }
+    SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), superset") {
+        lhsData = genGivenVals({VTData("str5.5"), VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3)});
+    }
+    SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), superset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str4.4"),
+                                VTData("str4.4"), VTData("str5.5")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)});
+    }
+    SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), superset") {
+        lhsData = genGivenVals({VTData("str1.1"), VTData("str1.1"), VTData("str2.2"), VTData("str4.4"),
+                                VTData("str4.4"), VTData("str5.5")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(0), VTPos(1), VTPos(2), VTPos(3), VTPos(4)});
+    }
+    SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), superset") {
+        lhsData = genGivenVals({VTData("str5.5"), VTData("str1.1"), VTData("str4.4"), VTData("str2.2"),
+                                VTData("str1.1"), VTData("str4.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3), VTPos(4), VTPos(5)});
+    }
+    SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), superset") {
+        lhsData = genGivenVals({VTData("str5.5"), VTData("str1.1"), VTData("str4.4"), VTData("str2.2"),
+                                VTData("str1.1"), VTData("str4.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = genGivenVals({VTPos(1), VTPos(2), VTPos(3), VTPos(4), VTPos(5)});
+    }
+    // (none) - NOTE(review): one lhsData below has "str222.2" where its siblings have "str22.2" - confirm
+    SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, sorted), none") {
+        lhsData = genGivenVals({VTData("str11.1"), VTData("str22.2"), VTData("str44.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+    SECTION("non-empty lhsData (unique, sorted), non-empty rhsData (unique, unsorted), none") {
+        lhsData = genGivenVals({VTData("str11.1"), VTData("str22.2"), VTData("str44.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+    SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, sorted), none") {
+        lhsData = genGivenVals({VTData("str11.1"), VTData("str44.4"), VTData("str22.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+    SECTION("non-empty lhsData (unique, unsorted), non-empty rhsData (unique, unsorted), none") {
+        lhsData = genGivenVals({VTData("str11.1"), VTData("str44.4"), VTData("str22.2")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+    SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, sorted), none") {
+        lhsData = genGivenVals(
+            {VTData("str11.1"), VTData("str11.1"), VTData("str22.2"), VTData("str44.4"), VTData("str44.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+    SECTION("non-empty lhsData (non-unique, sorted), non-empty rhsData (unique, unsorted), none") {
+        lhsData = genGivenVals(
+            {VTData("str11.1"), VTData("str11.1"), VTData("str22.2"), VTData("str44.4"), VTData("str44.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+    SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, sorted), none") {
+        lhsData = genGivenVals(
+            {VTData("str11.1"), VTData("str44.4"), VTData("str222.2"), VTData("str11.1"), VTData("str44.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str2.2"), VTData("str4.4")});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+    SECTION("non-empty lhsData (non-unique, unsorted), non-empty rhsData (unique, unsorted), none") {
+        lhsData = genGivenVals(
+            {VTData("str11.1"), VTData("str44.4"), VTData("str22.2"), VTData("str11.1"), VTData("str44.4")});
+        rhsData = genGivenVals({VTData("str1.1"), VTData("str4.4"), VTData("str2.2")});
+        resLhsPosExp = DataObjectFactory::create(0, false);
+    }
+
+    DTPos *resLhsPosFnd = nullptr; // the position list actually produced by the kernel
+    colSemiJoin(resLhsPosFnd, lhsData, rhsData, -1, nullptr); // -1: presumably numRes - TODO confirm
+    CHECK(*resLhsPosFnd == *resLhsPosExp);
+
+    DataObjectFactory::destroy(lhsData, rhsData, resLhsPosExp, resLhsPosFnd);
+}
+
+// The only possible invalid argument would be a numRes that is too low. We don't test this case here.
diff --git a/test/runtime/local/kernels/ConvertBitmapToPosListTest.cpp b/test/runtime/local/kernels/ConvertBitmapToPosListTest.cpp
new file mode 100644
index 000000000..6f219e413
--- /dev/null
+++ b/test/runtime/local/kernels/ConvertBitmapToPosListTest.cpp
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2025 The DAPHNE Consortium
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+
+// Checks that convertBitmapToPosList() turns a bitmap of zeros and ones into the ascending list of the positions
+// that hold a one; a zero-sized or all-zero bitmap must yield an empty position list.
+TEMPLATE_TEST_CASE("convertBitmapToPosList", TAG_KERNELS, double, int64_t, size_t) {
+    using VT = TestType;
+
+    DenseMatrix *arg = nullptr;
+    DenseMatrix *exp = nullptr;
+
+    SECTION("zero-sized bitmap") {
+        arg = DataObjectFactory::create>(0, 1, false);
+        exp = DataObjectFactory::create>(0, 1, false);
+    }
+    SECTION("all-zero bitmap") {
+        arg = genGivenVals>(5, {0, 0, 0, 0, 0});
+        exp = DataObjectFactory::create>(0, 1, false);
+    }
+    SECTION("all-ones bitmap") {
+        arg = genGivenVals>(5, {1, 1, 1, 1, 1});
+        exp = genGivenVals>(5, {0, 1, 2, 3, 4});
+    }
+    SECTION("mixed zeros/ones bitmap, case 1") {
+        arg = genGivenVals>(10, {0, 0, 0, 0, 1, 0, 1, 1, 0, 0});
+        exp = genGivenVals>(3, {4, 6, 7});
+    }
+    SECTION("mixed zeros/ones bitmap, case 2") {
+        arg = genGivenVals>(10, {1, 1, 1, 1, 0, 0, 0, 1, 1, 1});
+        exp = genGivenVals>(7, {0, 1, 2, 3, 7, 8, 9});
+    }
+
+    // Common to all sections: run the kernel on the section's input and compare with the expected output.
+    DenseMatrix *res = nullptr;
+    convertBitmapToPosList(res, arg, nullptr);
+    CHECK(*res == *exp);
+
+    DataObjectFactory::destroy(arg, exp, res);
+}
diff --git a/test/runtime/local/kernels/ConvertPosListToBitmapTest.cpp b/test/runtime/local/kernels/ConvertPosListToBitmapTest.cpp
new file mode 100644
index 000000000..b761cc637
--- /dev/null
+++ b/test/runtime/local/kernels/ConvertPosListToBitmapTest.cpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2025 The DAPHNE Consortium
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+
+// Checks that convertPosListToBitmap() turns a position list into a bitmap with numRowsRes entries that is one at
+// the listed positions and zero elsewhere; an empty position list must yield an all-zero bitmap.
+TEMPLATE_TEST_CASE("convertPosListToBitmap", TAG_KERNELS, double, int64_t, size_t) {
+    using VT = TestType;
+
+    DenseMatrix *arg = nullptr;
+    const size_t numRowsRes = 10;
+    DenseMatrix *exp = nullptr;
+
+    SECTION("empty poslist") {
+        arg = DataObjectFactory::create>(0, 1, false);
+        exp = genGivenVals>(10, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+    }
+    SECTION("non-empty poslist, case 1") {
+        arg = genGivenVals>(5, {0, 3, 7, 8, 9});
+        exp = genGivenVals>(10, {1, 0, 0, 1, 0, 0, 0, 1, 1, 1});
+    }
+    SECTION("non-empty poslist, case 2") {
+        arg = genGivenVals>(5, {2, 3, 4, 5, 6});
+        exp = genGivenVals>(10, {0, 0, 1, 1, 1, 1, 1, 0, 0, 0});
+    }
+
+    // Common to all sections: run the kernel on the section's input and compare with the expected output.
+    DenseMatrix *res = nullptr;
+    convertPosListToBitmap(res, arg, numRowsRes, nullptr);
+    CHECK(*res == *exp);
+
+    DataObjectFactory::destroy(arg, exp, res);
+}