diff --git a/app/src/main/java/org/mydomain/myscan/DocumentDetection.kt b/app/src/main/java/org/mydomain/myscan/DocumentDetection.kt
index edb83df..fb9cdca 100644
--- a/app/src/main/java/org/mydomain/myscan/DocumentDetection.kt
+++ b/app/src/main/java/org/mydomain/myscan/DocumentDetection.kt
@@ -68,7 +68,7 @@ fun detectDocumentQuad(mask: Bitmap, minQuadAreaRatio: Double = 0.02): Quad? {
     }
 
     val vertices = biggest?.toList()?.map { Point(it.x.toInt(), it.y.toInt()) }
-    return createQuad(vertices)
+    return if (vertices?.size == 4) createQuad(vertices) else null
 }
 
 /**
diff --git a/app/src/main/java/org/mydomain/myscan/Geometry.kt b/app/src/main/java/org/mydomain/myscan/Geometry.kt
index d3ffce0..e8767ee 100644
--- a/app/src/main/java/org/mydomain/myscan/Geometry.kt
+++ b/app/src/main/java/org/mydomain/myscan/Geometry.kt
@@ -44,10 +44,36 @@ data class Quad(
             Line(bottomRight, bottomLeft),
             Line(bottomLeft, topLeft))
     }
+
+    /**
+     * Returns a copy of this quad rotated by [iterations] quarter turns within an image of
+     * size [imageWidth] x [imageHeight].
+     *
+     * Any integer is accepted for [iterations]: the count is reduced modulo 4, so a negative
+     * value such as -1 gives the same result as 3. Each corner is rotated on its own and the
+     * four results are handed to [createQuad].
+     */
+    fun rotate90(iterations: Int, imageWidth: Int, imageHeight: Int): Quad {
+        val rotatedPoints = listOf(topLeft, topRight, bottomRight, bottomLeft)
+            .map { rotate90(it, imageWidth, imageHeight, iterations) }
+        return createQuad(rotatedPoints)
+    }
+
+    /** Rotates the single point [p] by [iterations] quarter turns; see the public overload. */
+    private fun rotate90(p: Point, width: Int, height: Int, iterations: Int): Point {
+        // Kotlin's % keeps the sign of the dividend, so fold the count into 0..3 first;
+        // a bare `iterations % 4` would send negative counts to the no-rotation branch.
+        return when (((iterations % 4) + 4) % 4) {
+            1 -> Point(height - p.y, p.x) // 90°
+            2 -> Point(width - p.x, height - p.y) // 180°
+            3 -> Point(p.y, width - p.x) // 270°
+            else -> p // 0°
+        }
+    }
 }
 
