Skip to content

Commit 383f8a3

Browse files
committed
initial ocr
1 parent 6fde712 commit 383f8a3

22 files changed

Lines changed: 1051 additions & 6 deletions

app/build.gradle.kts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ dependencies {
154154
// Serialization
155155
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.10.0")
156156
implementation("org.jetbrains.kotlinx:kotlinx-serialization-protobuf:1.10.0")
157+
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-play-services:1.9.0")
157158

158159
// Color Picker
159160
implementation("com.github.skydoves:colorpickerview:2.4.0")
@@ -198,6 +199,9 @@ dependencies {
198199
implementation("com.google.api-client:google-api-client-android:2.8.1")
199200
implementation("com.google.apis:google-api-services-drive:v3-rev20251210-2.0.0")
200201

202+
// ML Kit Handwriting Recognition
203+
implementation("com.google.mlkit:digital-ink-recognition:18.1.0")
204+
201205
// Markwon (Markdown Rendering & Editing)
202206
implementation("io.noties.markwon:core:4.6.2")
203207
implementation("io.noties.markwon:editor:4.6.2")

app/src/main/java/com/alexdremov/notate/CanvasActivity.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ class CanvasActivity : AppCompatActivity() {
264264
binding = ActivityMainBinding.inflate(layoutInflater)
265265
setContentView(binding.root)
266266

267+
viewModel.setControllerProvider { binding.canvasView.getController() }
268+
267269
currentCanvasPath = intent.getStringExtra("CANVAS_PATH")
268270

269271
enableImmersiveMode()
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
package com.alexdremov.notate.data
2+
3+
import android.graphics.PointF
4+
import android.graphics.RectF
5+
import com.alexdremov.notate.model.InfiniteCanvasModel
6+
import com.alexdremov.notate.model.Stroke
7+
import com.alexdremov.notate.model.StrokeType
8+
import com.alexdremov.notate.util.Logger
9+
import kotlinx.coroutines.CoroutineScope
10+
import kotlinx.coroutines.Dispatchers
11+
import kotlinx.coroutines.FlowPreview
12+
import kotlinx.coroutines.SupervisorJob
13+
import kotlinx.coroutines.delay
14+
import kotlinx.coroutines.flow.MutableSharedFlow
15+
import kotlinx.coroutines.flow.debounce
16+
import kotlinx.coroutines.flow.filter
17+
import kotlinx.coroutines.flow.launchIn
18+
import kotlinx.coroutines.flow.onEach
19+
import kotlinx.coroutines.isActive
20+
import kotlinx.coroutines.launch
21+
22+
@OptIn(FlowPreview::class)
23+
class HandwritingRecognitionCoordinator(
24+
private val model: InfiniteCanvasModel,
25+
private val recognitionManager: HandwritingRecognitionManager,
26+
private val isEnabledProvider: () -> Boolean,
27+
private val onOcrUpdated: ((RectF) -> Unit)? = null
28+
) {
29+
private val scope = CoroutineScope(Dispatchers.Default + SupervisorJob())
30+
private val pendingStrokes = ArrayList<Stroke>()
31+
private val strokeUpdateFlow = MutableSharedFlow<Unit>(extraBufferCapacity = 1)
32+
33+
init {
34+
// Observe model events for new strokes
35+
model.events
36+
.onEach { event ->
37+
if (event is InfiniteCanvasModel.ModelEvent.ItemsAdded) {
38+
val newStrokes = event.items.filterIsInstance<Stroke>().filter {
39+
it.style != StrokeType.DASH // Exclude erasers/selection tools
40+
}
41+
if (newStrokes.isNotEmpty()) {
42+
synchronized(pendingStrokes) {
43+
pendingStrokes.addAll(newStrokes)
44+
}
45+
strokeUpdateFlow.emit(Unit)
46+
}
47+
}
48+
}
49+
.launchIn(scope)
50+
51+
// Debounced processing
52+
strokeUpdateFlow
53+
.debounce(2000)
54+
.onEach {
55+
if (isEnabledProvider()) {
56+
processPendingStrokes()
57+
} else {
58+
synchronized(pendingStrokes) {
59+
pendingStrokes.clear()
60+
}
61+
}
62+
}
63+
.launchIn(scope)
64+
65+
// Periodic sweep for unrecognized strokes (e.g. from older documents or erasures)
66+
scope.launch {
67+
while (isActive) {
68+
delay(10000) // Sweep every 10 seconds
69+
if (isEnabledProvider()) {
70+
sweepUnrecognizedStrokes()
71+
}
72+
}
73+
}
74+
}
75+
76+
private fun clusterStrokesSpatially(strokes: List<Stroke>, maxDistance: Float): List<List<Stroke>> {
77+
val clusters = ArrayList<MutableList<Stroke>>()
78+
for (stroke in strokes) {
79+
val strokeBounds = RectF(stroke.bounds).apply { inset(-maxDistance, -maxDistance) }
80+
81+
// Find all clusters that intersect
82+
val intersectingClusters = clusters.filter { cluster ->
83+
cluster.any { s ->
84+
val cb = RectF(s.bounds).apply { inset(-maxDistance, -maxDistance) }
85+
RectF.intersects(strokeBounds, cb)
86+
}
87+
}
88+
89+
if (intersectingClusters.isEmpty()) {
90+
clusters.add(mutableListOf(stroke))
91+
} else {
92+
val firstCluster = intersectingClusters.first()
93+
firstCluster.add(stroke)
94+
for (i in 1 until intersectingClusters.size) {
95+
firstCluster.addAll(intersectingClusters[i])
96+
clusters.remove(intersectingClusters[i])
97+
}
98+
}
99+
}
100+
return clusters
101+
}
102+
103+
private suspend fun sweepUnrecognizedStrokes() {
104+
val rm = model.getRegionManager() ?: return
105+
val activeIds = rm.getActiveRegionIds()
106+
107+
for (rId in activeIds) {
108+
val region = rm.getRegionReadOnly(rId) ?: continue
109+
val unrecognizedInRegion = ArrayList<Stroke>()
110+
111+
for (item in region.items) {
112+
if (item is Stroke && item.style != StrokeType.DASH) {
113+
val strokeCenter = PointF(item.bounds.centerX(), item.bounds.centerY())
114+
var isRecognized = false
115+
for (ocr in region.recognizedTexts) {
116+
val ocrBounds = RectF(ocr.x, ocr.y, ocr.x + ocr.width, ocr.y + ocr.height)
117+
// Heuristic: if stroke center is within OCR bounds, it's recognized
118+
if (ocrBounds.contains(strokeCenter.x, strokeCenter.y) || RectF.intersects(ocrBounds, item.bounds)) {
119+
isRecognized = true
120+
break
121+
}
122+
}
123+
if (!isRecognized) {
124+
unrecognizedInRegion.add(item)
125+
}
126+
}
127+
}
128+
129+
if (unrecognizedInRegion.isNotEmpty()) {
130+
Logger.d("OCRCoordinator", "Sweep found ${unrecognizedInRegion.size} unrecognized strokes in region $rId")
131+
// Cluster the strokes spatially to avoid sending the whole region at once
132+
val clusters = clusterStrokesSpatially(unrecognizedInRegion, maxDistance = 150f)
133+
for (cluster in clusters) {
134+
processStrokeCluster(cluster)
135+
// Yield to avoid freezing the background thread
136+
delay(200)
137+
}
138+
}
139+
}
140+
}
141+
142+
private suspend fun processPendingStrokes() {
143+
val strokesToProcess = synchronized(pendingStrokes) {
144+
val copy = ArrayList(pendingStrokes)
145+
pendingStrokes.clear()
146+
copy
147+
}
148+
149+
if (strokesToProcess.isEmpty()) return
150+
151+
val clusters = clusterStrokesSpatially(strokesToProcess, maxDistance = 150f)
152+
for (cluster in clusters) {
153+
processStrokeCluster(cluster)
154+
delay(100)
155+
}
156+
}
157+
158+
private suspend fun processStrokeCluster(strokesToProcess: List<Stroke>) {
159+
if (strokesToProcess.isEmpty()) return
160+
161+
Logger.d("OCRCoordinator", "Processing cluster of ${strokesToProcess.size} strokes")
162+
163+
// Group strokes by spatial proximity and temporal overlap.
164+
val totalBounds = RectF()
165+
var isFirst = true
166+
for (s in strokesToProcess) {
167+
if (isFirst) {
168+
totalBounds.set(s.bounds)
169+
isFirst = false
170+
} else {
171+
totalBounds.union(s.bounds)
172+
}
173+
}
174+
175+
// Expand search area to catch adjacent letters/words and existing OCR blocks
176+
// Generous vertical padding (-100f) to allow newlines/paragraph grouping.
177+
val searchArea = RectF(totalBounds).apply { inset(-150f, -100f) }
178+
179+
val intersectingOcrStrokeOrders = HashSet<Long>()
180+
val finalArea = RectF(totalBounds)
181+
182+
// Find existing OCR blocks that intersect the search area
183+
model.getRegionManager()?.getRegionIdsInRect(searchArea)?.forEach { rId ->
184+
val region = model.getRegionManager()?.getRegionReadOnly(rId)
185+
region?.recognizedTexts?.forEach { ocr ->
186+
val ocrRect = RectF(ocr.x, ocr.y, ocr.x + ocr.width, ocr.y + ocr.height)
187+
if (RectF.intersects(ocrRect, searchArea)) {
188+
intersectingOcrStrokeOrders.addAll(ocr.strokeOrders)
189+
finalArea.union(ocrRect)
190+
}
191+
}
192+
}
193+
194+
val finalStrokesSet = HashSet<Stroke>()
195+
finalStrokesSet.addAll(strokesToProcess)
196+
197+
// If we touched existing OCR blocks, grab all their strokes so we don't truncate words
198+
if (intersectingOcrStrokeOrders.isNotEmpty()) {
199+
model.getRegionManager()?.visitItemsInRect(finalArea) { item ->
200+
if (item is Stroke && item.style != StrokeType.DASH && item.style != StrokeType.HIGHLIGHTER) {
201+
if (intersectingOcrStrokeOrders.contains(item.strokeOrder)) {
202+
finalStrokesSet.add(item)
203+
}
204+
}
205+
}
206+
}
207+
208+
// Spatially sort strokes to ensure correct word order (top-to-bottom, left-to-right)
209+
val sortedByY = finalStrokesSet.sortedBy { it.bounds.centerY() }
210+
val lines = ArrayList<MutableList<Stroke>>()
211+
for (stroke in sortedByY) {
212+
val lastLine = lines.lastOrNull()
213+
if (lastLine != null) {
214+
val avgY = lastLine.map { it.bounds.centerY() }.average().toFloat()
215+
// If center Y is within 60px of the line's average center, consider it the same line
216+
if (kotlin.math.abs(stroke.bounds.centerY() - avgY) < 60f) {
217+
lastLine.add(stroke)
218+
continue
219+
}
220+
}
221+
lines.add(mutableListOf(stroke))
222+
}
223+
val finalStrokes = lines.flatMap { line -> line.sortedBy { it.bounds.left } }
224+
225+
// Invalidate old OCR in the area of ALL strokes we are recognizing
226+
val invalidateArea = RectF()
227+
isFirst = true
228+
for (s in finalStrokes) {
229+
if (isFirst) {
230+
invalidateArea.set(s.bounds)
231+
isFirst = false
232+
} else {
233+
invalidateArea.union(s.bounds)
234+
}
235+
}
236+
237+
// Safety check, ensure invalidateArea isn't wildly wrong
238+
if (!invalidateArea.isEmpty || finalStrokes.isNotEmpty()) {
239+
model.removeRecognizedTextInRect(invalidateArea)
240+
}
241+
242+
val result = recognitionManager.recognizeStrokes(finalStrokes)
243+
if (result != null) {
244+
model.addRecognizedText(result)
245+
onOcrUpdated?.invoke(invalidateArea)
246+
}
247+
}
248+
249+
fun stop() {
250+
// RecognitionManager is usually managed externally, but we stop our scope
251+
scope.launch {
252+
synchronized(pendingStrokes) {
253+
pendingStrokes.clear()
254+
}
255+
}
256+
}
257+
258+
/**
259+
* Manually triggers recognition for a set of strokes (e.g. after movement).
260+
*/
261+
fun triggerManualRecognition(strokes: List<Stroke>) {
262+
if (strokes.isEmpty()) return
263+
synchronized(pendingStrokes) {
264+
pendingStrokes.addAll(strokes)
265+
}
266+
scope.launch {
267+
strokeUpdateFlow.emit(Unit)
268+
}
269+
}
270+
}
271+
}
272+
}
273+
}

0 commit comments

Comments
 (0)