Skip to content

Commit 482b90f

Browse files
Merge pull request #250 from jolars/pdag-meek-closure
fix: correctly implement meek closure for PDAGs
2 parents ab3340c + 0e72dc9 commit 482b90f

8 files changed

Lines changed: 338 additions & 326 deletions

File tree

src/rust/src/graph/alg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
pub mod acyclic;
55
pub mod bitset;
66
pub mod csr;
7+
pub mod meek;
78
pub mod moral;
89
pub mod reachability;
910
pub mod subsets;

src/rust/src/graph/alg/meek.rs

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
// SPDX-License-Identifier: MIT
2+
//! Shared Meek closure helpers for CPDAG orientation.
3+
4+
use std::collections::{HashSet, VecDeque};
5+
6+
/// Reports whether nodes `a` and `b` are joined by an edge of any kind.
///
/// The pair counts as adjacent when either endpoint lists the other in its
/// undirected (`und`), parent (`pa`) or child (`ch`) set. Both endpoints are
/// probed, so the answer does not rely on the three tables being kept
/// symmetric by the caller.
///
/// Panics if `a` or `b` is out of bounds for a table that gets probed.
#[inline]
pub(crate) fn adjacent(
    a: usize,
    b: usize,
    und: &[HashSet<u32>],
    pa: &[HashSet<u32>],
    ch: &[HashSet<u32>],
) -> bool {
    let (a_id, b_id) = (a as u32, b as u32);
    // (table, row to look in, neighbour id to look for), probed lazily in order.
    let probes = [
        (und, a, b_id),
        (und, b, a_id),
        (pa, a, b_id),
        (ch, a, b_id),
        (pa, b, a_id),
        (ch, b, a_id),
    ];
    probes
        .iter()
        .any(|&(table, row, other)| table[row].contains(&other))
}
21+
22+
/// Turns the edge between `a` and `b` into the directed edge `a -> b`.
///
/// Any undirected record of the pair is erased from both endpoints, then
/// `b` is added to the children of `a` and `a` to the parents of `b`.
/// No check is made that the undirected edge existed or that the new arrow
/// keeps the graph acyclic; callers needing those checks must do them first.
#[inline]
pub(crate) fn orient(
    a: u32,
    b: u32,
    und: &mut [HashSet<u32>],
    pa: &mut [HashSet<u32>],
    ch: &mut [HashSet<u32>],
) {
    let (tail, head) = (a as usize, b as usize);

    // Forget the undirected edge on both sides.
    und[tail].remove(&b);
    und[head].remove(&a);

    // Record the arrow tail -> head.
    ch[tail].insert(b);
    pa[head].insert(a);
}
37+
38+
#[inline]
39+
fn try_orient(
40+
a: u32,
41+
b: u32,
42+
und: &mut [HashSet<u32>],
43+
pa: &mut [HashSet<u32>],
44+
ch: &mut [HashSet<u32>],
45+
) -> bool {
46+
let ai = a as usize;
47+
if !und[ai].contains(&b) {
48+
return false;
49+
}
50+
if has_dir_path(ch, b, a) {
51+
return false;
52+
}
53+
orient(a, b, und, pa, ch);
54+
true
55+
}
56+
57+
/// Returns `true` when `tgt` is reachable from `src` along directed edges.
///
/// `ch[u]` holds the children of node `u`. A node is considered to reach
/// itself, so `src == tgt` yields `true` without touching the graph.
///
/// The search is breadth-first. Nodes are marked as seen when they are
/// enqueued, so each node enters the queue at most once, and the search
/// stops as soon as `tgt` is discovered rather than when it is dequeued.
///
/// Panics if `src` (when different from `tgt`) or any visited child id is
/// out of bounds for `ch`.
fn has_dir_path(ch: &[HashSet<u32>], src: u32, tgt: u32) -> bool {
    if src == tgt {
        return true;
    }

    let mut seen = vec![false; ch.len()];
    let mut queue = VecDeque::new();
    seen[src as usize] = true;
    queue.push_back(src);

    while let Some(u) = queue.pop_front() {
        for &v in &ch[u as usize] {
            // `replace` returns the previous flag: skip nodes already queued.
            if std::mem::replace(&mut seen[v as usize], true) {
                continue;
            }
            if v == tgt {
                return true;
            }
            queue.push_back(v);
        }
    }
    false
}
80+
81+
#[inline]
82+
fn creates_new_unshielded_collider(
83+
u: usize,
84+
v: u32,
85+
und: &[HashSet<u32>],
86+
pa: &[HashSet<u32>],
87+
ch: &[HashSet<u32>],
88+
) -> bool {
89+
for &p in &pa[v as usize] {
90+
if p as usize != u && !adjacent(u, p as usize, und, pa, ch) {
91+
return true;
92+
}
93+
}
94+
false
95+
}
96+
97+
/// Apply iterative Meek closure (R1-R4) to a partially directed graph state.
98+
///
99+
/// `guard_new_colliders` enables an R1 safeguard that skips orientations
100+
/// creating new unshielded colliders (pgmpy-aligned behavior).
101+
pub(crate) fn apply_meek_closure(
102+
pa: &mut [HashSet<u32>],
103+
ch: &mut [HashSet<u32>],
104+
und: &mut [HashSet<u32>],
105+
guard_new_colliders: bool,
106+
) {
107+
let n = pa.len();
108+
109+
loop {
110+
let mut changed = false;
111+
112+
// R1: a->b, b--c, a !~ c => b->c
113+
for b in 0..n {
114+
if pa[b].is_empty() || und[b].is_empty() {
115+
continue;
116+
}
117+
let pb: Vec<u32> = pa[b].iter().copied().collect();
118+
let ubs: Vec<u32> = und[b].clone().into_iter().collect();
119+
'c_loop: for c in ubs {
120+
let ci = c as usize;
121+
for &a in &pb {
122+
if !adjacent(a as usize, ci, und, pa, ch)
123+
&& (!guard_new_colliders
124+
|| !creates_new_unshielded_collider(b, c, und, pa, ch))
125+
{
126+
if try_orient(b as u32, c, und, pa, ch) {
127+
changed = true;
128+
continue 'c_loop;
129+
}
130+
}
131+
}
132+
}
133+
}
134+
135+
// R2: a--b and ∃ w: a->w, w->b => a->b
136+
for a in 0..n {
137+
let uab: Vec<u32> = und[a].clone().into_iter().collect();
138+
for b_u in uab {
139+
let b = b_u as usize;
140+
if ch[a].iter().any(|w| pa[b].contains(w)) {
141+
if try_orient(a as u32, b_u, und, pa, ch) {
142+
changed = true;
143+
continue;
144+
}
145+
}
146+
if ch[b].iter().any(|w| pa[a].contains(w)) {
147+
if try_orient(b_u, a as u32, und, pa, ch) {
148+
changed = true;
149+
}
150+
}
151+
}
152+
}
153+
154+
// R3: a--b and ∃ c,d: c->b, d->b, c !~ d, a--c, a--d => a->b
155+
for a in 0..n {
156+
let uab: Vec<u32> = und[a].clone().into_iter().collect();
157+
for b_u in uab {
158+
let b = b_u as usize;
159+
let pb: Vec<u32> = pa[b].iter().copied().collect();
160+
'pairs: for i in 0..pb.len() {
161+
for j in (i + 1)..pb.len() {
162+
let c = pb[i] as usize;
163+
let d = pb[j] as usize;
164+
if !adjacent(c, d, und, pa, ch)
165+
&& und[a].contains(&pb[i])
166+
&& und[a].contains(&pb[j])
167+
{
168+
if try_orient(a as u32, b_u, und, pa, ch) {
169+
changed = true;
170+
break 'pairs;
171+
}
172+
}
173+
}
174+
}
175+
}
176+
}
177+
178+
// R4: a--b and (a ⇒ b or b ⇒ a) => orient along reachability
179+
for a in 0..n {
180+
let uab: Vec<u32> = und[a].clone().into_iter().collect();
181+
for b_u in uab {
182+
if has_dir_path(ch, a as u32, b_u) {
183+
if try_orient(a as u32, b_u, und, pa, ch) {
184+
changed = true;
185+
}
186+
} else if has_dir_path(ch, b_u, a as u32) {
187+
if try_orient(b_u, a as u32, und, pa, ch) {
188+
changed = true;
189+
}
190+
}
191+
}
192+
}
193+
194+
if !changed {
195+
break;
196+
}
197+
}
198+
}

