Skip to content

Commit 349f3f1

Browse files
committed
feat: add MPDAG class
Add a new `MPDAG` class, which is a ergonomic way to represent a `PDAG` that fulfills Meek's rules. Right now, this is safe and always validates the graph. The intention is that we later on will be able to specialize behavior for this class, possibly leading to performance improvements since after validating we have stronger guarantees about the graph structure. We already have tests for `MPDAG`, and the functionality is already gated behind the `is_mpdag()` stuff we have already implemented, so hopefully this should be safe.
1 parent 280d9ee commit 349f3f1

11 files changed

Lines changed: 149 additions & 15 deletions

File tree

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
## New Features
44

55
- Add `list_caugi_edges()` function to list all available edge types.
6+
- Add first-class `"MPDAG"` graph class support across constructor, class
7+
mutation, and class resolution. `class = "AUTO"` now resolves Meek-closed
8+
PDAGs to `"MPDAG"`.
69

710
## Improvements
811

R/caugi.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@
3838
#' @param simple Logical; if `TRUE` (default), the graph is a simple graph, and
3939
#' the function will throw an error if the input contains parallel edges or
4040
#' self-loops.
41-
#' @param class Character; one of `"AUTO"`, `"DAG"`, `"UG"`, `"PDAG"`, `"ADMG"`,
42-
#' `"AG"`, or `"UNKNOWN"`. `"AUTO"` will automatically pick the appropriate
43-
#' class based on the first match in the order of `"DAG"`, `"UG"`, `"PDAG"`,
44-
#' `"ADMG"`, and `"AG"`.
41+
#' @param class Character; one of `"AUTO"`, `"DAG"`, `"UG"`, `"PDAG"`, `"MPDAG"`,
42+
#' `"ADMG"`, `"AG"`, or `"UNKNOWN"`. `"AUTO"` will automatically pick the
43+
#' appropriate class based on the first match in the order of `"DAG"`, `"UG"`,
44+
#' `"MPDAG"`, `"PDAG"`, `"ADMG"`, and `"AG"`.
4545
#' It will default to `"UNKNOWN"` if no match is found.
4646
#' @param .session For internal use. Build a graph by supplying a
4747
#' pre-constructed session pointer from Rust.
@@ -244,7 +244,7 @@ caugi <- S7::new_class(
244244
edges_df = NULL,
245245
simple = TRUE,
246246
build = NULL, # deprecated
247-
class = c("AUTO", "DAG", "UG", "PDAG", "ADMG", "AG", "UNKNOWN"),
247+
class = c("AUTO", "DAG", "UG", "PDAG", "MPDAG", "ADMG", "AG", "UNKNOWN"),
248248
state = NULL, # deprecated
249249
.session = NULL
250250
) {

R/caugi_to.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
as_igraph <- function(x, ...) {
2020
is_caugi(x, throw_error = TRUE)
2121

22-
if (!(x@graph_class %in% c("DAG", "PDAG", "ADMG", "UG", "AG", "UNKNOWN"))) {
22+
if (!(x@graph_class %in% c("DAG", "PDAG", "MPDAG", "ADMG", "UG", "AG", "UNKNOWN"))) {
2323
stop(
2424
"caugi graphs of class '",
2525
x@graph_class,

R/operations.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ mutate_caugi <- function(cg, class) {
234234
class,
235235
"DAG" = is_dag(cg),
236236
"PDAG" = is_pdag(cg),
237+
"MPDAG" = is_mpdag(cg),
237238
"UG" = is_ug(cg),
238239
"ADMG" = is_admg(cg),
239240
"AG" = is_ag(cg),
@@ -801,8 +802,8 @@ are_connected <- function(cg, u, v) {
801802
#'
802803
#' @export
803804
dag_from_pdag <- function(PDAG) {
804-
if (PDAG@graph_class != "PDAG") {
805-
stop("Input must be a caugi PDAG graph")
805+
if (!(PDAG@graph_class %in% c("PDAG", "MPDAG"))) {
806+
stop("Input must be a caugi PDAG/MPDAG graph")
806807
}
807808

808809
output_graph <- PDAG

R/queries.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ is_dag <- function(cg, force_check = FALSE) {
315315
is_pdag <- function(cg, force_check = FALSE) {
316316
is_caugi(cg, throw_error = TRUE)
317317

318-
if (identical(cg@graph_class, "PDAG") && !force_check) {
318+
if (cg@graph_class %in% c("PDAG", "MPDAG") && !force_check) {
319319
is_it <- TRUE
320320
} else {
321321
# if we can't be sure from the class, we check

src/rust/src/graph/session.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ pub enum GraphClass {
2828
Dag,
2929
/// Partially Directed Acyclic Graph (`-->`, `---`)
3030
Pdag,
31+
/// Maximally Oriented Partially Directed Acyclic Graph (Meek-closed PDAG)
32+
Mpdag,
3133
/// Undirected Graph (only `---`)
3234
Ug,
3335
/// Acyclic Directed Mixed Graph (`-->`, `<->`)
@@ -47,6 +49,7 @@ impl std::str::FromStr for GraphClass {
4749
match s.to_lowercase().as_str() {
4850
"dag" => Ok(GraphClass::Dag),
4951
"pdag" | "cpdag" => Ok(GraphClass::Pdag),
52+
"mpdag" => Ok(GraphClass::Mpdag),
5053
"ug" => Ok(GraphClass::Ug),
5154
"admg" => Ok(GraphClass::Admg),
5255
"ag" | "mag" | "pag" => Ok(GraphClass::Ag),
@@ -62,6 +65,7 @@ impl GraphClass {
6265
match self {
6366
GraphClass::Dag => "DAG",
6467
GraphClass::Pdag => "PDAG",
68+
GraphClass::Mpdag => "MPDAG",
6569
GraphClass::Ug => "UG",
6670
GraphClass::Admg => "ADMG",
6771
GraphClass::Ag => "AG",
@@ -480,6 +484,14 @@ impl GraphSession {
480484
let pdag = Pdag::new(core).map_err(|e| self.map_error(e))?;
481485
Ok(GraphView::Pdag(Arc::new(pdag)))
482486
}
487+
GraphClass::Mpdag => {
488+
let pdag = Pdag::new(core).map_err(|e| self.map_error(e))?;
489+
if pdag.is_meek_closed() {
490+
Ok(GraphView::Pdag(Arc::new(pdag)))
491+
} else {
492+
Err("graph is not MPDAG (not closed under Meek rules)".to_string())
493+
}
494+
}
483495
GraphClass::Ug => {
484496
let ug = Ug::new(core).map_err(|e| self.map_error(e))?;
485497
Ok(GraphView::Ug(Arc::new(ug)))
@@ -694,6 +706,14 @@ impl GraphSession {
694706
Pdag::new(Arc::new(core.as_ref().clone())).map_err(|e| self.map_error(e))?;
695707
Ok(GraphClass::Pdag)
696708
}
709+
GraphClass::Mpdag => {
710+
let pdag = Pdag::new(Arc::new(core.as_ref().clone())).map_err(|e| self.map_error(e))?;
711+
if pdag.is_meek_closed() {
712+
Ok(GraphClass::Mpdag)
713+
} else {
714+
Err("graph is not MPDAG (not closed under Meek rules)".to_string())
715+
}
716+
}
697717
GraphClass::Ug => {
698718
Ug::new(Arc::new(core.as_ref().clone())).map_err(|e| self.map_error(e))?;
699719
Ok(GraphClass::Ug)
@@ -712,8 +732,12 @@ impl GraphSession {
712732
Ok(GraphClass::Dag)
713733
} else if Ug::new(Arc::new(core.as_ref().clone())).is_ok() {
714734
Ok(GraphClass::Ug)
715-
} else if Pdag::new(Arc::new(core.as_ref().clone())).is_ok() {
716-
Ok(GraphClass::Pdag)
735+
} else if let Ok(pdag) = Pdag::new(Arc::new(core.as_ref().clone())) {
736+
if pdag.is_meek_closed() {
737+
Ok(GraphClass::Mpdag)
738+
} else {
739+
Ok(GraphClass::Pdag)
740+
}
717741
} else if Admg::new(Arc::new(core.as_ref().clone())).is_ok() {
718742
Ok(GraphClass::Admg)
719743
} else if Ag::new(Arc::new(core.as_ref().clone())).is_ok() {
@@ -1212,6 +1236,7 @@ mod tests {
12121236
fn graph_class_from_str_and_as_str() {
12131237
assert_eq!("dag".parse::<GraphClass>().unwrap(), GraphClass::Dag);
12141238
assert_eq!("CPDAG".parse::<GraphClass>().unwrap(), GraphClass::Pdag);
1239+
assert_eq!("mpdag".parse::<GraphClass>().unwrap(), GraphClass::Mpdag);
12151240
assert_eq!("ug".parse::<GraphClass>().unwrap(), GraphClass::Ug);
12161241
assert_eq!("admg".parse::<GraphClass>().unwrap(), GraphClass::Admg);
12171242
assert_eq!("mag".parse::<GraphClass>().unwrap(), GraphClass::Ag);
@@ -1221,6 +1246,7 @@ mod tests {
12211246

12221247
assert_eq!(GraphClass::Dag.as_str(), "DAG");
12231248
assert_eq!(GraphClass::Pdag.as_str(), "PDAG");
1249+
assert_eq!(GraphClass::Mpdag.as_str(), "MPDAG");
12241250
assert_eq!(GraphClass::Ug.as_str(), "UG");
12251251
assert_eq!(GraphClass::Admg.as_str(), "ADMG");
12261252
assert_eq!(GraphClass::Ag.as_str(), "AG");
@@ -1516,6 +1542,22 @@ mod tests {
15161542
GraphClass::Pdag
15171543
);
15181544

1545+
let mut mpdag =
1546+
GraphSession::from_snapshot(Arc::clone(&snapshot), 3, true, GraphClass::Mpdag);
1547+
let mut mpdag_edges = EdgeBuffer::new();
1548+
mpdag_edges.push(0, 2, d);
1549+
mpdag_edges.push(1, 2, d);
1550+
mpdag.set_edges(mpdag_edges);
1551+
assert!(matches!(&*mpdag.view().unwrap(), GraphView::Pdag(_)));
1552+
assert_eq!(
1553+
mpdag.resolve_class(GraphClass::Mpdag).unwrap(),
1554+
GraphClass::Mpdag
1555+
);
1556+
assert_eq!(
1557+
mpdag.resolve_class(GraphClass::Auto).unwrap(),
1558+
GraphClass::Mpdag
1559+
);
1560+
15191561
let mut ug = GraphSession::from_snapshot(Arc::clone(&snapshot), 2, true, GraphClass::Ug);
15201562
let mut ug_edges = EdgeBuffer::new();
15211563
ug_edges.push(0, 1, u);

src/rust/src/lib.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,11 @@ fn graphview_new(core: ExternalPtr<CaugiGraph>, class: &str) -> ExternalPtr<Grap
553553
let dag = Dag::new(Arc::clone(&core_arc)).unwrap_or_else(|e| throw_r_error(e));
554554
ExternalPtr::new(GraphView::Dag(Arc::new(dag)))
555555
}
556-
"PDAG" | "CPDAG" => {
556+
"PDAG" | "CPDAG" | "MPDAG" => {
557557
let pdag = Pdag::new(Arc::clone(&core_arc)).unwrap_or_else(|e| throw_r_error(e));
558+
if class.trim().eq_ignore_ascii_case("MPDAG") && !pdag.is_meek_closed() {
559+
throw_r_error("graph is not MPDAG (not closed under Meek rules)");
560+
}
558561
ExternalPtr::new(GraphView::Pdag(Arc::new(pdag)))
559562
}
560563
"UG" => {
@@ -595,8 +598,19 @@ fn graph_builder_resolve_class(mut b: ExternalPtr<GraphBuilder>, class: &str) ->
595598
.as_mut()
596599
.finalize_in_place()
597600
.unwrap_or_else(|e| throw_r_error(e));
598-
let view = graphview_new(ExternalPtr::new(core), class);
599-
graph_class_label_from_view(view.as_ref()).to_string()
601+
let graph_class = GraphClass::from_str(class).unwrap_or_else(|e| throw_r_error(e));
602+
let mut session = GraphSession::from_snapshot(
603+
Arc::new(core.registry.clone()),
604+
core.n(),
605+
core.simple,
606+
GraphClass::Unknown,
607+
);
608+
session.set_edges(edge_buffer_from_core(&core));
609+
session
610+
.resolve_class(graph_class)
611+
.unwrap_or_else(|e| throw_r_error(e))
612+
.as_str()
613+
.to_string()
600614
}
601615

602616
// ── Metrics ────────────────────────────────────────────────────────────────

tests/testthat/test-caugi_graph.R

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,26 @@ test_that("building PDAG with bidirected edges results in error", {
265265
)
266266
})
267267

268+
test_that("building MPDAG validates Meek closure", {
269+
expect_s7_class(
270+
caugi(
271+
A %---% B,
272+
A %-->% C,
273+
B %-->% C,
274+
class = "MPDAG"
275+
),
276+
caugi
277+
)
278+
279+
expect_error(
280+
caugi(
281+
A %-->% B,
282+
B %---% C,
283+
class = "MPDAG"
284+
)
285+
)
286+
})
287+
268288
# ──────────────────────────────────────────────────────────────────────────────
269289
# ──────────────────────────────── AUTO tests ──────────────────────────────────
270290
# ──────────────────────────────────────────────────────────────────────────────
@@ -280,6 +300,8 @@ test_that("AUTO class picks the correct class", {
280300
expect_equal(cg@graph_class, "UNKNOWN")
281301
cg <- caugi(A %-->% B %---% C, class = "AUTO")
282302
expect_equal(cg@graph_class, "PDAG")
303+
cg <- caugi(A %---% B, A %-->% C, B %-->% C, class = "AUTO")
304+
expect_equal(cg@graph_class, "MPDAG")
283305
})
284306

285307
# ──────────────────────────────────────────────────────────────────────────────

tests/testthat/test-caugi_to.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,22 @@ test_that("mixed edges: directed kept, undirected duplicated as bidirected", {
100100
expect_equal(nrow(ed), 5L)
101101
})
102102

103+
test_that("MPDAG converts to igraph like PDAG", {
104+
cg <- caugi(
105+
A %---% B,
106+
A %-->% C,
107+
B %-->% C,
108+
class = "MPDAG"
109+
)
110+
ig <- as_igraph(cg)
111+
expect_true(igraph::is_directed(ig))
112+
ed <- igraph::as_data_frame(ig)
113+
expect_true(any(ed$from == "A" & ed$to == "C"))
114+
expect_true(any(ed$from == "B" & ed$to == "C"))
115+
expect_true(any(ed$from == "A" & ed$to == "B"))
116+
expect_true(any(ed$from == "B" & ed$to == "A"))
117+
})
118+
103119
test_that("conversion from UG --> igraph works", {
104120
cg <- caugi(A %---% B, class = "UG")
105121
ig <- as_igraph(cg)

tests/testthat/test-operations.R

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,27 @@ test_that("mutate_caugi works from DAG to PDAG", {
205205
expect_equal(nodes(cg_pdag), nodes(cg_dag))
206206
})
207207

208+
test_that("mutate_caugi supports MPDAG target when valid", {
209+
cg_pdag <- caugi(
210+
A %---% B,
211+
A %-->% C,
212+
B %-->% C,
213+
class = "PDAG"
214+
)
215+
expect_true(is_mpdag(cg_pdag))
216+
cg_mpdag <- mutate_caugi(cg_pdag, class = "MPDAG")
217+
expect_equal(cg_mpdag@graph_class, "MPDAG")
218+
expect_equal(edges(cg_mpdag), edges(cg_pdag))
219+
220+
cg_not_mpdag <- caugi(
221+
A %-->% B,
222+
B %---% C,
223+
class = "PDAG"
224+
)
225+
expect_false(is_mpdag(cg_not_mpdag))
226+
expect_error(mutate_caugi(cg_not_mpdag, class = "MPDAG"), "Cannot convert caugi")
227+
})
228+
208229
test_that("mutate_caugi works from PDAG to DAG if PDAG is a DAG", {
209230
cg_pdag <- caugi(
210231
A %-->% B,
@@ -228,6 +249,10 @@ test_that("mutate_caugi works on empty caugi", {
228249
cg_empty_ug <- mutate_caugi(cg_empty, class = "UG")
229250
expect_equal(length(cg_empty_ug), 0)
230251
expect_equal(cg_empty_ug@graph_class, "UG")
252+
253+
cg_empty_mpdag <- mutate_caugi(cg_empty, class = "MPDAG")
254+
expect_equal(length(cg_empty_mpdag), 0)
255+
expect_equal(cg_empty_mpdag@graph_class, "MPDAG")
231256
})
232257

233258
test_that("mutate_caugi doesn't change class if old class is equal to new class", {
@@ -637,7 +662,7 @@ test_that("dag_from_pdag errors on non-PDAG input", {
637662
B %-->% C,
638663
class = "DAG"
639664
)
640-
expect_error(dag_from_pdag(cg), "Input must be a caugi PDAG graph")
665+
expect_error(dag_from_pdag(cg), "Input must be a caugi PDAG/MPDAG graph")
641666
})
642667

643668
test_that("dag_from_pdag errors if PDAG cannot be extended to a DAG", {

0 commit comments

Comments
 (0)