Skip to content

Commit 411c171

Browse files
committed
reduce GC
1 parent 1d6fa9a commit 411c171

File tree

1 file changed

+46
-36
lines changed
  • src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata

1 file changed

+46
-36
lines changed

src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/GRE.kt

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,7 @@ fun repairWithSparseGRE(brokenStr: List<Σᐩ>, cfg: CFG): GRE? {
539539
val upperBound = MAX_RADIUS * 3
540540
val timer = TimeSource.Monotonic.markNow()
541541
val startIdx = cfg.bindex[START_SYMBOL]
542+
val ladj = cfg.leftAdj
542543

543544
fun nonemptyLevIntSparse(levFSA: FSA): Int? {
544545
val n = levFSA.numStates
@@ -572,7 +573,7 @@ fun repairWithSparseGRE(brokenStr: List<Σᐩ>, cfg: CFG): GRE? {
572573
val rightBits = dp[r][q]
573574

574575
for (B in leftBits.iterator()) {
575-
val adj = cfg.leftAdj[B] ?: continue
576+
val adj = ladj[B] ?: continue
576577
adj.forEachIfIn(rightBits) { _, A ->
577578
if (!tgt[A]) {
578579
tgt.set(A); aCount[p][q]++
@@ -641,7 +642,7 @@ fun repairWithSparseGRE(brokenStr: List<Σᐩ>, cfg: CFG): GRE? {
641642
val rightMap = dp[r][q]
642643

643644
for (B in leftBits.iterator()) {
644-
val adj = cfg.leftAdj[B] ?: continue
645+
val adj = ladj[B] ?: continue
645646
adj.forEachIfIn(rightBits) { C, A ->
646647
val l = leftMap[B] ?: return@forEachIfIn
647648
val rgre = rightMap[C] ?: return@forEachIfIn
@@ -691,21 +692,20 @@ fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): PTree? {
691692
val tmm = cfg.tmMap // terminal -> terminal index
692693
val t2vs = cfg.tmToVidx // terminal index -> IntArray of NT indices
693694
val unitNTs = cfg.unitNonterminals // nonterminals that can appear over a HOLE
694-
val units = cfg.bimap.UNITS // Map<Σᐩ, List<Σᐩ>> of unit expansions
695+
val units = cfg.bimap.UNITS // Map<Σᐩ, List<Σᐩ>> of unit expansions
696+
val ladj = cfg.leftAdj
695697

696698
// -------------------------------
697699
// PASS 1: Boolean CYK chart
698700
// -------------------------------
699701

700-
// active[i][j]: KBitSet of nonterminals that derive template[i..j)
701702
val active: Array<Array<KBitSet>> = Array(nTok + 1) { Array(nTok + 1) { KBitSet(W) } }
702703

703704
// Base case: spans of length 1 (i, i+1)
704705
for (i in 0 until nTok) {
705706
val tok = template[i]
706707
val cellBits = active[i][i + 1]
707708

708-
// Treat "_" and HOLE_MARKER as holes
709709
if (tok == "_" || tok == HOLE_MARKER) {
710710
// Any NT that has at least one unit expansion can sit on a hole
711711
for (nt in unitNTs) {
@@ -717,9 +717,7 @@ fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): PTree? {
717717
} else {
718718
// Ordinary terminal: hook into the lexical index
719719
val tIdx = tmm[tok] ?: continue
720-
for (A in t2vs[tIdx]) {
721-
cellBits.set(A)
722-
}
720+
for (A in t2vs[tIdx]) cellBits.set(A)
723721
}
724722
}
725723

@@ -736,16 +734,8 @@ fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): PTree? {
736734
val rightBits = active[k][j]
737735
if (leftBits.isEmpty() || rightBits.isEmpty()) { k++; continue }
738736

739-
// For each B that can derive [i, k)
740-
for (B in leftBits.iterator()) {
741-
val adj = cfg.leftAdj[B] ?: continue
742-
// adj.forEachIfIn iterates over C in rightBits such that B C -> A exists
743-
adj.forEachIfIn(rightBits) { _, A ->
744-
if (!tgtBits[A]) {
745-
tgtBits.set(A)
746-
}
747-
}
748-
}
737+
for (B in leftBits.iterator())
738+
(ladj[B] ?: continue).forEachIfIn(rightBits) { _, A -> if (!tgtBits[A]) tgtBits.set(A) }
749739
k++
750740
}
751741

@@ -759,7 +749,7 @@ fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): PTree? {
759749
}
760750

761751
// -------------------------------
762-
// PASS 2: PTree chart
752+
// PASS 2: PTree chart (branch-accumulating)
763753
// -------------------------------
764754

765755
// trees[i][j][A]: PTree for nonterminal A deriving template[i..j), or null
@@ -771,29 +761,35 @@ fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): PTree? {
771761
val cellTrees = trees[i][i + 1]
772762
val cellBits = active[i][i + 1]
773763

764+
// Scratch: branchesAcc[A] = all branches we want for this (i, i+1, A)
765+
val branchesAcc = arrayOfNulls<MutableList<Π2A<PTree>>>(W)
766+
774767
if (tok == "_" || tok == HOLE_MARKER) {
775768
// For each NT that was marked active on this hole, attach its unit expansions
776769
for (nt in unitNTs) {
777-
val ntIdx = cfg.bindex[nt] ?: continue
770+
val ntIdx = cfg.bindex[nt]
778771
if (!cellBits[ntIdx]) continue
779772

780773
val exp = units[nt] ?: continue
781774
if (exp.isEmpty()) continue
782775

776+
val list = branchesAcc[ntIdx] ?: mutableListOf<Π2A<PTree>>().also { branchesAcc[ntIdx] = it }
777+
783778
// Each unit expansion contributes PSingleton(terminal)
784-
val branches = exp.flatMap { term -> PSingleton(term) }
785-
val newTree = PTree(nt, branches)
786-
val prev = cellTrees[ntIdx]
787-
cellTrees[ntIdx] = prev?.plus(newTree) ?: newTree
779+
for (term in exp) { list += PSingleton(term) }
788780
}
781+
782+
// Now materialize exactly one PTree per active ntIdx
783+
for (A in 0 until W) cellTrees[A] = PTree("", branchesAcc[A] ?: continue)
789784
} else {
790785
val tIdx = tmm[tok] ?: continue
791786
for (A in t2vs[tIdx]) {
792787
if (!cellBits[A]) continue
793-
val newTree = PTree("", PSingleton(tok))
794-
val prev = cellTrees[A]
795-
cellTrees[A] = prev?.plus(newTree) ?: newTree
788+
val list = branchesAcc[A] ?: mutableListOf<Π2A<PTree>>().also { branchesAcc[A] = it }
789+
list += PSingleton(tok)
796790
}
791+
792+
for (A in 0 until W) cellTrees[A] = PTree("", branchesAcc[A] ?: continue)
797793
}
798794
}
799795

@@ -808,6 +804,9 @@ fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): PTree? {
808804

809805
val cellTrees = trees[i][j]
810806

807+
// branchesAcc[A] will collect all (leftTree, rightTree) branches
808+
val branchesAcc = arrayOfNulls<MutableList<Π2A<PTree>>>(W)
809+
811810
var k = i + 1
812811
while (k < j) {
813812
val leftBits = active[i][k]
@@ -817,26 +816,37 @@ fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): PTree? {
817816
val leftRow = trees[i][k]
818817
val rightRow = trees[k][j]
819818

820-
// For each B in leftBits, try to combine with C in rightBits via B C -> A
821819
for (B in leftBits.iterator()) {
822820
val leftTree = leftRow[B] ?: continue
823-
val adj = cfg.leftAdj[B] ?: continue
821+
val adj = ladj[B] ?: continue
824822

825823
adj.forEachIfIn(rightBits) { C, A ->
826-
if (!tgtBits[A]) return@forEachIfIn // sanity: should already be true by pass 1
824+
if (!tgtBits[A]) return@forEachIfIn
827825
val rightTree = rightRow[C] ?: return@forEachIfIn
828-
829-
// Build one binary branch for A, then union into the forest
830-
val branch = listOf(leftTree to rightTree)
831-
val newTree = PTree("", branch)
832-
val prev = cellTrees[A]
833-
cellTrees[A] = prev?.plus(newTree) ?: newTree
826+
(branchesAcc[A] ?: mutableListOf<Π2A<PTree>>().also { branchesAcc[A] = it })
827+
.add(leftTree to rightTree)
834828
}
835829
}
836830

837831
k++
838832
}
839833

834+
// After all splits k, build one PTree per A (and merge with any existing ones once)
835+
for (A in 0 until W) {
836+
val newBranches = branchesAcc[A] ?: continue
837+
val prev = cellTrees[A]
838+
839+
// First time we see A in this cell
840+
if (prev == null) cellTrees[A] = PTree("", newBranches)
841+
else {
842+
// Merge previous and new branches in a single list
843+
val merged = ArrayList<Π2A<PTree>>(prev.branches.size + newBranches.size)
844+
merged.addAll(prev.branches)
845+
merged.addAll(newBranches)
846+
cellTrees[A] = PTree(prev.root, merged)
847+
}
848+
}
849+
840850
i++
841851
}
842852
}

0 commit comments

Comments
 (0)