Skip to content

Commit 70e76ba

Browse files
New subgraph implementation
1 parent 234939c commit 70e76ba

5 files changed

Lines changed: 154 additions & 15 deletions

File tree

R/extendr-wrappers.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ markov_blanket_of_ptr <- function(g, idxs) .Call(wrap__markov_blanket_of_ptr, g,
4646

4747
exogenous_nodes_of_ptr <- function(g, undirected_as_parents) .Call(wrap__exogenous_nodes_of_ptr, g, undirected_as_parents)
4848

49+
induced_subgraph_ptr <- function(g, keep) .Call(wrap__induced_subgraph_ptr, g, keep)
50+
4951
is_simple_ptr <- function(g) .Call(wrap__is_simple_ptr, g)
5052

5153
graph_class_ptr <- function(g) .Call(wrap__graph_class_ptr, g)

R/queries.R

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -615,52 +615,70 @@ subgraph <- function(cg, nodes = NULL, index = NULL) {
615615
stop("Supply one of `nodes` or `index`.", call. = FALSE)
616616
}
617617

618+
# resolve -> keep_ids0 (0-based), keep_names, preserving order
618619
if (index_supplied) {
620+
if (!is.numeric(index) || anyNA(index)) {
621+
stop("`index` must be numeric without NA.", call. = FALSE)
622+
}
619623
idx1 <- as.integer(index)
620624
if (any(idx1 < 1L) || any(idx1 > nrow(cg@nodes))) {
621625
stop("`index` out of range (1..n).", call. = FALSE)
622626
}
627+
keep_ids0 <- idx1 - 1L
623628
keep_names <- cg@nodes$name[idx1]
624629
} else {
625630
if (!is.character(nodes)) {
626631
stop("`nodes` must be a character vector.", call. = FALSE)
627632
}
633+
if (anyNA(nodes)) {
634+
stop("`nodes` contains NA.", call. = FALSE)
635+
}
628636
missing <- setdiff(nodes, cg@nodes$name)
629637
if (length(missing)) {
630638
stop("Unknown node(s): ", paste(missing, collapse = ", "), call. = FALSE)
631639
}
632640
keep_names <- nodes
641+
keep_ids0 <- vapply(nodes, cg@name_index_map$get, integer(1))
633642
}
634643

635-
if (any(duplicated(keep_names))) {
636-
dups <- unique(keep_names[duplicated(keep_names)])
637-
stop("`nodes`/`index` contains duplicates: ",
638-
paste(dups, collapse = ", "),
639-
call. = FALSE
640-
)
644+
# duplicates are an error
645+
if (any(duplicated(keep_ids0))) {
646+
dups <- unique(keep_names[duplicated(keep_ids0) | duplicated(keep_ids0, fromLast = TRUE)])
647+
stop("`nodes`/`index` contains duplicates: ", paste(dups, collapse = ", "), call. = FALSE)
641648
}
642649

643-
# filter edges to the kept nodes; keep constructor’s sort
650+
# call Rust (always reindexed)
651+
ptr_sub <- induced_subgraph_ptr(cg@ptr, as.integer(keep_ids0))
652+
653+
# nodes table in input order
654+
nodes_sub <- tibble::tibble(name = keep_names)
655+
656+
# filter edges to kept names and sort like constructor
644657
keep_set <- fastmap::fastmap()
645658
for (nm in keep_names) keep_set$set(nm, TRUE)
646659
edges_sub <- cg@edges |>
647660
dplyr::filter(keep_set$has(.data$from) & keep_set$has(.data$to)) |>
648661
dplyr::arrange(.data$from, .data$to, .data$edge)
649662

650-
# rebuild via constructor: declared nodes contain all edge nodes
651-
caugi_graph(
652-
from = edges_sub$from,
653-
edge = edges_sub$edge,
654-
to = edges_sub$to,
655-
nodes = keep_names,
663+
# rebuild name_index_map
664+
name_index_map_sub <- fastmap::fastmap()
665+
for (i in seq_len(nrow(nodes_sub))) name_index_map_sub$set(nodes_sub$name[i], i - 1L)
666+
667+
state_sub <- .cg_state(
668+
nodes = nodes_sub,
669+
edges = edges_sub,
670+
ptr = ptr_sub,
671+
built = TRUE,
656672
simple = cg@simple,
657-
build = TRUE,
658-
class = cg@graph_class
673+
class = cg@graph_class,
674+
name_index_map = name_index_map_sub
659675
)
676+
caugi_graph(state = state_sub)
660677
}
661678

662679

663680

681+
664682
# ──────────────────────────────────────────────────────────────────────────────
665683
# ──────────────────────────── Relations helpers ───────────────────────────────
666684
# ──────────────────────────────────────────────────────────────────────────────

src/rust/src/graph/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,83 @@ impl CaugiGraph {
8383
}
8484
}
8585

86+
impl CaugiGraph {
87+
/// Node-induced subgraph on `keep` (new ids are 0..k-1, in the SAME order as `keep`).
88+
/// Returns: (new_core, new_to_old, old_to_new).
89+
pub fn induced_subgraph(
90+
&self,
91+
keep: &[u32],
92+
) -> Result<(CaugiGraph, Vec<u32>, Vec<u32>), String> {
93+
let n = self.n() as usize;
94+
95+
// validate + deduplicate while preserving order
96+
let mut seen = vec![false; n];
97+
let mut new_to_old: Vec<u32> = Vec::with_capacity(keep.len());
98+
for &u in keep {
99+
if (u as usize) >= n {
100+
return Err("node id out of range".into());
101+
}
102+
if std::mem::replace(&mut seen[u as usize], true) {
103+
return Err("duplicate node id in `keep`".into());
104+
}
105+
new_to_old.push(u);
106+
}
107+
108+
// old -> new map
109+
let mut old_to_new = vec![u32::MAX; n];
110+
for (new, &old) in new_to_old.iter().enumerate() {
111+
old_to_new[old as usize] = new as u32;
112+
}
113+
114+
// row counts
115+
let k = new_to_old.len();
116+
let mut row_index: Vec<u32> = Vec::with_capacity(k + 1);
117+
row_index.push(0);
118+
for &old_u in &new_to_old {
119+
let mut cnt = 0u32;
120+
for kk in self.row_range(old_u) {
121+
let ov = self.col_index[kk] as usize;
122+
if ov < n && old_to_new[ov] != u32::MAX {
123+
cnt += 1;
124+
}
125+
}
126+
row_index.push(row_index.last().unwrap() + cnt);
127+
}
128+
129+
// allocate + scatter
130+
let nnz = *row_index.last().unwrap() as usize;
131+
let mut col_index = vec![0u32; nnz];
132+
let mut etype = vec![0u8; nnz];
133+
let mut side = vec![0u8; nnz];
134+
let mut cur = row_index[..k].to_vec();
135+
136+
for (new_u, &old_u) in new_to_old.iter().enumerate() {
137+
for kk in self.row_range(old_u) {
138+
let ov = self.col_index[kk] as usize;
139+
let nv = old_to_new[ov];
140+
if nv == u32::MAX {
141+
continue;
142+
}
143+
let p = cur[new_u] as usize;
144+
col_index[p] = nv;
145+
etype[p] = self.etype[kk];
146+
side[p] = self.side[kk];
147+
cur[new_u] += 1;
148+
}
149+
}
150+
151+
let out = CaugiGraph::from_csr(
152+
row_index,
153+
col_index,
154+
etype,
155+
side,
156+
self.simple,
157+
self.registry.clone(),
158+
)?;
159+
Ok((out, new_to_old, old_to_new))
160+
}
161+
}
162+
86163
#[cfg(test)]
87164
mod tests {
88165
use super::*;

src/rust/src/graph/view.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,25 @@ impl GraphView {
131131
_ => Err("all_backdoor_sets is only defined for DAGs".into()),
132132
}
133133
}
134+
135+
pub fn induced_subgraph(
136+
&self,
137+
keep: &[u32],
138+
) -> Result<GraphView, String> {
139+
let (core2, _new_to_old, _old_to_new) = self.core().induced_subgraph(keep)?;
140+
let gv = match self {
141+
GraphView::Dag(_) => {
142+
let d = super::dag::Dag::new(std::sync::Arc::new(core2))?;
143+
GraphView::Dag(std::sync::Arc::new(d))
144+
}
145+
GraphView::Pdag(_) => {
146+
let p = super::pdag::Pdag::new(std::sync::Arc::new(core2))?;
147+
GraphView::Pdag(std::sync::Arc::new(p))
148+
}
149+
GraphView::Raw(_) => GraphView::Raw(std::sync::Arc::new(core2)),
150+
};
151+
Ok(gv)
152+
}
134153
}
135154

