diff --git a/src/include/typesr.hpp b/src/include/typesr.hpp index bc5c17361..9cc5839ff 100644 --- a/src/include/typesr.hpp +++ b/src/include/typesr.hpp @@ -21,6 +21,7 @@ enum class RTypeId { LOGICAL, INTEGER, NUMERIC, + COMPLEX, STRING, FACTOR, DATE, @@ -43,12 +44,14 @@ enum class RTypeId { // No RType equivalent BYTE, LIST, + MATRIX, STRUCT, }; struct RType { RType(); RType(RTypeId id); // NOLINT: Allow implicit conversion from `RTypeId` + RType(RTypeId id, R_len_t size); // NOLINT: Allow implicit conversion from `RTypeId` RType(const RType &other); RType(RType &&other) noexcept; @@ -57,12 +60,14 @@ struct RType { // copy assignment inline RType &operator=(const RType &other) { id_ = other.id_; + size_ = other.size_; aux_ = other.aux_; return *this; } // move assignment inline RType &operator=(RType &&other) noexcept { id_ = other.id_; + size_ = other.size_; std::swap(aux_, other.aux_); return *this; } @@ -76,6 +81,7 @@ struct RType { static constexpr const RTypeId LOGICAL = RTypeId::LOGICAL; static constexpr const RTypeId INTEGER = RTypeId::INTEGER; static constexpr const RTypeId NUMERIC = RTypeId::NUMERIC; + static constexpr const RTypeId COMPLEX = RTypeId::COMPLEX; static constexpr const RTypeId STRING = RTypeId::STRING; static constexpr const RTypeId DATE = RTypeId::DATE; static constexpr const RTypeId DATE_INTEGER = RTypeId::DATE_INTEGER; @@ -105,8 +111,13 @@ struct RType { static RType STRUCT(child_list_t &&children); child_list_t GetStructChildTypes() const; + static RType MATRIX(const RType &child, R_len_t ncols); + RType GetMatrixElementType() const; + R_len_t GetMatrixNcols() const; + private: RTypeId id_; + R_len_t size_; child_list_t aux_; }; diff --git a/src/scan.cpp b/src/scan.cpp index d1c68517b..c2583562e 100644 --- a/src/scan.cpp +++ b/src/scan.cpp @@ -7,6 +7,7 @@ using namespace duckdb; using namespace cpp11; +static data_ptr_t GetColDataPtr(const RType &rtype, SEXP coldata) { switch (rtype.id()) { case RType::LOGICAL: @@ -51,6 +52,7 @@ data_ptr_t GetColDataPtr(const RType &rtype, SEXP coldata) { return (data_ptr_t)DATAPTR_RO(coldata); case RTypeId::LIST: return (data_ptr_t)DATAPTR_RO(coldata); + case RTypeId::MATRIX: case RTypeId::STRUCT: // Will bind child columns dynamically. Could also optimize by descending early and recording. return (data_ptr_t)coldata; @@ -83,6 +85,7 @@ static void AppendColumnSegment(SRC *source_data, idx_t sexp_offset, Vector &res } } +static void AppendListColumnSegment(const RType &rtype, SEXP *source_data, idx_t sexp_offset, Vector &result, idx_t count) { source_data += sexp_offset; auto &result_mask = FlatVector::Validity(result); @@ -104,9 +107,84 @@ void AppendListColumnSegment(const RType &rtype, SEXP *source_data, idx_t sexp_o } } +template +static inline +void AppendMatrixSegmentAtomic(SRC *src_ptr, int nrows, int ncols, idx_t sexp_offset, + Vector &child_vector, idx_t count) { + auto child_data = FlatVector::GetData(child_vector); + auto &child_mask = FlatVector::Validity(child_vector); + idx_t vector_idx = 0; + for (idx_t i = 0; i < count; i++) { + auto matrix_elt_idx = sexp_offset + i; + for (idx_t k = 0; k < ncols; k++) { + auto val = src_ptr[matrix_elt_idx]; + if (RTYPE::IsNull(val)) { + child_mask.SetInvalid(vector_idx++); + } else { + child_data[vector_idx++] = RTYPE::Convert(val); + } + matrix_elt_idx += nrows; + } + } +} + +static +void AppendMatrixColumnSegment(const RType &rtype, bool experimental, SEXP source_data, idx_t sexp_offset, Vector &result, idx_t count) { + auto element_rtype = rtype.GetMatrixElementType(); + auto nrows = Rf_nrows(source_data); + auto ncols = Rf_ncols(source_data); + auto &child_vector = ArrayVector::GetEntry(result); + + switch (element_rtype.id()) { + case RType::LOGICAL: //LGLSXP + AppendMatrixSegmentAtomic(LOGICAL_POINTER(source_data), + nrows, ncols, sexp_offset, child_vector, count); + break; + + case RType::INTEGER: //INTSXP + AppendMatrixSegmentAtomic(INTEGER_POINTER(source_data), + nrows, ncols, sexp_offset, child_vector, count); + break; + + case RType::INTEGER64: //REALSXP + AppendMatrixSegmentAtomic((int64_t *)NUMERIC_POINTER(source_data), + nrows, ncols, sexp_offset, child_vector, count); + break; + + case RType::NUMERIC: //REALSXP + AppendMatrixSegmentAtomic(NUMERIC_POINTER(source_data), + nrows, ncols, sexp_offset, child_vector, count); + break; + + case RType::COMPLEX: //CPLXSXP + cpp11::stop("Matrix with complex numbers are not supported."); + break; + + case RTypeId::BYTE: // RAWSXP + cpp11::stop("Matrix of type raw is not supported."); + break; + + case RType::STRING: //STRSXP + if (experimental) { + D_ASSERT(result.GetType().id() == LogicalTypeId::POINTER); + AppendMatrixSegmentAtomic((SEXP *)DATAPTR_RO(source_data), + nrows, ncols, sexp_offset, child_vector, count); + } else { + AppendMatrixSegmentAtomic((SEXP *)DATAPTR_RO(source_data), + nrows, ncols, sexp_offset, child_vector, count); + } + break; + + default: + cpp11::stop("AppendMatrixColumnSegment: Unsupported matrix type for scan"); + } +} + +static void AppendAnyColumnSegment(const RType &rtype, bool experimental, data_ptr_t coldata_ptr, idx_t sexp_offset, Vector &v, idx_t this_count); +static void AppendStructColumnSegment(const RType &rtype, bool experimental, SEXP source_data, idx_t sexp_offset, Vector &result, idx_t count) { // No NULL values for STRUCTs. @@ -120,6 +198,7 @@ void AppendStructColumnSegment(const RType &rtype, bool experimental, SEXP sourc } } +static void AppendAnyColumnSegment(const RType &rtype, bool experimental, data_ptr_t coldata_ptr, idx_t sexp_offset, Vector &v, idx_t this_count) { switch (rtype.id()) { @@ -253,6 +332,11 @@ void AppendAnyColumnSegment(const RType &rtype, bool experimental, data_ptr_t co AppendListColumnSegment(rtype, data_ptr, sexp_offset, v, this_count); break; } + case RTypeId::MATRIX: { + auto data_ptr = (SEXP)coldata_ptr; + AppendMatrixColumnSegment(rtype, experimental, data_ptr, sexp_offset, v, this_count); + break; + } case RTypeId::STRUCT: { auto data_ptr = (SEXP)coldata_ptr; AppendStructColumnSegment(rtype, experimental, data_ptr, sexp_offset, v, this_count); diff --git a/src/types.cpp b/src/types.cpp index 52d17af5f..b5c3a1e05 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -12,13 +12,16 @@ using namespace duckdb; RType::RType() : id_(RTypeId::UNKNOWN) { } -RType::RType(RTypeId id) : id_(id) { +RType::RType(RTypeId id) : id_(id), size_(0) { } -RType::RType(const RType &other) : id_(other.id_), aux_(other.aux_) { +RType::RType(RTypeId id, R_len_t size) : id_(id), size_(size) { } -RType::RType(RType &&other) noexcept : id_(other.id_), aux_(std::move(other.aux_)) { +RType::RType(const RType &other) : id_(other.id_), size_(other.size_), aux_(other.aux_) { +} + +RType::RType(RType &&other) noexcept : id_(other.id_), size_(other.size_), aux_(std::move(other.aux_)) { } RTypeId RType::id() const { @@ -26,7 +29,7 @@ RTypeId RType::id() const { } bool RType::operator==(const RType &rhs) const { - return id_ == rhs.id_ && aux_ == rhs.aux_; + return id_ == rhs.id_ && size_ == rhs.size_ && aux_ == rhs.aux_; } RType RType::FACTOR(cpp11::strings levels) { @@ -74,6 +77,22 @@ RType RType::GetListChildType() const { return aux_.front().second; } +RType RType::MATRIX(const RType &child, R_len_t ncols) { + RType out = RType(RTypeId::MATRIX, ncols); + out.aux_.push_back(std::make_pair("", child)); + return out; +} + +RType RType::GetMatrixElementType() const { + D_ASSERT(id_ == RTypeId::MATRIX); + return aux_.front().second; +} + +R_len_t RType::GetMatrixNcols() const { + D_ASSERT(id_ == RTypeId::MATRIX); + return size_; +} + RType RType::STRUCT(child_list_t &&children) { RType out = RType(RTypeId::STRUCT); std::swap(out.aux_, children); @@ -132,6 +151,23 @@ RType RApiTypes::DetectRType(SEXP v, bool integer64) { } } else if (Rf_isFactor(v) && TYPEOF(v) == INTSXP) { return RType::FACTOR(GET_LEVELS(v)); + } else if (Rf_isMatrix(v)) { + if (TYPEOF(v) == LGLSXP) { + return RType::MATRIX(RType::LOGICAL, Rf_ncols(v)); + } else if (TYPEOF(v) == INTSXP) { + return RType::MATRIX(RType::INTEGER, Rf_ncols(v)); + } else if (TYPEOF(v) == REALSXP) { + if (integer64 && Rf_inherits(v, "integer64")) { + return RType::MATRIX(RType::INTEGER64, Rf_ncols(v)); + } + return RType::MATRIX(RType::NUMERIC, Rf_ncols(v)); + } else if (TYPEOF(v) == CPLXSXP) { + return RType::MATRIX(RType::COMPLEX, Rf_ncols(v)); + } else if (TYPEOF(v) == STRSXP) { + return RType::MATRIX(RType::STRING, Rf_ncols(v)); + } else { + return RType::UNKNOWN; + } } else if (TYPEOF(v) == LGLSXP) { return RType::LOGICAL; } else if (TYPEOF(v) == INTSXP) { @@ -145,6 +181,8 @@ RType RApiTypes::DetectRType(SEXP v, bool integer64) { return RType::NUMERIC; } else if (TYPEOF(v) == STRSXP) { return RType::STRING; + } else if (TYPEOF(v) == CPLXSXP) { + return RType::COMPLEX; } else if (TYPEOF(v) == VECSXP) { if (Rf_inherits(v, "blob")) { return RType::BLOB; @@ -211,6 +249,8 @@ LogicalType RApiTypes::LogicalTypeFromRType(const RType &rtype, bool experimenta return LogicalType::DOUBLE; case RType::INTEGER64: return LogicalType::BIGINT; + case RType::COMPLEX: + return LogicalType::ARRAY(LogicalType::DOUBLE, 2); case RTypeId::FACTOR: { auto duckdb_levels = rtype.GetFactorLevels(); return LogicalType::ENUM(duckdb_levels, rtype.GetFactorLevelsCount()); @@ -244,6 +284,8 @@ LogicalType RApiTypes::LogicalTypeFromRType(const RType &rtype, bool experimenta return LogicalType::BLOB; case RTypeId::LIST: return LogicalType::LIST(RApiTypes::LogicalTypeFromRType(rtype.GetListChildType(), experimental)); + case RTypeId::MATRIX: + return LogicalType::ARRAY(RApiTypes::LogicalTypeFromRType(rtype.GetMatrixElementType(), experimental), rtype.GetMatrixNcols()); case RTypeId::STRUCT: { child_list_t children; for (const auto &child : rtype.GetStructChildTypes()) { diff --git a/src/utils.cpp b/src/utils.cpp index 493a65de0..608fcf718 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -116,6 +116,9 @@ R_len_t RApiTypes::GetVecSize(RType rtype, SEXP coldata) { D_ASSERT(TYPEOF(coldata) == VECSXP); coldata = VECTOR_ELT(coldata, 0); } + if (rtype.id() == RTypeId::MATRIX) { + return Rf_nrows(coldata); + } // This still isn't quite accurate, but good enough for the types we support. return Rf_length(coldata); } diff --git a/tests/testthat/_snaps/array.md b/tests/testthat/_snaps/array.md index 0ab36db40..1aaa94dfb 100644 --- a/tests/testthat/_snaps/array.md +++ b/tests/testthat/_snaps/array.md @@ -22,3 +22,14 @@ Error in `duckdb_result()`: ! Use `dbConnect(array = "matrix")` to enable arrays to be returned to R. +# array errors when writing matrix of complex numbers + + Code + dbWriteTable(con, "tbl", df) + Condition + Error in `duckdb_result()`: + ! Matrix with complex numbers are not supported. + Error in `duckdb_result()`: + ! rapi_execute: Failed to run query + Error: Invalid Error: std::exception + diff --git a/tests/testthat/test-array.R b/tests/testthat/test-array.R index 2ccddf1ce..2a873cee9 100644 --- a/tests/testthat/test-array.R +++ b/tests/testthat/test-array.R @@ -12,8 +12,8 @@ test_that("arrays of INTEGER can be read", { df <- dbGetQuery(con, "FROM tbl") - a = c(10, 11, 12, 13) - b = matrix(1:12, nrow = 4, ncol = 3) + a <- c(10, 11, 12, 13) + b <- matrix(1:12, nrow = 4, ncol = 3) expect_equal(df$a, a) expect_equal(df$b, b) @@ -32,8 +32,8 @@ test_that("arrays of INTEGER with NULL can be read", { df <- dbGetQuery(con, "FROM tbl") - a = c(10, 11, 12, 13) - b = matrix(c(1, NA, 3, 4, 5, 6, 7, 8, 9, 10, NA, 12), nrow = 4, ncol = 3) + a <- c(10, 11, 12, 13) + b <- matrix(c(1, NA, 3, 4, 5, 6, 7, 8, 9, 10, NA, 12), nrow = 4, ncol = 3) expect_equal(df$a, a) expect_equal(df$b, b) @@ -52,8 +52,8 @@ test_that("arrays of DOUBLE can be read", { df <- dbGetQuery(con, "FROM tbl") - a = c(10, 11, 12, 13) - b = matrix(as.double(1:12), nrow = 4, ncol = 3) + a <- c(10, 11, 12, 13) + b <- matrix(as.double(1:12), nrow = 4, ncol = 3) expect_equal(df$a, a) expect_equal(df$b, b) @@ -72,8 +72,8 @@ test_that("arrays of DOUBLE with NULL can be read", { df <- dbGetQuery(con, "FROM tbl") - a = c(10, 11, 12, 13) - b = matrix(as.double(c(1, 2, 3, 4, 5, 6, 7, NA, 9, NA, 11, 12)), nrow = 4, ncol = 3) + a <- c(10, 11, 12, 13) + b <- matrix(as.double(c(1, 2, 3, 4, 5, 6, 7, NA, 9, NA, 11, 12)), nrow = 4, ncol = 3) expect_equal(df$a, a) expect_equal(df$b, b) @@ -92,8 +92,8 @@ test_that("arrays of BOOELAN can be read", { df <- dbGetQuery(con, "FROM tbl") - a = c(10, 11, 12, 13) - b = matrix( c(T, F, T, F, F, T, T, F, T, T, F, F) , nrow = 4, ncol = 3 ) + a <- c(10, 11, 12, 13) + b <- matrix(c(T, F, T, F, F, T, T, F, T, T, F, F) , nrow = 4, ncol = 3) expect_equal(df$a, a) expect_equal(df$b, b) @@ -112,8 +112,8 @@ test_that("arrays of BOOELAN with NULL can be read", { df <- dbGetQuery(con, "FROM tbl") - a = c(10, 11, 12, 13) - b = matrix( c(T, F, NA, F, NA, T, T, F, T, T, F, F) , nrow = 4, ncol = 3 ) + a <- c(10, 11, 12, 13) + b <- matrix(c(T, F, NA, F, NA, T, T, F, T, T, F, F) , nrow = 4, ncol = 3) expect_equal(df$a, a) expect_equal(df$b, b) @@ -132,8 +132,8 @@ test_that("arrays of INTEGER in struct column can be read", { df <- dbGetQuery(con, "FROM tbl") - a = c(10, 11, 12, 13) - b = matrix(1:12, nrow = 4, ncol = 3) + a <- c(10, 11, 12, 13) + b <- matrix(1:12, nrow = 4, ncol = 3) expect_equal(df$s$a, a) expect_equal(df$s$b, b) @@ -152,8 +152,8 @@ test_that("arrays of DOUBLE in struct column can be read", { df <- dbGetQuery(con, "FROM tbl") - a = c(10, 11, 12, 13) - b = matrix(as.double(1:12), nrow = 4, ncol = 3) + a <- c(10, 11, 12, 13) + b <- matrix(as.double(1:12), nrow = 4, ncol = 3) expect_equal(df$s$a, a) expect_equal(df$s$b, b) @@ -172,8 +172,8 @@ test_that("arrays of BOOLEAN in struct column can be read", { df <- dbGetQuery(con, "FROM tbl") - a = c(10, 11, 12, 13) - b = matrix( c(T, F, T, F, F, T, T, F, T, T, F, F) , nrow = 4, ncol = 3 ) + a <- c(10, 11, 12, 13) + b <- matrix( c(T, F, T, F, F, T, T, F, T, T, F, F) , nrow = 4, ncol = 3 ) expect_equal(df$s$a, a) expect_equal(df$s$b, b) @@ -224,3 +224,170 @@ test_that("array errors with default convert option array", { dbGetQuery(con, "FROM tbl") }) }) + + +test_that("Single array of INTEGER can be written", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- matrix(1:12, nrow = 4, ncol = 3) + dbWriteTable(con, "tbl", dplyr::tibble(a)) + + df <- dbGetQuery(con, "FROM tbl") + + expect_equal(df$a, a) +}) + + +test_that("arrays of INTEGER can be written", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- c(10, 11, 12, 13) + b <- matrix(1:12, nrow = 4, ncol = 3) + dbWriteTable(con, "tbl", dplyr::tibble(a, b)) + + df <- dbGetQuery(con, "FROM tbl") + + expect_equal(df$a, a) + expect_equal(df$b, b) +}) + + +test_that("arrays of INTEGER with NULL can be written", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- c(10, 11, 12, 13) + b <- matrix(c(1, NA, 3, 4, 5, 6, 7, 8, 9, 10, NA, 12), nrow = 4, ncol = 3) + dbWriteTable(con, "tbl", dplyr::tibble(a, b)) + + df <- dbGetQuery(con, "FROM tbl") + + expect_equal(df$a, a) + expect_equal(df$b, b) +}) + + +test_that("arrays of DOUBLE can be written", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- c(10, 11, 12, 13) + b <- matrix(as.double(1:12), nrow = 4, ncol = 3) + dbWriteTable(con, "tbl", dplyr::tibble(a, b)) + + df <- dbGetQuery(con, "FROM tbl") + + expect_equal(df$a, a) + expect_equal(df$b, b) +}) + + +test_that("arrays of DOUBLE with NULL can be written", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- c(10, 11, 12, 13) + b <- matrix(as.double(c(1, 2, 3, 4, 5, 6, 7, NA, 9, NA, 11, 12)), nrow = 4, ncol = 3) + dbWriteTable(con, "tbl", dplyr::tibble(a, b)) + + df <- dbGetQuery(con, "FROM tbl") + + expect_equal(df$a, a) + expect_equal(df$b, b) +}) + + +test_that("arrays of BOOLEAN can be written", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- c(10, 11, 12, 13) + b <- matrix(c(T, F, T, F, F, T, T, F, T, T, F, F) , nrow = 4, ncol = 3) + dbWriteTable(con, "tbl", dplyr::tibble(a, b)) + + df <- dbGetQuery(con, "FROM tbl") + + expect_equal(df$a, a) + expect_equal(df$b, b) +}) + + +test_that("arrays of BOOLEAN with NULL can be written", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- c(10, 11, 12, 13) + b <- matrix(c(T, F, NA, F, NA, T, T, F, T, T, F, F) , nrow = 4, ncol = 3) + dbWriteTable(con, "tbl", dplyr::tibble(a, b)) + + df <- dbGetQuery(con, "FROM tbl") + + expect_equal(df$a, a) + expect_equal(df$b, b) +}) + + +test_that("arrays of STRING can be written", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- c(10, 11, 12, 13) + b <- matrix(c("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l") , nrow = 4, ncol = 3) + dbWriteTable(con, "tbl", dplyr::tibble(a, b)) + + df <- dbGetQuery(con, "FROM tbl") + + expect_equal(df$a, a) + expect_equal(df$b, b) +}) + + +test_that("arrays of STRING with NULL can be written", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- c(10, 11, 12, 13) + b <- matrix(c("a", "b", "c", "d", "e", NA, "g", "h", NA, "j", "k", "l") , nrow = 4, ncol = 3) + dbWriteTable(con, "tbl", dplyr::tibble(a, b)) + + df <- dbGetQuery(con, "FROM tbl") + + expect_equal(df$a, a) + expect_equal(df$b, b) +}) + + +test_that("array errors when writing matrix of complex numbers", { + skip_if_not_installed("dplyr") + + con <- dbConnect(duckdb(), array = "matrix") + on.exit(dbDisconnect(con, shutdown = TRUE)) + + a <- c(10, 11, 12, 13) + b <- matrix(1+1i , nrow = 4, ncol = 3) + df <- dplyr::tibble(a, b) + + expect_snapshot( error = TRUE, { + dbWriteTable(con, "tbl", df) + }) +})