-fun createQuad(vertices: List<Point>?): Quad?
{
-    if (vertices == null || vertices.size != 4) return null
+fun createQuad(vertices: List<Point>): Quad {
+    require(vertices.size == 4) { "A quad needs exactly 4 vertices, got ${vertices.size}" }
 
     // Centroid of the points
     val cx = vertices.map { it.x }.average()
diff --git a/app/src/main/java/org/mydomain/myscan/LiveAnalysisState.kt b/app/src/main/java/org/mydomain/myscan/LiveAnalysisState.kt
index b7037f9..6f67a26 100644
--- a/app/src/main/java/org/mydomain/myscan/LiveAnalysisState.kt
+++ b/app/src/main/java/org/mydomain/myscan/LiveAnalysisState.kt
@@ -22,4 +22,6 @@ data class LiveAnalysisState(
     val inferenceTime: Long = 0L,
     val binaryMask: Bitmap? = null,
     val documentQuad: Quad? = null,
+    // Wall-clock time in milliseconds; defaults to the moment this state object is constructed.
+    val timestamp: Long = System.currentTimeMillis(),
 )
diff --git a/app/src/main/java/org/mydomain/myscan/MainViewModel.kt b/app/src/main/java/org/mydomain/myscan/MainViewModel.kt
index d8ce352..c9df5b8 100644
--- a/app/src/main/java/org/mydomain/myscan/MainViewModel.kt
+++ b/app/src/main/java/org/mydomain/myscan/MainViewModel.kt
@@ -65,6 +65,7 @@ class MainViewModel(
 
     private var _liveAnalysisState = MutableStateFlow(LiveAnalysisState())
     val liveAnalysisState: StateFlow<LiveAnalysisState> = _liveAnalysisState.asStateFlow()
+    private var lastSuccessfulLiveAnalysisState: LiveAnalysisState?
= null
 
     private val _screenStack = MutableStateFlow<List<Screen>>(listOf(Screen.Camera))
     val currentScreen: StateFlow<Screen> = _screenStack.map { it.last() }
@@ -86,11 +87,15 @@ class MainViewModel(
                     LiveAnalysisState(
                         inferenceTime = it.inferenceTime,
                         binaryMask = binaryMask,
-                        documentQuad = detectDocumentQuad(binaryMask)
+                        documentQuad = detectDocumentQuad(binaryMask),
+                        timestamp = System.currentTimeMillis(),
                     )
                 }
                 .collect {
                     _liveAnalysisState.value = it
+                    if (it.documentQuad != null) {
+                        lastSuccessfulLiveAnalysisState = it
+                    }
                 }
             }
         }
@@ -164,7 +169,26 @@ class MainViewModel(
             val segmentation = imageSegmentationService.runSegmentationAndReturn(bitmap, 0)
             if (segmentation != null) {
                 val mask = segmentation.segmentation.toBinaryMask()
-                val quad = detectDocumentQuad(mask)
+                // Prefer a quad detected on the captured image; if there is none, fall back
+                // to the quad of the last successful live analysis when it is recent enough.
+                val quad = detectDocumentQuad(mask) ?: run {
+                    val now = System.currentTimeMillis()
+                    // Read the mutable field once so the log and the age check use one state.
+                    val lastLive = lastSuccessfulLiveAnalysisState
+                    if (lastLive != null) {
+                        val offset = now - lastLive.timestamp
+                        Log.i("Quad", "Last successful live analysis was $offset ms ago")
+                    }
+                    val maxLiveQuadAgeMs = 1500L
+                    val recentLive = lastLive?.takeIf { now - it.timestamp <= maxLiveQuadAgeMs }
+                    // Negated rotationDegrees in quarter turns; +4 keeps it non-negative up to 360°.
+                    val rotations = (-imageProxy.imageInfo.rotationDegrees / 90) + 4
+                    recentLive?.documentQuad
+                        ?.rotate90(rotations, mask.width, mask.height)
+                        ?.also {
+                            Log.i("Quad", "Using quad taken in live analysis; rotations=$rotations")
+                        }
+                }
                 if (quad != null) {
                     val resizedQuad = quad.scaledTo(mask.width, mask.height, bitmap.width, bitmap.height)
                     corrected = extractDocument(bitmap, resizedQuad, imageProxy.imageInfo.rotationDegrees)
diff --git a/app/src/test/java/org/mydomain/myscan/GeometryTest.kt b/app/src/test/java/org/mydomain/myscan/GeometryTest.kt
index 47b55be..08bc791 100644
--- a/app/src/test/java/org/mydomain/myscan/GeometryTest.kt
+++ b/app/src/test/java/org/mydomain/myscan/GeometryTest.kt
@@ -15,6 +15,7 @@
 package org.mydomain.myscan
 
 import org.assertj.core.api.Assertions.assertThat
+import org.assertj.core.api.Assertions.assertThatThrownBy
 import org.junit.Test
 
class GeometryTest {
@@ -24,4 +25,46 @@ class GeometryTest {
         assertThat(Line(Point(0, 0), Point(10, 0)).norm()).isEqualTo(10.0)
         assertThat(Line(Point(1, 2), Point(4, 6)).norm()).isEqualTo(5.0)
     }
+
+    @Test
+    fun createQuad() {
+        // The order in which the four vertices are supplied does not affect the result
+        val quad = createQuad(listOf(
+            Point(3, 9), Point(1,2), Point(11,12), Point(10, 3)))
+        assertThat(quad).isEqualTo(
+            Quad(Point(1,2), Point(10, 3), Point(11,12), Point(3, 9)))
+        // Anything other than exactly four vertices is rejected
+        assertThatThrownBy { createQuad(listOf()) }
+            .isInstanceOf(IllegalArgumentException::class.java)
+        assertThatThrownBy { createQuad(listOf(Point(1, 2), Point(10, 3), Point(11, 12))) }
+            .isInstanceOf(IllegalArgumentException::class.java)
+        assertThatThrownBy {
+            createQuad(listOf(Point(1, 2), Point(10, 3), Point(11, 12), Point(3, 9), Point(5, 5)))
+        }.isInstanceOf(IllegalArgumentException::class.java)
+    }
+
+    @Test
+    fun rotateQuad() {
+        val quad = createQuad(listOf(
+            Point(1,2), Point(10, 3), Point(11,12), Point(3, 9)))
+        // Zero quarter turns leave the quad untouched
+        assertThat(quad.rotate90(0, 100, 50)).isEqualTo(quad)
+        assertThat(quad.rotate90(1, 100, 50)).isEqualTo(
+            createQuad(listOf(
+                Point(48,1), Point(47, 10), Point(38,11), Point(41, 3)
+            )))
+        assertThat(quad.rotate90(2, 100, 50)).isEqualTo(
+            createQuad(listOf(
+                Point(99,48), Point(90, 47), Point(89,38), Point(97, 41)
+            )))
+        assertThat(quad.rotate90(3, 100, 50)).isEqualTo(
+            createQuad(listOf(
+                Point(2,99), Point(3, 90), Point(12,89), Point(9, 97)
+            )))
+        // Counts of four or more wrap around
+        assertThat(quad.rotate90(4, 100, 50)).isEqualTo(quad)
+        assertThat(quad.rotate90(5, 100, 50)).isEqualTo(
+            quad.rotate90(1, 100, 50)
+        )
+    }
 }