Skip to content

Commit 5dbf32a

Browse files
committed
cleanup + compare GRE vs. PTree repair
1 parent e16fc66 commit 5dbf32a

File tree

4 files changed

+269
-19
lines changed

4 files changed

+269
-19
lines changed

src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,8 @@ class KBitSet(val n: Int) {
162162

163163
fun clear() { data.fill(0L) }
164164

165-
infix fun or(other: KBitSet) {
166-
for (i in data.indices) data[i] = data[i] or other.data[i]
167-
}
168-
169-
infix fun and(other: KBitSet) {
170-
for (i in data.indices) data[i] = data[i] and other.data[i]
171-
}
165+
infix fun or(other: KBitSet) { for (i in data.indices) data[i] = data[i] or other.data[i] }
166+
infix fun and(other: KBitSet) { for (i in data.indices) data[i] = data[i] and other.data[i] }
172167

173168
fun merge(other: KBitSet): KBitSet = KBitSet(n).also { it or other }.also { it or this }
174169

@@ -252,4 +247,13 @@ class KBitSet(val n: Int) {
252247
for (i in data.indices) if ((data[i] and other.data[i]) != 0L) return true
253248
return false
254249
}
250+
251+
fun cardinality(): Int {
252+
var count = 0
253+
val last = data.lastIndex
254+
if (last < 0) return 0
255+
for (i in 0 until last) count += data[i].countOneBits()
256+
count += (data[last] and lastMask).countOneBits()
257+
return count
258+
}
255259
}

src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/GRE.kt

