Skip to content

Commit c876736

Browse files
committed
perf: rewrite normalize_latent_structure() in Rust
Implement `normalize_latent_structure()` in Rust instead to improve performance. This change is stacked on top of #268.
1 parent 48aa538 commit c876736

8 files changed

Lines changed: 414 additions & 103 deletions

File tree

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- Improved performance of all queries. Speedups are more significant on larger graphs,
1313
but even on small graphs, queries are roughly 5x faster.
1414
- `exogenize()` is now implemented in Rust for DAGs, which reduces overhead on larger graphs.
15+
- `normalize_latent_structure()` is now implemented in Rust for DAGs for faster latent normalization workflows.
1516

1617

1718
# caugi 1.1.0

R/extendr-wrappers.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ rs_latent_project <- function(session, latents) .Call(wrap__rs_latent_project, s
124124

125125
rs_exogenize <- function(session, nodes) .Call(wrap__rs_exogenize, session, nodes)
126126

127+
rs_normalize_latent_structure <- function(session, latents) .Call(wrap__rs_normalize_latent_structure, session, latents)
128+
127129
rs_induced_subgraph <- function(session, keep) .Call(wrap__rs_induced_subgraph, session, keep)
128130

129131
subgraph <- function(cg, nodes, index) .Call(wrap__subgraph, cg, nodes, index)

R/operations.R

Lines changed: 6 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -401,90 +401,12 @@ normalize_latent_structure <- function(cg, latents) {
401401
)
402402
}
403403

404-
cg <- exogenize(cg, nodes = latents)
405-
406-
changed <- TRUE
407-
408-
while (changed) {
409-
changed <- FALSE
410-
current_latents <- intersect(latents, nodes(cg)$name)
411-
412-
if (length(current_latents) == 0L) {
413-
break
414-
}
415-
416-
# Lemma 3: remove exogenous latents with <= 1 child
417-
child_counts <- vapply(
418-
current_latents,
419-
function(l) {
420-
ch <- children(cg, l)
421-
422-
if (is.null(ch)) {
423-
0L
424-
} else {
425-
length(ch)
426-
}
427-
},
428-
integer(1)
429-
)
430-
431-
to_drop <- current_latents[child_counts <= 1L]
432-
433-
if (length(to_drop) > 0L) {
434-
cg <- remove_nodes(cg, name = to_drop)
435-
changed <- TRUE
436-
next
437-
}
438-
439-
# Lemma 2: remove nested child sets among exogenous latents
440-
current_latents <- intersect(latents, nodes(cg)$name)
441-
442-
if (length(current_latents) < 2L) {
443-
break
444-
}
445-
446-
child_sets <- lapply(
447-
current_latents,
448-
function(l) {
449-
ch <- children(cg, l)
450-
if (is.null(ch)) {
451-
character(0)
452-
} else {
453-
sort(unique(ch))
454-
}
455-
}
456-
)
457-
458-
drop_one <- NULL
459-
460-
for (i in seq_len(length(current_latents) - 1L)) {
461-
for (j in (i + 1L):length(current_latents)) {
462-
ch_i <- child_sets[[i]]
463-
ch_j <- child_sets[[j]]
464-
465-
if (length(ch_i) < length(ch_j) && all(ch_i %in% ch_j)) {
466-
drop_one <- current_latents[i]
467-
break
468-
}
469-
470-
if (length(ch_j) < length(ch_i) && all(ch_j %in% ch_i)) {
471-
drop_one <- current_latents[j]
472-
break
473-
}
474-
}
475-
476-
if (!is.null(drop_one)) {
477-
break
478-
}
479-
}
480-
481-
if (!is.null(drop_one)) {
482-
cg <- remove_nodes(cg, name = drop_one)
483-
changed <- TRUE
484-
}
485-
}
486-
487-
cg
404+
latent_indices <- .nodes_to_indices(cg, latents)
405+
normalized_session <- rs_normalize_latent_structure(
406+
cg@session,
407+
latent_indices
408+
)
409+
.session_to_caugi(normalized_session)
488410
}
489411

490412

0 commit comments

Comments
 (0)