Skip to content

Commit 344592a

Browse files
committed
gpu-accelerated code completion
1 parent ffc8ce9 commit 344592a

File tree

5 files changed

+55
-15
lines changed

5 files changed

+55
-15
lines changed

src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/FSA.kt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,9 @@ open class FSA constructor(open val Q: TSA, open val init: Set<Σᐩ>, open val
146146

147147
open val midpoints: List<List<List<Int>>> by lazy { TODO() }
148148

149-
val finalIdxs by lazy { final.map { stateMap[it]!! }.filter { 0 < idsToCoords[it]!!.second }.toTypedArray() }
150-
val isFinal by lazy { BooleanArray(numStates).also { fm -> for (f in finalIdxs) fm[f] = true } }
149+
val finalIdxsq by lazy { final.map { stateMap[it]!! }.toTypedArray() }
150+
val levFinalIdxs by lazy { finalIdxsq.filter { idsToCoords[it]!!.second > 0 }.toTypedArray() }
151+
val isFinal by lazy { BooleanArray(numStates).also { fm -> for (f in levFinalIdxs) fm[f] = true } }
151152

152153
// TODO: Implement Lev state pairing function to avoid this pain
153154
val idsToCoords by lazy { stateLst.mapIndexed { i, it -> i to it.coords() }.toMap() }
@@ -235,7 +236,7 @@ open class FSA constructor(open val Q: TSA, open val init: Set<Σᐩ>, open val
235236
}
236237
}
237238

238-
if (p == 0 && A == startIdx && q in levFSA.finalIdxs && dp[p][q][A]) return true
239+
if (p == 0 && A == startIdx && q in levFSA.levFinalIdxs && dp[p][q][A]) return true
239240
}
240241
}
241242
}
@@ -304,7 +305,7 @@ open class FSA constructor(open val Q: TSA, open val init: Set<Σᐩ>, open val
304305
println("Completed parse matrix in: ${timer.elapsedNow()}")
305306

306307
// 4) Gather final parse trees from dp[0][f][startIdx], for all final states f
307-
val allParses = levFSA.finalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
308+
val allParses = levFSA.levFinalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
308309

309310
return PTree(START_SYMBOL, allParses.flatMap { forest -> forest.branches })
310311
}

src/commonMain/kotlin/ai/hypergraph/kaliningraph/automata/GRE.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ fun repairWithGREAtDist(brokenStr: List<Σᐩ>, cfg: CFG, d: Int): Pair<GRE.CUP,
285285
}
286286
}
287287

