Skip to content

Commit c3bebbb

Browse files
committed
perf: rewrite not_m_separated_for_all_subsets() in Rustq
1 parent 48aa538 commit c3bebbb

4 files changed

Lines changed: 114 additions & 38 deletions

File tree

R/extendr-wrappers.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ rs_minimal_d_separator <- function(session, xs, ys, include, restrict) .Call(wra
134134

135135
rs_m_separated <- function(session, xs, ys, z) .Call(wrap__rs_m_separated, session, xs, ys, z)
136136

137+
rs_not_m_separated_for_all_subsets <- function(session, node_a, node_b, other_nodes, cond_vars) .Call(wrap__rs_not_m_separated_for_all_subsets, session, node_a, node_b, other_nodes, cond_vars)
138+
137139
rs_adjustment_set_parents <- function(session, xs, ys) .Call(wrap__rs_adjustment_set_parents, session, xs, ys)
138140

139141
rs_adjustment_set_backdoor <- function(session, xs, ys) .Call(wrap__rs_adjustment_set_backdoor, session, xs, ys)

R/operations.R

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -643,45 +643,25 @@ condition_marginalize <- function(cg, cond_vars = NULL, marg_vars = NULL) {
643643
other_nodes,
644644
cond_vars
645645
) {
646-
n_other <- length(other_nodes)
647-
648-
# Generate all subsets of other_nodes
649-
subsets <- if (n_other == 0L) {
650-
list(NULL)
646+
node_idx <- .nodes_to_indices(cg, c(node_a, node_b))
647+
other_idx <- if (length(other_nodes) == 0L) {
648+
integer(0)
651649
} else {
652-
# Build subsets from largest to smallest (often finds separation faster)
653-
c(
654-
list(other_nodes),
655-
if (n_other > 1L) {
656-
unlist(
657-
lapply(
658-
seq_len(n_other - 1L),
659-
function(k) combn(other_nodes, n_other - k, simplify = FALSE)
660-
),
661-
recursive = FALSE
662-
)
663-
},
664-
list(NULL)
665-
)
650+
.nodes_to_indices(cg, other_nodes)
666651
}
667-
668-
# Check each conditioning set
669-
670-
for (subset in subsets) {
671-
conditioning_set <- c(cond_vars, subset)
672-
if (length(conditioning_set) == 0L) {
673-
conditioning_set <- NULL
674-
}
675-
676-
if (m_separated(cg, X = node_a, Y = node_b, Z = conditioning_set)) {
677-
# Found a set that m-separates them: no edge needed
678-
return(FALSE)
679-
}
652+
cond_idx <- if (length(cond_vars) == 0L) {
653+
integer(0)
654+
} else {
655+
.nodes_to_indices(cg, cond_vars)
680656
}
681657

682-
# Not m-separated for any conditioning set: edge is required
683-
684-
TRUE
658+
rs_not_m_separated_for_all_subsets(
659+
cg@session,
660+
node_idx[[1]],
661+
node_idx[[2]],
662+
other_idx,
663+
cond_idx
664+
)
685665
}
686666

687667
#' @title

src/rust/src/lib.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,6 +1898,90 @@ fn rs_m_separated(
18981898
.unwrap_or_else(|e| throw_r_error(e))
18991899
}
19001900

1901+
#[extendr]
1902+
fn rs_not_m_separated_for_all_subsets(
1903+
mut session: ExternalPtr<GraphSession>,
1904+
node_a: i32,
1905+
node_b: i32,
1906+
other_nodes: Integers,
1907+
cond_vars: Integers,
1908+
) -> bool {
1909+
let a = rint_to_u32(Rint::from(node_a), "node_a");
1910+
let b = rint_to_u32(Rint::from(node_b), "node_b");
1911+
1912+
if a >= session.as_ref().n() {
1913+
throw_r_error(format!("Index {} is out of bounds", a));
1914+
}
1915+
if b >= session.as_ref().n() {
1916+
throw_r_error(format!("Index {} is out of bounds", b));
1917+
}
1918+
1919+
let mut z_base: Vec<u32> = cond_vars
1920+
.iter()
1921+
.map(|ri| rint_to_u32(ri, "cond_vars"))
1922+
.collect();
1923+
let other_u: Vec<u32> = other_nodes
1924+
.iter()
1925+
.map(|ri| rint_to_u32(ri, "other_nodes"))
1926+
.collect();
1927+
1928+
for &i in &z_base {
1929+
if i >= session.as_ref().n() {
1930+
throw_r_error(format!("Index {} is out of bounds", i));
1931+
}
1932+
}
1933+
for &i in &other_u {
1934+
if i >= session.as_ref().n() {
1935+
throw_r_error(format!("Index {} is out of bounds", i));
1936+
}
1937+
}
1938+
1939+
if session
1940+
.as_mut()
1941+
.m_separated(&[a], &[b], &z_base)
1942+
.unwrap_or_else(|e| throw_r_error(e))
1943+
{
1944+
return false;
1945+
}
1946+
1947+
let m = other_u.len();
1948+
if m == 0 {
1949+
return true;
1950+
}
1951+
1952+
for k in (1..=m).rev() {
1953+
let mut idx: Vec<usize> = (0..k).collect();
1954+
loop {
1955+
z_base.truncate(cond_vars.len());
1956+
for &ii in &idx {
1957+
z_base.push(other_u[ii]);
1958+
}
1959+
1960+
if session
1961+
.as_mut()
1962+
.m_separated(&[a], &[b], &z_base)
1963+
.unwrap_or_else(|e| throw_r_error(e))
1964+
{
1965+
return false;
1966+
}
1967+
1968+
let mut i = k;
1969+
while i > 0 && idx[i - 1] == i - 1 + (m - k) {
1970+
i -= 1;
1971+
}
1972+
if i == 0 {
1973+
break;
1974+
}
1975+
idx[i - 1] += 1;
1976+
for j in i..k {
1977+
idx[j] = idx[j - 1] + 1;
1978+
}
1979+
}
1980+
}
1981+
1982+
true
1983+
}
1984+
19011985
#[extendr]
19021986
fn rs_adjustment_set_parents(
19031987
mut session: ExternalPtr<GraphSession>,
@@ -2086,6 +2170,7 @@ extendr_module! {
20862170
fn rs_d_separated;
20872171
fn rs_minimal_d_separator;
20882172
fn rs_m_separated;
2173+
fn rs_not_m_separated_for_all_subsets;
20892174
fn rs_adjustment_set_parents;
20902175
fn rs_adjustment_set_backdoor;
20912176
fn rs_adjustment_set_optimal;

tests/testthat/test-operations.R

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -970,15 +970,24 @@ test_that("condition_marginalize and helper branches are covered", {
970970
"must be the same length"
971971
)
972972

973-
expect_type(
973+
expect_true(
974974
caugi:::.not_m_separated_for_all_subsets(
975975
cg = cg2,
976976
node_a = "A",
977977
node_b = "C",
978978
other_nodes = character(0),
979979
cond_vars = character(0)
980-
),
981-
"logical"
980+
)
981+
)
982+
983+
expect_false(
984+
caugi:::.not_m_separated_for_all_subsets(
985+
cg = cg2,
986+
node_a = "A",
987+
node_b = "C",
988+
other_nodes = "B",
989+
cond_vars = character(0)
990+
)
982991
)
983992

984993
edge_rev <- caugi:::.edge_type_from_anteriors(

0 commit comments

Comments
 (0)