src/rust/src/graph/dag/transforms.rs

Lines changed: 6 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
use super::Dag;
55
use crate::edges::EdgeClass;
66
use crate::graph::admg::Admg;
7-
use crate::graph::alg::csr;
7+
use crate::graph::alg::{csr, meek};
88
use crate::graph::pdag::Pdag;
99
use crate::graph::ug::Ug;
1010
use crate::graph::CaugiGraph;
11-
use std::collections::{BTreeSet, HashSet, VecDeque};
11+
use std::collections::{BTreeSet, HashSet};
1212
use std::sync::Arc;
1313

1414
impl Dag {
@@ -335,62 +335,6 @@ impl Dag {
335335
let mut ch: Vec<HashSet<u32>> = vec![HashSet::new(); n];
336336
let mut und: Vec<HashSet<u32>> = vec![HashSet::new(); n];
337337

338-
#[inline]
339-
fn adjacent(
340-
a: usize,
341-
b: usize,
342-
und: &[HashSet<u32>],
343-
pa: &[HashSet<u32>],
344-
ch: &[HashSet<u32>],
345-
) -> bool {
346-
und[a].contains(&(b as u32))
347-
|| und[b].contains(&(a as u32))
348-
|| pa[a].contains(&(b as u32))
349-
|| ch[a].contains(&(b as u32))
350-
|| pa[b].contains(&(a as u32))
351-
|| ch[b].contains(&(a as u32))
352-
}
353-
354-
#[inline]
355-
fn orient(
356-
a: u32,
357-
b: u32,
358-
und: &mut [HashSet<u32>],
359-
pa: &mut [HashSet<u32>],
360-
ch: &mut [HashSet<u32>],
361-
) {
362-
let ai = a as usize;
363-
let bi = b as usize;
364-
und[ai].remove(&b);
365-
und[bi].remove(&a);
366-
ch[ai].insert(b);
367-
pa[bi].insert(a);
368-
}
369-
370-
fn has_dir_path(ch: &[HashSet<u32>], src: u32, tgt: u32) -> bool {
371-
if src == tgt {
372-
return true;
373-
}
374-
let n = ch.len();
375-
let mut seen = vec![false; n];
376-
let mut q = VecDeque::new();
377-
q.push_back(src);
378-
while let Some(u) = q.pop_front() {
379-
if u == tgt {
380-
return true;
381-
}
382-
if std::mem::replace(&mut seen[u as usize], true) {
383-
continue;
384-
}
385-
for &v in &ch[u as usize] {
386-
if !seen[v as usize] {
387-
q.push_back(v);
388-
}
389-
}
390-
}
391-
false
392-
}
393-
394338
// Skeleton from DAG (undirected)
395339
for u in 0..self.n() {
396340
for &v in self.children_of(u) {
@@ -406,95 +350,15 @@ impl Dag {
406350
for j in (i + 1)..parents.len() {
407351
let a = parents[i] as usize;
408352
let c = parents[j] as usize;
409-
if !adjacent(a, c, &und, &pa, &ch) {
410-
orient(parents[i], b, &mut und, &mut pa, &mut ch);
411-
orient(parents[j], b, &mut und, &mut pa, &mut ch);
353+
if !meek::adjacent(a, c, &und, &pa, &ch) {
354+
meek::orient(parents[i], b, &mut und, &mut pa, &mut ch);
355+
meek::orient(parents[j], b, &mut und, &mut pa, &mut ch);
412356
}
413357
}
414358
}
415359
}
416360

417-
// Meek closure (R1–R4)
418-
loop {
419-
let mut changed = false;
420-
421-
// R1: a->b, b--c, a !~ c ⇒ b->c
422-
for b in 0..n {
423-
if pa[b].is_empty() || und[b].is_empty() {
424-
continue;
425-
}
426-
let pb: Vec<u32> = pa[b].iter().copied().collect();
427-
let ubs: Vec<u32> = und[b].clone().into_iter().collect();
428-
'c_loop: for c in ubs {
429-
let ci = c as usize;
430-
for &a in &pb {
431-
if !adjacent(a as usize, ci, &und, &pa, &ch) {
432-
orient(b as u32, c, &mut und, &mut pa, &mut ch);
433-
changed = true;
434-
continue 'c_loop;
435-
}
436-
}
437-
}
438-
}
439-
440-
// R2: a--b and ∃ w: a->w, w->b ⇒ a->b
441-
for a in 0..n {
442-
let uab: Vec<u32> = und[a].clone().into_iter().collect();
443-
for b_u in uab {
444-
let b = b_u as usize;
445-
if ch[a].iter().any(|w| pa[b].contains(w)) {
446-
orient(a as u32, b_u, &mut und, &mut pa, &mut ch);
447-
changed = true;
448-
continue;
449-
}
450-
if ch[b].iter().any(|w| pa[a].contains(w)) {
451-
orient(b_u, a as u32, &mut und, &mut pa, &mut ch);
452-
changed = true;
453-
}
454-
}
455-
}
456-
457-
// R3: a--b and ∃ c,d: c->b, d->b, c !~ d, a--c, a--d ⇒ a->b
458-
for a in 0..n {
459-
let uab: Vec<u32> = und[a].clone().into_iter().collect();
460-
for b_u in uab {
461-
let b = b_u as usize;
462-
let pb: Vec<u32> = pa[b].iter().copied().collect();
463-
'pairs: for i in 0..pb.len() {
464-
for j in (i + 1)..pb.len() {
465-
let c = pb[i] as usize;
466-
let d = pb[j] as usize;
467-
if !adjacent(c, d, &und, &pa, &ch)
468-
&& und[a].contains(&pb[i])
469-
&& und[a].contains(&pb[j])
470-
{
471-
orient(a as u32, b_u, &mut und, &mut pa, &mut ch);
472-
changed = true;
473-
break 'pairs;
474-
}
475-
}
476-
}
477-
}
478-
}
479-
480-
// R4: a--b and (a ⇒ b or b ⇒ a) ⇒ orient along reachability
481-
for a in 0..n {
482-
let uab: Vec<u32> = und[a].clone().into_iter().collect();
483-
for b_u in uab {
484-
if has_dir_path(&ch, a as u32, b_u) {
485-
orient(a as u32, b_u, &mut und, &mut pa, &mut ch);
486-
changed = true;
487-
} else if has_dir_path(&ch, b_u, a as u32) {
488-
orient(b_u, a as u32, &mut und, &mut pa, &mut ch);
489-
changed = true;
490-
}
491-
}
492-
}
493-
494-
if !changed {
495-
break;
496-
}
497-
}
361+
meek::apply_meek_closure(&mut pa, &mut ch, &mut und, false);
498362

499363
// Build CSR core (parents | undirected | children)
500364
let specs = &self.core_ref().registry.specs;

0 commit comments

Comments
 (0)