288-
if (p == 0 && A == startIdx && q in levFSA.finalIdxs && dp[p][q][A]) {
288+
if (p == 0 && A == startIdx && q in levFSA.levFinalIdxs && dp[p][q][A]) {
289289
val (x, y) = levFSA.idsToCoords[q]!!
290290
/** See final state conditions for [makeExactLevCFL] */
291291
// The minimum radius such that this final state is included in the L-FSA
@@ -356,7 +356,7 @@ fun repairWithGREAtDist(brokenStr: List<Σᐩ>, cfg: CFG, d: Int): Pair<GRE.CUP,
356356
}
357357

358358
// 4) Gather final parse trees from dp[0][f][startIdx], for all final states f
359-
val allParses = levFSA.finalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
359+
val allParses = levFSA.levFinalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
360360

361361
// 5) Combine under a single GRE
362362
return (if (allParses.isEmpty()) null else GRE.CUP(*allParses.toTypedArray()) to diff)
@@ -476,7 +476,7 @@ fun repairWithGRE(brokenStr: List<Σᐩ>, cfg: CFG): GRE? {
476476
println("Completed parse matrix in: ${timer.elapsedNow()}")
477477

478478
// 4) Gather final parse trees from dp[0][f][startIdx], for all final states f
479-
val allParses = levFSA.finalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
479+
val allParses = levFSA.levFinalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
480480

481481
// 5) Combine under a single GRE
482482
return if (allParses.isEmpty()) null else GRE.CUP(*allParses.toTypedArray())
@@ -601,7 +601,7 @@ suspend fun initiateSuspendableRepair(brokenStr: List<Σᐩ>, cfg: CFG): GRE? {
601601
println("Completed parse matrix in: ${timer.elapsedNow()}")
602602

603603
// 4) Gather final parse trees from dp[0][f][startIdx], for all final states f
604-
val allParses = levFSA.finalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
604+
val allParses = levFSA.levFinalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
605605

606606
println("Parsing took ${timer.elapsedNow()} with |σ|=${brokenStr.size}, " +
607607
// "|Q|=$nStates, |G|=${cfg.size}, maxBranch=$maxBranch, |V|=$width, |Σ|=$tms, maxChildren=$maxChildren@$location")
@@ -757,7 +757,7 @@ fun repairWithSparseGRE(brokenStr: List<Σᐩ>, cfg: CFG): GRE? {
757757

758758
println("Completed sparse parse matrix in: ${timer.elapsedNow()} |Q|=$n, |G|=${cfg.size}, |V|=$W, |Σ|=$tms, maxChildren=$maxChildren")
759759

760-
val allParses = levFSA.finalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
760+
val allParses = levFSA.levFinalIdxs.mapNotNull { q -> dp[0][q][startIdx] }
761761
return if (allParses.isEmpty()) null else GRE.CUP(*allParses.toTypedArray()).flatunion()
762762
}
763763

src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,14 @@ val CFG.noNonterminalStubs: CFG by cache {
383383
.also { it.blocked.addAll(blocked) }
384384
}
385385

386+
/**
 * A frozen copy of this grammar keeping only the elements whose string form
 * does not contain "ε".
 *
 * Side effects: prints a notice, appends the frozen receiver to its own
 * rewrite-history entry and stores that under the new grammar, and copies the
 * receiver's `blocked` entries into the result. The `!!` means this throws if
 * the frozen receiver has no entry in `rewriteHistory`.
 */
val CFG.noEpsilon: CFG by cache {
  // try { throw Exception() } catch (e: Exception) { e.printStackTrace() }
  println("Disabling ε!")
  val epsilonFree = filter { rule -> "ε" !in rule.toString() }.toSet().freeze()
  // History of the result = history of the (frozen) receiver, followed by the receiver itself.
  rewriteHistory.put(epsilonFree, freeze().let { parent -> rewriteHistory[parent]!! + listOf(parent) })
  epsilonFree.blocked.addAll(blocked)
  epsilonFree
}
393+
386394
val CFG.noEpsilonOrNonterminalStubs: CFG by cache {
387395
// try { throw Exception() } catch (e: Exception) { e.printStackTrace() }
388396
println("Disabling nonterminal stubs!")

src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Levenshtein.kt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ fun makeExactLevCFL(
6565
): FSA =
6666
(upArcs(str, radius, digits) +
6767
diagArcs(str, radius, digits) +
68-
str.mapIndexed { i, it -> rightArcs(i, radius, it, digits) }.flatten() +
69-
str.mapIndexed { i, it -> knightArcs(i, radius, it, digits, str) }.flatten())
68+
str.flatMapIndexed { i, it -> rightArcs(i, radius, it, digits) } +
69+
str.flatMapIndexed { i, it -> knightArcs(i, radius, it, digits, str) })
7070
.let { Q ->
7171
val initialStates = setOf("q_" + pd(0, digits).let { "$it/$it" })
7272
val finalStates = Q.states().filter { it.unpackCoordinates().let { (i, j) -> ((str.size - i + j).absoluteValue == radius) } }
@@ -86,8 +86,8 @@ fun makeLevFSA(
8686
var initSize = 0
8787
val fsa = (upArcs(str, maxRad, digits) +
8888
diagArcs(str, maxRad, digits) +
89-
str.mapIndexed { i, it -> rightArcs(i, maxRad, it, digits) }.flatten() +
90-
str.mapIndexed { i, it -> knightArcs(i, maxRad, it, digits, str) }.flatten())
89+
str.flatMapIndexed { i, it -> rightArcs(i, maxRad, it, digits) } +
90+
str.flatMapIndexed { i, it -> knightArcs(i, maxRad, it, digits, str) })
9191
.also { initSize = it.size }
9292
.let { Q ->
9393
val initialStates = setOf("q_" + pd(0, digits).let { "$it/$it" })
@@ -131,8 +131,8 @@ fun makeLevFSA(
131131
var initSize = 0
132132
val fsa = (upArcs(str, maxRad, digits) +
133133
diagArcs(str, maxRad, digits) +
134-
str.mapIndexed { i, it -> rightArcs(i, maxRad, it, digits) }.flatten() +
135-
str.mapIndexed { i, it -> knightArcs(i, maxRad, it, digits, str) }.flatten())
134+
str.flatMapIndexed { i, it -> rightArcs(i, maxRad, it, digits) } +
135+
str.flatMapIndexed { i, it -> knightArcs(i, maxRad, it, digits, str) })
136136
.also { initSize = it.size }
137137
.filter { arc ->
138138
listOf(arc.first.unpackCoordinates(), arc.third.unpackCoordinates())
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package ai.hypergraph.kaliningraph.parsing
2+
3+
import ai.hypergraph.kaliningraph.automata.*
4+
5+
/**
 * Builds a straight-line automaton over [tokens]: one arc per token, leading
 * from state `i` to state `i + 1` and labelled with the token verbatim (no
 * token, including "_", receives special treatment here).
 *
 * States are named `q_<i>/<0>`, with both numbers zero-padded to the number
 * of decimal digits in `tokens.size + 1`. The only initial state is state 0
 * and the only final state is state `tokens.size`.
 *
 * @param tokens the token sequence to encode, in order.
 * @return the automaton, with `width` set to the token count, `height` set
 *   to 0 and `levString` set to [tokens].
 */
fun makePorousFSA(tokens: List<String>): FSA {
  val length = tokens.size
  val padWidth = (length + 1).toString().length
  // Second coordinate is always zero; format it once.
  val zeroRow = 0.toString().padStart(padWidth, '0')

  fun stateName(col: Int): String = "q_${col.toString().padStart(padWidth, '0')}/$zeroRow"

  val chain: TSA = tokens
    .mapIndexed { col, label -> Triple(stateName(col), label, stateName(col + 1)) }
    .toSet()

  val automaton = AFSA(chain, setOf(stateName(0)), setOf(stateName(length)))
  automaton.width = length
  automaton.height = 0
  automaton.levString = tokens
  return automaton
}
23+
24+
// Code emitted by porousToCodePoints for the hole token "_" in place of a terminal index.
private const val HOLE_SENTINEL_INT: Int = -1 // 0xFFFF_FFFFu on GPU
25+
26+
/**
 * Encodes a token sequence as integer codes, position by position.
 *
 * The token "_" is encoded as [HOLE_SENTINEL_INT]; every other token is
 * looked up in `cfg.tmMap`.
 *
 * @param cfg grammar whose `tmMap` supplies the code for each non-hole token.
 * @param porous the token sequence, where "_" marks a hole.
 * @return an array of the same length as [porous] holding one code per token.
 * @throws IllegalStateException if a non-hole token is absent from `cfg.tmMap`.
 */
fun porousToCodePoints(cfg: CFG, porous: List<String>): IntArray {
  val codes = IntArray(porous.size)
  for ((pos, token) in porous.withIndex()) {
    codes[pos] = when (token) {
      "_" -> HOLE_SENTINEL_INT
      else -> cfg.tmMap[token] ?: error("Unknown token '$token' (not in cfg.tmMap)")
    }
  }
  return codes
}

0 commit comments

Comments
 (0)