136155
#[cfg(test)]

src/rust/src/lib.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,28 @@ fn all_backdoor_sets_ptr(
611611
extendr_api::prelude::List::from_values(robjs).into_robj()
612612
}
613613

614+
// ── Subgraph ────────────────────────────────────────────────────────────────
615+
616+
#[extendr]
617+
fn induced_subgraph_ptr(g: ExternalPtr<GraphView>, keep: Integers) -> Robj {
618+
let mut ks: Vec<u32> = Vec::with_capacity(keep.len());
619+
for ri in keep.iter() {
620+
let u = rint_to_u32(ri, "keep");
621+
if u >= g.as_ref().n() {
622+
throw_r_error(format!("node id {} out of bounds", u));
623+
}
624+
ks.push(u);
625+
}
626+
627+
let sub= g
628+
.as_ref()
629+
.induced_subgraph(&ks)
630+
.unwrap_or_else(|e| throw_r_error(e));
631+
632+
let sub_ptr = ExternalPtr::new(sub);
633+
sub_ptr.into_robj()
634+
}
635+
614636
extendr_module! {
615637
mod caugi;
616638
// registry
@@ -638,6 +660,7 @@ extendr_module! {
638660
fn descendants_of_ptr;
639661
fn markov_blanket_of_ptr;
640662
fn exogenous_nodes_of_ptr;
663+
fn induced_subgraph_ptr;
641664

642665
// graph properties
643666
fn is_simple_ptr;

0 commit comments

Comments
 (0)