Skip to content

Commit 9071d6b

Browse files
committed
fixup! perf: rewrite not_m_separated_for_all_subsets() in Rust
1 parent 4da87b7 commit 9071d6b

1 file changed

Lines changed: 135 additions & 0 deletions

File tree

tests/testthat/test-operations.R

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,3 +1000,138 @@ test_that("condition_marginalize and helper branches are covered", {
10001000
expect_identical(edge_rev$edge, "-->")
10011001
expect_identical(edge_rev$to, "A")
10021002
})
1003+
1004+
test_that("not_m_separated helper matches R reference across subset scenarios", {
1005+
helper_reference <- function(cg, node_a, node_b, other_nodes, cond_vars) {
1006+
n_other <- length(other_nodes)
1007+
subsets <- if (n_other == 0L) {
1008+
list(NULL)
1009+
} else {
1010+
c(
1011+
list(other_nodes),
1012+
if (n_other > 1L) {
1013+
unlist(
1014+
lapply(
1015+
seq_len(n_other - 1L),
1016+
function(k) combn(other_nodes, n_other - k, simplify = FALSE)
1017+
),
1018+
recursive = FALSE
1019+
)
1020+
},
1021+
list(NULL)
1022+
)
1023+
}
1024+
1025+
for (subset in subsets) {
1026+
z <- c(cond_vars, subset)
1027+
if (length(z) == 0L) {
1028+
z <- NULL
1029+
}
1030+
if (m_separated(cg, X = node_a, Y = node_b, Z = z)) {
1031+
return(FALSE)
1032+
}
1033+
}
1034+
TRUE
1035+
}
1036+
1037+
# 1) Separation found only for full subset.
1038+
# A <- B -> C and A <- D -> C: both B and D are required to block all paths.
1039+
g_full <- caugi(B %-->% A + C, D %-->% A + C, class = "DAG")
1040+
expect_true(m_separated(g_full, X = "A", Y = "C", Z = c("B", "D")))
1041+
expect_false(m_separated(g_full, X = "A", Y = "C", Z = "B"))
1042+
expect_false(m_separated(g_full, X = "A", Y = "C", Z = "D"))
1043+
1044+
rust_full <- caugi:::.not_m_separated_for_all_subsets(
1045+
g_full,
1046+
node_a = "A",
1047+
node_b = "C",
1048+
other_nodes = c("B", "D"),
1049+
cond_vars = character(0)
1050+
)
1051+
ref_full <- helper_reference(
1052+
g_full,
1053+
node_a = "A",
1054+
node_b = "C",
1055+
other_nodes = c("B", "D"),
1056+
cond_vars = character(0)
1057+
)
1058+
expect_identical(rust_full, ref_full)
1059+
1060+
# 2) Separation found only for an intermediate subset size.
1061+
# A <- B -> C is blocked by conditioning on B, but conditioning on D opens A->E<-C via E->D.
1062+
g_mid <- caugi(B %-->% A + C, A %-->% E, C %-->% E, E %-->% D, class = "DAG")
1063+
expect_true(m_separated(g_mid, X = "A", Y = "C", Z = "B"))
1064+
expect_false(m_separated(g_mid, X = "A", Y = "C", Z = c("B", "D")))
1065+
1066+
rust_mid <- caugi:::.not_m_separated_for_all_subsets(
1067+
g_mid,
1068+
node_a = "A",
1069+
node_b = "C",
1070+
other_nodes = c("B", "D"),
1071+
cond_vars = character(0)
1072+
)
1073+
ref_mid <- helper_reference(
1074+
g_mid,
1075+
node_a = "A",
1076+
node_b = "C",
1077+
other_nodes = c("B", "D"),
1078+
cond_vars = character(0)
1079+
)
1080+
expect_identical(rust_mid, ref_mid)
1081+
1082+
# 3) Separation found for cond_vars alone.
1083+
g_cond <- caugi(S %-->% A + C, T %-->% A, class = "DAG")
1084+
expect_true(m_separated(g_cond, X = "A", Y = "C", Z = "S"))
1085+
1086+
rust_cond <- caugi:::.not_m_separated_for_all_subsets(
1087+
g_cond,
1088+
node_a = "A",
1089+
node_b = "C",
1090+
other_nodes = "T",
1091+
cond_vars = "S"
1092+
)
1093+
ref_cond <- helper_reference(
1094+
g_cond,
1095+
node_a = "A",
1096+
node_b = "C",
1097+
other_nodes = "T",
1098+
cond_vars = "S"
1099+
)
1100+
expect_identical(rust_cond, ref_cond)
1101+
1102+
# Additional parity checks on small random DAGs (all subsets feasible).
1103+
set.seed(123)
1104+
for (i in seq_len(20L)) {
1105+
g <- generate_graph(n = 8, m = 8, class = "DAG")
1106+
nn <- nodes(g)$name
1107+
pair <- sample(nn, size = 2, replace = FALSE)
1108+
rest <- setdiff(nn, pair)
1109+
other <- sample(rest, size = sample.int(3L, 1), replace = FALSE)
1110+
cond_pool <- setdiff(rest, other)
1111+
cond <- if (length(cond_pool) > 0L) {
1112+
sample(
1113+
cond_pool,
1114+
size = sample.int(min(2L, length(cond_pool)), 1) - 1L,
1115+
replace = FALSE
1116+
)
1117+
} else {
1118+
character(0)
1119+
}
1120+
1121+
rust <- caugi:::.not_m_separated_for_all_subsets(
1122+
g,
1123+
node_a = pair[[1]],
1124+
node_b = pair[[2]],
1125+
other_nodes = other,
1126+
cond_vars = cond
1127+
)
1128+
ref <- helper_reference(
1129+
g,
1130+
node_a = pair[[1]],
1131+
node_b = pair[[2]],
1132+
other_nodes = other,
1133+
cond_vars = cond
1134+
)
1135+
expect_identical(rust, ref)
1136+
}
1137+
})

0 commit comments

Comments
 (0)