Skip to content

Commit 3384ffa

Browse files
committed
simplify CNF renormalization step
1 parent 5a139fd commit 3384ffa

File tree

4 files changed

+30
-41
lines changed

4 files changed

+30
-41
lines changed

src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SetValiant.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,11 @@ fun ptreeUnion(left: List<PTree?>, right: List<PTree?>): List<PTree?> =
201201
}
202202

203203
val CFG.bitwiseAlgebra: Ring<Blns> by cache {
204-
vindex.let {
204+
vindex.let { vi ->
205205
Ring.of(
206206
nil = BooleanArray(nonterminals.size) { false },
207207
plus = { x, y -> union(x, y) },
208-
times = { x, y -> fastJoin(it, x, y) }
208+
times = { x, y -> fastJoin(vi, x, y) },
209209
)
210210
}
211211
}

src/commonMain/kotlin/ai/hypergraph/kaliningraph/tensor/Tensor.kt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,13 +403,11 @@ open class UTMatrix<T> constructor(
403403
algebra = algebra
404404
)
405405
else carry.windowed(2, 1).map { window ->
406-
window[0].second.zip(window[1].third)
407-
.map { (l, r) -> with(algebra) { l * r } }
408-
.fold(algebra.nil) { t, acc -> with(algebra) { acc + t } }
409-
.let { it to (window[0].second + it) to (listOf(it) + window[1].third) }
406+
with(algebra) { dot(window[0].π2, window[1].π3) }
407+
.let { it to (window[0].π2 + it) to (listOf(it) + window[1].π3) }
410408
}.let { next ->
411409
UTMatrix(
412-
diagonals = diagonals + listOf(next.map { it.first }),
410+
diagonals = diagonals + listOf(next.map { it.π1 }),
413411
algebra = algebra
414412
).seekFixpoint(next, iteration + 1, maxIterations)
415413
}

src/commonMain/kotlin/ai/hypergraph/kaliningraph/types/Types.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ interface Group<T>: Nat<T> {
8585
interface Ring<T>: Group<T> {
8686
override fun T.plus(t: T): T
8787
override fun T.times(t: T): T
88+
/**
 * Dot product of two lists under this ring: multiplies the elements pairwise
 * with [times] and sums the products with [plus].
 *
 * The lists are paired with `zip`, so any surplus elements of the longer list
 * are ignored. When there are no pairs to combine (either list is empty) the
 * result is [nil] instead of an exception from reducing an empty list.
 *
 * @param l1 left operands
 * @param l2 right operands
 * @return the sum of `l1[i] * l2[i]` over the zipped pairs, or [nil] if there are none
 */
fun dot(l1: List<T>, l2: List<T>): T =
  // reduceOrNull keeps the same left-to-right association as reduce for non-empty input
  l1.zip(l2).map { (l, r) -> l * r }.reduceOrNull { acc, t -> acc + t } ?: nil
8889

8990
open class of<T>(
9091
override val nil: T, override val one: T = nil,

src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import kotlin.streams.*
1010
import kotlin.time.Duration.Companion.minutes
1111
import kotlin.time.TimeSource
1212
import java.util.concurrent.ConcurrentHashMap
13+
import kotlin.collections.asSequence
1314

1415
fun CFG.parallelEnumSeqMinimalWOR(
1516
prompt: List<String>,
@@ -284,35 +285,24 @@ fun CFG.jvmIntersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CF
284285
// .jdvpNew()
285286
}
286287

287-
// Parallel streaming doesn't seem to be that much faster (yet)?
288-
289288
fun CFG.jvmPostProcess(clock: TimeSource.Monotonic.ValueTimeMark) =
290-
jvmElimVarUnitProds(
291-
cfg = jvmDropVestigialProductions(clock)
292-
).also { println("Normalization eliminated ${size - it.size} productions in ${clock.elapsedNow()}") }
293-
.freeze()
294-
295-
tailrec fun jvmElimVarUnitProds(
296-
cfg: CFG,
297-
toVisit: Set<Σᐩ> = cfg.nonterminals,
298-
vars: Set<Σᐩ> = cfg.nonterminals,
299-
toElim: Σᐩ? = toVisit.firstOrNull()
300-
): CFG {
301-
fun Production.isVariableUnitProd() = RHS.size == 1 && RHS[0] in vars
302-
if (toElim == null) return cfg.filter { !it.isVariableUnitProd() }
303-
val varsThatMapToMe =
304-
cfg.asSequence().asStream().parallel()
305-
.filter { it.RHS.size == 1 && it.RHS[0] == toElim }
306-
.map { it.LHS }.collect(Collectors.toSet())
307-
val thingsIMapTo =
308-
cfg.asSequence().asStream().parallel()
309-
.filter { it.LHS == toElim }.map { it.RHS }
310-
.collect(Collectors.toSet())
311-
return jvmElimVarUnitProds(
312-
(varsThatMapToMe * thingsIMapTo).fold(cfg) { g, p -> g + p },
313-
toVisit.drop(1).toSet(),
314-
vars
315-
)
289+
jvmDropVestigialProductions(clock)
290+
.also { println("Normalization eliminated ${size - it.size} productions in ${clock.elapsedNow()}") }
291+
292+
// Eliminates unit productions whose RHS is not a terminal. For Bar-Hillel intersections, we know the only
293+
// examples of this are the (S -> *) rules, so elimination is much simpler than the full CNF normalization.
294+
/**
 * Folds the targets of `START -> X` unit productions into `START` itself.
 *
 * Steps, each run over a parallel stream of the productions in [cfg]:
 *  1. collect every left-hand-side symbol;
 *  2. collect every symbol `X` for which a production `START -> X` exists and
 *     `X` is one of those left-hand-side symbols;
 *  3. drop every unit production whose single right-hand-side symbol is such an
 *     `X`, and relabel every remaining production whose left-hand side is such
 *     an `X` so that its left-hand side becomes `START`.
 *
 * @param cfg the grammar to rewrite
 * @return the rewritten set of productions
 */
fun jvmElimVarUnitProds(cfg: CFG): CFG {
  val productions = cfg.asSequence()
  // A fresh parallel stream per pass; a stream can only be consumed once.
  fun prodStream() = productions.asStream().parallel()

  val lhsSymbols = prodStream().map { it.first }.collect(Collectors.toSet())

  // Symbols X reachable via a unit production START -> X where X itself has productions.
  val startAliases = prodStream()
    .filter { p -> p.LHS == "START" && p.RHS.size == 1 && p.RHS[0] in lhsSymbols }
    .map { p -> p.RHS[0] }
    .collect(Collectors.toSet())

  val rewritten = prodStream()
    .filter { p -> 1 < p.RHS.size || p.RHS[0] !in startAliases }
    .map { p -> if (p.LHS in startAliases) "START" to p.RHS else p }
    .collect(Collectors.toSet())

  return rewritten
}
317307

318308
// TODO: Incomplete / untested
@@ -394,9 +384,9 @@ tailrec fun jvmElimVarUnitProds(
394384
fun CFG.jvmDropVestigialProductions(clock: TimeSource.Monotonic.ValueTimeMark): CFG {
395385
val start = clock.elapsedNow()
396386
var counter = 0
397-
val nts: Set<Σᐩ> = asSequence().asStream().parallel().map { it.first }.collect(Collectors.toSet())
398-
val rw = asSequence().asStream().parallel()
399-
.filter { prod ->
387+
val scfg = asSequence()
388+
val nts: Set<Σᐩ> = scfg.asStream().parallel().map { it.first }.collect(Collectors.toSet())
389+
val rw = scfg.asStream().parallel().filter { prod ->
400390
if (counter++ % 1000 == 0 && BH_TIMEOUT < clock.elapsedNow()) throw Exception("Timeout! ${clock.elapsedNow()}")
401391
// Only keep productions whose RHS symbols are not synthetic or are in the set of NTs
402392
prod.RHS.all { !(it.first() == '[' && 1 < it.length) || it in nts }
@@ -405,11 +395,11 @@ fun CFG.jvmDropVestigialProductions(clock: TimeSource.Monotonic.ValueTimeMark):
405395
.also { println("Removed ${size - it.size} invalid productions in ${clock.elapsedNow() - start}") }
406396
.freeze()
407397
.jvmRemoveUselessSymbols(nts)
408-
// .jdvpNew()
409398

410399
println("Removed ${size - rw.size} vestigial productions, resulting in ${rw.size} productions.")
411400

412-
return if (rw.size == size) rw else rw.jvmDropVestigialProductions(clock)
401+
return if (rw.size == size) jvmElimVarUnitProds(rw).freeze()
402+
else rw.jvmDropVestigialProductions(clock)
413403
}
414404

415405
/**
@@ -472,7 +462,7 @@ private fun CFG.jvmGenSym(
472462
val nextGenerating: MutableSet<Σᐩ> = from.toMutableSet()
473463
val TDEPS =
474464
ConcurrentHashMap<Σᐩ, MutableSet<Σᐩ>>(size).apply {
475-
this@jvmGenSym.asSequence().asStream().parallel()
465+
this@jvmGenSym.toHashSet().asSequence().asStream().parallel()
476466
.forEach { (l, r) -> r.forEach { getOrPut(it) { ConcurrentHashMap.newKeySet() }.add(l) } }
477467
}
478468
// this@jvmGenSym.asSequence().asStream().parallel()

0 commit comments

Comments
 (0)