diff --git a/r/adbcdrivermanager/DESCRIPTION b/r/adbcdrivermanager/DESCRIPTION index cbd53577ac..4810c28c6a 100644 --- a/r/adbcdrivermanager/DESCRIPTION +++ b/r/adbcdrivermanager/DESCRIPTION @@ -17,6 +17,8 @@ Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.2.3 Suggests: + later, + promises, testthat (>= 3.0.0), withr Config/testthat/edition: 3 diff --git a/r/adbcdrivermanager/NAMESPACE b/r/adbcdrivermanager/NAMESPACE index 5c671a2b3d..a37824acb9 100644 --- a/r/adbcdrivermanager/NAMESPACE +++ b/r/adbcdrivermanager/NAMESPACE @@ -1,11 +1,20 @@ # Generated by roxygen2: do not edit by hand +S3method("$",adbc_async_task) S3method("$",adbc_error) S3method("$",adbc_xptr) S3method("$<-",adbc_xptr) +S3method("[[",adbc_async_task) S3method("[[",adbc_error) S3method("[[",adbc_xptr) S3method("[[<-",adbc_xptr) +S3method(adbc_async_task_cancel,adbc_async_statement_cancellable) +S3method(adbc_async_task_cancel,default) +S3method(adbc_async_task_result,adbc_async_execute_query) +S3method(adbc_async_task_result,adbc_async_prepare) +S3method(adbc_async_task_result,adbc_async_sleep) +S3method(adbc_async_task_result,adbc_async_statement_stream_get_next) +S3method(adbc_async_task_result,adbc_async_statement_stream_schema) S3method(adbc_connection_init,adbc_database_log) S3method(adbc_connection_init,adbc_database_monkey) S3method(adbc_connection_init,default) @@ -21,6 +30,7 @@ S3method(execute_adbc,default) S3method(format,adbc_xptr) S3method(length,adbc_error) S3method(length,adbc_xptr) +S3method(names,adbc_async_task) S3method(names,adbc_error) S3method(names,adbc_xptr) S3method(print,adbc_driver) diff --git a/r/adbcdrivermanager/R/async.R b/r/adbcdrivermanager/R/async.R new file mode 100644 index 0000000000..78bbe60eb5 --- /dev/null +++ b/r/adbcdrivermanager/R/async.R @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +adbc_async_task <- function(subclass = character()) { + structure( + .Call(RAdbcAsyncTaskNew, adbc_allocate_error()), + class = union(subclass, "adbc_async_task") + ) +} + +adbc_async_task_status <- function(task) { + .Call(RAdbcAsyncTaskWaitFor, task, 0) +} + +adbc_async_task_set_callback <- function(task, resolve, reject = NULL, + loop = later::current_loop()) { + # If the task is completed, run the callback (or else the callback + # will not run) + if (adbc_async_task_status(task) == "ready") { + adbc_async_task_run_callback(task, resolve, reject) + } else { + .Call(RAdbcAsyncTaskSetCallback, task, resolve, reject, loop$id) + } + + invisible(task) +} + +adbc_async_task_run_callback <- function(task, resolve = task$resolve, + reject = task$reject) { + tryCatch({ + result <- adbc_async_task_result(task) + resolve(result) + }, + error = function(e) { + if (is.null(reject)) { + stop(e) + } else { + reject(e) + } + } + ) + + invisible(task) +} + +adbc_async_task_wait_non_cancellable <- function(task, resolution = 0.1) { + .Call(RAdbcAsyncTaskWaitFor, task, round(resolution * 1000)) +} + +adbc_async_task_wait <- function(task, resolution = 0.1) { + withCallingHandlers( + status <- .Call(RAdbcAsyncTaskWait, task, round(resolution * 1000)), + interrupt = function(e) { + adbc_async_task_cancel(task) + } + ) + + if (status != "ready") { + stop(sprintf("Expected status ready but got %s", status)) + } + + adbc_async_task_result(task) +} + +as.promise.adbc_async_task <- function(task) { + force(task) + promises::promise(function(resolve, reject) { + adbc_async_task_set_callback(task, resolve, reject) + }) +} + +adbc_async_task_cancel <- function(task) { + UseMethod("adbc_async_task_cancel") +} + +#' @export +adbc_async_task_cancel.default <- function(task) { + FALSE +} + +adbc_async_task_result <- function(task) { + UseMethod("adbc_async_task_result") +} + +#' @export +names.adbc_async_task <- function(x) { + names(.Call(RAdbcAsyncTaskData, x)) +} + +#' @export +`[[.adbc_async_task` <- function(x, i) { + .Call(RAdbcAsyncTaskData, x)[[i]] +} + +#' @export +`$.adbc_async_task` <- function(x, name) { + .Call(RAdbcAsyncTaskData, x)[[name]] +} + +adbc_async_sleep <- function(duration_ms, error_message = NULL) { + task <- adbc_async_task("adbc_async_sleep") + .Call(RAdbcAsyncTaskLaunchSleep, task, duration_ms) + + user_data <- task$user_data + user_data$duration_ms <- duration_ms + user_data$error_message <- error_message + + task +} + +#' @export +adbc_async_task_result.adbc_async_sleep <- function(task) { + if (!is.null(task$user_data$error_message)) { + cnd <- simpleError(task$user_data$error_message) + class(cnd) <- c("adbc_async_sleep_error", class(cnd)) + stop(cnd) + } + + task$user_data$duration_ms +} + +#' @export +adbc_async_task_cancel.adbc_async_statement_cancellable <- function(task) { + adbc_statement_cancel(task$user_data$statement) + TRUE +} + +adbc_statement_prepare_async <- function(statement) { + task <- adbc_async_task( + c("adbc_async_prepare", "adbc_async_statement_cancellable") + ) + + user_data <- task$user_data + user_data$statement <- statement + .Call(RAdbcAsyncTaskLaunchPrepare, task, statement) + + task +} + +#' @export +adbc_async_task_result.adbc_async_prepare <- function(task) { + if (!identical(task$return_code, 0L)) { + stop_for_error(task$return_code, task$error_xptr) + } + + task$user_data$statement +} + +adbc_statement_execute_query_async <- function(statement, stream = NULL) { + task <- adbc_async_task( + c("adbc_async_execute_query", "adbc_async_statement_cancellable") + ) + + user_data <- task$user_data + user_data$statement <- statement + user_data$stream <- stream + + user_data$rows_affected <- .Call( + RAdbcAsyncTaskLaunchExecuteQuery, + task, + statement, + stream + ) + + task +} + +#' @export +adbc_async_task_result.adbc_async_execute_query <- function(task) { + if (!identical(task$return_code, 0L)) { + stop_for_error(task$return_code, task$error_xptr) + } + + list( + statement = task$user_data$statement, + stream = task$user_data$stream, + rows_affected = task$user_data$rows_affected + ) +} + +adbc_statement_stream_get_schema_async <- function(statement, stream) { + task <- adbc_async_task( + c("adbc_async_statement_stream_get_next", "adbc_async_statement_cancellable") + ) + + user_data <- task$user_data + user_data$statement <- statement + user_data$stream <- stream + user_data$schema <- nanoarrow::nanoarrow_allocate_schema() + + user_data$rows_affected <- .Call( + RAdbcAsyncTaskLaunchStreamGetSchema, + task, + stream, + user_data$schema + ) + + task +} + + +#' @export +adbc_async_task_result.adbc_async_statement_stream_schema <- function(task) { + if (!identical(task$return_code, 0L)) { + adbc_statement_release(task$user_data$statement) + stop(task$user_data$stream$get_last_error()) + } + + list( + statement = task$user_data$statement, + array = task$user_data$schema + ) +} + +adbc_statement_stream_get_next_async <- function(statement, stream) { + task <- adbc_async_task( + c("adbc_async_statement_stream_get_next", "adbc_async_statement_cancellable") + ) + + user_data <- task$user_data + user_data$statement <- statement + user_data$stream <- stream + user_data$array <- nanoarrow::nanoarrow_allocate_array() + + user_data$rows_affected <- .Call( + RAdbcAsyncTaskLaunchStreamGetNext, + task, + stream, + user_data$array + ) + + task +} + +#' @export +adbc_async_task_result.adbc_async_statement_stream_get_next <- function(task) { + if (!identical(task$return_code, 0L)) { + adbc_statement_release(task$user_data$statement) + stop(task$user_data$stream$get_last_error()) + } + + list( + statement = task$user_data$statement, + array = task$user_data$array + ) +} diff --git a/r/adbcdrivermanager/R/error.R b/r/adbcdrivermanager/R/error.R index a7c3eeb8d8..397008edf3 100644 --- a/r/adbcdrivermanager/R/error.R +++ b/r/adbcdrivermanager/R/error.R @@ -52,7 +52,12 @@ adbc_allocate_error <- function(shelter = NULL, use_legacy_error = NULL) { stop_for_error <- function(status, error) { if (!identical(status, 0L)) { - error <- .Call(RAdbcErrorProxy, error) + if (inherits(error, "adbc_error")) { + error <- .Call(RAdbcErrorProxy, error) + } else { + error <- list() + } + error$status <- status error$status_code_message <- .Call(RAdbcStatusCodeMessage, status) if (!is.null(error$message)) { @@ -79,6 +84,22 @@ stop_for_error <- function(status, error) { } } +adbc_error_message <- function(status, error) { + if (!identical(status, 0L)) { + if (inherits(error, "adbc_error")) { + error <- .Call(RAdbcErrorProxy, error) + } else { + error <- list() + } + + error$status <- status + error$status_code_message <- .Call(RAdbcStatusCodeMessage, status) + if (!is.null(error$message)) error$message else error$status_code_message + } else { + "OK" + } +} + #' @export print.adbc_error <- function(x, ...) { str(x, ...) diff --git a/r/adbcdrivermanager/R/zzz.R b/r/adbcdrivermanager/R/zzz.R new file mode 100644 index 0000000000..e5ba51804b --- /dev/null +++ b/r/adbcdrivermanager/R/zzz.R @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +.onLoad <- function(...) { + s3_register("promises::as.promise", "adbc_async_task") +} + +# From the `vctrs` package (this function is intended to be copied +# without attribution or license requirements to avoid a hard dependency on +# vctrs: +# https://github.com/r-lib/vctrs/blob/c2a7710fe55e3a2249c4fdfe75bbccbafcf38804/R/register-s3.R#L25-L31 +s3_register <- function(generic, class, method = NULL) { + stopifnot(is.character(generic), length(generic) == 1) + stopifnot(is.character(class), length(class) == 1) + + pieces <- strsplit(generic, "::")[[1]] + stopifnot(length(pieces) == 2) + package <- pieces[[1]] + generic <- pieces[[2]] + + caller <- parent.frame() + + get_method_env <- function() { + top <- topenv(caller) + if (isNamespace(top)) { + asNamespace(environmentName(top)) + } else { + caller + } + } + get_method <- function(method, env) { + if (is.null(method)) { + get(paste0(generic, ".", class), envir = get_method_env()) + } else { + method + } + } + + register <- function(...) { + envir <- asNamespace(package) + + # Refresh the method each time, it might have been updated by + # `devtools::load_all()` + method_fn <- get_method(method) + stopifnot(is.function(method_fn)) + + + # Only register if generic can be accessed + if (exists(generic, envir)) { + registerS3method(generic, class, method_fn, envir = envir) + } else if (identical(Sys.getenv("NOT_CRAN"), "true")) { + warning(sprintf( + "Can't find generic `%s` in package %s to register S3 method.", + generic, + package + )) + } + } + + # Always register hook in case package is later unloaded & reloaded + setHook(packageEvent(package, "onLoad"), register) + + # Avoid registration failures during loading (pkgload or regular) + if (isNamespaceLoaded(package)) { + register() + } + + invisible() +} diff --git a/r/adbcdrivermanager/src/async.cc b/r/adbcdrivermanager/src/async.cc new file mode 100644 index 0000000000..6ef130192b --- /dev/null +++ b/r/adbcdrivermanager/src/async.cc @@ -0,0 +1,308 @@ + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#define R_NO_REMAP +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "radbc.h" + +typedef void (*ExcecLaterNativeFn)(void (*func)(void*), void*, double, int); + +static ExcecLaterNativeFn later_execLaterNative2 = NULL; + +static inline void later_ensure_initialized() { + later_execLaterNative2 = + (ExcecLaterNativeFn)R_GetCCallable("later", "execLaterNative2"); +} + +static void later_task_callback_wrapper(void* data); + +struct ArrowArrayCustomDeleter { + void operator()(ArrowArray* array) const { + if (array->release != nullptr) { + array->release(array); + } + delete array; + } +}; + +using UniqueArrowArrayPtr = std::unique_ptr; + +enum class RAdbcAsyncTaskStatus { NOT_STARTED, STARTED, READY }; + +struct RAdbcAsyncTask { + RAdbcAsyncTask() : callback_data_sexp(R_NilValue) {} + + void SetCallback(SEXP data, int loop_id) { + callback_data_sexp = data; + later_loop_id = loop_id; + later_ensure_initialized(); + } + + void ScheduleCallbackIfSet() { + if (callback_data_sexp != R_NilValue) { + later_execLaterNative2(&later_task_callback_wrapper, this, 0, later_loop_id); + } + } + + AdbcError* return_error{nullptr}; + int* return_code{nullptr}; + + SEXP callback_data_sexp; + int later_loop_id{-1}; + + RAdbcAsyncTaskStatus status{RAdbcAsyncTaskStatus::NOT_STARTED}; + std::future result; +}; + +static void later_task_callback_wrapper(void* data) { + auto task = reinterpret_cast(data); + + SEXP func_sym = PROTECT(Rf_install("adbc_async_task_run_callback")); + SEXP func_call = PROTECT(Rf_lang2(func_sym, task->callback_data_sexp)); + SEXP pkg_chr = PROTECT(Rf_mkString("adbcdrivermanager")); + SEXP pkg_ns = PROTECT(R_FindNamespace(pkg_chr)); + Rf_eval(func_call, pkg_ns); + UNPROTECT(4); +} + +template <> +inline const char* adbc_xptr_class() { + return "adbc_async_task"; +} + +static void FinalizeTaskXptr(SEXP task_xptr) { + auto task = reinterpret_cast(R_ExternalPtrAddr(task_xptr)); + if (task != nullptr) { + delete task; + } +} + +static void error_for_started_task(RAdbcAsyncTask* task) { + if (task->result.valid()) { + Rf_error("adbc_async_task is already in use"); + } +} + +extern "C" SEXP RAdbcAsyncTaskNew(SEXP error_xptr) { + const char* names[] = {"error_xptr", "return_code", "user_data", + "resolve", "reject", ""}; + SEXP task_prot = PROTECT(Rf_mkNamed(VECSXP, names)); + + SET_VECTOR_ELT(task_prot, 0, error_xptr); + + SEXP return_code_sexp = PROTECT(Rf_allocVector(INTSXP, 1)); + SET_VECTOR_ELT(task_prot, 1, return_code_sexp); + UNPROTECT(1); + + SEXP new_env = PROTECT(adbc_new_env()); + SET_VECTOR_ELT(task_prot, 2, new_env); + UNPROTECT(1); + + auto task = new RAdbcAsyncTask(); + SEXP task_xptr = PROTECT(R_MakeExternalPtr(task, R_NilValue, task_prot)); + R_RegisterCFinalizer(task_xptr, &FinalizeTaskXptr); + + task->return_error = adbc_from_xptr(error_xptr); + task->return_code = INTEGER(VECTOR_ELT(task_prot, 1)); + + *(task->return_code) = NA_INTEGER; + + UNPROTECT(2); + return task_xptr; +} + +extern "C" SEXP RAdbcAsyncTaskSetCallback(SEXP task_xptr, SEXP callback_resolve_sexp, + SEXP callback_reject_sexp, SEXP loop_id_sexp) { + auto task = adbc_from_xptr(task_xptr); + SEXP task_prot = R_ExternalPtrProtected(task_xptr); + int loop_id = adbc_as_int(loop_id_sexp); + + SET_VECTOR_ELT(task_prot, 3, callback_resolve_sexp); + SET_VECTOR_ELT(task_prot, 4, callback_reject_sexp); + task->SetCallback(task_xptr, loop_id); + return R_NilValue; +} + +extern "C" SEXP RAdbcAsyncTaskData(SEXP task_xptr) { + adbc_from_xptr(task_xptr); + return R_ExternalPtrProtected(task_xptr); +} + +extern "C" SEXP RAdbcAsyncTaskWaitFor(SEXP task_xptr, SEXP duration_ms_sexp) { + auto task = adbc_from_xptr(task_xptr); + int duration_ms = adbc_as_int(duration_ms_sexp); + + if (duration_ms < 0) { + Rf_error("duration_ms must be >= 0"); + } + + switch (task->status) { + case RAdbcAsyncTaskStatus::NOT_STARTED: + return Rf_mkString("not_started"); + case RAdbcAsyncTaskStatus::READY: + return Rf_mkString("ready"); + default: + break; + } + + std::future_status status = + task->result.wait_for(std::chrono::milliseconds(duration_ms)); + switch (status) { + case std::future_status::timeout: + return Rf_mkString("started"); + case std::future_status::ready: + task->status = RAdbcAsyncTaskStatus::READY; + return Rf_mkString("ready"); + default: + Rf_error("Unknown status returned from future::wait_for()"); + } +} + +extern "C" SEXP RAdbcAsyncTaskWait(SEXP task_xptr, SEXP resolution_ms_sexp) { + auto task = adbc_from_xptr(task_xptr); + int resolution_ms = adbc_as_int(resolution_ms_sexp); + + switch (task->status) { + case RAdbcAsyncTaskStatus::NOT_STARTED: + return Rf_mkString("not_started"); + case RAdbcAsyncTaskStatus::READY: + return Rf_mkString("ready"); + default: + break; + } + + std::future_status status; + do { + status = task->result.wait_for(std::chrono::milliseconds(resolution_ms)); + R_CheckUserInterrupt(); + } while (status == std::future_status::timeout); + + switch (status) { + case std::future_status::ready: + task->status = RAdbcAsyncTaskStatus::READY; + return Rf_mkString("ready"); + default: + Rf_error("Unknown status returned from future::wait_for()"); + } +} + +extern "C" SEXP RAdbcAsyncTaskLaunchSleep(SEXP task_xptr, SEXP duration_ms_sexp) { + auto task = adbc_from_xptr(task_xptr); + error_for_started_task(task); + + int duration_ms = adbc_as_int(duration_ms_sexp); + + task->result = std::async(std::launch::async, [task, duration_ms] { + std::this_thread::sleep_for(std::chrono::milliseconds(duration_ms)); + *(task->return_code) = ADBC_STATUS_OK; + task->ScheduleCallbackIfSet(); + }); + + task->status = RAdbcAsyncTaskStatus::STARTED; + return R_NilValue; +} + +extern "C" SEXP RAdbcAsyncTaskLaunchPrepare(SEXP task_xptr, SEXP statement_xptr) { + auto task = adbc_from_xptr(task_xptr); + error_for_started_task(task); + + auto statement = adbc_from_xptr(statement_xptr); + + task->result = std::async(std::launch::async, [task, statement] { + *(task->return_code) = AdbcStatementPrepare(statement, task->return_error); + task->ScheduleCallbackIfSet(); + }); + + task->status = RAdbcAsyncTaskStatus::STARTED; + UNPROTECT(1); + return R_NilValue; +} + +extern "C" SEXP RAdbcAsyncTaskLaunchExecuteQuery(SEXP task_xptr, SEXP statement_xptr, + SEXP stream_xptr) { + auto task = adbc_from_xptr(task_xptr); + error_for_started_task(task); + + auto statement = adbc_from_xptr(statement_xptr); + ArrowArrayStream* stream = nullptr; + if (stream_xptr != R_NilValue) { + stream = adbc_from_xptr(stream_xptr); + } + + SEXP rows_affected_sexp = PROTECT(Rf_allocVector(REALSXP, 1)); + double* rows_affected_dbl = REAL(rows_affected_sexp); + + task->result = + std::async(std::launch::async, [task, statement, stream, rows_affected_dbl] { + int64_t rows_affected = -1; + *(task->return_code) = AdbcStatementExecuteQuery( + statement, stream, &rows_affected, task->return_error); + *rows_affected_dbl = static_cast(rows_affected); + task->ScheduleCallbackIfSet(); + }); + + task->status = RAdbcAsyncTaskStatus::STARTED; + UNPROTECT(1); + return rows_affected_sexp; +} + +extern "C" SEXP RAdbcAsyncTaskLaunchStreamGetSchema(SEXP task_xptr, SEXP stream_xptr, + SEXP schema_xptr) { + auto task = adbc_from_xptr(task_xptr); + error_for_started_task(task); + + auto stream = adbc_from_xptr(stream_xptr); + auto schema = adbc_from_xptr(schema_xptr); + + task->result = std::async(std::launch::async, [task, stream, schema] { + *(task->return_code) = stream->get_schema(stream, schema); + task->ScheduleCallbackIfSet(); + }); + + task->status = RAdbcAsyncTaskStatus::STARTED; + return R_NilValue; +} + +extern "C" SEXP RAdbcAsyncTaskLaunchStreamGetNext(SEXP task_xptr, SEXP stream_xptr, + SEXP array_xptr) { + auto task = adbc_from_xptr(task_xptr); + error_for_started_task(task); + + auto stream = adbc_from_xptr(stream_xptr); + auto array = adbc_from_xptr(array_xptr); + + task->result = std::async(std::launch::async, [task, stream, array] { + *(task->return_code) = stream->get_next(stream, array); + task->ScheduleCallbackIfSet(); + }); + + task->status = RAdbcAsyncTaskStatus::STARTED; + return R_NilValue; +} diff --git a/r/adbcdrivermanager/src/init.c b/r/adbcdrivermanager/src/init.c index ad7ff6dcb9..2ea4c7dde3 100644 --- a/r/adbcdrivermanager/src/init.c +++ b/r/adbcdrivermanager/src/init.c @@ -20,6 +20,19 @@ #include /* generated by tools/make-callentries.R */ +SEXP RAdbcAsyncTaskNew(SEXP error_xptr); +SEXP RAdbcAsyncTaskSetCallback(SEXP task_xptr, SEXP callback_resolve_sexp, + SEXP callback_reject_sexp, SEXP loop_id_sexp); +SEXP RAdbcAsyncTaskData(SEXP task_xptr); +SEXP RAdbcAsyncTaskWaitFor(SEXP task_xptr, SEXP duration_ms_sexp); +SEXP RAdbcAsyncTaskWait(SEXP task_xptr, SEXP resolution_ms_sexp); +SEXP RAdbcAsyncTaskLaunchSleep(SEXP task_xptr, SEXP duration_ms_sexp); +SEXP RAdbcAsyncTaskLaunchPrepare(SEXP task_xptr, SEXP statement_xptr); +SEXP RAdbcAsyncTaskLaunchExecuteQuery(SEXP task_xptr, SEXP statement_xptr, + SEXP stream_xptr); +SEXP RAdbcAsyncTaskLaunchStreamGetSchema(SEXP task_xptr, SEXP stream_xptr, + SEXP schema_xptr); +SEXP RAdbcAsyncTaskLaunchStreamGetNext(SEXP task_xptr, SEXP stream_xptr, SEXP array_xptr); SEXP RAdbcVoidDriverInitFunc(void); SEXP RAdbcMonkeyDriverInitFunc(void); SEXP RAdbcLogDriverInitFunc(void); @@ -102,6 +115,17 @@ SEXP RAdbcXptrEnv(SEXP xptr); SEXP RAdbcXptrSetProtected(SEXP xptr, SEXP prot); static const R_CallMethodDef CallEntries[] = { + {"RAdbcAsyncTaskNew", (DL_FUNC)&RAdbcAsyncTaskNew, 1}, + {"RAdbcAsyncTaskSetCallback", (DL_FUNC)&RAdbcAsyncTaskSetCallback, 4}, + {"RAdbcAsyncTaskData", (DL_FUNC)&RAdbcAsyncTaskData, 1}, + {"RAdbcAsyncTaskWaitFor", (DL_FUNC)&RAdbcAsyncTaskWaitFor, 2}, + {"RAdbcAsyncTaskWait", (DL_FUNC)&RAdbcAsyncTaskWait, 2}, + {"RAdbcAsyncTaskLaunchSleep", (DL_FUNC)&RAdbcAsyncTaskLaunchSleep, 2}, + {"RAdbcAsyncTaskLaunchPrepare", (DL_FUNC)&RAdbcAsyncTaskLaunchPrepare, 2}, + {"RAdbcAsyncTaskLaunchExecuteQuery", (DL_FUNC)&RAdbcAsyncTaskLaunchExecuteQuery, 3}, + {"RAdbcAsyncTaskLaunchStreamGetSchema", (DL_FUNC)&RAdbcAsyncTaskLaunchStreamGetSchema, + 3}, + {"RAdbcAsyncTaskLaunchStreamGetNext", (DL_FUNC)&RAdbcAsyncTaskLaunchStreamGetNext, 3}, {"RAdbcVoidDriverInitFunc", (DL_FUNC)&RAdbcVoidDriverInitFunc, 0}, {"RAdbcMonkeyDriverInitFunc", (DL_FUNC)&RAdbcMonkeyDriverInitFunc, 0}, {"RAdbcLogDriverInitFunc", (DL_FUNC)&RAdbcLogDriverInitFunc, 0}, diff --git a/r/adbcdrivermanager/src/radbc.h b/r/adbcdrivermanager/src/radbc.h index 4f1ec28317..72cf10c1a0 100644 --- a/r/adbcdrivermanager/src/radbc.h +++ b/r/adbcdrivermanager/src/radbc.h @@ -24,6 +24,17 @@ #include +static inline SEXP adbc_new_env() { + SEXP new_env_sym = PROTECT(Rf_install("new_env")); + SEXP new_env_call = PROTECT(Rf_lang1(new_env_sym)); + SEXP pkg_chr = PROTECT(Rf_mkString("adbcdrivermanager")); + SEXP pkg_ns = PROTECT(R_FindNamespace(pkg_chr)); + SEXP new_env = PROTECT(Rf_eval(new_env_call, pkg_ns)); + UNPROTECT(5); + + return new_env; +} + template static inline const char* adbc_xptr_class(); @@ -89,13 +100,9 @@ static inline SEXP adbc_borrow_xptr(T* ptr, SEXP shelter_sexp = R_NilValue) { Rf_setAttrib(xptr, R_ClassSymbol, xptr_class); UNPROTECT(1); - SEXP new_env_sym = PROTECT(Rf_install("new_env")); - SEXP new_env_call = PROTECT(Rf_lang1(new_env_sym)); - SEXP pkg_chr = PROTECT(Rf_mkString("adbcdrivermanager")); - SEXP pkg_ns = PROTECT(R_FindNamespace(pkg_chr)); - SEXP new_env = PROTECT(Rf_eval(new_env_call, pkg_ns)); + SEXP new_env = PROTECT(adbc_new_env()); R_SetExternalPtrTag(xptr, new_env); - UNPROTECT(5); + UNPROTECT(1); UNPROTECT(1); return xptr; diff --git a/r/adbcdrivermanager/tests/testthat/test-async.R b/r/adbcdrivermanager/tests/testthat/test-async.R new file mode 100644 index 0000000000..517b51a5ff --- /dev/null +++ b/r/adbcdrivermanager/tests/testthat/test-async.R @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +test_that("async tasks can be created and inspected", { + task <- adbc_async_task(subclass = "specific_cls") + expect_s3_class(task, "adbc_async_task") + expect_s3_class(task, "specific_cls") + + expect_identical( + names(task), + c("error_xptr", "return_code", "user_data", "resolve", "reject") + ) + + expect_s3_class(task$error_xptr, "adbc_error") + expect_identical(task$return_code, NA_integer_) + + expect_identical(adbc_async_task_status(task), "not_started") +}) + +test_that("async tasks can update R-level user data", { + task <- adbc_async_task() + expect_identical(as.list(task$user_data), list()) + + user_data <- task$user_data + user_data$some_field <- "some_value" + expect_identical(task$user_data$some_field, "some_value") +}) + +test_that("async task methods error for invalid input", { + task <- unserialize(serialize(adbc_async_task(), NULL)) + expect_error( + names(task), + "Can't convert external pointer to NULL" + ) + + expect_error( + adbc_async_task_wait_non_cancellable(adbc_async_task(), -1), + "duration_ms must be >= 0" + ) +}) + +test_that("async sleeper test works", { + sleep_task <- adbc_async_sleep(500) + expect_identical(adbc_async_task_status(sleep_task), "started") + expect_identical(adbc_async_task_wait_non_cancellable(sleep_task, 1000), "ready") + expect_identical(adbc_async_task_status(sleep_task), "ready") + expect_identical(sleep_task$return_code, 0L) + expect_identical(adbc_async_task_result(sleep_task), 500) +}) + +test_that("async task waiter works", { + sleep_task <- adbc_async_sleep(500) + expect_identical(adbc_async_task_wait(sleep_task), 500) + + erroring_sleep_task <- adbc_async_sleep(500, error_message = "some error") + expect_error( + adbc_async_task_wait(erroring_sleep_task), + "some error", + class = "adbc_async_sleep_error" + ) +}) + +test_that("async tasks can set an R callback", { + skip_if_not_installed("later") + + async_called <- FALSE + sleep_task <- adbc_async_sleep(200) + adbc_async_task_set_callback(sleep_task, function(x) { async_called <<- TRUE }) + Sys.sleep(0.4) + later::run_now() + expect_true(async_called) + + # Ensure the callback runs even if the task is already finished + async_called <- FALSE + sleep_task <- adbc_async_sleep(0) + Sys.sleep(0.1) + adbc_async_task_set_callback(sleep_task, function(x) { async_called <<- TRUE }) + Sys.sleep(0.1) + expect_true(async_called) + + # Ensure this also works on error + async_called <- FALSE + sleep_task <- adbc_async_sleep(0, error_message = "some error") + Sys.sleep(0.1) + adbc_async_task_set_callback( + sleep_task, + resolve = function(x) NULL, + reject = function(x) { async_called <<- TRUE } + ) + Sys.sleep(0.1) + expect_true(async_called) +}) + +test_that("async task can be converted to a promise", { + skip_if_not_installed("promises") + + # Enough time for most CI runners to handle this + max_wait_s <- 5 + + # Check successful call + async_called <- FALSE + adbc_async_sleep(100) %>% + promises::as.promise() %>% + promises::then( + onFulfilled = function(duration_ms) { + expect_identical(duration_ms, 100) + async_called <<- TRUE + } + ) + + # Only wait for so long before bailing on this test + for (i in seq_len(max_wait_s * 100)) { + later::run_now() + if (async_called) { + break + } + + Sys.sleep(max_wait_s / 100) + } + + expect_true(async_called) + + # Check erroring call + async_called <- FALSE + adbc_async_sleep(100, error_message = "errored after 100 ms") %>% + promises::as.promise() %>% + promises::then( + onRejected = function(reason) { + expect_s3_class(reason, "adbc_async_sleep_error") + async_called <<- TRUE + } + ) + + for (i in seq_len(max_wait_s * 100)) { + later::run_now() + if (async_called) { + break + } + Sys.sleep(max_wait_s / 100) + } + + expect_true(async_called) + +})