Skip to content

Commit a7b90ed

Browse files
committed
improve text
1 parent 383f8a3 commit a7b90ed

14 files changed

Lines changed: 541 additions & 179 deletions

app/build.gradle.kts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ dependencies {
169169

170170
debugImplementation("androidx.compose.ui:ui-tooling")
171171
debugImplementation("androidx.compose.ui:ui-test-manifest")
172-
implementation(kotlin("stdlib-jdk8"))
173172

174173
// Testing
175174
testImplementation("junit:junit:4.13.2")
@@ -200,7 +199,7 @@ dependencies {
200199
implementation("com.google.apis:google-api-services-drive:v3-rev20251210-2.0.0")
201200

202201
// ML Kit Handwriting Recognition
203-
implementation("com.google.mlkit:digital-ink-recognition:18.1.0")
202+
implementation("com.google.mlkit:digital-ink-recognition:19.0.0")
204203

205204
// Markwon (Markdown Rendering & Editing)
206205
implementation("io.noties.markwon:core:4.6.2")

app/src/main/java/com/alexdremov/notate/data/HandwritingRecognitionCoordinator.kt

Lines changed: 119 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package com.alexdremov.notate.data
22

33
import android.graphics.PointF
44
import android.graphics.RectF
5+
import com.alexdremov.notate.data.region.RegionId
56
import com.alexdremov.notate.model.InfiniteCanvasModel
67
import com.alexdremov.notate.model.Stroke
78
import com.alexdremov.notate.model.StrokeType
@@ -18,35 +19,37 @@ import kotlinx.coroutines.flow.launchIn
1819
import kotlinx.coroutines.flow.onEach
1920
import kotlinx.coroutines.isActive
2021
import kotlinx.coroutines.launch
22+
import kotlinx.coroutines.withContext
2123

2224
@OptIn(FlowPreview::class)
2325
class HandwritingRecognitionCoordinator(
2426
private val model: InfiniteCanvasModel,
2527
private val recognitionManager: HandwritingRecognitionManager,
2628
private val isEnabledProvider: () -> Boolean,
27-
private val onOcrUpdated: ((RectF) -> Unit)? = null
29+
private val onOcrUpdated: ((RectF) -> Unit)? = null,
2830
) {
2931
private val scope = CoroutineScope(Dispatchers.Default + SupervisorJob())
3032
private val pendingStrokes = ArrayList<Stroke>()
3133
private val strokeUpdateFlow = MutableSharedFlow<Unit>(extraBufferCapacity = 1)
34+
private val processingRegions = HashSet<RegionId>()
3235

3336
init {
3437
// Observe model events for new strokes
3538
model.events
3639
.onEach { event ->
3740
if (event is InfiniteCanvasModel.ModelEvent.ItemsAdded) {
38-
val newStrokes = event.items.filterIsInstance<Stroke>().filter {
39-
it.style != StrokeType.DASH // Exclude erasers/selection tools
40-
}
41+
val newStrokes =
42+
event.items.filterIsInstance<Stroke>().filter {
43+
it.style != StrokeType.DASH && it.style != StrokeType.HIGHLIGHTER // Exclude erasers/selection tools and highlighters
44+
}
4145
if (newStrokes.isNotEmpty()) {
4246
synchronized(pendingStrokes) {
4347
pendingStrokes.addAll(newStrokes)
4448
}
4549
strokeUpdateFlow.emit(Unit)
4650
}
4751
}
48-
}
49-
.launchIn(scope)
52+
}.launchIn(scope)
5053

5154
// Debounced processing
5255
strokeUpdateFlow
@@ -59,190 +62,163 @@ class HandwritingRecognitionCoordinator(
5962
pendingStrokes.clear()
6063
}
6164
}
62-
}
63-
.launchIn(scope)
65+
}.launchIn(scope)
6466

6567
// Periodic sweep for unrecognized strokes (e.g. from older documents or erasures)
6668
scope.launch {
69+
// Trigger immediate sweep after document open
70+
if (isEnabledProvider()) {
71+
sweepUnrecognizedStrokes()
72+
}
6773
while (isActive) {
68-
delay(10000) // Sweep every 10 seconds
74+
delay(15000) // Sweep every 15 seconds
6975
if (isEnabledProvider()) {
7076
sweepUnrecognizedStrokes()
7177
}
7278
}
7379
}
7480
}
7581

