diff --git a/app/src/main/kotlin/io/homeassistant/companion/android/onboarding/connection/ConnectionViewModel.kt b/app/src/main/kotlin/io/homeassistant/companion/android/onboarding/connection/ConnectionViewModel.kt index cf8b5aa0159..4bd686ecde6 100644 --- a/app/src/main/kotlin/io/homeassistant/companion/android/onboarding/connection/ConnectionViewModel.kt +++ b/app/src/main/kotlin/io/homeassistant/companion/android/onboarding/connection/ConnectionViewModel.kt @@ -116,6 +116,14 @@ internal class ConnectionViewModel @VisibleForTesting constructor( init { viewModelScope.launch { + // Pre-set the mTLS flag before emitting the auth URL to handle TLS session + // resumption. See preInitializeTLSClientAuthState for details. + try { + webViewClient.preInitializeTLSClientAuthState(rawUrl.toHttpUrl().host) + } catch (_: IllegalArgumentException) { + // Malformed URL: this is a best-effort pre-initialisation; + // buildAuthUrl below will surface the error to the user. + } buildAuthUrl(rawUrl) } } diff --git a/app/src/main/kotlin/io/homeassistant/companion/android/util/TLSWebViewClient.kt b/app/src/main/kotlin/io/homeassistant/companion/android/util/TLSWebViewClient.kt index e540e1977ea..d919c1f83a8 100644 --- a/app/src/main/kotlin/io/homeassistant/companion/android/util/TLSWebViewClient.kt +++ b/app/src/main/kotlin/io/homeassistant/companion/android/util/TLSWebViewClient.kt @@ -11,15 +11,23 @@ import android.webkit.WebViewClient import androidx.annotation.VisibleForTesting import io.homeassistant.companion.android.common.data.keychain.KeyChainRepository import java.lang.ref.WeakReference +import java.net.InetAddress +import java.net.UnknownHostException import java.security.PrivateKey import java.security.cert.CertificateException +import java.security.cert.CertificateParsingException import java.security.cert.X509Certificate +import javax.security.auth.x500.X500Principal import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.launch import timber.log.Timber +// SAN (Subject Alternative Name) type codes per RFC 5280 section 4.2.1.6 +private const val SAN_TYPE_DNS_NAME = 2 +private const val SAN_TYPE_IP_ADDRESS = 7 + /* * [TLSWebViewClient] is on the onboarding module for convenience, since we don't have yet * a place to share components between app modules. Common is shared with wear and @@ -39,6 +47,117 @@ open class TLSWebViewClient(private var keyChainRepository: KeyChainRepository) private var key: PrivateKey? = null private var chain: Array? = null + /** + * Pre-initializes [isTLSClientAuthNeeded] by verifying whether the currently loaded + * certificate chain covers [targetHost], to handle TLS session resumption. + * + * Normally [isTLSClientAuthNeeded] is set when [onReceivedClientCertRequest] fires during + * a full TLS handshake. However, when TLS session resumption occurs (the WebView reuses an + * existing session from the same process), the server does not issue a new + * `CertificateRequest`, so [onReceivedClientCertRequest] is never called — even if the + * server requires a client certificate. + * + * This is the root cause of the Wear OS onboarding mTLS failure: the main app WebView + * establishes a TLS session while the user is connected; the onboarding WebView immediately + * resumes it, bypassing the callback that would reveal the mTLS requirement to the + * navigation layer. + * + * The fix inspects the in-memory certificate chain (if any) and checks whether it covers + * [targetHost] via its Subject Alternative Names (SANs), or its Common Name (CN) as a + * fallback. This avoids a false positive when the user has multiple servers where only one + * requires mTLS: the loaded cert will not match the non-mTLS server's hostname. + * + * If the app was force-stopped first (clearing in-memory state) no TLS session can be + * resumed either, so [onReceivedClientCertRequest] will fire naturally on the fresh handshake. + * + * Must be called **before** the WebView starts loading (i.e. before the URL is emitted). + * Idempotent: if the flag is already `true` (set by a real handshake) this is a no-op. + * + * @param targetHost the hostname of the server being connected to (e.g. "myha.example.com") + */ + fun preInitializeTLSClientAuthState(targetHost: String) { + if (isTLSClientAuthNeeded) return + val cert = keyChainRepository.getCertificateChain()?.firstOrNull() ?: return + isTLSClientAuthNeeded = certCoversHost(cert, targetHost) + } + + /** + * Returns `true` if [cert] is valid for [host]. + * + * Checks Subject Alternative Names (SANs) first — both DNS names (with wildcard support) + * and IP addresses. Falls back to the Common Name (CN) in the Subject DN if no SANs are + * present, matching the behaviour of legacy TLS stacks. + */ + private fun certCoversHost(cert: X509Certificate, host: String): Boolean { + val sans: Collection>? = try { + cert.subjectAlternativeNames + } catch (_: CertificateParsingException) { + null + } + + return if (!sans.isNullOrEmpty()) { + sans.any { san -> + if (san.size < 2) return@any false + val type = san[0] as? Int ?: return@any false + when (type) { + SAN_TYPE_DNS_NAME -> { // dNSName — returned as String + val value = san[1] as? String ?: return@any false + hostMatchesSan(host, value) + } + SAN_TYPE_IP_ADDRESS -> { + // iPAddress — the standard Java X.509 API returns this as a String + // (dotted-quad or colon-hex), but some providers (e.g. BouncyCastle) + // return a ByteArray; handle both defensively. + // Normalize both sides through InetAddress so that different textual + // representations of the same address compare equal (e.g. "::1" vs + // "0:0:0:0:0:0:0:1"). + val sanAddress = try { + when (val ipEntry = san[1]) { + is ByteArray -> InetAddress.getByAddress(ipEntry) + is String -> InetAddress.getByName(ipEntry) + else -> return@any false + } + } catch (_: UnknownHostException) { + return@any false + } + val hostAddress = try { + InetAddress.getByName(host) + } catch (_: UnknownHostException) { + return@any false + } + hostAddress == sanAddress + } + else -> false + } + } + } else { + // Fallback: extract CN from the Subject DN. + // getName(RFC2253) uses comma as AVA separator; commas inside values are escaped + // as \, which we don't need to handle because hostnames never contain commas. + val dn = cert.subjectX500Principal.getName(X500Principal.RFC2253) + val cn = dn.splitToSequence(",") + .map { it.trim() } + .firstOrNull { it.startsWith("CN=", ignoreCase = true) } + ?.let { it.substring(it.indexOf('=') + 1).trim() } + ?.takeIf { it.isNotEmpty() } + cn != null && hostMatchesSan(host, cn) + } + } + + /** + * Matches [host] against a SAN value that may contain a leading wildcard. + * + * A wildcard (`*.example.com`) covers any single label: `foo.example.com` matches but + * `foo.bar.example.com` and `example.com` do not (per RFC 2818 §3.1). + */ + private fun hostMatchesSan(host: String, san: String): Boolean { + if (!san.startsWith("*.")) return host.equals(san, ignoreCase = true) + val suffix = san.substring(1) // ".example.com" + if (!host.endsWith(suffix, ignoreCase = true)) return false + val wildcardLabel = host.substring(0, host.length - suffix.length) + return wildcardLabel.isNotEmpty() && !wildcardLabel.contains('.') + } + private fun getActivity(context: Context?): Activity? { if (context == null) { return null diff --git a/app/src/test/kotlin/io/homeassistant/companion/android/onboarding/connection/ConnectionViewModelTest.kt b/app/src/test/kotlin/io/homeassistant/companion/android/onboarding/connection/ConnectionViewModelTest.kt index d1770a67311..c78b951afb6 100644 --- a/app/src/test/kotlin/io/homeassistant/companion/android/onboarding/connection/ConnectionViewModelTest.kt +++ b/app/src/test/kotlin/io/homeassistant/companion/android/onboarding/connection/ConnectionViewModelTest.kt @@ -22,7 +22,10 @@ import io.mockk.mockk import io.mockk.mockkStatic import io.mockk.slot import io.mockk.verify +import java.net.InetAddress import java.net.URL +import java.security.cert.CertificateParsingException +import java.security.cert.X509Certificate import kotlin.reflect.KClass import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.MutableSharedFlow @@ -415,4 +418,241 @@ class ConnectionViewModelTest { assertEquals("class java.lang.UnsatisfiedLinkError", error.rawErrorType) verify(exactly = 1) { connectivityCheckRepository.runChecks(rawUrl) } } + + // --- preInitializeTLSClientAuthState / cert-host matching tests --- + + @Test + fun `Given cert with exact DNS SAN matching target host when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest { + val cert = mockk { + every { subjectAlternativeNames } returns listOf(listOf(2, "homeassistant.local")) + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "http://homeassistant.local:8123", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with wildcard DNS SAN matching target host when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest { + val cert = mockk { + every { subjectAlternativeNames } returns listOf(listOf(2, "*.example.com")) + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "https://ha.example.com", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with DNS SAN for a different host when initializing then isTLSClientAuthNeeded remains false`() = runTest { + val cert = mockk { + every { subjectAlternativeNames } returns listOf(listOf(2, "other-server.example.com")) + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "http://homeassistant.local:8123", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertFalse(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given no certificate chain in memory when initializing then isTLSClientAuthNeeded remains false`() = runTest { + every { keyChainRepository.getCertificateChain() } returns null + + val viewModel = ConnectionViewModel( + "http://homeassistant.local:8123", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertFalse(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with CN matching target host and no SANs when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest { + val cert = mockk { + every { subjectAlternativeNames } returns null + every { subjectX500Principal } returns mockk { + every { getName("RFC2253") } returns "CN=homeassistant.local,O=Home Assistant" + } + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "http://homeassistant.local:8123", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with wildcard SAN that does not cover a multi-label subdomain when initializing then isTLSClientAuthNeeded remains false`() = runTest { + val cert = mockk { + // *.example.com covers foo.example.com but not foo.bar.example.com + every { subjectAlternativeNames } returns listOf(listOf(2, "*.example.com")) + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "https://foo.bar.example.com", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertFalse(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with wildcard SAN that does not cover apex domain when initializing then isTLSClientAuthNeeded remains false`() = runTest { + val cert = mockk { + // *.example.com covers foo.example.com but not example.com itself (RFC 2818 §3.1) + every { subjectAlternativeNames } returns listOf(listOf(2, "*.example.com")) + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "https://example.com", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertFalse(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with non-matching SANs and matching CN when initializing then isTLSClientAuthNeeded remains false`() = runTest { + // When SANs are present, the CN must not be used as fallback even if it would match — + // this is the standard behaviour defined in RFC 2818 §3.1. + val cert = mockk { + every { subjectAlternativeNames } returns listOf(listOf(2, "other-server.example.com")) + every { subjectX500Principal } returns mockk { + every { getName("RFC2253") } returns "CN=homeassistant.local,O=Home Assistant" + } + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "http://homeassistant.local:8123", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertFalse(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with IP address SAN as ByteArray matching target host when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest { + // Some providers (e.g. BouncyCastle) return iPAddress (type 7) as a ByteArray. + val ipBytes = InetAddress.getByName("192.168.1.100").address + val cert = mockk { + every { subjectAlternativeNames } returns listOf(listOf(7, ipBytes)) + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "https://192.168.1.100", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with IP address SAN as String matching target host when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest { + // The standard Java X.509 API returns iPAddress (type 7) as a String (dotted-quad or colon-hex). + val cert = mockk { + every { subjectAlternativeNames } returns listOf(listOf(7, "192.168.1.100")) + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "https://192.168.1.100", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with IPv6 SAN in expanded form matching target host in compressed form when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest { + // InetAddress equality normalizes different textual forms of the same IPv6 address. + val cert = mockk { + every { subjectAlternativeNames } returns listOf(listOf(7, "0:0:0:0:0:0:0:1")) + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "https://[::1]", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert with DNS SAN in mixed case when initializing then isTLSClientAuthNeeded is pre-set to true`() = runTest { + val cert = mockk { + every { subjectAlternativeNames } returns listOf(listOf(2, "HomeAssistant.Local")) + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "http://homeassistant.local:8123", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded) + } + + @Test + fun `Given cert whose subjectAlternativeNames throws CertificateParsingException when initializing then falls back to CN matching`() = runTest { + val cert = mockk { + every { subjectAlternativeNames } throws CertificateParsingException("bad extension") + every { subjectX500Principal } returns mockk { + every { getName("RFC2253") } returns "CN=homeassistant.local,O=Home Assistant" + } + } + every { keyChainRepository.getCertificateChain() } returns arrayOf(cert) + + val viewModel = ConnectionViewModel( + "http://homeassistant.local:8123", + webViewClientFactory, + connectivityCheckRepository, + ) + advanceUntilIdle() + + assertTrue(viewModel.webViewClient.isTLSClientAuthNeeded) + } }