@@ -2,6 +2,7 @@ package com.alexdremov.notate.data
22
33import android.graphics.PointF
44import android.graphics.RectF
5+ import com.alexdremov.notate.data.region.RegionId
56import com.alexdremov.notate.model.InfiniteCanvasModel
67import com.alexdremov.notate.model.Stroke
78import com.alexdremov.notate.model.StrokeType
@@ -18,35 +19,37 @@ import kotlinx.coroutines.flow.launchIn
1819import kotlinx.coroutines.flow.onEach
1920import kotlinx.coroutines.isActive
2021import kotlinx.coroutines.launch
22+ import kotlinx.coroutines.withContext
2123
2224@OptIn(FlowPreview ::class )
2325class HandwritingRecognitionCoordinator (
2426 private val model : InfiniteCanvasModel ,
2527 private val recognitionManager : HandwritingRecognitionManager ,
2628 private val isEnabledProvider : () -> Boolean ,
27- private val onOcrUpdated : ((RectF ) -> Unit )? = null
29+ private val onOcrUpdated : ((RectF ) -> Unit )? = null ,
2830) {
2931 private val scope = CoroutineScope (Dispatchers .Default + SupervisorJob ())
3032 private val pendingStrokes = ArrayList <Stroke >()
3133 private val strokeUpdateFlow = MutableSharedFlow <Unit >(extraBufferCapacity = 1 )
34+ private val processingRegions = HashSet <RegionId >()
3235
3336 init {
3437 // Observe model events for new strokes
3538 model.events
3639 .onEach { event ->
3740 if (event is InfiniteCanvasModel .ModelEvent .ItemsAdded ) {
38- val newStrokes = event.items.filterIsInstance<Stroke >().filter {
39- it.style != StrokeType .DASH // Exclude erasers/selection tools
40- }
41+ val newStrokes =
42+ event.items.filterIsInstance<Stroke >().filter {
43+ it.style != StrokeType .DASH && it.style != StrokeType .HIGHLIGHTER // Exclude erasers/selection tools and highlighters
44+ }
4145 if (newStrokes.isNotEmpty()) {
4246 synchronized(pendingStrokes) {
4347 pendingStrokes.addAll(newStrokes)
4448 }
4549 strokeUpdateFlow.emit(Unit )
4650 }
4751 }
48- }
49- .launchIn(scope)
52+ }.launchIn(scope)
5053
5154 // Debounced processing
5255 strokeUpdateFlow
@@ -59,190 +62,163 @@ class HandwritingRecognitionCoordinator(
5962 pendingStrokes.clear()
6063 }
6164 }
62- }
63- .launchIn(scope)
65+ }.launchIn(scope)
6466
6567 // Periodic sweep for unrecognized strokes (e.g. from older documents or erasures)
6668 scope.launch {
69+ // Trigger immediate sweep after document open
70+ if (isEnabledProvider()) {
71+ sweepUnrecognizedStrokes()
72+ }
6773 while (isActive) {
68- delay(10000 ) // Sweep every 10 seconds
74+ delay(15000 ) // Sweep every 15 seconds
6975 if (isEnabledProvider()) {
7076 sweepUnrecognizedStrokes()
7177 }
7278 }
7379 }
7480 }
7581
76- private fun clusterStrokesSpatially (strokes : List <Stroke >, maxDistance : Float ): List <List <Stroke >> {
77- val clusters = ArrayList <MutableList <Stroke >>()
78- for (stroke in strokes) {
79- val strokeBounds = RectF (stroke.bounds).apply { inset(- maxDistance, - maxDistance) }
80-
81- // Find all clusters that intersect
82- val intersectingClusters = clusters.filter { cluster ->
83- cluster.any { s ->
84- val cb = RectF (s.bounds).apply { inset(- maxDistance, - maxDistance) }
85- RectF .intersects(strokeBounds, cb)
86- }
87- }
88-
89- if (intersectingClusters.isEmpty()) {
90- clusters.add(mutableListOf (stroke))
91- } else {
92- val firstCluster = intersectingClusters.first()
93- firstCluster.add(stroke)
94- for (i in 1 until intersectingClusters.size) {
95- firstCluster.addAll(intersectingClusters[i])
96- clusters.remove(intersectingClusters[i])
97- }
98- }
99- }
100- return clusters
101- }
102-
103- private suspend fun sweepUnrecognizedStrokes () {
82+ suspend fun sweepUnrecognizedStrokes () {
10483 val rm = model.getRegionManager() ? : return
10584 val activeIds = rm.getActiveRegionIds()
106-
85+
10786 for (rId in activeIds) {
87+ // Coordination: Skip if this region is currently being processed by real-time logic
88+ val skip =
89+ synchronized(processingRegions) {
90+ processingRegions.contains(rId)
91+ }
92+ if (skip) continue
93+
10894 val region = rm.getRegionReadOnly(rId) ? : continue
109- val unrecognizedInRegion = ArrayList <Stroke >()
110-
111- for (item in region.items) {
112- if (item is Stroke && item.style != StrokeType .DASH ) {
113- val strokeCenter = PointF (item.bounds.centerX(), item.bounds.centerY())
114- var isRecognized = false
115- for (ocr in region.recognizedTexts) {
116- val ocrBounds = RectF (ocr.x, ocr.y, ocr.x + ocr.width, ocr.y + ocr.height)
117- // Heuristic: if stroke center is within OCR bounds, it's recognized
118- if (ocrBounds.contains(strokeCenter.x, strokeCenter.y) || RectF .intersects(ocrBounds, item.bounds)) {
119- isRecognized = true
120- break
121- }
122- }
123- if (! isRecognized) {
124- unrecognizedInRegion.add(item)
125- }
95+
96+ // 1. Gather all "recognizable" strokes in the region
97+ val strokesInRegion =
98+ region.items.filterIsInstance<Stroke >().filter {
99+ it.style != StrokeType .DASH && it.style != StrokeType .HIGHLIGHTER
100+ }
101+ if (strokesInRegion.isEmpty()) continue
102+
103+ // 2. Perform algorithmic line detection on ALL strokes in the region
104+ // This establishes the "ground truth" for how strokes SHOULD be grouped.
105+ val highLevelClusters = StrokeClusteringManager .clusterStrokes(strokesInRegion)
106+ val detectedLines = highLevelClusters.flatMap { StrokeClusteringManager .segmentIntoLines(it) }
107+
108+ // 3. Build a fast lookup for existing OCR blocks by their stroke sets
109+ // We use a Set<Long> of strokeOrders as the stable identity of an OCR block.
110+ val existingOcrByStrokes = region.recognizedTexts.associateBy { it.strokeOrders.toSet() }
111+
112+ val unrecognizedLines = ArrayList <List <Stroke >>()
113+
114+ for (line in detectedLines) {
115+ val lineStrokeOrders = line.map { it.strokeOrder }.toSet()
116+
117+ // INVARIANT CHECK: Does an OCR block exist that matches this exact line?
118+ if (! existingOcrByStrokes.containsKey(lineStrokeOrders)) {
119+ // If not, this line is either new, modified, or fragmented.
120+ unrecognizedLines.add(line)
126121 }
127122 }
128-
129- if (unrecognizedInRegion.isNotEmpty()) {
130- Logger .d(" OCRCoordinator" , " Sweep found ${unrecognizedInRegion.size} unrecognized strokes in region $rId " )
131- // Cluster the strokes spatially to avoid sending the whole region at once
132- val clusters = clusterStrokesSpatially(unrecognizedInRegion, maxDistance = 150f )
133- for (cluster in clusters) {
134- processStrokeCluster(cluster)
135- // Yield to avoid freezing the background thread
123+
124+ if (unrecognizedLines.isNotEmpty()) {
125+ Logger .d(" OCRCoordinator" , " Sweep found ${unrecognizedLines.size} lines in region $rId violating OCR invariant." )
126+ for (line in unrecognizedLines) {
127+ processStrokesInternal(line)
136128 delay(200 )
137129 }
138130 }
139131 }
140132 }
141133
142134 private suspend fun processPendingStrokes () {
143- val strokesToProcess = synchronized(pendingStrokes) {
144- val copy = ArrayList (pendingStrokes)
145- pendingStrokes.clear()
146- copy
147- }
148-
149- if (strokesToProcess.isEmpty()) return
150-
151- val clusters = clusterStrokesSpatially(strokesToProcess, maxDistance = 150f )
152- for (cluster in clusters) {
153- processStrokeCluster(cluster)
154- delay(100 )
155- }
135+ val strokesToProcess =
136+ synchronized(pendingStrokes) {
137+ val copy = ArrayList (pendingStrokes)
138+ pendingStrokes.clear()
139+ copy
140+ }
141+ processStrokesInternal(strokesToProcess)
156142 }
157143
158- private suspend fun processStrokeCluster (strokesToProcess : List <Stroke >) {
159- if (strokesToProcess.isEmpty()) return
160-
161- Logger .d(" OCRCoordinator" , " Processing cluster of ${strokesToProcess.size} strokes" )
162-
163- // Group strokes by spatial proximity and temporal overlap.
164- val totalBounds = RectF ()
165- var isFirst = true
166- for (s in strokesToProcess) {
167- if (isFirst) {
168- totalBounds.set(s.bounds)
169- isFirst = false
144+ private suspend fun processStrokesInternal (initialStrokes : List <Stroke >) {
145+ if (initialStrokes.isEmpty()) return
146+
147+ // 1. Recursive spatial expansion to find all connected strokes and intersecting OCR blocks
148+ val fullClusterSet = HashSet <Stroke >(initialStrokes)
149+ val expandedSearchArea = RectF ()
150+ initialStrokes.forEach {
151+ if (expandedSearchArea.isEmpty) {
152+ expandedSearchArea.set(
153+ it.bounds,
154+ )
170155 } else {
171- totalBounds .union(s .bounds)
156+ expandedSearchArea .union(it .bounds)
172157 }
173158 }
159+ expandedSearchArea.inset(- 150f , - 100f )
174160
175- // Expand search area to catch adjacent letters/words and existing OCR blocks
176- // Generous vertical padding (-100f) to allow newlines/paragraph grouping.
177- val searchArea = RectF (totalBounds).apply { inset(- 150f , - 100f ) }
178-
161+ // Identify all existing OCR blocks that intersect our current search area
179162 val intersectingOcrStrokeOrders = HashSet <Long >()
180- val finalArea = RectF (totalBounds)
181-
182- // Find existing OCR blocks that intersect the search area
183- model.getRegionManager()?.getRegionIdsInRect(searchArea)?.forEach { rId ->
163+ val totalInvalidateArea = RectF (expandedSearchArea)
164+
165+ model.getRegionManager()?.getRegionIdsInRect(expandedSearchArea)?.forEach { rId ->
184166 val region = model.getRegionManager()?.getRegionReadOnly(rId)
185167 region?.recognizedTexts?.forEach { ocr ->
186168 val ocrRect = RectF (ocr.x, ocr.y, ocr.x + ocr.width, ocr.y + ocr.height)
187- if (RectF .intersects(ocrRect, searchArea )) {
169+ if (RectF .intersects(ocrRect, expandedSearchArea )) {
188170 intersectingOcrStrokeOrders.addAll(ocr.strokeOrders)
189- finalArea .union(ocrRect)
171+ totalInvalidateArea .union(ocrRect)
190172 }
191173 }
192174 }
193-
194- val finalStrokesSet = HashSet <Stroke >()
195- finalStrokesSet.addAll(strokesToProcess)
196-
197- // If we touched existing OCR blocks, grab all their strokes so we don't truncate words
198- if (intersectingOcrStrokeOrders.isNotEmpty()) {
199- model.getRegionManager()?.visitItemsInRect(finalArea) { item ->
175+
176+ // Lock regions affected by the expanded area
177+ val affectedRegions = model.getRegionManager()?.getRegionIdsInRect(totalInvalidateArea) ? : emptyList()
178+ synchronized(processingRegions) {
179+ processingRegions.addAll(affectedRegions)
180+ }
181+
182+ try {
183+ // Find ALL strokes in the final expanded area (including those from intersected OCR blocks)
184+ model.getRegionManager()?.visitItemsInRect(totalInvalidateArea) { item ->
200185 if (item is Stroke && item.style != StrokeType .DASH && item.style != StrokeType .HIGHLIGHTER ) {
201- if (intersectingOcrStrokeOrders.contains(item.strokeOrder)) {
202- finalStrokesSet.add(item)
186+ // Include if spatially inside OR part of an invalidated OCR block
187+ if (RectF .intersects(totalInvalidateArea, item.bounds) ||
188+ intersectingOcrStrokeOrders.contains(item.strokeOrder)
189+ ) {
190+ fullClusterSet.add(item)
203191 }
204192 }
205193 }
206- }
207194
208- // Spatially sort strokes to ensure correct word order (top-to-bottom, left-to-right)
209- val sortedByY = finalStrokesSet.sortedBy { it.bounds.centerY() }
210- val lines = ArrayList <MutableList <Stroke >>()
211- for (stroke in sortedByY) {
212- val lastLine = lines.lastOrNull()
213- if (lastLine != null ) {
214- val avgY = lastLine.map { it.bounds.centerY() }.average().toFloat()
215- // If center Y is within 60px of the line's average center, consider it the same line
216- if (kotlin.math.abs(stroke.bounds.centerY() - avgY) < 60f ) {
217- lastLine.add(stroke)
218- continue
195+ // 2. High-level Clustering (Group into paragraphs/sections)
196+ val clusters = StrokeClusteringManager .clusterStrokes(fullClusterSet.toList())
197+
198+ // 3. Clear existing OCR for the entire affected area ONCE to prevent inter-line conflicts
199+ model.removeRecognizedTextInRect(totalInvalidateArea)
200+
201+ for (cluster in clusters) {
202+ // 4. Line Segmentation (Split into individual horizontal lines)
203+ val lines = StrokeClusteringManager .segmentIntoLines(cluster)
204+
205+ // 5. Individual Recognition and Persistence
206+ for (line in lines) {
207+ if (line.isEmpty()) continue
208+ val result = recognitionManager.recognizeStrokes(line)
209+ if (result != null ) {
210+ model.addRecognizedText(result)
211+ }
219212 }
220213 }
221- lines.add(mutableListOf (stroke))
222- }
223- val finalStrokes = lines.flatMap { line -> line.sortedBy { it.bounds.left } }
224214
225- // Invalidate old OCR in the area of ALL strokes we are recognizing
226- val invalidateArea = RectF ()
227- isFirst = true
228- for (s in finalStrokes) {
229- if (isFirst) {
230- invalidateArea.set(s.bounds)
231- isFirst = false
232- } else {
233- invalidateArea.union(s.bounds)
215+ withContext(Dispatchers .Main ) {
216+ onOcrUpdated?.invoke(totalInvalidateArea)
217+ }
218+ } finally {
219+ synchronized(processingRegions) {
220+ processingRegions.removeAll(affectedRegions.toSet())
234221 }
235- }
236-
237- // Safety check, ensure invalidateArea isn't wildly wrong
238- if (! invalidateArea.isEmpty || finalStrokes.isNotEmpty()) {
239- model.removeRecognizedTextInRect(invalidateArea)
240- }
241-
242- val result = recognitionManager.recognizeStrokes(finalStrokes)
243- if (result != null ) {
244- model.addRecognizedText(result)
245- onOcrUpdated?.invoke(invalidateArea)
246222 }
247223 }
248224
@@ -267,7 +243,4 @@ class HandwritingRecognitionCoordinator(
267243 strokeUpdateFlow.emit(Unit )
268244 }
269245 }
270- }
271- }
272- }
273246}
0 commit comments