Skip to content

Commit 0d3f31b

Browse files
committed
now I understand
1 parent 21ddc74 commit 0d3f31b

File tree

3 files changed

+37
-25
lines changed

3 files changed

+37
-25
lines changed

src/commonMain/kotlin/ai/hypergraph/kaliningraph/CommonUtils.kt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,8 @@ class KBitSet(val n: Int) {
133133
constructor(n: Int, v: Collection<Int>) : this(n) { v.forEach { set(it) } }
134134
// Each element of 'data' holds 64 bits, covering up to n bits total.
135135
private val data = LongArray((n + 63) ushr 6)
136-
// var set = mutableSetOf<Int>()
137-
// var modified = false
138136

139137
fun set(index: Int) {
140-
// modified = true
141-
// set += index
142138
val word = index ushr 6
143139
val bit = index and 63
144140
data[word] = data[word] or (1L shl bit)

src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SetValiant.kt

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -77,25 +77,28 @@ fun CFG.isValid(str: List<Σᐩ>): Bln =
7777
dp[0][str.size][bindex[START_SYMBOL]]
7878
}
7979

80-
//fun CFG.isValidAlt(str: List<Σᐩ>): Bln =
81-
// if (str.size == 1) checkUnitWord(str.first()).isNotEmpty()
82-
// else {
83-
// val dp = Array(str.size + 1) { Array(str.size + 1) { KBitSet(nonterminals.size) } }
84-
// str.map {
85-
// if (it == "_" || tmMap[it] == null) (0..<nonterminals.size).toList()
86-
// else tmToVidx[tmMap[it]!!] }.forEachIndexed { i, it -> it.forEach { vidx -> dp[i][i+1].set(vidx) } }
87-
//
88-
// for (dist: Int in 0 until dp.size) {
89-
// for (iP: Int in 0 until dp.size - dist) {
90-
// val p = iP
91-
// val q = iP + dist
92-
// val appq = p..q
93-
// for (r in appq) for (lt in dp[p][r].set) for (rt in dp[r][q].set)
94-
// bimap.R2LHSI[lt][rt].forEach { dp[p][q].set(it) }
95-
// }
96-
// }
97-
// dp[0][str.size][bindex[START_SYMBOL]]
98-
// }
80+
// Differs only by the JOIN\otimes operation.
81+
// This strategy only wins over child-enumeration under low ambiguity.
82+
// If the number of child pairs is high, better to just loop over grammar
83+
fun CFG.isValidAlt(str: List<Σᐩ>): Bln =
84+
if (str.size == 1) checkUnitWord(str.first()).isNotEmpty()
85+
else {
86+
val dp = Array(str.size + 1) { Array(str.size + 1) { KBitSet(nonterminals.size) } }
87+
str.map {
88+
if (it == "_" || tmMap[it] == null) (0..<nonterminals.size).toList()
89+
else tmToVidx[tmMap[it]!!] }.forEachIndexed { i, it -> it.forEach { vidx -> dp[i][i+1].set(vidx) } }
90+
91+
for (dist: Int in 0 until dp.size) {
92+
for (iP: Int in 0 until dp.size - dist) {
93+
val p = iP
94+
val q = iP + dist
95+
val appq = p..q
96+
for (r in appq) for (lt in dp[p][r].toList()) for (rt in dp[r][q].toList())
97+
bimap.R2LHSI[lt][rt].forEach { dp[p][q].set(it) }
98+
}
99+
}
100+
dp[0][str.size][bindex[START_SYMBOL]]
101+
}
99102

100103
fun CFG.corner(str: Σᐩ) =
101104
solveFixedpoint(str.tokenizeByWhitespace())[0].last().map { it.root }.toSet()
@@ -298,9 +301,18 @@ fun Σᐩ.isNonterminalStubIn(CJL: CJL): Bln = CJL.cfgs.map { isNonterminalStubI
298301
fun Σᐩ.containsNonterminal(): Bln = Regex("<[^\\s>]*>") in this
299302

300303
// Converts tokens to UT matrix via constructor: σ_i = { A | (A -> w[i]) ∈ P }
301-
fun CFG.initialMatrix(str: List<Σᐩ>): TreeMatrix =
304+
fun CFG.initialMatrix(
305+
str: List<Σᐩ>,
306+
bmp: BiMap = bimap,
307+
unitReach: Map<Σᐩ, Set<Σᐩ>> = originalForm.unitReachability
308+
): TreeMatrix =
302309
FreeMatrix(makeForestAlgebra(), str.size + 1) { i, j ->
303310
if (i + 1 != j) emptySet()
311+
else if (str[j - 1] == HOLE_MARKER)
312+
unitReach.values.flatten().toSet().map { root ->
313+
bmp[root].filter { it.size == 1 }.map { it.first() }.filter { it in terminals }
314+
.map { Tree(root = root, terminal = it, span = i until (i + 1)) }
315+
}.flatten().toSet()
304316
else bimap[listOf(str[j - 1])].map {
305317
Tree(root = it, terminal = str[j - 1], span = (j - 1) until j)
306318
}.toSet()

src/commonTest/kotlin/ai/hypergraph/kaliningraph/parsing/SetValiantTest.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ class SetValiantTest {
418418
fun testUTMRepresentationEquivalence() {
419419
with(vanillaS2PCFG) {
420420
println("SIZE: ${nonterminals.size}")
421-
val str = "NAME = NAME ( NAME , NUMBER ) . NAME ( ) NEWLINE NAME = NAME . NAME NEWLINE NAME = STRING . NAME ( NAME for NAME in NAME if NAME . NAME ( ) ) NEWLINE NAME ( NAME ) NEWLINE".tokenizeByWhitespace()
421+
val str = "NAME = NAME ( NAME , NUMBER ) . NAME ( ) NEWLINE _ _ NAME . NAME NEWLINE NAME = STRING . NAME ( NAME for NAME in NAME if NAME . NAME ( ) ) NEWLINE NAME ( NAME ) NEWLINE".tokenizeByWhitespace()
422422
// with("""P -> ( P ) | P P | ε""".parseCFG()) {
423423
// val str = "( ( ) ( ) ) ( ) ( ( ( ) ) ( ) ) ( ( ( ) ) ) ( ) ( ) ( ) ( ( ) ( ) ) ( ) ( ( ) ( ) ) ( ) ( ( ) ( ) ) ( )".tokenizeByWhitespace()
424424
val iter = ceil(log2(str.size.toDouble())).toInt() + 9
@@ -433,13 +433,17 @@ class SetValiantTest {
433433
)[0, str.size]
434434
}.also { println("Slow transition: ${it.duration.inWholeMilliseconds}ms") }.value
435435

436+
// println(slowTransitionFP)
437+
436438
i = 0
437439
val fastTransitionFP = measureTimedValue {
438440
initialUTMatrix(str).seekFixpoint(
439441
// debug = { println("ITER=${i++}"); it.toFullMatrix().prettyPrint(); println() }
440442
).toFullMatrix()[0, str.size]
441443
}.also { println("Fast transition: ${it.duration.inWholeMilliseconds}ms") }.value
442444

445+
// println(fastTransitionFP)
446+
443447
measureTimedValue {
444448
println(isValidAlt(str))
445449
}.also { println("DP transition: ${it.duration.inWholeMilliseconds}ms") }.value

0 commit comments

Comments
 (0)