Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion app/src/main/java/shop/whitedns/client/MainActivity.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package shop.whitedns.client

import android.Manifest
import android.app.Activity
import android.app.AlertDialog
import android.content.Intent
import android.content.pm.PackageManager
import android.net.VpnService
Expand All @@ -26,11 +27,15 @@ import shop.whitedns.client.ui.WhiteDnsScreen
import shop.whitedns.client.ui.WhiteDnsTheme
import shop.whitedns.client.ui.WhiteDnsViewModel
import shop.whitedns.client.model.ConnectionStatus
import shop.whitedns.client.model.StormDnsProfileLinkPreview
import shop.whitedns.client.model.WhiteDnsOptions
import shop.whitedns.client.model.previewStormDnsProfileLink
import shop.whitedns.client.model.resolve

class MainActivity : ComponentActivity() {

private val viewModel by viewModels<WhiteDnsViewModel>()
private var profileImportDialog: AlertDialog? = null

override fun onResume() {
super.onResume()
Expand Down Expand Up @@ -147,6 +152,12 @@ class MainActivity : ComponentActivity() {
handleProfileLinkIntent(intent)
}

override fun onDestroy() {
profileImportDialog?.dismiss()
profileImportDialog = null
super.onDestroy()
}

private fun openNotificationSettings() {
val packageUri = Uri.parse("package:$packageName")
val settingsIntent = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
Expand Down Expand Up @@ -190,7 +201,36 @@ class MainActivity : ComponentActivity() {
}
val link = intent.dataString?.takeIf(String::isNotBlank) ?: return
intent.putExtra(ExtraProfileImportHandled, true)
viewModel.importProfileLink(link)
val preview = runCatching {
previewStormDnsProfileLink(link)
}.getOrElse {
viewModel.importProfileLink(link)
return
}
showProfileImportConfirmation(link, preview)
}

private fun showProfileImportConfirmation(
link: String,
preview: StormDnsProfileLinkPreview,
) {
profileImportDialog?.dismiss()
profileImportDialog = AlertDialog.Builder(this)
.setTitle("Import StormDNS profile?")
.setMessage(
buildString {
appendLine("Name: ${preview.name.ifBlank { "Imported profile" }}")
appendLine("Server: ${preview.domain}")
appendLine("Encryption: ${WhiteDnsOptions.encryptionMethodLabel(preview.encryptionMethod)}")
appendLine()
append("Only import profiles from sources you trust.")
},
)
.setNegativeButton(android.R.string.cancel, null)
.setPositiveButton("Import") { _, _ ->
viewModel.importProfileLink(link)
}
.show()
}

private companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ import org.json.JSONObject
private const val StormDnsProfileScheme = "stormdns"
private const val StormDnsProfileSchema = "whitedns.profile"
private const val StormDnsProfileVersion = 1
private val UnsafeProfileControlRegex = Regex("""[\u0000-\u001F\u007F]""")
private val UnsafeDomainWhitespaceRegex = Regex("""\s""")

data class StormDnsProfileLinkPreview(
val name: String,
val domain: String,
val encryptionMethod: Int,
)

fun WhiteDnsSettings.exportStormDnsProfileLink(profile: ConnectionProfile = selectedConnectionProfile()): String {
val normalizedProfile = profile.copy(
Expand Down Expand Up @@ -85,6 +93,40 @@ fun WhiteDnsSettings.importStormDnsProfileLink(
rawLink: String,
nowMillis: Long = System.currentTimeMillis(),
): WhiteDnsSettings {
val imported = parseStormDnsProfileLink(rawLink)
val profileId = uniqueImportedProfileId(normalizedConnectionProfiles(), nowMillis)

val importedProfile = ConnectionProfile(
id = profileId,
name = imported.name,
serverMode = "custom",
customServerDomain = imported.domain,
customServerEncryptionKey = imported.encryptionKey,
customServerEncryptionMethod = imported.encryptionMethod,
resolverProfileId = "",
connectionMode = connectionMode,
)

return copy(
selectedConnectionProfileId = profileId,
connectionProfiles = normalizedConnectionProfiles() + importedProfile,
serverMode = "custom",
customServerDomain = imported.domain,
customServerEncryptionKey = imported.encryptionKey,
customServerEncryptionMethod = importedProfile.customServerEncryptionMethod,
).syncSelectedConnectionProfileFields()
}

fun previewStormDnsProfileLink(rawLink: String): StormDnsProfileLinkPreview {
val imported = parseStormDnsProfileLink(rawLink)
return StormDnsProfileLinkPreview(
name = imported.name,
domain = imported.domain,
encryptionMethod = imported.encryptionMethod,
)
}

private fun parseStormDnsProfileLink(rawLink: String): ImportedStormDnsProfile {
val root = decodeProfilePayload(rawLink)
val schema = root.requiredString("schema")
if (schema != StormDnsProfileSchema) {
Expand All @@ -101,6 +143,11 @@ fun WhiteDnsSettings.importStormDnsProfileLink(
?: throw IllegalArgumentException("Missing server")
val domain = serverJson.requiredString("domain").trim().trimEnd('.')
val encryptionKey = serverJson.requiredString("encryption_key").trim()
rejectControlCharacters(domain, "Server domain")
rejectControlCharacters(encryptionKey, "Server encryption key")
if (UnsafeDomainWhitespaceRegex.containsMatchIn(domain)) {
throw IllegalArgumentException("Server domain cannot contain whitespace")
}
if (domain.isBlank()) {
throw IllegalArgumentException("Server domain is required")
}
Expand All @@ -109,30 +156,20 @@ fun WhiteDnsSettings.importStormDnsProfileLink(
}

val profileName = profileJson.requiredString("name").trim()
val profileId = uniqueImportedProfileId(normalizedConnectionProfiles(), nowMillis)
rejectControlCharacters(profileName, "Profile name")
if (profileName.isBlank()) {
throw IllegalArgumentException("Profile name is required")
}
val encryptionMethod = serverJson.requiredInt("encryption_method")
if (encryptionMethod !in 0..5) {
throw IllegalArgumentException("Server encryption method must be between 0 and 5")
}
val importedProfile = ConnectionProfile(
id = profileId,
return ImportedStormDnsProfile(
name = profileName,
serverMode = "custom",
customServerDomain = domain,
customServerEncryptionKey = encryptionKey,
customServerEncryptionMethod = encryptionMethod,
resolverProfileId = "",
connectionMode = connectionMode,
domain = domain,
encryptionKey = encryptionKey,
encryptionMethod = encryptionMethod,
)

return copy(
selectedConnectionProfileId = profileId,
connectionProfiles = normalizedConnectionProfiles() + importedProfile,
serverMode = "custom",
customServerDomain = domain,
customServerEncryptionKey = encryptionKey,
customServerEncryptionMethod = importedProfile.customServerEncryptionMethod,
).syncSelectedConnectionProfileFields()
}

private fun encodeProfilePayload(root: JSONObject): String {
Expand Down Expand Up @@ -209,3 +246,16 @@ private fun JSONObject.optionalInt(name: String): Int? {
else -> null
}
}

private fun rejectControlCharacters(value: String, label: String) {
if (UnsafeProfileControlRegex.containsMatchIn(value)) {
throw IllegalArgumentException("$label cannot contain control characters")
}
}

private data class ImportedStormDnsProfile(
val name: String,
val domain: String,
val encryptionKey: String,
val encryptionMethod: Int,
)
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,27 @@ object StormDnsConfigRenderer {
}

private fun escape(value: String): String {
return value
.replace("\\", "\\\\")
.replace("\"", "\\\"")
return buildString {
value.forEach { character ->
when (character) {
'\\' -> append("\\\\")
'"' -> append("\\\"")
'\b' -> append("\\b")
'\t' -> append("\\t")
'\n' -> append("\\n")
'\u000C' -> append("\\f")
'\r' -> append("\\r")
else -> {
if (character.code < 0x20 || character.code == 0x7F) {
append("\\u")
append(character.code.toString(16).uppercase().padStart(4, '0'))
} else {
append(character)
}
}
}
}
}
}

private fun ConnectionProfile.toStormDnsServerProfile(): StormDnsServerProfile {
Expand Down
75 changes: 75 additions & 0 deletions app/src/test/java/shop/whitedns/client/model/WhiteDnsModelsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,81 @@ class WhiteDnsModelsTest {
assertEquals("proxy", importedSettings.connectionMode)
}

@Test
fun previewStormDnsProfileLinkReturnsSafeImportSummary() {
val payload = """
{
"schema": "whitedns.profile",
"version": 1,
"profile": {
"name": "Imported Profile",
"server": {
"domain": "server.example.com",
"encryption_key": "secret-key",
"encryption_method": 3
}
}
}
""".trimIndent()
val link = "stormdns://${Base64.getUrlEncoder().withoutPadding().encodeToString(payload.toByteArray())}"

val preview = previewStormDnsProfileLink(link)

assertEquals("Imported Profile", preview.name)
assertEquals("server.example.com", preview.domain)
assertEquals(3, preview.encryptionMethod)
}

@Test
fun importStormDnsProfileLinkRejectsControlCharacters() {
val payload = """
{
"schema": "whitedns.profile",
"version": 1,
"profile": {
"name": "Imported Profile",
"server": {
"domain": "server.example.com",
"encryption_key": "secret\nkey",
"encryption_method": 2
}
}
}
""".trimIndent()
val link = "stormdns://${Base64.getUrlEncoder().withoutPadding().encodeToString(payload.toByteArray())}"

val error = assertThrows(IllegalArgumentException::class.java) {
WhiteDnsSettings().importStormDnsProfileLink(link, nowMillis = 42L)
}

assertEquals("Server encryption key cannot contain control characters", error.message)
}

@Test
fun importStormDnsProfileLinkRejectsDomainWhitespace() {
val payload = """
{
"schema": "whitedns.profile",
"version": 1,
"profile": {
"name": "Imported Profile",
"server": {
"domain": "server example.com",
"encryption_key": "secret-key",
"encryption_method": 2
}
}
}
""".trimIndent()
val link = "stormdns://${Base64.getUrlEncoder().withoutPadding().encodeToString(payload.toByteArray())}"

val error = assertThrows(IllegalArgumentException::class.java) {
WhiteDnsSettings().importStormDnsProfileLink(link, nowMillis = 42L)
}

assertEquals("Server domain cannot contain whitespace", error.message)
}

@Test
fun exportAndImportStormDnsProfileLinkUsesOnlyRequiredProfileFields() {
val resolverProfile = ResolverProfile(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,26 @@ class StormDnsConfigRendererTest {
assertTrue(toml.contains("MTU_TEST_PARALLELISM_RESOLVERS = 1"))
assertTrue(toml.contains("STARTUP_MODE = \"resolvers\""))
}

@Test
fun renderClientTomlEscapesControlCharactersInStrings() {
val toml = StormDnsConfigRenderer.renderClientToml(
serverProfile = shop.whitedns.client.model.StormDnsServerProfile(
id = "server",
label = "Server",
domain = "server.example.com",
encryptionKey = "line\nkey\t\"\\${1.toChar()}",
encryptionMethod = 1,
),
settings = WhiteDnsSettings(
socks5Authentication = true,
socksUsername = "user\rname",
socksPassword = "pass\bword",
),
)

assertTrue(toml.contains("ENCRYPTION_KEY = \"line\\nkey\\t\\\"\\\\\\u0001\""))
assertTrue(toml.contains("SOCKS5_USER = \"user\\rname\""))
assertTrue(toml.contains("SOCKS5_PASS = \"pass\\bword\""))
}
}