diff --git a/NEWS.md b/NEWS.md index 1ae3a429..6d19ae09 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,11 @@ # caugi (development version) +## Improvements + +- Improved performance of all queries. Speedups are more significant on larger graphs, +but even on small graphs, queries are roughly 5x faster. + + # caugi 1.1.0 ## New Features diff --git a/R/extendr-wrappers.R b/R/extendr-wrappers.R index bd0ba930..560ccf87 100644 --- a/R/extendr-wrappers.R +++ b/R/extendr-wrappers.R @@ -58,6 +58,8 @@ rs_graph_class <- function(session) .Call(wrap__rs_graph_class, session) rs_names <- function(session) .Call(wrap__rs_names, session) +rs_names_subset <- function(session, indices) .Call(wrap__rs_names_subset, session, indices) + rs_index_of <- function(session, name) .Call(wrap__rs_index_of, session, name) rs_indices_of <- function(session, names) .Call(wrap__rs_indices_of, session, names) @@ -68,33 +70,29 @@ rs_is_valid <- function(session) .Call(wrap__rs_is_valid, session) rs_build <- function(session) invisible(.Call(wrap__rs_build, session)) -rs_topological_sort <- function(session) .Call(wrap__rs_topological_sort, session) - -rs_parents_of <- function(session, idxs) .Call(wrap__rs_parents_of, session, idxs) +children <- function(cg, nodes, index) .Call(wrap__children, cg, nodes, index) -rs_children_of <- function(session, idxs) .Call(wrap__rs_children_of, session, idxs) +neighbors <- function(cg, nodes, index, mode) .Call(wrap__neighbors, cg, nodes, index, mode) -rs_undirected_of <- function(session, idxs) .Call(wrap__rs_undirected_of, session, idxs) +ancestors <- function(cg, nodes, index, open) .Call(wrap__ancestors, cg, nodes, index, open) -rs_neighbors_of <- function(session, idxs, mode) .Call(wrap__rs_neighbors_of, session, idxs, mode) +descendants <- function(cg, nodes, index, open) .Call(wrap__descendants, cg, nodes, index, open) -rs_ancestors_of <- function(session, node) .Call(wrap__rs_ancestors_of, session, node) +anteriors <- function(cg, nodes, index, open) 
.Call(wrap__anteriors, cg, nodes, index, open) -rs_descendants_of <- function(session, node) .Call(wrap__rs_descendants_of, session, node) +posteriors <- function(cg, nodes, index, open) .Call(wrap__posteriors, cg, nodes, index, open) -rs_anteriors_of <- function(session, node) .Call(wrap__rs_anteriors_of, session, node) +markov_blanket <- function(cg, nodes, index) .Call(wrap__markov_blanket, cg, nodes, index) -rs_posteriors_of <- function(session, node) .Call(wrap__rs_posteriors_of, session, node) +exogenous <- function(cg, undirected_as_parents) .Call(wrap__exogenous, cg, undirected_as_parents) -rs_markov_blanket_of <- function(session, node) .Call(wrap__rs_markov_blanket_of, session, node) +topological_sort <- function(cg) .Call(wrap__topological_sort, cg) -rs_spouses_of <- function(session, idxs) .Call(wrap__rs_spouses_of, session, idxs) +spouses <- function(cg, nodes, index) .Call(wrap__spouses, cg, nodes, index) -rs_exogenous_nodes <- function(session, undirected_as_parents) .Call(wrap__rs_exogenous_nodes, session, undirected_as_parents) +districts <- function(cg, nodes, index, all) .Call(wrap__districts, cg, nodes, index, all) -rs_districts <- function(session) .Call(wrap__rs_districts, session) - -rs_district_of <- function(session, idx) .Call(wrap__rs_district_of, session, idx) +parents <- function(cg, nodes, index) .Call(wrap__parents, cg, nodes, index) rs_is_acyclic <- function(session) .Call(wrap__rs_is_acyclic, session) @@ -126,6 +124,8 @@ rs_latent_project <- function(session, latents) .Call(wrap__rs_latent_project, s rs_induced_subgraph <- function(session, keep) .Call(wrap__rs_induced_subgraph, session, keep) +subgraph <- function(cg, nodes, index) .Call(wrap__subgraph, cg, nodes, index) + rs_d_separated <- function(session, xs, ys, z) .Call(wrap__rs_d_separated, session, xs, ys, z) rs_minimal_d_separator <- function(session, xs, ys, include, restrict) .Call(wrap__rs_minimal_d_separator, session, xs, ys, include, restrict) diff --git a/R/queries.R 
b/R/queries.R index 7fabec85..a62fa069 100644 --- a/R/queries.R +++ b/R/queries.R @@ -719,24 +719,9 @@ edge_types <- function(cg) { #' @concept queries #' #' @export -parents <- function(cg, nodes = NULL, index = NULL) { - check <- .validate_nodes_and_index(nodes, index) - - if (check$index_supplied) { - return(.getter_output( - cg, - rs_parents_of(cg@session, as.integer(index - 1L)), - cg@nodes$name[index] - )) - } - - index <- .nodes_to_indices(cg, nodes) - .getter_output( - cg, - rs_parents_of(cg@session, as.integer(index)), - nodes - ) -} +# Implemented directly in Rust via `parents()` in `src/rust/src/lib.rs`. +parents <- parents +formals(parents) <- alist(cg = , nodes = NULL, index = NULL) #' @title Get children of nodes in a `caugi` #' @@ -770,25 +755,9 @@ parents <- function(cg, nodes = NULL, index = NULL) { #' @concept queries #' #' @export -children <- function(cg, nodes = NULL, index = NULL) { - check <- .validate_nodes_and_index(nodes, index) - - if (check$index_supplied) { - return(.getter_output( - cg, - rs_children_of(cg@session, as.integer(index - 1L)), - cg@nodes$name[index] - )) - } - - index <- .nodes_to_indices(cg, nodes) - - .getter_output( - cg, - rs_children_of(cg@session, as.integer(index)), - nodes - ) -} +# Implemented directly in Rust via `children()` in `src/rust/src/lib.rs`. 
+children <- children +formals(children) <- alist(cg = , nodes = NULL, index = NULL) #' @title Get neighbors of nodes in a `caugi` #' @@ -858,40 +827,9 @@ children <- function(cg, nodes = NULL, index = NULL) { #' @concept queries #' #' @export -neighbors <- function( - cg, - nodes = NULL, - index = NULL, - mode = c( - "all", - "in", - "out", - "undirected", - "bidirected", - "partial" - ) -) { - check <- .validate_nodes_and_index(nodes, index) - - mode <- match.arg(mode) - - if (check$index_supplied) { - idx <- as.integer(index - 1L) - return(.getter_output( - cg, - rs_neighbors_of(cg@session, idx, mode), - cg@nodes$name[index] - )) - } - - index <- rs_indices_of(cg@session, nodes) - - .getter_output( - cg, - rs_neighbors_of(cg@session, as.integer(index), mode), - nodes - ) -} +# Implemented directly in Rust via `neighbors()` in `src/rust/src/lib.rs`. +neighbors <- neighbors +formals(neighbors) <- alist(cg = , nodes = NULL, index = NULL, mode = "all") #' @rdname neighbors #' @export @@ -928,50 +866,14 @@ neighbours <- neighbors #' @concept queries #' #' @export -ancestors <- function( - cg, +# Implemented directly in Rust via `ancestors()` in `src/rust/src/lib.rs`. +ancestors <- ancestors +formals(ancestors) <- alist( + cg = , nodes = NULL, index = NULL, open = caugi_options("use_open_graph_definition") -) { - if (!is.logical(open) || length(open) != 1L) { - stop("`open` must be a single TRUE or FALSE.", call. 
= FALSE) - } - - check <- .validate_nodes_and_index(nodes, index) - - if (check$index_supplied) { - idx0_list <- lapply( - as.integer(index - 1L), - function(ix) { - anc <- rs_ancestors_of(cg@session, ix) - if (!open) { - anc <- c(ix, anc) - } - anc - } - ) - return(.getter_output( - cg, - idx0_list, - cg@nodes$name[index] - )) - } - - index <- rs_indices_of(cg@session, nodes) - - idx0_list <- lapply( - as.integer(index), - function(ix) { - anc <- rs_ancestors_of(cg@session, ix) - if (!open) { - anc <- c(ix, anc) - } - anc - } - ) - .getter_output(cg, idx0_list, nodes) -} +) #' @title Get descendants of nodes in a `caugi` #' @@ -1004,50 +906,14 @@ ancestors <- function( #' @concept queries #' #' @export -descendants <- function( - cg, +# Implemented directly in Rust via `descendants()` in `src/rust/src/lib.rs`. +descendants <- descendants +formals(descendants) <- alist( + cg = , nodes = NULL, index = NULL, open = caugi_options("use_open_graph_definition") -) { - if (!is.logical(open) || length(open) != 1L) { - stop("`open` must be a single TRUE or FALSE.", call. = FALSE) - } - - check <- .validate_nodes_and_index(nodes, index) - - if (check$index_supplied) { - idx0_list <- lapply( - as.integer(index - 1L), - function(ix) { - anc <- rs_descendants_of(cg@session, ix) - if (!open) { - anc <- c(ix, anc) - } - anc - } - ) - return(.getter_output( - cg, - idx0_list, - cg@nodes$name[index] - )) - } - - index <- rs_indices_of(cg@session, nodes) - - idx0_list <- lapply( - as.integer(index), - function(ix) { - anc <- rs_descendants_of(cg@session, ix) - if (!open) { - anc <- c(ix, anc) - } - anc - } - ) - .getter_output(cg, idx0_list, nodes) -} +) #' @title Get anteriors of nodes in a `caugi` #' @@ -1096,49 +962,14 @@ descendants <- function( #' @concept queries #' #' @export -anteriors <- function( - cg, +# Implemented directly in Rust via `anteriors()` in `src/rust/src/lib.rs`. 
+anteriors <- anteriors +formals(anteriors) <- alist( + cg = , nodes = NULL, index = NULL, open = caugi_options("use_open_graph_definition") -) { - if (!is.logical(open) || length(open) != 1L) { - stop("`open` must be a single TRUE or FALSE.", call. = FALSE) - } - check <- .validate_nodes_and_index(nodes, index) - - if (check$index_supplied) { - idx0_list <- lapply( - as.integer(index - 1L), - function(ix) { - anc <- rs_anteriors_of(cg@session, ix) - if (!open) { - anc <- c(ix, anc) - } - anc - } - ) - return(.getter_output( - cg, - idx0_list, - cg@nodes$name[index] - )) - } - - index <- rs_indices_of(cg@session, nodes) - - idx0_list <- lapply( - as.integer(index), - function(ix) { - anc <- rs_anteriors_of(cg@session, ix) - if (!open) { - anc <- c(ix, anc) - } - anc - } - ) - .getter_output(cg, idx0_list, nodes) -} +) #' @title Get posteriors of nodes in a `caugi` #' @@ -1185,63 +1016,14 @@ anteriors <- function( #' @concept queries #' #' @export -posteriors <- function( - cg, +# Implemented directly in Rust via `posteriors()` in `src/rust/src/lib.rs`. +posteriors <- posteriors +formals(posteriors) <- alist( + cg = , nodes = NULL, index = NULL, open = caugi_options("use_open_graph_definition") -) { - if (!is.logical(open) || length(open) != 1L) { - stop("`open` must be a single TRUE or FALSE.", call. = FALSE) - } - nodes_supplied <- !is.null(nodes) - index_supplied <- !is.null(index) - - if (nodes_supplied && index_supplied) { - stop("Supply either `nodes` or `index`, not both.", call. = FALSE) - } - - if (index_supplied) { - idx0_list <- lapply( - as.integer(index - 1L), - function(ix) { - anc <- rs_posteriors_of(cg@session, ix) - if (!open) { - anc <- c(ix, anc) - } - anc - } - ) - return(.getter_output( - cg, - idx0_list, - cg@nodes$name[index] - )) - } - - if (!nodes_supplied) { - stop("Supply one of `nodes` or `index`.", call. = FALSE) - } - - if (!is.character(nodes)) { - stop("`nodes` must be a character vector of node names.", call. 
= FALSE) - } - - index <- rs_indices_of(cg@session, nodes) - - idx0_list <- lapply( - as.integer(index), - function(ix) { - anc <- rs_posteriors_of(cg@session, ix) - if (!open) { - anc <- c(ix, anc) - } - anc - } - ) - - .getter_output(cg, idx0_list, nodes) -} +) #' @title Get Markov blanket of nodes in a `caugi` #' @@ -1270,29 +1052,9 @@ posteriors <- function( #' @concept queries #' #' @export -markov_blanket <- function(cg, nodes = NULL, index = NULL) { - check <- .validate_nodes_and_index(nodes, index) - - if (check$index_supplied) { - idx0_list <- lapply( - as.integer(index - 1L), - function(ix) rs_markov_blanket_of(cg@session, ix) - ) - return(.getter_output( - cg, - idx0_list, - cg@nodes$name[index] - )) - } - - index <- rs_indices_of(cg@session, nodes) - - idx0_list <- lapply( - as.integer(index), - function(ix) rs_markov_blanket_of(cg@session, ix) - ) - .getter_output(cg, idx0_list, nodes) -} +# Implemented directly in Rust via `markov_blanket()` in `src/rust/src/lib.rs`. +markov_blanket <- markov_blanket +formals(markov_blanket) <- alist(cg = , nodes = NULL, index = NULL) #' @title Get all exogenous nodes in a `caugi` #' @@ -1318,12 +1080,9 @@ markov_blanket <- function(cg, nodes = NULL, index = NULL) { #' @concept queries #' #' @export -exogenous <- function(cg, undirected_as_parents = FALSE) { - is_caugi(cg, throw_error = TRUE) - - idx0 <- rs_exogenous_nodes(cg@session, undirected_as_parents) - cg@nodes$name[idx0 + 1L] -} +# Implemented directly in Rust via `exogenous()` in `src/rust/src/lib.rs`. 
+exogenous <- exogenous +formals(exogenous) <- alist(cg = , undirected_as_parents = FALSE) #' @title Get a topological ordering of a DAG #' @@ -1357,12 +1116,8 @@ exogenous <- function(cg, undirected_as_parents = FALSE) { #' @concept queries #' #' @export -topological_sort <- function(cg) { - is_caugi(cg, throw_error = TRUE) - - idx0 <- rs_topological_sort(cg@session) - cg@nodes$name[idx0 + 1L] -} +# Implemented directly in Rust via `topological_sort()` in `src/rust/src/lib.rs`. +topological_sort <- topological_sort # ────────────────────────────────────────────────────────────────────────────── # ─────────────────────────── ADMG-specific queries ──────────────────────────── @@ -1392,25 +1147,9 @@ topological_sort <- function(cg) { #' @concept queries #' #' @export -spouses <- function(cg, nodes = NULL, index = NULL) { - check <- .validate_nodes_and_index(nodes, index) - - if (check$index_supplied) { - return(.getter_output( - cg, - rs_spouses_of(cg@session, as.integer(index - 1L)), - cg@nodes$name[index] - )) - } - - index <- rs_indices_of(cg@session, nodes) - - .getter_output( - cg, - rs_spouses_of(cg@session, as.integer(index)), - nodes - ) -} +# Implemented directly in Rust via `spouses()` in `src/rust/src/lib.rs`. +spouses <- spouses +formals(spouses) <- alist(cg = , nodes = NULL, index = NULL) #' @title Get districts (c-components) of an ADMG or AG #' @@ -1446,91 +1185,18 @@ spouses <- function(cg, nodes = NULL, index = NULL) { #' @concept queries #' #' @export -districts <- function(cg, nodes = NULL, index = NULL, all) { - is_caugi(cg, throw_error = TRUE) - +# Implemented directly in Rust via `districts()` in `src/rust/src/lib.rs`, +# with deprecation-warning compatibility handled in R. +districts <- function(cg, nodes = NULL, index = NULL, all = NULL) { if (!missing(all)) { - # TODO: Remove in a future major release warning( "`all` argument is deprecated and will be removed in a future version. 
", "To get all districts, simply call `districts(cg)` without `nodes` or `index`.", call. = FALSE ) - - if ( - !is.null(all) && (!is.logical(all) || length(all) != 1L || is.na(all)) - ) { - stop("`all` must be TRUE, FALSE, or NULL.", call. = FALSE) - } - } else { - all <- is.null(nodes) && is.null(index) } - nodes_supplied <- !is.null(nodes) - index_supplied <- !is.null(index) - - if (nodes_supplied && index_supplied) { - stop("Supply either `nodes` or `index`, not both.", call. = FALSE) - } - - if (isTRUE(all) && (nodes_supplied || index_supplied)) { - stop( - "`all = TRUE` cannot be combined with `nodes` or `index`.", - call. = FALSE - ) - } - - if (identical(all, FALSE) && !nodes_supplied && !index_supplied) { - stop( - "`all = FALSE` requires `nodes` or `index` to be supplied.", - call. = FALSE - ) - } - - all_requested <- if (is.null(all)) { - !nodes_supplied && !index_supplied - } else { - isTRUE(all) - } - - if (all_requested) { - idx0_list <- rs_districts(cg@session) - return(lapply(idx0_list, function(idx0) cg@nodes$name[idx0 + 1L])) - } - - if (index_supplied) { - if (!is.numeric(index) || anyNA(index)) { - stop("`index` must be numeric without NA.", call. = FALSE) - } - idx1 <- as.integer(index) - n <- nrow(cg@nodes) - if (any(idx1 < 1L) || any(idx1 > n)) { - stop("`index` out of range (1..n).", call. = FALSE) - } - - idx0_list <- lapply( - as.integer(idx1 - 1L), - function(ix) rs_district_of(cg@session, ix) - ) - return(.getter_output(cg, idx0_list, cg@nodes$name[idx1])) - } - - if (!nodes_supplied) { - stop( - "Supply one of `nodes` or `index`, or set `all = TRUE`.", - call. = FALSE - ) - } - - if (!is.character(nodes) || anyNA(nodes)) { - stop("`nodes` must be a character vector without NA.", call. 
= FALSE) - } - - idx0 <- rs_indices_of(cg@session, nodes) - idx0_list <- lapply(as.integer(idx0), function(ix) { - rs_district_of(cg@session, ix) - }) - .getter_output(cg, idx0_list, nodes) + .Call(wrap__districts, cg, nodes, index, all) } #' @title M-separation test for AGs and ADMGs @@ -1607,82 +1273,6 @@ m_separated <- function( #' @concept queries #' #' @export -subgraph <- function(cg, nodes = NULL, index = NULL) { - is_caugi(cg, throw_error = TRUE) - session_names <- rs_names(cg@session) - - check <- .validate_nodes_and_index(nodes, index) - - if (check$index_supplied) { - idx1 <- as.integer(index) - n <- length(session_names) - if (any(idx1 < 1L) || any(idx1 > n)) { - stop("`index` out of range (1..n).", call. = FALSE) - } - keep_idx0 <- idx1 - 1L - keep_names <- session_names[idx1] - } else { - pos <- match(nodes, session_names) - if (anyNA(pos)) { - miss <- nodes[is.na(pos)] - stop( - "Unknown node(s): ", - paste(unique(miss), collapse = ", "), - call. = FALSE - ) - } - keep_names <- nodes - keep_idx0 <- pos - 1L - } - - if (anyDuplicated(keep_idx0)) { - dpos <- duplicated(keep_idx0) | duplicated(keep_idx0, fromLast = TRUE) - stop( - "`nodes`/`index` contains duplicates: ", - paste(unique(keep_names[dpos]), collapse = ", "), - call. = FALSE - ) - } - - sub_session <- rs_induced_subgraph( - cg@session, - as.integer(keep_idx0) - ) - .session_to_caugi(sub_session, node_names = keep_names) -} - -# ────────────────────────────────────────────────────────────────────────────── -# ──────────────────────────── Relations helpers ─────────────────────────────── -# ────────────────────────────────────────────────────────────────────────────── - -#' @title Output object of getter queries -#' -#' @description Helper to format the output of getter queries. -#' -#' @param cg A `caugi` object. -#' @param idx0 A vector of zero-based node indices. -#' @param nodes A vector of node names. -#' -#' @returns A list of character vectors, each a set of node names. 
-#' If only one node is requested, returns a character vector. -#' -#' @keywords internal -.getter_output <- function(cg, idx0, nodes) { - nm <- cg@nodes$name - to_names <- function(ix0) { - if (length(ix0) == 0L) { - return(NULL) - } - nm[ix0 + 1L] - } - - # faster check than doing is.null and length == 1, since length(NULL) == 0 - if (length(nodes) <= 1L && length(idx0) == 1L) { - ix <- idx0[[1L]] - return(to_names(ix)) - } - - out <- lapply(idx0, to_names) - names(out) <- nodes - out -} +# Implemented directly in Rust via `subgraph()` in `src/rust/src/lib.rs`. +subgraph <- subgraph +formals(subgraph) <- alist(cg = , nodes = NULL, index = NULL) diff --git a/man/districts.Rd b/man/districts.Rd index dca26915..47f9fa1c 100644 --- a/man/districts.Rd +++ b/man/districts.Rd @@ -4,7 +4,7 @@ \alias{districts} \title{Get districts (c-components) of an ADMG or AG} \usage{ -districts(cg, nodes = NULL, index = NULL, all) +districts(cg, nodes = NULL, index = NULL, all = NULL) } \arguments{ \item{cg}{A \code{caugi} object of class ADMG or AG.} diff --git a/man/dot-getter_output.Rd b/man/dot-getter_output.Rd deleted file mode 100644 index 154e90e1..00000000 --- a/man/dot-getter_output.Rd +++ /dev/null @@ -1,23 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/queries.R -\name{.getter_output} -\alias{.getter_output} -\title{Output object of getter queries} -\usage{ -.getter_output(cg, idx0, nodes) -} -\arguments{ -\item{cg}{A \code{caugi} object.} - -\item{idx0}{A vector of zero-based node indices.} - -\item{nodes}{A vector of node names.} -} -\value{ -A list of character vectors, each a set of node names. -If only one node is requested, returns a character vector. -} -\description{ -Helper to format the output of getter queries. 
-} -\keyword{internal} diff --git a/man/neighbors.Rd b/man/neighbors.Rd index d243b148..f34fabd2 100644 --- a/man/neighbors.Rd +++ b/man/neighbors.Rd @@ -5,19 +5,9 @@ \alias{neighbours} \title{Get neighbors of nodes in a \code{caugi}} \usage{ -neighbors( - cg, - nodes = NULL, - index = NULL, - mode = c("all", "in", "out", "undirected", "bidirected", "partial") -) +neighbors(cg, nodes = NULL, index = NULL, mode = "all") -neighbours( - cg, - nodes = NULL, - index = NULL, - mode = c("all", "in", "out", "undirected", "bidirected", "partial") -) +neighbours(cg, nodes = NULL, index = NULL, mode = "all") } \arguments{ \item{cg}{A \code{caugi} object.} diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 3b9e4ef3..956bf9cf 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -48,6 +48,279 @@ fn rbool_to_bool(x: Rbool, field: &str) -> bool { x.is_true() } +fn indices_to_names(indices: &[u32], names: &[String]) -> Robj { + indices + .iter() + .map(|&idx| names[idx as usize].as_str()) + .collect_robj() +} + +fn indices_to_names_or_null(indices: &[u32], names: &[String]) -> Robj { + if indices.is_empty() { + ().into_robj() + } else { + indices_to_names(indices, names) + } +} + +fn named_list(values: Vec, names: Vec) -> Robj { + let mut out = extendr_api::prelude::List::from_values(values); + out.set_names(names.iter().map(|s| s.as_str())) + .unwrap_or_else(|e| throw_r_error(e.to_string())); + out.into_robj() +} + +fn session_ptr_from_cg(cg: &Robj) -> ExternalPtr { + let session = cg + .get_attrib(sym!(session)) + .unwrap_or_else(|| throw_r_error("Input must be a caugi")); + if session.is_null() { + throw_r_error("Cannot look up indices for empty graph."); + } + session + .try_into() + .unwrap_or_else(|_| throw_r_error("Input must be a caugi")) +} + +fn parse_parent_nodes(nodes: Robj) -> Vec { + let node_strings: Strings = nodes.try_into().unwrap_or_else(|_| { + throw_r_error("`nodes` must be a character vector of node names.") + }); + + let mut out = 
Vec::with_capacity(node_strings.len()); + for i in 0..node_strings.len() { + let s = node_strings.elt(i); + if s.is_na() { + throw_r_error("`nodes` cannot contain NA values."); + } + out.push(s.to_string()); + } + out +} + +fn parse_parent_index0(index: Robj) -> Vec { + if index.is_integer() { + let idx_int: Integers = index + .try_into() + .unwrap_or_else(|_| throw_r_error("`index` must be numeric.")); + let mut out = Vec::with_capacity(idx_int.len()); + for x in idx_int.iter() { + if x.is_na() { + throw_r_error("`index` cannot contain NA values."); + } + out.push(x.inner() - 1); + } + return out; + } + + if index.is_real() { + let idx_num: Doubles = index + .try_into() + .unwrap_or_else(|_| throw_r_error("`index` must be numeric.")); + let mut out = Vec::with_capacity(idx_num.len()); + for x in idx_num.iter() { + if x.is_na() { + throw_r_error("`index` cannot contain NA values."); + } + out.push((x.inner() - 1.0).trunc() as i32); + } + return out; + } + + throw_r_error("`index` must be numeric."); +} + +fn parse_subgraph_index1(index: Robj) -> Vec { + if index.is_integer() { + let idx_int: Integers = index + .try_into() + .unwrap_or_else(|_| throw_r_error("`index` must be numeric.")); + let mut out = Vec::with_capacity(idx_int.len()); + for x in idx_int.iter() { + if x.is_na() { + throw_r_error("`index` cannot contain NA values."); + } + out.push(x.inner()); + } + return out; + } + + if index.is_real() { + let idx_num: Doubles = index + .try_into() + .unwrap_or_else(|_| throw_r_error("`index` must be numeric.")); + let mut out = Vec::with_capacity(idx_num.len()); + for x in idx_num.iter() { + if x.is_na() { + throw_r_error("`index` cannot contain NA values."); + } + out.push(x.inner().trunc() as i32); + } + return out; + } + + throw_r_error("`index` must be numeric."); +} + +fn resolve_query_idx0( + session: &ExternalPtr, + nodes: Robj, + index: Robj, + missing_msg: &str, +) -> Vec { + let nodes_supplied = !nodes.is_null(); + let index_supplied = 
!index.is_null(); + + if nodes_supplied && index_supplied { + throw_r_error("Supply either `nodes` or `index`, not both."); + } + if !nodes_supplied && !index_supplied { + throw_r_error(missing_msg); + } + + if nodes_supplied { + let node_names = parse_parent_nodes(nodes); + session + .as_ref() + .indices_of(&node_names) + .unwrap_or_else(|e| throw_r_error(e)) + .into_iter() + .map(|x| x as i32) + .collect() + } else { + parse_parent_index0(index) + } +} + +fn parse_single_logical(value: Robj, err_msg: &str) -> bool { + let vals: Logicals = value.try_into().unwrap_or_else(|_| throw_r_error(err_msg)); + if vals.len() != 1 { + throw_r_error(err_msg); + } + let out = vals.elt(0); + if out.is_na() { + throw_r_error(err_msg); + } + out.is_true() +} + +fn parse_open_arg(open: Robj) -> bool { + parse_single_logical(open, "`open` must be a single TRUE or FALSE.") +} + +fn parse_neighbors_mode(mode: Robj) -> String { + let mode_vals: Strings = mode + .try_into() + .unwrap_or_else(|_| throw_r_error("`mode` must be a character vector.")); + if mode_vals.len() != 1 { + throw_r_error("`mode` must be length 1."); + } + + let raw = mode_vals.elt(0); + if raw.is_na() { + throw_r_error("`mode` cannot contain NA values."); + } + let mode_lc = raw.as_str().to_ascii_lowercase(); + + const CHOICES: [&str; 6] = ["all", "in", "out", "undirected", "bidirected", "partial"]; + if CHOICES.iter().any(|&m| m == mode_lc) { + return mode_lc; + } + + let matches: Vec<&str> = CHOICES + .iter() + .copied() + .filter(|m| m.starts_with(&mode_lc)) + .collect(); + match matches.len() { + 1 => matches[0].to_string(), + 0 => throw_r_error("`mode` must be one of: all, in, out, undirected, bidirected, partial."), + _ => throw_r_error("`mode` is ambiguous. 
Use one of: all, in, out, undirected, bidirected, partial."), + } +} + +fn parse_district_index1(index: Robj) -> Vec { + let err_msg = "`index` must be numeric without NA."; + + if index.is_integer() { + let vals: Integers = index.try_into().unwrap_or_else(|_| throw_r_error(err_msg)); + let mut out = Vec::with_capacity(vals.len()); + for x in vals.iter() { + if x.is_na() { + throw_r_error(err_msg); + } + out.push(x.inner()); + } + return out; + } + + if index.is_real() { + let vals: Doubles = index.try_into().unwrap_or_else(|_| throw_r_error(err_msg)); + let mut out = Vec::with_capacity(vals.len()); + for x in vals.iter() { + if x.is_na() { + throw_r_error(err_msg); + } + out.push(x.inner().trunc() as i32); + } + return out; + } + + throw_r_error(err_msg); +} + +fn parse_optional_single_logical(value: Robj, err_msg: &str) -> Option { + if value.is_null() { + return None; + } + let vals: Logicals = value.try_into().unwrap_or_else(|_| throw_r_error(err_msg)); + if vals.len() != 1 { + throw_r_error(err_msg); + } + let out = vals.elt(0); + if out.is_na() { + throw_r_error(err_msg); + } + Some(out.is_true()) +} + +fn run_relation_query( + session: &mut ExternalPtr, + idx0: Vec, + scalar_field: &str, + vector_field: &str, + mut query: F, +) -> Robj +where + F: FnMut(&mut ExternalPtr, u32) -> std::result::Result, String>, +{ + if idx0.len() <= 1 { + if idx0.is_empty() { + throw_r_error("Expected non zero length"); + } + let i = rint_to_u32(Rint::from(idx0[0]), scalar_field); + if i >= session.as_ref().n() { + throw_r_error(format!("Index {} is out of bounds", i)); + } + let v = query(session, i).unwrap_or_else(|e| throw_r_error(e)); + return indices_to_names_or_null(&v, session.as_ref().names()); + } + + let mut out: Vec = Vec::with_capacity(idx0.len()); + let mut out_names: Vec = Vec::with_capacity(idx0.len()); + for ii in idx0 { + let i = rint_to_u32(Rint::from(ii), vector_field); + if i >= session.as_ref().n() { + throw_r_error(format!("Index {} is out of bounds", i)); 
+ } + out_names.push(session.as_ref().names()[i as usize].clone()); + let v = query(session, i).unwrap_or_else(|e| throw_r_error(e)); + out.push(indices_to_names_or_null(&v, session.as_ref().names())); + } + + named_list(out, out_names) +} + /// Convert coordinate pairs to R list with x and y vectors. fn coords_to_list(coords: Vec<(f64, f64)>) -> Robj { let mut x: Vec = Vec::with_capacity(coords.len()); @@ -123,6 +396,13 @@ fn session_from_view(view: GraphView, node_names: Vec) -> GraphSession { session } +fn caugi_from_session_ptr(template: &Robj, session: ExternalPtr) -> Robj { + let mut out = template.duplicate(); + out.set_attrib(sym!(session), Robj::from(session)) + .unwrap_or_else(|e| throw_r_error(e.to_string())); + out +} + // ── Edge Registry ──────────────────────────────────────────────────────────────── #[extendr] @@ -793,6 +1073,16 @@ fn rs_names(session: ExternalPtr) -> Strings { .collect() } +#[extendr] +fn rs_names_subset(session: ExternalPtr, indices: Vec) -> Strings { + let names = session.as_ref().names(); + + indices + .iter() + .map(|&i| names[i as usize].as_str()) + .collect() +} + #[extendr] fn rs_index_of(session: ExternalPtr, name: &str) -> Robj { match session.as_ref().index_of(name) { @@ -865,211 +1155,307 @@ fn rs_build(mut session: ExternalPtr) { session.as_mut().view().unwrap_or_else(|e| throw_r_error(e)); } -// Query accessors #[extendr] -fn rs_topological_sort(mut session: ExternalPtr) -> Robj { - let result = session - .as_mut() - .topological_sort() - .unwrap_or_else(|e| throw_r_error(e)); - result.iter().map(|&x| x as i32).collect_robj() +fn parents( + cg: Robj, + nodes: Robj, + index: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); + let idx0 = resolve_query_idx0(&session, nodes, index, "Must supply either `nodes` or `index`."); + run_relation_query(&mut session, idx0, "idx", "idxs", |s, i| { + s.as_mut().parents_of(i).map_err(|e| e) + }) } #[extendr] -fn rs_parents_of(mut session: ExternalPtr, idxs: Integers) -> 
Robj { - let mut out: Vec = Vec::with_capacity(idxs.len()); - for ri in idxs.iter() { - let i = rint_to_u32(ri, "idxs"); - if i >= session.as_ref().n() { - throw_r_error(format!("Index {} is out of bounds", i)); - } - let v = session - .as_mut() - .parents_of(i) - .unwrap_or_else(|e| throw_r_error(e)); - out.push(v.iter().map(|&x| x as i32).collect_robj()); - } - extendr_api::prelude::List::from_values(out).into_robj() +fn children( + cg: Robj, + nodes: Robj, + index: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); + let idx0 = resolve_query_idx0(&session, nodes, index, "Must supply either `nodes` or `index`."); + run_relation_query(&mut session, idx0, "idx", "idxs", |s, i| { + s.as_mut().children_of(i).map_err(|e| e) + }) } #[extendr] -fn rs_children_of(mut session: ExternalPtr, idxs: Integers) -> Robj { - let mut out: Vec = Vec::with_capacity(idxs.len()); - for ri in idxs.iter() { - let i = rint_to_u32(ri, "idxs"); - if i >= session.as_ref().n() { - throw_r_error(format!("Index {} is out of bounds", i)); - } - let v = session - .as_mut() - .children_of(i) - .unwrap_or_else(|e| throw_r_error(e)); - out.push(v.iter().map(|&x| x as i32).collect_robj()); - } - extendr_api::prelude::List::from_values(out).into_robj() +fn neighbors( + cg: Robj, + nodes: Robj, + index: Robj, + mode: Robj, +) -> Robj { + use graph::NeighborMode; + + let mut session = session_ptr_from_cg(&cg); + let idx0 = resolve_query_idx0(&session, nodes, index, "Must supply either `nodes` or `index`."); + let mode_norm = parse_neighbors_mode(mode); + let neighbor_mode = NeighborMode::from_str(mode_norm.as_str()).unwrap_or_else(|e| throw_r_error(e)); + + run_relation_query(&mut session, idx0, "idx", "idxs", move |s, i| { + s.as_mut().neighbors_of(i, neighbor_mode).map_err(|e| e) + }) } #[extendr] -fn rs_undirected_of(mut session: ExternalPtr, idxs: Integers) -> Robj { - let mut out: Vec = Vec::with_capacity(idxs.len()); - for ri in idxs.iter() { - let i = rint_to_u32(ri, "idxs"); - if 
i >= session.as_ref().n() { - throw_r_error(format!("Index {} is out of bounds", i)); +fn ancestors( + cg: Robj, + nodes: Robj, + index: Robj, + open: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); + let idx0 = resolve_query_idx0(&session, nodes, index, "Must supply either `nodes` or `index`."); + let open_flag = parse_open_arg(open); + + run_relation_query(&mut session, idx0, "node", "node", move |s, i| { + let mut v = s.as_mut().ancestors_of(i).map_err(|e| e)?; + if !open_flag { + v.insert(0, i); } - let v = session - .as_mut() - .undirected_of(i) - .unwrap_or_else(|e| throw_r_error(e)); - out.push(v.iter().map(|&x| x as i32).collect_robj()); - } - extendr_api::prelude::List::from_values(out).into_robj() + Ok(v) + }) } #[extendr] -fn rs_neighbors_of(mut session: ExternalPtr, idxs: Integers, mode: Strings) -> Robj { - use graph::NeighborMode; - // Allow single mode to apply to all indices, or one mode per index - if mode.len() != 1 && mode.len() != idxs.len() { - throw_r_error("mode must be length 1 or match index length"); - } - let single_mode = mode.len() == 1; - let first_mode = if single_mode { - Some( - NeighborMode::from_str(mode.iter().next().unwrap().as_str()) - .unwrap_or_else(|e| throw_r_error(e)), - ) - } else { - None - }; - - let mut out: Vec = Vec::with_capacity(idxs.len()); - for (idx, ri) in idxs.iter().enumerate() { - let i = rint_to_u32(ri, "idxs"); - if i >= session.as_ref().n() { - throw_r_error(format!("Index {} is out of bounds", i)); +fn descendants( + cg: Robj, + nodes: Robj, + index: Robj, + open: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); + let idx0 = resolve_query_idx0(&session, nodes, index, "Must supply either `nodes` or `index`."); + let open_flag = parse_open_arg(open); + + run_relation_query(&mut session, idx0, "node", "node", move |s, i| { + let mut v = s.as_mut().descendants_of(i).map_err(|e| e)?; + if !open_flag { + v.insert(0, i); } - let neighbor_mode = if single_mode { - 
first_mode.unwrap() - } else { - NeighborMode::from_str(mode.elt(idx).as_str()).unwrap_or_else(|e| throw_r_error(e)) - }; - let v = session - .as_mut() - .neighbors_of(i, neighbor_mode) - .unwrap_or_else(|e| throw_r_error(e)); - out.push(v.iter().map(|&x| x as i32).collect_robj()); - } - extendr_api::prelude::List::from_values(out).into_robj() + Ok(v) + }) } #[extendr] -fn rs_ancestors_of(mut session: ExternalPtr, node: i32) -> Robj { - let idx = rint_to_u32(Rint::from(node), "node"); - let result = session - .as_mut() - .ancestors_of(idx) - .unwrap_or_else(|e| throw_r_error(e)); - result.iter().map(|&x| x as i32).collect_robj() +fn anteriors( + cg: Robj, + nodes: Robj, + index: Robj, + open: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); + let idx0 = resolve_query_idx0(&session, nodes, index, "Must supply either `nodes` or `index`."); + let open_flag = parse_open_arg(open); + + run_relation_query(&mut session, idx0, "node", "node", move |s, i| { + let mut v = s.as_mut().anteriors_of(i).map_err(|e| e)?; + if !open_flag { + v.insert(0, i); + } + Ok(v) + }) } #[extendr] -fn rs_descendants_of(mut session: ExternalPtr, node: i32) -> Robj { - let idx = rint_to_u32(Rint::from(node), "node"); - let result = session - .as_mut() - .descendants_of(idx) - .unwrap_or_else(|e| throw_r_error(e)); - result.iter().map(|&x| x as i32).collect_robj() +fn posteriors( + cg: Robj, + nodes: Robj, + index: Robj, + open: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); + let idx0 = resolve_query_idx0(&session, nodes, index, "Supply one of `nodes` or `index`."); + let open_flag = parse_open_arg(open); + + run_relation_query(&mut session, idx0, "node", "node", move |s, i| { + let mut v = s.as_mut().posteriors_of(i).map_err(|e| e)?; + if !open_flag { + v.insert(0, i); + } + Ok(v) + }) } #[extendr] -fn rs_anteriors_of(mut session: ExternalPtr, node: i32) -> Robj { - let idx = rint_to_u32(Rint::from(node), "node"); - let result = session - .as_mut() - 
.anteriors_of(idx) - .unwrap_or_else(|e| throw_r_error(e)); - result.iter().map(|&x| x as i32).collect_robj() +fn markov_blanket( + cg: Robj, + nodes: Robj, + index: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); + let idx0 = resolve_query_idx0(&session, nodes, index, "Must supply either `nodes` or `index`."); + run_relation_query(&mut session, idx0, "node", "node", |s, i| { + s.as_mut().markov_blanket_of(i).map_err(|e| e) + }) } #[extendr] -fn rs_posteriors_of(mut session: ExternalPtr, node: i32) -> Robj { - let idx = rint_to_u32(Rint::from(node), "node"); - let result = session - .as_mut() - .posteriors_of(idx) - .unwrap_or_else(|e| throw_r_error(e)); - result.iter().map(|&x| x as i32).collect_robj() +fn spouses( + cg: Robj, + nodes: Robj, + index: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); + let idx0 = resolve_query_idx0(&session, nodes, index, "Must supply either `nodes` or `index`."); + run_relation_query(&mut session, idx0, "idx", "idxs", |s, i| { + s.as_mut().spouses_of(i).map_err(|e| e) + }) } #[extendr] -fn rs_markov_blanket_of(mut session: ExternalPtr, node: i32) -> Robj { - let idx = rint_to_u32(Rint::from(node), "node"); - let result = session - .as_mut() - .markov_blanket_of(idx) - .unwrap_or_else(|e| throw_r_error(e)); - result.iter().map(|&x| x as i32).collect_robj() -} +fn districts( + cg: Robj, + nodes: Robj, + index: Robj, + all: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); -#[extendr] -fn rs_spouses_of(mut session: ExternalPtr, idxs: Integers) -> Robj { - let mut out: Vec = Vec::with_capacity(idxs.len()); - for ri in idxs.iter() { - let i = rint_to_u32(ri, "idxs"); - if i >= session.as_ref().n() { - throw_r_error(format!("Index {} is out of bounds", i)); + let nodes_supplied = !nodes.is_null(); + let index_supplied = !index.is_null(); + + if nodes_supplied && index_supplied { + throw_r_error("Supply either `nodes` or `index`, not both."); + } + + let all_opt = 
parse_optional_single_logical(all, "`all` must be TRUE, FALSE, or NULL."); + if all_opt == Some(true) && (nodes_supplied || index_supplied) { + throw_r_error("`all = TRUE` cannot be combined with `nodes` or `index`."); + } + + if all_opt == Some(false) && !nodes_supplied && !index_supplied { + throw_r_error("`all = FALSE` requires `nodes` or `index` to be supplied."); + } + + let all_requested = match all_opt { + None => !nodes_supplied && !index_supplied, + Some(v) => v, + }; + + if all_requested { + let result = session + .as_mut() + .districts() + .unwrap_or_else(|e| throw_r_error(e)); + + let out: Vec = result + .iter() + .map(|d| indices_to_names(d, session.as_ref().names())) + .collect(); + return extendr_api::prelude::List::from_values(out).into_robj(); + } + + if index_supplied { + let idx1 = parse_district_index1(index); + let n = session.as_ref().n() as i32; + if idx1.iter().any(|&i| i < 1 || i > n) { + throw_r_error("`index` out of range (1..n)."); + } + + if idx1.len() <= 1 { + if idx1.is_empty() { + throw_r_error("Expected non zero length"); + } + let i0 = idx1[0] - 1; + let i = rint_to_u32(Rint::from(i0), "idx"); + let v = session + .as_mut() + .district_of(i) + .unwrap_or_else(|e| throw_r_error(e)); + return indices_to_names(&v, session.as_ref().names()); + } + + let mut out: Vec = Vec::with_capacity(idx1.len()); + let mut out_names: Vec = Vec::with_capacity(idx1.len()); + for i1 in idx1 { + let i0 = i1 - 1; + let i = rint_to_u32(Rint::from(i0), "idxs"); + out_names.push(session.as_ref().names()[i as usize].clone()); + let v = session + .as_mut() + .district_of(i) + .unwrap_or_else(|e| throw_r_error(e)); + out.push(indices_to_names(&v, session.as_ref().names())); + } + return named_list(out, out_names); + } + + if !nodes_supplied { + throw_r_error("Supply one of `nodes` or `index`, or set `all = TRUE`."); + } + + let node_vals: Strings = nodes + .try_into() + .unwrap_or_else(|_| throw_r_error("`nodes` must be a character vector without NA.")); + let 
mut node_names = Vec::with_capacity(node_vals.len()); + for i in 0..node_vals.len() { + let s = node_vals.elt(i); + if s.is_na() { + throw_r_error("`nodes` must be a character vector without NA."); } + node_names.push(s.to_string()); + } + let idx0: Vec = session + .as_ref() + .indices_of(&node_names) + .unwrap_or_else(|e| throw_r_error(e)) + .into_iter() + .map(|x| x as i32) + .collect(); + + if idx0.len() <= 1 { + if idx0.is_empty() { + throw_r_error("Expected non zero length"); + } + let i = rint_to_u32(Rint::from(idx0[0]), "idx"); let v = session .as_mut() - .spouses_of(i) + .district_of(i) .unwrap_or_else(|e| throw_r_error(e)); - out.push(v.iter().map(|&x| x as i32).collect_robj()); + return indices_to_names(&v, session.as_ref().names()); } - extendr_api::prelude::List::from_values(out).into_robj() -} -#[extendr] -fn rs_exogenous_nodes( - mut session: ExternalPtr, - undirected_as_parents: Rbool, -) -> Robj { - let result = session - .as_mut() - .exogenous_nodes(undirected_as_parents.is_true()) - .unwrap_or_else(|e| throw_r_error(e)); - result.iter().map(|&x| x as i32).collect_robj() + let mut out: Vec = Vec::with_capacity(idx0.len()); + for &ii in idx0.iter() { + let i = rint_to_u32(Rint::from(ii), "idxs"); + let v = session + .as_mut() + .district_of(i) + .unwrap_or_else(|e| throw_r_error(e)); + out.push(indices_to_names(&v, session.as_ref().names())); + } + named_list(out, node_names) } #[extendr] -fn rs_districts(mut session: ExternalPtr) -> Robj { +fn topological_sort(cg: Robj) -> Robj { + let mut session = session_ptr_from_cg(&cg); let result = session .as_mut() - .districts() + .topological_sort() .unwrap_or_else(|e| throw_r_error(e)); - - let out: Vec = result - .iter() - .map(|d| d.iter().map(|&x| x as i32).collect_robj()) - .collect(); - extendr_api::prelude::List::from_values(out).into_robj() + indices_to_names(&result, session.as_ref().names()) } #[extendr] -fn rs_district_of(mut session: ExternalPtr, idx: i32) -> Robj { - if idx < 0 { - 
throw_r_error("idx must be >= 0"); - } - let i = idx as u32; - if i >= session.as_ref().n() { - throw_r_error(format!("Index {} is out of bounds", i)); - } - let v = session +fn exogenous( + cg: Robj, + undirected_as_parents: Robj, +) -> Robj { + let mut session = session_ptr_from_cg(&cg); + let undirected = parse_single_logical( + undirected_as_parents, + "`undirected_as_parents` must be a single TRUE or FALSE.", + ); + let result = session .as_mut() - .district_of(i) + .exogenous_nodes(undirected) .unwrap_or_else(|e| throw_r_error(e)); - v.into_iter().map(|x| x as i32).collect_robj() + indices_to_names(&result, session.as_ref().names()) } // ── Session validation / class checks ──────────────────────────────────────── @@ -1222,13 +1608,11 @@ fn rs_latent_project( ExternalPtr::new(session_from_view(view, names)) } -#[extendr] -fn rs_induced_subgraph( - mut session: ExternalPtr, - keep: Integers, +fn induced_subgraph_session_from_keep( + session: &mut ExternalPtr, + keep_u: &[u32], ) -> ExternalPtr { - let keep_u: Vec = keep.iter().map(|ri| rint_to_u32(ri, "keep")).collect(); - for &i in &keep_u { + for &i in keep_u { if i >= session.as_ref().n() { throw_r_error(format!("Index {} is out of bounds", i)); } @@ -1237,13 +1621,12 @@ fn rs_induced_subgraph( let view = session.as_mut().view().unwrap_or_else(|e| throw_r_error(e)); let sub_view = view .as_ref() - .induced_subgraph(&keep_u) + .induced_subgraph(keep_u) .unwrap_or_else(|e| throw_r_error(e)); - let all_names: Vec = session.as_ref().names().to_vec(); let names: Vec = keep_u .iter() - .map(|&i| all_names[i as usize].clone()) + .map(|&i| session.as_ref().names()[i as usize].clone()) .collect(); // Preserve original input edge orientation/order by filtering the source @@ -1270,11 +1653,94 @@ fn rs_induced_subgraph( } } - let mut out = session_from_view(sub_view, names); + // Build output session directly to avoid materializing edge buffer from the + // temporary subgraph core (we already have the desired filtered 
edge order). + let sub_core = sub_view.core(); + let mut out = GraphSession::from_snapshot( + Arc::new(sub_core.registry.clone()), + sub_core.n(), + sub_core.simple, + graph_class_from_view(&sub_view), + ); + out.set_names(names); out.set_edges(kept_edges); ExternalPtr::new(out) } +#[extendr] +fn rs_induced_subgraph( + mut session: ExternalPtr, + keep: Integers, +) -> ExternalPtr { + let keep_u: Vec = keep.iter().map(|ri| rint_to_u32(ri, "keep")).collect(); + induced_subgraph_session_from_keep(&mut session, &keep_u) +} + +#[extendr] +fn subgraph(cg: Robj, nodes: Robj, index: Robj) -> Robj { + use std::collections::HashSet; + + let mut session = session_ptr_from_cg(&cg); + + let nodes_supplied = !nodes.is_null(); + let index_supplied = !index.is_null(); + + if nodes_supplied && index_supplied { + throw_r_error("Supply either `nodes` or `index`, not both."); + } + if !nodes_supplied && !index_supplied { + throw_r_error("Must supply either `nodes` or `index`."); + } + + let keep_u: Vec = if index_supplied { + let idx1 = parse_subgraph_index1(index); + let n = session.as_ref().n() as i32; + if idx1.iter().any(|&i| i < 1 || i > n) { + throw_r_error("`index` out of range (1..n)."); + } + idx1.into_iter().map(|i| (i - 1) as u32).collect() + } else { + let node_strings: Strings = nodes.try_into().unwrap_or_else(|_| { + throw_r_error("`nodes` must be a character vector of node names.") + }); + + let mut keep = Vec::with_capacity(node_strings.len()); + let mut miss_seen: HashSet = HashSet::new(); + let mut miss: Vec = Vec::new(); + + for i in 0..node_strings.len() { + let s = node_strings.elt(i); + if s.is_na() { + throw_r_error("`nodes` cannot contain NA values."); + } + let name = s.as_str(); + match session.as_ref().index_of(name) { + Some(idx) => keep.push(idx), + None => { + if miss_seen.insert(name.to_string()) { + miss.push(name.to_string()); + } + } + } + } + + if !miss.is_empty() { + throw_r_error(format!("Unknown node(s): {}", miss.join(", "))); + } + keep + }; + + 
let mut seen: HashSet = HashSet::with_capacity(keep_u.len()); + for &idx in &keep_u { + if !seen.insert(idx) { + throw_r_error("`nodes`/`index` contains duplicates."); + } + } + + let sub_session = induced_subgraph_session_from_keep(&mut session, &keep_u); + caugi_from_session_ptr(&cg, sub_session) +} + // ── Session causal queries ─────────────────────────────────────────────────── #[extendr] fn rs_d_separated( @@ -1503,25 +1969,24 @@ extendr_module! { fn rs_class; fn rs_graph_class; fn rs_names; + fn rs_names_subset; fn rs_index_of; fn rs_indices_of; fn rs_edges_df; fn rs_is_valid; fn rs_build; - fn rs_topological_sort; - fn rs_parents_of; - fn rs_children_of; - fn rs_undirected_of; - fn rs_neighbors_of; - fn rs_ancestors_of; - fn rs_descendants_of; - fn rs_anteriors_of; - fn rs_posteriors_of; - fn rs_markov_blanket_of; - fn rs_spouses_of; - fn rs_exogenous_nodes; - fn rs_districts; - fn rs_district_of; + fn children; + fn neighbors; + fn ancestors; + fn descendants; + fn anteriors; + fn posteriors; + fn markov_blanket; + fn exogenous; + fn topological_sort; + fn spouses; + fn districts; + fn parents; fn rs_is_acyclic; fn rs_is_dag_type; fn rs_is_pdag_type; @@ -1537,6 +2002,7 @@ extendr_module! 
{ fn rs_moralize; fn rs_latent_project; fn rs_induced_subgraph; + fn subgraph; fn rs_d_separated; fn rs_minimal_d_separator; fn rs_m_separated; diff --git a/tests/testthat/test-builder-rust.R b/tests/testthat/test-builder-rust.R index 80d4a54b..2f91c01f 100644 --- a/tests/testthat/test-builder-rust.R +++ b/tests/testthat/test-builder-rust.R @@ -33,89 +33,38 @@ test_that("graph builder works for directed edge in reverse direction", { reset_caugi_registry() }) -test_that("queries work for DAGs and PDAGs via session", { +test_that("queries work for DAGs and PDAGs", { # DAG EXAMPLE cg <- caugi(A %-->% B, class = "DAG") - expect_identical( - rs_parents_of(cg@session, 0L), - list(integer(0)) - ) - expect_identical( - rs_children_of(cg@session, 0L), - list(1L) - ) - expect_identical( - rs_parents_of(cg@session, 1L), - list(0L) - ) - expect_identical( - rs_children_of(cg@session, 1L), - list(integer(0)) - ) + expect_null(parents(cg, index = 1L)) + expect_identical(children(cg, index = 1L), "B") + expect_identical(parents(cg, index = 2L), "A") + expect_null(children(cg, index = 2L)) cg <- add_edges(cg, B %-->% C) # Session syncs automatically - expect_identical( - rs_parents_of(cg@session, 0L), - list(integer(0)) - ) - expect_identical( - rs_children_of(cg@session, 0L), - list(1L) - ) - expect_identical( - rs_parents_of(cg@session, 1L), - list(0L) - ) - expect_identical( - rs_children_of(cg@session, 1L), - list(2L) - ) - expect_identical( - rs_parents_of(cg@session, 2L), - list(1L) - ) - expect_identical( - rs_children_of(cg@session, 2L), - list(integer(0)) - ) + expect_null(parents(cg, index = 1L)) + expect_identical(children(cg, index = 1L), "B") + expect_identical(parents(cg, index = 2L), "A") + expect_identical(children(cg, index = 2L), "C") + expect_identical(parents(cg, index = 3L), "B") + expect_null(children(cg, index = 3L)) # PDAG EXAMPLE cg <- caugi(A %-->% B, class = "PDAG") - expect_identical( - rs_parents_of(cg@session, 0L), - list(integer(0)) - ) - 
expect_identical( - rs_children_of(cg@session, 0L), - list(1L) - ) - expect_identical( - rs_parents_of(cg@session, 1L), - list(0L) - ) - expect_identical( - rs_children_of(cg@session, 1L), - list(integer(0)) - ) + expect_null(parents(cg, index = 1L)) + expect_identical(children(cg, index = 1L), "B") + expect_identical(parents(cg, index = 2L), "A") + expect_null(children(cg, index = 2L)) cg <- add_edges(cg, B %---% C) # Session syncs automatically - expect_identical( - rs_parents_of(cg@session, 0L), - list(integer(0)) - ) - expect_identical( - rs_children_of(cg@session, 0L), - list(1L) - ) - expect_identical( - rs_undirected_of(cg@session, 1L), - list(2L) - ) + expect_null(parents(cg, index = 1L)) + expect_identical(children(cg, index = 1L), "B") + expect_identical(neighbors(cg, index = 2L, mode = "undirected"), "C") }) test_that("rs_build compiles core and view lazily", { diff --git a/tests/testthat/test-queries.R b/tests/testthat/test-queries.R index 49d116e9..57e61e3d 100644 --- a/tests/testthat/test-queries.R +++ b/tests/testthat/test-queries.R @@ -681,21 +681,6 @@ test_that("getter queries builds", { } }) -# ────────────────────────────────────────────────────────────────────────────── -# ────────────────────────────── Getter helpers ──────────────────────────────── -# ────────────────────────────────────────────────────────────────────────────── - -test_that(".getter_output returns data frame with name column", { - cg <- caugi(A %-->% B, B %-->% C, class = "DAG") - out <- caugi:::.getter_output(cg, c(0L, 2L), c("A", "C")) - expect_identical(out[["A"]], "A") - expect_identical(out[["C"]], "C") - - out_null <- caugi:::.getter_output(cg, 0L, NULL) - expect_equal(out_null, "A") -}) - - # ────────────────────────────────────────────────────────────────────────────── # ───────────────────────────────── Subgraph ─────────────────────────────────── # ────────────────────────────────────────────────────────────────────────────── diff --git 
a/tests/testthat/test-registry.R b/tests/testthat/test-registry.R index 2ddae7fc..02efd2bf 100644 --- a/tests/testthat/test-registry.R +++ b/tests/testthat/test-registry.R @@ -134,30 +134,12 @@ test_that("reverse edge, <--, behaves correctly, when initalized in a cg", { reset_caugi_registry() register_caugi_edge("<--", "arrow", "tail", "directed", FALSE) cg <- caugi(A %-->% B, B %<--% C, class = "DAG") - expect_identical( - rs_parents_of(cg@session, 0L), - list(integer(0)) - ) - expect_identical( - rs_children_of(cg@session, 0L), - list(1L) - ) - expect_identical( - rs_parents_of(cg@session, 1L), - list(c(0L, 2L)) - ) - expect_identical( - rs_children_of(cg@session, 1L), - list(integer(0)) - ) - expect_identical( - rs_parents_of(cg@session, 2L), - list(integer(0)) - ) - expect_identical( - rs_children_of(cg@session, 2L), - list(1L) - ) + expect_null(parents(cg, index = 1L)) + expect_identical(children(cg, index = 1L), "B") + expect_setequal(parents(cg, index = 2L), c("A", "C")) + expect_null(children(cg, index = 2L)) + expect_null(parents(cg, index = 3L)) + expect_identical(children(cg, index = 3L), "B") reset_caugi_registry() }) diff --git a/vignettes/articles/performance.Rmd b/vignettes/articles/performance.Rmd index 425db30e..86f1e01d 100644 --- a/vignettes/articles/performance.Rmd +++ b/vignettes/articles/performance.Rmd @@ -109,7 +109,7 @@ bm_parents_children <- bench::mark( plot(bm_parents_children) ``` -As you can see, `bnlearn` performs best for this particular example. +As you can see, `caugi` followed by `bnlearn` performs the best for this particular example. In our next experiment, however, we will examine if this extends to different graph sizes and densities, by parameterizing our benchmark over `n` and `p`. 
Note that we adjust `p` as a function @@ -192,7 +192,8 @@ plot_parameterized_benchmark <- function(bm) { plot_parameterized_benchmark(bm_parents_children_np) ``` -For ancestors and descendants, we see that `caugi` outperforms all other packages by a several magnitudes, except for `igraph`, which it still beats, but by a smaller margin: +For ancestors and descendants, we see that `caugi` outperforms all other packages by several orders of magnitude, +except for `igraph`, which it still beats, but by a smaller margin: ```{r benchmark-an-de} #| fig-cap: "Benchmarking ancestors/descendants queries for different packages." @@ -244,9 +245,12 @@ bm_dsep <- bench::mark( plot(bm_dsep) ``` +We see that `caugi` again outperforms the other packages by a large margin. + #### Subgraph (building) -Here we see an example of where the frontloading hurts performance. When we build a subgraph, we have to rebuild the entire `caugi` graph object. Here, we see that while `caugi` outperforms other packages for queries (except for parents/children for `bnlearn`), it is slower for building the graph objects themselves, which shows in the subgraph benchmark: +Here we see an example of where frontloading hurts performance. When we build a subgraph, we have to rebuild the entire `caugi` graph object. Here, we see that while `caugi` outperforms other packages for queries, it is slower than `igraph` at building the graph +object itself, as seen below: ```{r benchmark-subgraph} #| fig-cap: "Benchmarking subgraph extraction for different packages."