76-
private fun clusterStrokesSpatially(strokes: List<Stroke>, maxDistance: Float): List<List<Stroke>> {
77-
val clusters = ArrayList<MutableList<Stroke>>()
78-
for (stroke in strokes) {
79-
val strokeBounds = RectF(stroke.bounds).apply { inset(-maxDistance, -maxDistance) }
80-
81-
// Find all clusters that intersect
82-
val intersectingClusters = clusters.filter { cluster ->
83-
cluster.any { s ->
84-
val cb = RectF(s.bounds).apply { inset(-maxDistance, -maxDistance) }
85-
RectF.intersects(strokeBounds, cb)
86-
}
87-
}
88-
89-
if (intersectingClusters.isEmpty()) {
90-
clusters.add(mutableListOf(stroke))
91-
} else {
92-
val firstCluster = intersectingClusters.first()
93-
firstCluster.add(stroke)
94-
for (i in 1 until intersectingClusters.size) {
95-
firstCluster.addAll(intersectingClusters[i])
96-
clusters.remove(intersectingClusters[i])
97-
}
98-
}
99-
}
100-
return clusters
101-
}
102-
103-
private suspend fun sweepUnrecognizedStrokes() {
82+
suspend fun sweepUnrecognizedStrokes() {
10483
val rm = model.getRegionManager() ?: return
10584
val activeIds = rm.getActiveRegionIds()
106-
85+
10786
for (rId in activeIds) {
87+
// Coordination: Skip if this region is currently being processed by real-time logic
88+
val skip =
89+
synchronized(processingRegions) {
90+
processingRegions.contains(rId)
91+
}
92+
if (skip) continue
93+
10894
val region = rm.getRegionReadOnly(rId) ?: continue
109-
val unrecognizedInRegion = ArrayList<Stroke>()
110-
111-
for (item in region.items) {
112-
if (item is Stroke && item.style != StrokeType.DASH) {
113-
val strokeCenter = PointF(item.bounds.centerX(), item.bounds.centerY())
114-
var isRecognized = false
115-
for (ocr in region.recognizedTexts) {
116-
val ocrBounds = RectF(ocr.x, ocr.y, ocr.x + ocr.width, ocr.y + ocr.height)
117-
// Heuristic: if stroke center is within OCR bounds, it's recognized
118-
if (ocrBounds.contains(strokeCenter.x, strokeCenter.y) || RectF.intersects(ocrBounds, item.bounds)) {
119-
isRecognized = true
120-
break
121-
}
122-
}
123-
if (!isRecognized) {
124-
unrecognizedInRegion.add(item)
125-
}
95+
96+
// 1. Gather all "recognizable" strokes in the region
97+
val strokesInRegion =
98+
region.items.filterIsInstance<Stroke>().filter {
99+
it.style != StrokeType.DASH && it.style != StrokeType.HIGHLIGHTER
100+
}
101+
if (strokesInRegion.isEmpty()) continue
102+
103+
// 2. Perform algorithmic line detection on ALL strokes in the region
104+
// This establishes the "ground truth" for how strokes SHOULD be grouped.
105+
val highLevelClusters = StrokeClusteringManager.clusterStrokes(strokesInRegion)
106+
val detectedLines = highLevelClusters.flatMap { StrokeClusteringManager.segmentIntoLines(it) }
107+
108+
// 3. Build a fast lookup for existing OCR blocks by their stroke sets
109+
// We use a Set<Long> of strokeOrders as the stable identity of an OCR block.
110+
val existingOcrByStrokes = region.recognizedTexts.associateBy { it.strokeOrders.toSet() }
111+
112+
val unrecognizedLines = ArrayList<List<Stroke>>()
113+
114+
for (line in detectedLines) {
115+
val lineStrokeOrders = line.map { it.strokeOrder }.toSet()
116+
117+
// INVARIANT CHECK: Does an OCR block exist that matches this exact line?
118+
if (!existingOcrByStrokes.containsKey(lineStrokeOrders)) {
119+
// If not, this line is either new, modified, or fragmented.
120+
unrecognizedLines.add(line)
126121
}
127122
}
128-
129-
if (unrecognizedInRegion.isNotEmpty()) {
130-
Logger.d("OCRCoordinator", "Sweep found ${unrecognizedInRegion.size} unrecognized strokes in region $rId")
131-
// Cluster the strokes spatially to avoid sending the whole region at once
132-
val clusters = clusterStrokesSpatially(unrecognizedInRegion, maxDistance = 150f)
133-
for (cluster in clusters) {
134-
processStrokeCluster(cluster)
135-
// Yield to avoid freezing the background thread
123+
124+
if (unrecognizedLines.isNotEmpty()) {
125+
Logger.d("OCRCoordinator", "Sweep found ${unrecognizedLines.size} lines in region $rId violating OCR invariant.")
126+
for (line in unrecognizedLines) {
127+
processStrokesInternal(line)
136128
delay(200)
137129
}
138130
}
139131
}
140132
}
141133

142134
private suspend fun processPendingStrokes() {
143-
val strokesToProcess = synchronized(pendingStrokes) {
144-
val copy = ArrayList(pendingStrokes)
145-
pendingStrokes.clear()
146-
copy
147-
}
148-
149-
if (strokesToProcess.isEmpty()) return
150-
151-
val clusters = clusterStrokesSpatially(strokesToProcess, maxDistance = 150f)
152-
for (cluster in clusters) {
153-
processStrokeCluster(cluster)
154-
delay(100)
155-
}
135+
val strokesToProcess =
136+
synchronized(pendingStrokes) {
137+
val copy = ArrayList(pendingStrokes)
138+
pendingStrokes.clear()
139+
copy
140+
}
141+
processStrokesInternal(strokesToProcess)
156142
}
157143

158-
private suspend fun processStrokeCluster(strokesToProcess: List<Stroke>) {
159-
if (strokesToProcess.isEmpty()) return
160-
161-
Logger.d("OCRCoordinator", "Processing cluster of ${strokesToProcess.size} strokes")
162-
163-
// Group strokes by spatial proximity and temporal overlap.
164-
val totalBounds = RectF()
165-
var isFirst = true
166-
for (s in strokesToProcess) {
167-
if (isFirst) {
168-
totalBounds.set(s.bounds)
169-
isFirst = false
144+
private suspend fun processStrokesInternal(initialStrokes: List<Stroke>) {
145+
if (initialStrokes.isEmpty()) return
146+
147+
// 1. Recursive spatial expansion to find all connected strokes and intersecting OCR blocks
148+
val fullClusterSet = HashSet<Stroke>(initialStrokes)
149+
val expandedSearchArea = RectF()
150+
initialStrokes.forEach {
151+
if (expandedSearchArea.isEmpty) {
152+
expandedSearchArea.set(
153+
it.bounds,
154+
)
170155
} else {
171-
totalBounds.union(s.bounds)
156+
expandedSearchArea.union(it.bounds)
172157
}
173158
}
159+
expandedSearchArea.inset(-150f, -100f)
174160

175-
// Expand search area to catch adjacent letters/words and existing OCR blocks
176-
// Generous vertical padding (-100f) to allow newlines/paragraph grouping.
177-
val searchArea = RectF(totalBounds).apply { inset(-150f, -100f) }
178-
161+
// Identify all existing OCR blocks that intersect our current search area
179162
val intersectingOcrStrokeOrders = HashSet<Long>()
180-
val finalArea = RectF(totalBounds)
181-
182-
// Find existing OCR blocks that intersect the search area
183-
model.getRegionManager()?.getRegionIdsInRect(searchArea)?.forEach { rId ->
163+
val totalInvalidateArea = RectF(expandedSearchArea)
164+
165+
model.getRegionManager()?.getRegionIdsInRect(expandedSearchArea)?.forEach { rId ->
184166
val region = model.getRegionManager()?.getRegionReadOnly(rId)
185167
region?.recognizedTexts?.forEach { ocr ->
186168
val ocrRect = RectF(ocr.x, ocr.y, ocr.x + ocr.width, ocr.y + ocr.height)
187-
if (RectF.intersects(ocrRect, searchArea)) {
169+
if (RectF.intersects(ocrRect, expandedSearchArea)) {
188170
intersectingOcrStrokeOrders.addAll(ocr.strokeOrders)
189-
finalArea.union(ocrRect)
171+
totalInvalidateArea.union(ocrRect)
190172
}
191173
}
192174
}
193-
194-
val finalStrokesSet = HashSet<Stroke>()
195-
finalStrokesSet.addAll(strokesToProcess)
196-
197-
// If we touched existing OCR blocks, grab all their strokes so we don't truncate words
198-
if (intersectingOcrStrokeOrders.isNotEmpty()) {
199-
model.getRegionManager()?.visitItemsInRect(finalArea) { item ->
175+
176+
// Lock regions affected by the expanded area
177+
val affectedRegions = model.getRegionManager()?.getRegionIdsInRect(totalInvalidateArea) ?: emptyList()
178+
synchronized(processingRegions) {
179+
processingRegions.addAll(affectedRegions)
180+
}
181+
182+
try {
183+
// Find ALL strokes in the final expanded area (including those from intersected OCR blocks)
184+
model.getRegionManager()?.visitItemsInRect(totalInvalidateArea) { item ->
200185
if (item is Stroke && item.style != StrokeType.DASH && item.style != StrokeType.HIGHLIGHTER) {
201-
if (intersectingOcrStrokeOrders.contains(item.strokeOrder)) {
202-
finalStrokesSet.add(item)
186+
// Include if spatially inside OR part of an invalidated OCR block
187+
if (RectF.intersects(totalInvalidateArea, item.bounds) ||
188+
intersectingOcrStrokeOrders.contains(item.strokeOrder)
189+
) {
190+
fullClusterSet.add(item)
203191
}
204192
}
205193
}
206-
}
207194

208-
// Spatially sort strokes to ensure correct word order (top-to-bottom, left-to-right)
209-
val sortedByY = finalStrokesSet.sortedBy { it.bounds.centerY() }
210-
val lines = ArrayList<MutableList<Stroke>>()
211-
for (stroke in sortedByY) {
212-
val lastLine = lines.lastOrNull()
213-
if (lastLine != null) {
214-
val avgY = lastLine.map { it.bounds.centerY() }.average().toFloat()
215-
// If center Y is within 60px of the line's average center, consider it the same line
216-
if (kotlin.math.abs(stroke.bounds.centerY() - avgY) < 60f) {
217-
lastLine.add(stroke)
218-
continue
195+
// 2. High-level Clustering (Group into paragraphs/sections)
196+
val clusters = StrokeClusteringManager.clusterStrokes(fullClusterSet.toList())
197+
198+
// 3. Clear existing OCR for the entire affected area ONCE to prevent inter-line conflicts
199+
model.removeRecognizedTextInRect(totalInvalidateArea)
200+
201+
for (cluster in clusters) {
202+
// 4. Line Segmentation (Split into individual horizontal lines)
203+
val lines = StrokeClusteringManager.segmentIntoLines(cluster)
204+
205+
// 5. Individual Recognition and Persistence
206+
for (line in lines) {
207+
if (line.isEmpty()) continue
208+
val result = recognitionManager.recognizeStrokes(line)
209+
if (result != null) {
210+
model.addRecognizedText(result)
211+
}
219212
}
220213
}
221-
lines.add(mutableListOf(stroke))
222-
}
223-
val finalStrokes = lines.flatMap { line -> line.sortedBy { it.bounds.left } }
224214

225-
// Invalidate old OCR in the area of ALL strokes we are recognizing
226-
val invalidateArea = RectF()
227-
isFirst = true
228-
for (s in finalStrokes) {
229-
if (isFirst) {
230-
invalidateArea.set(s.bounds)
231-
isFirst = false
232-
} else {
233-
invalidateArea.union(s.bounds)
215+
withContext(Dispatchers.Main) {
216+
onOcrUpdated?.invoke(totalInvalidateArea)
217+
}
218+
} finally {
219+
synchronized(processingRegions) {
220+
processingRegions.removeAll(affectedRegions.toSet())
234221
}
235-
}
236-
237-
// Safety check, ensure invalidateArea isn't wildly wrong
238-
if (!invalidateArea.isEmpty || finalStrokes.isNotEmpty()) {
239-
model.removeRecognizedTextInRect(invalidateArea)
240-
}
241-
242-
val result = recognitionManager.recognizeStrokes(finalStrokes)
243-
if (result != null) {
244-
model.addRecognizedText(result)
245-
onOcrUpdated?.invoke(invalidateArea)
246222
}
247223
}
248224

@@ -267,7 +243,4 @@ class HandwritingRecognitionCoordinator(
267243
strokeUpdateFlow.emit(Unit)
268244
}
269245
}
270-
}
271-
}
272-
}
273246
}

app/src/main/java/com/alexdremov/notate/data/HandwritingRecognitionManager.kt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ import com.alexdremov.notate.util.Logger
77
import com.google.mlkit.common.MlKitException
88
import com.google.mlkit.common.model.DownloadConditions
99
import com.google.mlkit.common.model.RemoteModelManager
10-
import com.google.mlkit.vision.digitalink.DigitalInkRecognition
11-
import com.google.mlkit.vision.digitalink.DigitalInkRecognitionModel
12-
import com.google.mlkit.vision.digitalink.DigitalInkRecognitionModelIdentifier
13-
import com.google.mlkit.vision.digitalink.DigitalInkRecognizer
14-
import com.google.mlkit.vision.digitalink.DigitalInkRecognizerOptions
15-
import com.google.mlkit.vision.digitalink.Ink
10+
import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognition
11+
import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognitionModel
12+
import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognitionModelIdentifier
13+
import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognizer
14+
import com.google.mlkit.vision.digitalink.recognition.DigitalInkRecognizerOptions
15+
import com.google.mlkit.vision.digitalink.recognition.Ink
1616
import kotlinx.coroutines.tasks.await
1717
import java.util.concurrent.atomic.AtomicReference
1818

0 commit comments

Comments
 (0)