Lines changed: 247 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package ai.hypergraph.kaliningraph.automata
33
import ai.hypergraph.kaliningraph.KBitSet
44
import ai.hypergraph.kaliningraph.parsing.*
55
import ai.hypergraph.kaliningraph.repair.*
6+
import ai.hypergraph.kaliningraph.sampling.bigLFSRSequence
67
import ai.hypergraph.kaliningraph.tensor.UTMatrix
78
import ai.hypergraph.kaliningraph.types.*
9+
import com.ionspin.kotlin.bignum.integer.BigInteger
810
import kotlinx.coroutines.delay
911
import kotlin.math.*
1012
import kotlin.time.Duration.Companion.nanoseconds
@@ -19,7 +21,7 @@ sealed class GRE(open vararg val args: GRE) {
1921
class CAT(val l: GRE, val r: GRE): GRE(l, r)
2022

2123
fun words(terminals: List<Σᐩ>, shouldContinue: () -> Boolean = { true }): Sequence<Σᐩ> =
22-
enumerate(shouldContinue).takeWhile { shouldContinue() }.distinct()
24+
enumerate(shouldContinue).takeWhile { shouldContinue() }
2325
.map { it.mapNotNull { terminals[it].let { if (it == "ε") null else it } }.joinToString(" ") }
2426

2527
fun wordsOrdered(
@@ -137,6 +139,80 @@ sealed class GRE(open vararg val args: GRE) {
137139
// is UNI -> "( ${args.joinToString(" ∪ "){ "$it" }} )"
138140
// is CAT -> "$l $r"
139141
// }
142+
143+
/** Like [words], but sampled pseudorandomly from the space of all derivations */
144+
fun sampleStrWithoutReplacement(terminals: List<Σᐩ>, stride: Int = 1, offset: Int = 0): Sequence<Σᐩ> {
145+
val memo = hashMapOf<GRE, GreMemo>()
146+
val total = this.sizeMemo(memo).size
147+
if (total.isZero()) return emptySequence()
148+
149+
val idxSeq =
150+
if (6 < total.bitLength()) bigLFSRSequence(total)
151+
else sequence { var i = BigInteger.ZERO; while (i < total) { yield(i); i++ } }
152+
153+
return idxSeq.mapIndexedNotNull { ix, bi ->
154+
if (ix % stride != offset) return@mapIndexedNotNull null
155+
val buf = ArrayList<Int>(32)
156+
unrank(bi, memo, buf)
157+
buf.mapNotNull { terminals[it].let { t -> if (t == "ε") null else t } }.joinToString(" ")
158+
}
159+
}
160+
161+
private data class GreMemo(val size: BigInteger, val ranges: List<Pair<BigInteger, BigInteger>>? = null)
162+
163+
// Counts derivations (like PTree.totalTrees), not unique strings.
164+
private fun sizeMemo(m: MutableMap<GRE, GreMemo>): GreMemo = m[this] ?: run {
165+
val memo = when (this) {
166+
is EPS -> GreMemo(BigInteger.ONE)
167+
is SET -> GreMemo(BigInteger.fromInt(s.cardinality()))
168+
is CAT -> {
169+
val lm = l.sizeMemo(m); val rm = r.sizeMemo(m)
170+
GreMemo(lm.size * rm.size)
171+
}
172+
is CUP -> {
173+
val child = args.map { it.sizeMemo(m).size }
174+
val total = child.fold(BigInteger.ZERO) { a, b -> a + b }
175+
val ranges =
176+
child.fold(listOf(BigInteger.ZERO)) { acc, it -> acc + (acc.last() + it) }
177+
.windowed(2) { (a, b) -> a to (b - BigInteger.ONE) }
178+
GreMemo(total, ranges)
179+
}
180+
}
181+
m[this] = memo
182+
memo
183+
}
184+
185+
private fun unrank(i: BigInteger, m: MutableMap<GRE, GreMemo>, out: MutableList<Int>) {
186+
when (this) {
187+
is EPS -> return
188+
189+
is SET -> {
190+
// pick the i-th element of the set in iteration order
191+
val idx = i.intValue(true)
192+
var k = 0
193+
for (t in s.iterator()) {
194+
if (k++ == idx) { out.add(t); return }
195+
}
196+
return
197+
}
198+
199+
is CAT -> {
200+
val rm = r.sizeMemo(m).size
201+
val (iLeft, iRight) = i.divrem(rm)
202+
l.unrank(iLeft, m, out)
203+
r.unrank(iRight, m, out)
204+
}
205+
206+
is CUP -> {
207+
val memo = sizeMemo(m)
208+
val ranges = memo.ranges!!
209+
val t = ranges.indexOfFirst { (lo, hi) -> i in lo..hi }
210+
val child = args[t]
211+
val offset = i - ranges[t].first
212+
child.unrank(offset, m, out)
213+
}
214+
}
215+
}
140216
}
141217

142218
fun CFG.initGREListMat(tokens: List<Σᐩ>): UTMatrix<List<GRE?>> =
@@ -426,7 +502,7 @@ suspend fun initiateSuspendableRepair(brokenStr: List<Σᐩ>, cfg: CFG): GRE? {
426502
val maxBranch = vindex.maxOf { it.size }
427503
val startIdx = bindex[START_SYMBOL]
428504

429-
var i = 0; suspend fun pause(freq: Int = 300_000) { if (i++ % freq == 0) { delay(50.nanoseconds) }}
505+
var spin = 0; suspend fun pause(mask: Int = (1 shl 18) - 1) { if ((++spin and mask) == 0) delay(50.nanoseconds) }
430506

431507
suspend fun nonemptyLevInt(levFSA: FSA): Int? {
432508
val ap: List<List<List<Int>?>> = levFSA.allPairs
@@ -685,7 +761,7 @@ fun repairWithSparseGRE(brokenStr: List<Σᐩ>, cfg: CFG): GRE? {
685761
return if (allParses.isEmpty()) null else GRE.CUP(*allParses.toTypedArray()).flatunion()
686762
}
687763

688-
fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): PTree? {
764+
fun completeWithSparsePTree(template: List<Σᐩ>, cfg: CFG): PTree? {
689765
val timer = TimeSource.Monotonic.markNow()
690766
val startIdx = cfg.bindex[START_SYMBOL]
691767
val W = cfg.nonterminals.size
@@ -856,4 +932,172 @@ fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): PTree? {
856932
val result = trees[0][nTok][startIdx]
857933
println("Completed sparse completion chart in ${timer.elapsedNow()} |w|=$nTok, |V|=$W")
858934
return result
935+
}
936+
937+
fun completeWithSparseGRE(template: List<Σᐩ>, cfg: CFG): GRE? {
938+
val timer = TimeSource.Monotonic.markNow()
939+
val startIdx = cfg.bindex[START_SYMBOL]
940+
val W = cfg.nonterminals.size
941+
val nTok = template.size
942+
943+
val tmm = cfg.tmMap // terminal -> terminal index
944+
val t2vs = cfg.tmToVidx // terminal index -> IntArray of NT indices
945+
val unitNTs = cfg.unitNonterminals // nonterminals that can appear over a HOLE
946+
val units = cfg.bimap.UNITS // Map<NT, List<Terminal>> of unit expansions
947+
val ladj = cfg.leftAdj // B -> (C -> A) adjacency for A -> B C
948+
val tms = cfg.tmLst.size
949+
950+
// -------------------------------
951+
// PASS 1: Boolean CYK chart
952+
// -------------------------------
953+
val active: Array<Array<KBitSet>> = Array(nTok + 1) { Array(nTok + 1) { KBitSet(W) } }
954+
955+
// Base case: spans of length 1 (i, i+1)
956+
for (i in 0 until nTok) {
957+
val tok = template[i]
958+
val cellBits = active[i][i + 1]
959+
960+
if (tok == "_" || tok == HOLE_MARKER) {
961+
for (nt in unitNTs) {
962+
val exp = units[nt] ?: continue
963+
if (exp.isEmpty()) continue
964+
cellBits.set(cfg.bindex[nt])
965+
}
966+
} else {
967+
val tIdx = tmm[tok] ?: continue
968+
for (A in t2vs[tIdx]) cellBits.set(A)
969+
}
970+
}
971+
972+
// CYK-style DP for spans of length ≥ 2
973+
for (len in 2..nTok) {
974+
var i = 0
975+
while (i + len <= nTok) {
976+
val j = i + len
977+
val tgtBits = active[i][j]
978+
979+
var k = i + 1
980+
while (k < j) {
981+
val leftBits = active[i][k]
982+
val rightBits = active[k][j]
983+
if (leftBits.isEmpty() || rightBits.isEmpty()) { k++; continue }
984+
985+
for (B in leftBits.iterator())
986+
(ladj[B] ?: continue).forEachIfIn(rightBits) { _, A ->
987+
if (!tgtBits[A]) tgtBits.set(A)
988+
}
989+
990+
k++
991+
}
992+
i++
993+
}
994+
}
995+
996+
if (!active[0][nTok][startIdx]) {
997+
println("No completion: START does not derive the template under the hole semantics.")
998+
return null
999+
}
1000+
1001+
// -------------------------------
1002+
// PASS 2: Sparse GRE chart
1003+
// -------------------------------
1004+
1005+
// Precompute GRE.SET for hole-lexical expansions (NT -> SET(terminals))
1006+
val holeGre: Array<GRE?> = arrayOfNulls(W)
1007+
for (nt in unitNTs) {
1008+
val ntIdx = cfg.bindex[nt]
1009+
val exp = units[nt] ?: continue
1010+
if (exp.isEmpty()) continue
1011+
1012+
val bs = KBitSet(tms)
1013+
for (term in exp) {
1014+
val tid = tmm[term] ?: continue
1015+
bs.set(tid)
1016+
}
1017+
if (!bs.isEmpty()) holeGre[ntIdx] = GRE.SET(bs)
1018+
}
1019+
1020+
// dp[i][j]: map NT-index -> GRE for template[i..j)
1021+
val dp: Array<Array<MutableMap<Int, GRE>>> =
1022+
Array(nTok + 1) { Array(nTok + 1) { mutableMapOf() } }
1023+
1024+
// Base case: spans of length 1
1025+
for (i in 0 until nTok) {
1026+
val tok = template[i]
1027+
val cellBits = active[i][i + 1]
1028+
val cellMap = dp[i][i + 1]
1029+
1030+
if (tok == "_" || tok == HOLE_MARKER) {
1031+
for (nt in unitNTs) {
1032+
val A = cfg.bindex[nt]
1033+
if (!cellBits[A]) continue
1034+
val g = holeGre[A] ?: continue
1035+
cellMap[A] = g
1036+
}
1037+
} else {
1038+
val tIdx = tmm[tok] ?: continue
1039+
for (A in t2vs[tIdx]) {
1040+
if (!cellBits[A]) continue
1041+
// (Usually unique, but allow merging just in case)
1042+
val prev = cellMap[A] as? GRE.SET
1043+
cellMap[A] = (prev ?: GRE.SET(tms)).apply { s.set(tIdx) }
1044+
}
1045+
}
1046+
}
1047+
1048+
var maxChildren = 0
1049+
val acc: MutableMap<Int, MutableList<GRE>> = hashMapOf()
1050+
1051+
// Internal spans: 2..nTok
1052+
for (len in 2..nTok) {
1053+
var i = 0
1054+
while (i + len <= nTok) {
1055+
val j = i + len
1056+
val tgtBits = active[i][j]
1057+
if (tgtBits.isEmpty()) { i++; continue }
1058+
1059+
acc.clear()
1060+
1061+
var k = i + 1
1062+
while (k < j) {
1063+
val leftBits = active[i][k]
1064+
val rightBits = active[k][j]
1065+
if (leftBits.isEmpty() || rightBits.isEmpty()) { k++; continue }
1066+
1067+
val leftMap = dp[i][k]
1068+
val rightMap = dp[k][j]
1069+
1070+
for (B in leftBits.iterator()) {
1071+
val adj = ladj[B] ?: continue
1072+
val lgre = leftMap[B] ?: continue
1073+
1074+
adj.forEachIfIn(rightBits) { C, A ->
1075+
if (!tgtBits[A]) return@forEachIfIn
1076+
val rgre = rightMap[C] ?: return@forEachIfIn
1077+
acc.getOrPut(A) { mutableListOf() }.add(lgre * rgre)
1078+
}
1079+
}
1080+
1081+
k++
1082+
}
1083+
1084+
if (acc.isNotEmpty()) {
1085+
val cellMap = dp[i][j]
1086+
for ((A, parts) in acc) {
1087+
val combined = if (parts.size == 1) parts[0] else GRE.CUP(*parts.toTypedArray()).flatunion()
1088+
1089+
val prev = cellMap[A]
1090+
cellMap[A] = if (prev == null) combined else (prev + combined).flatunion()
1091+
1092+
if (parts.size > maxChildren) maxChildren = parts.size
1093+
}
1094+
}
1095+
1096+
i++
1097+
}
1098+
}
1099+
1100+
val result = dp[0][nTok][startIdx]
1101+
println("Completed sparse completion chart in ${timer.elapsedNow()} |w|=$nTok, |V|=$W, maxChildren=$maxChildren")
1102+
return result?.flatunion()
8591103
}

src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class PTree constructor(val root: String = ".ε", val branches: List<Π2A<PTree>
2828

2929
val allTerminals: Set<Σᐩ> by lazy {
3030
if (branches.isEmpty()) setOf(root)
31-
else branches.map { (l, r) -> l.allTerminals + r.allTerminals }.flatten().toSet()
31+
else branches.flatMap { (l, r) -> l.allTerminals + r.allTerminals }.toSet()
3232
}
3333

3434
val termDict by lazy { TermDict(allTerminals) }
@@ -42,10 +42,10 @@ class PTree constructor(val root: String = ".ε", val branches: List<Π2A<PTree>
4242
// TODO: Use weighted choice mechanism
4343
val shuffledBranches by lazy { branches.shuffled().sortedBy { "ε" !in it.first.root + it.second.root } }
4444
val toCFG: CFG by lazy {
45-
branches.map { (x, z) ->
45+
branches.flatMap { (x, z) ->
4646
if ("" == z.root) setOf(root to listOf(x.root))
4747
else setOf(root to listOf(x.root, z.root)) + x.toCFG + z.toCFG
48-
}.flatten().toSet()
48+
}.toSet()
4949
}
5050

5151
val totalTreesStr by lazy { totalTrees.toString() }
@@ -311,8 +311,7 @@ fun CFG.initPTreeListMat(tokens: List<String>): UTMatrix<List<PTree?>> =
311311
(if (token != HOLE_MARKER) bimap[listOf(token)] else unitNonterminals)
312312
.associateWith { nt ->
313313
if (token != HOLE_MARKER) PSingleton(token)
314-
else bimap.UNITS[nt]?.map {
315-
PSingleton(it) }?.flatten() ?: listOf()
314+
else bimap.UNITS[nt]?.flatMap { PSingleton(it) } ?: listOf()
316315
}.forEach { (k, v) -> ptreeList[bindex[k]] = PTree(k, v) }
317316
ptreeList
318317
}.toTypedArray(),
@@ -325,7 +324,7 @@ fun CFG.initPForestMat(tokens: List<String>): UTMatrix<PForest> =
325324
(if (token != HOLE_MARKER) bimap[listOf(token)] else unitNonterminals)
326325
.associateWith { nt ->
327326
if (token != HOLE_MARKER) PSingleton(token)
328-
else bimap.UNITS[nt]?.map { PSingleton(it) }?.flatten() ?: listOf()
327+
else bimap.UNITS[nt]?.flatMap { PSingleton(it) } ?: listOf()
329328
}.map { (k, v) -> k to PTree(k, v) }.toMap()
330329
}.toTypedArray(),
331330
algebra = Ring.of(

src/jvmTest/kotlin/ai/hypergraph/kaliningraph/repair/ProbabilisticLBH.kt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -747,9 +747,12 @@ class ProbabilisticLBH {
747747
// @Test // This test requires a lot of memory (more than GHCI can provide)
748748
fun testSLPs() {
749749
// val pt = k2.startPTree(List(35) { "_" })!!
750-
val pt = completeWithSparseGRE(List(20) { "_" }, k3)!!
751-
pt.sampleStrWithoutReplacement().take(1000)//.filter { it.length <= 80 }
752-
.forEach { println("\\texttt{ " + it.replace("{", "\\{")
750+
val template = List(40) { "_" }
751+
val pt = completeWithSparseGRE(template, k3)!!.sampleStrWithoutReplacement(k3.tmLst).take(1000)
752+
// val pt = completeWithSparsePTree(template, k3)!!.sampleStrWithoutReplacement().take(1000)
753+
// pt.sampleStrWithoutReplacement().take(1000)//.filter { it.length <= 80 }
754+
755+
pt.forEach { println("\\texttt{ " + it.replace("{", "\\{")
753756
.replace("}", "\\}") + "}\\\\") }
754757

755758
// val broke = "fn f0 ( p1 : T , p2 : T ) -> T { let mut p3 = add ( p1 , p1 ) ; p3 = mul ( p1 , p1 ) ; p3 }"

0 commit comments

Comments
 (0)