Skip to content

Commit f840e27

Browse files
committed
perf: rewrite normalize_latent_structure() in Rust
Implement `normalize_latent_structure()` in Rust instead to improve performance. This change is stacked on top of #268.
1 parent 4ef96ac commit f840e27

8 files changed

Lines changed: 414 additions & 103 deletions

File tree

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- Improved performance of all queries. Speedups are more significant on larger graphs,
1010
but even on small graphs, queries are roughly 5x faster.
1111
- `exogenize()` is now implemented in Rust for DAGs, which reduces overhead on larger graphs.
12+
- `normalize_latent_structure()` is now implemented in Rust for DAGs for faster latent normalization workflows.
1213

1314

1415
# caugi 1.1.0

R/extendr-wrappers.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ rs_latent_project <- function(session, latents) .Call(wrap__rs_latent_project, s
124124

125125
rs_exogenize <- function(session, nodes) .Call(wrap__rs_exogenize, session, nodes)
126126

127+
rs_normalize_latent_structure <- function(session, latents) .Call(wrap__rs_normalize_latent_structure, session, latents)
128+
127129
rs_induced_subgraph <- function(session, keep) .Call(wrap__rs_induced_subgraph, session, keep)
128130

129131
subgraph <- function(cg, nodes, index) .Call(wrap__subgraph, cg, nodes, index)

R/operations.R

Lines changed: 6 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -400,90 +400,12 @@ normalize_latent_structure <- function(cg, latents) {
400400
)
401401
}
402402

403-
cg <- exogenize(cg, nodes = latents)
404-
405-
changed <- TRUE
406-
407-
while (changed) {
408-
changed <- FALSE
409-
current_latents <- intersect(latents, nodes(cg)$name)
410-
411-
if (length(current_latents) == 0L) {
412-
break
413-
}
414-
415-
# Lemma 3: remove exogenous latents with <= 1 child
416-
child_counts <- vapply(
417-
current_latents,
418-
function(l) {
419-
ch <- children(cg, l)
420-
421-
if (is.null(ch)) {
422-
0L
423-
} else {
424-
length(ch)
425-
}
426-
},
427-
integer(1)
428-
)
429-
430-
to_drop <- current_latents[child_counts <= 1L]
431-
432-
if (length(to_drop) > 0L) {
433-
cg <- remove_nodes(cg, name = to_drop)
434-
changed <- TRUE
435-
next
436-
}
437-
438-
# Lemma 2: remove nested child sets among exogenous latents
439-
current_latents <- intersect(latents, nodes(cg)$name)
440-
441-
if (length(current_latents) < 2L) {
442-
break
443-
}
444-
445-
child_sets <- lapply(
446-
current_latents,
447-
function(l) {
448-
ch <- children(cg, l)
449-
if (is.null(ch)) {
450-
character(0)
451-
} else {
452-
sort(unique(ch))
453-
}
454-
}
455-
)
456-
457-
drop_one <- NULL
458-
459-
for (i in seq_len(length(current_latents) - 1L)) {
460-
for (j in (i + 1L):length(current_latents)) {
461-
ch_i <- child_sets[[i]]
462-
ch_j <- child_sets[[j]]
463-
464-
if (length(ch_i) < length(ch_j) && all(ch_i %in% ch_j)) {
465-
drop_one <- current_latents[i]
466-
break
467-
}
468-
469-
if (length(ch_j) < length(ch_i) && all(ch_j %in% ch_i)) {
470-
drop_one <- current_latents[j]
471-
break
472-
}
473-
}
474-
475-
if (!is.null(drop_one)) {
476-
break
477-
}
478-
}
479-
480-
if (!is.null(drop_one)) {
481-
cg <- remove_nodes(cg, name = drop_one)
482-
changed <- TRUE
483-
}
484-
}
485-
486-
cg
403+
latent_indices <- .nodes_to_indices(cg, latents)
404+
normalized_session <- rs_normalize_latent_structure(
405+
cg@session,
406+
latent_indices
407+
)
408+
.session_to_caugi(normalized_session)
487409
}
488410

489411

0 commit comments

Comments
 (0)