diff --git a/r/adbcdrivermanager/R/async.R b/r/adbcdrivermanager/R/async.R index ff54f5cb37..81eaa8b872 100644 --- a/r/adbcdrivermanager/R/async.R +++ b/r/adbcdrivermanager/R/async.R @@ -27,23 +27,35 @@ adbc_async_task_status <- function(task) { .Call(RAdbcAsyncTaskWaitFor, task, 0) } -adbc_async_task_set_callback <- function(task, callback, loop = later::current_loop()) { +adbc_async_task_set_callback <- function(task, resolve, reject = NULL, + loop = later::current_loop()) { # If the task is completed, run the callback (or else the callback # will not run) if (adbc_async_task_status(task) == "ready") { - result <- adbc_async_task_result(task) - callback(result) + adbc_async_task_run_callback(task, resolve, reject) } else { - .Call(RAdbcAsyncTaskSetCallback, task, callback, loop$id) + .Call(RAdbcAsyncTaskSetCallback, task, resolve, reject, loop$id) } invisible(task) } -adbc_async_task_run_callback <- function(task) { - callback <- task$callback - result <- adbc_async_task_result(task) - callback(result) +adbc_async_task_run_callback <- function(task, resolve = task$resolve, + reject = task$reject) { + tryCatch({ + result <- adbc_async_task_result(task) + resolve(result) + }, + error = function(e) { + if (is.null(reject)) { + stop(e) + } else { + reject(e) + } + } + ) + + invisible(task) } adbc_async_task_wait_non_cancellable <- function(task, resolution = 0.05) { @@ -63,36 +75,10 @@ adbc_async_task_wait <- function(task, resolution = 0.05) { adbc_async_task_result(task) } -later_loop_schedule_task_callback <- function(task, resolve, reject, - loop = later::current_loop(), - delay = 0) { - force(task) - force(resolve) - force(reject) - - later::later(function() { - status <- adbc_async_task_status(task) - if (status == "timeout") { - later_loop_schedule_task_callback( - task, - resolve, - reject, - loop = loop, - delay = delay - ) - } else { - tryCatch( - resolve(adbc_async_task_result(task)), - error = function(e) reject(e) - ) - } - }, delay = delay, loop = loop) -} - as.promise.adbc_async_task <- function(task) { force(task) promises::promise(function(resolve, reject) { - later_loop_schedule_task_callback(task, resolve, reject) + adbc_async_task_set_callback(task, resolve, reject) }) } diff --git a/r/adbcdrivermanager/src/async.cc b/r/adbcdrivermanager/src/async.cc index 19379d2b7e..c32afee4ef 100644 --- a/r/adbcdrivermanager/src/async.cc +++ b/r/adbcdrivermanager/src/async.cc @@ -99,7 +99,8 @@ static void error_for_started_task(RAdbcAsyncTask* task) { } extern "C" SEXP RAdbcAsyncTaskNew(SEXP error_xptr) { - const char* names[] = {"error_xptr", "return_code", "user_data", "callback", ""}; + const char* names[] = {"error_xptr", "return_code", "user_data", + "resolve", "reject", ""}; SEXP task_prot = PROTECT(Rf_mkNamed(VECSXP, names)); SET_VECTOR_ELT(task_prot, 0, error_xptr); @@ -125,13 +126,14 @@ extern "C" SEXP RAdbcAsyncTaskNew(SEXP error_xptr) { return task_xptr; } -extern "C" SEXP RAdbcAsyncTaskSetCallback(SEXP task_xptr, SEXP callback_sexp, - SEXP loop_id_sexp) { +extern "C" SEXP RAdbcAsyncTaskSetCallback(SEXP task_xptr, SEXP callback_resolve_sexp, + SEXP callback_reject_sexp, SEXP loop_id_sexp) { auto task = adbc_from_xptr(task_xptr); SEXP task_prot = R_ExternalPtrProtected(task_xptr); int loop_id = adbc_as_int(loop_id_sexp); - SET_VECTOR_ELT(task_prot, 3, callback_sexp); + SET_VECTOR_ELT(task_prot, 3, callback_resolve_sexp); + SET_VECTOR_ELT(task_prot, 4, callback_reject_sexp); task->SetCallback(task_xptr, loop_id); return R_NilValue; } diff --git a/r/adbcdrivermanager/src/init.c b/r/adbcdrivermanager/src/init.c index a2e5a59ac5..1ae9ab30d3 100644 --- a/r/adbcdrivermanager/src/init.c +++ b/r/adbcdrivermanager/src/init.c @@ -21,7 +21,8 @@ /* generated by tools/make-callentries.R */ SEXP RAdbcAsyncTaskNew(SEXP error_xptr); -SEXP RAdbcAsyncTaskSetCallback(SEXP task_xptr, SEXP callback_sexp, SEXP loop_id_sexp); +SEXP RAdbcAsyncTaskSetCallback(SEXP task_xptr, SEXP callback_resolve_sexp, + SEXP callback_reject_sexp, SEXP loop_id_sexp); SEXP RAdbcAsyncTaskData(SEXP task_xptr); SEXP RAdbcAsyncTaskWaitFor(SEXP task_xptr, SEXP duration_ms_sexp); SEXP RAdbcAsyncTaskLaunchSleep(SEXP task_xptr, SEXP duration_ms_sexp); @@ -111,7 +112,7 @@ SEXP RAdbcXptrSetProtected(SEXP xptr, SEXP prot); static const R_CallMethodDef CallEntries[] = { {"RAdbcAsyncTaskNew", (DL_FUNC)&RAdbcAsyncTaskNew, 1}, - {"RAdbcAsyncTaskSetCallback", (DL_FUNC)&RAdbcAsyncTaskSetCallback, 3}, + {"RAdbcAsyncTaskSetCallback", (DL_FUNC)&RAdbcAsyncTaskSetCallback, 4}, {"RAdbcAsyncTaskData", (DL_FUNC)&RAdbcAsyncTaskData, 1}, {"RAdbcAsyncTaskWaitFor", (DL_FUNC)&RAdbcAsyncTaskWaitFor, 2}, {"RAdbcAsyncTaskLaunchSleep", (DL_FUNC)&RAdbcAsyncTaskLaunchSleep, 2}, diff --git a/r/adbcdrivermanager/tests/testthat/test-async.R b/r/adbcdrivermanager/tests/testthat/test-async.R index 60b22f82db..1a48fe90bf 100644 --- a/r/adbcdrivermanager/tests/testthat/test-async.R +++ b/r/adbcdrivermanager/tests/testthat/test-async.R @@ -22,7 +22,7 @@ test_that("async tasks can be created and inspected", { expect_identical( names(task), - c("error_xptr", "return_code", "user_data", "callback") + c("error_xptr", "return_code", "user_data", "resolve", "reject") ) expect_s3_class(task$error_xptr, "adbc_error") @@ -90,6 +90,17 @@ test_that("async tasks can set an R callback", { adbc_async_task_set_callback(sleep_task, function(x) { async_called <<- TRUE }) Sys.sleep(0.1) expect_true(async_called) + + # Ensure this also works on error + async_called <- FALSE + sleep_task <- adbc_async_sleep(0, error_message = "some error") + adbc_async_task_set_callback( + sleep_task, + resolve = function(x) NULL, + reject = function(x) { async_called <<- TRUE } + ) + Sys.sleep(0.1) + expect_true(async_called) }) test_that("async task can be converted to a promise", {