diff --git a/core/common/src/Exceptions.kt b/core/common/src/Exceptions.kt index 1f5e93ed..0f43109c 100644 --- a/core/common/src/Exceptions.kt +++ b/core/common/src/Exceptions.kt @@ -13,6 +13,10 @@ public class DateTimeArithmeticException: RuntimeException { public constructor(message: String): super(message) public constructor(cause: Throwable): super(cause) public constructor(message: String, cause: Throwable): super(message, cause) + + private companion object { + private const val serialVersionUID: Long = -3207806170214997982L + } } /** @@ -23,6 +27,10 @@ public class IllegalTimeZoneException: IllegalArgumentException { public constructor(message: String): super(message) public constructor(cause: Throwable): super(cause) public constructor(message: String, cause: Throwable): super(message, cause) + + private companion object { + private const val serialVersionUID: Long = 1159315966274264801L + } } internal class DateTimeFormatException: IllegalArgumentException { @@ -30,4 +38,8 @@ internal class DateTimeFormatException: IllegalArgumentException { constructor(message: String): super(message) constructor(cause: Throwable): super(cause) constructor(message: String, cause: Throwable): super(message, cause) + + private companion object { + private const val serialVersionUID: Long = 4231196759387994100L + } } diff --git a/core/common/src/internal/format/parser/Parser.kt b/core/common/src/internal/format/parser/Parser.kt index 9958e3fb..5e1c3f84 100644 --- a/core/common/src/internal/format/parser/Parser.kt +++ b/core/common/src/internal/format/parser/Parser.kt @@ -209,7 +209,13 @@ internal value class Parser>( ) } -internal class ParseException(errors: List) : Exception(formatError(errors)) +// note that the message of this exception could be anything (even null) after deserialization of a manually constructed +// or corrupted stream (via Java Object Serialization) +internal class ParseException(errors: List) : Exception(formatError(errors)) { + private companion object { + private const val serialVersionUID: Long = 5691186997393344103L + } +} private fun formatError(errors: List): String { if (errors.size == 1) { diff --git a/core/jvm/src/LocalDate.kt b/core/jvm/src/LocalDate.kt index 7e7cda70..9983e2f8 100644 --- a/core/jvm/src/LocalDate.kt +++ b/core/jvm/src/LocalDate.kt @@ -52,6 +52,11 @@ public actual class LocalDate internal constructor( @Suppress("FunctionName") public actual fun Format(block: DateTimeFormatBuilder.WithDate.() -> Unit): DateTimeFormat = LocalDateFormat.build(block) + + // even though this class uses writeReplace (so serialVersionUID is not needed for a stable serialized form), a + // stable serialVersionUID means exceptions caused by deserialization of malicious streams will be consistent + // (InvalidClassException vs. InvalidObjectException, see MaliciousJvmSerializationTest) + private const val serialVersionUID = 7026816023079564263L } public actual object Formats { @@ -103,6 +108,9 @@ public actual class LocalDate internal constructor( @JvmName("toEpochDays") internal fun toEpochDaysJvm(): Int = value.toEpochDay().clampToInt() + private fun readObject(ois: java.io.ObjectInputStream): Unit = + throw java.io.InvalidObjectException("kotlinx.datetime.LocalDate must be deserialized via kotlinx.datetime.Ser") + private fun writeReplace(): Any = Ser(Ser.DATE_TAG, this) } diff --git a/core/jvm/src/LocalDateTimeJvm.kt b/core/jvm/src/LocalDateTimeJvm.kt index 235be907..0053c47e 100644 --- a/core/jvm/src/LocalDateTimeJvm.kt +++ b/core/jvm/src/LocalDateTimeJvm.kt @@ -106,12 +106,21 @@ public actual class LocalDateTime internal constructor( @Suppress("FunctionName") public actual fun Format(builder: DateTimeFormatBuilder.WithDateTime.() -> Unit): DateTimeFormat = LocalDateTimeFormat.build(builder) + + // even though this class uses writeReplace (so serialVersionUID is not needed for a stable serialized form), a + // stable serialVersionUID means exceptions caused by deserialization of malicious streams will be consistent + // (InvalidClassException vs. InvalidObjectException, see MaliciousJvmSerializationTest) + private const val serialVersionUID: Long = -4261744960416354711L } public actual object Formats { public actual val ISO: DateTimeFormat = ISO_DATETIME } + private fun readObject(ois: java.io.ObjectInputStream): Unit = throw java.io.InvalidObjectException( + "kotlinx.datetime.LocalDateTime must be deserialized via kotlinx.datetime.Ser" + ) + private fun writeReplace(): Any = Ser(Ser.DATE_TIME_TAG, this) } diff --git a/core/jvm/src/LocalTimeJvm.kt b/core/jvm/src/LocalTimeJvm.kt index 98f42011..18f08263 100644 --- a/core/jvm/src/LocalTimeJvm.kt +++ b/core/jvm/src/LocalTimeJvm.kt @@ -85,6 +85,11 @@ public actual class LocalTime internal constructor( @Suppress("FunctionName") public actual fun Format(builder: DateTimeFormatBuilder.WithTime.() -> Unit): DateTimeFormat = LocalTimeFormat.build(builder) + + // even though this class uses writeReplace (so serialVersionUID is not needed for a stable serialized form), a + // stable serialVersionUID means exceptions caused by deserialization of malicious streams will be consistent + // (InvalidClassException vs. InvalidObjectException, see MaliciousJvmSerializationTest) + private const val serialVersionUID: Long = -352249606036216323L } public actual object Formats { @@ -92,6 +97,9 @@ public actual class LocalTime internal constructor( } + private fun readObject(ois: java.io.ObjectInputStream): Unit = + throw java.io.InvalidObjectException("kotlinx.datetime.LocalTime must be deserialized via kotlinx.datetime.Ser") + private fun writeReplace(): Any = Ser(Ser.TIME_TAG, this) } diff --git a/core/jvm/src/UtcOffsetJvm.kt b/core/jvm/src/UtcOffsetJvm.kt index 7f9ed703..3228b3f5 100644 --- a/core/jvm/src/UtcOffsetJvm.kt +++ b/core/jvm/src/UtcOffsetJvm.kt @@ -39,6 +39,11 @@ public actual class UtcOffset( @Suppress("FunctionName") public actual fun Format(block: DateTimeFormatBuilder.WithUtcOffset.() -> Unit): DateTimeFormat = UtcOffsetFormat.build(block) + + // even though this class uses writeReplace (so serialVersionUID is not needed for a stable serialized form), a + // stable serialVersionUID means exceptions caused by deserialization of malicious streams will be consistent + // (InvalidClassException vs. InvalidObjectException, see MaliciousJvmSerializationTest) + private const val serialVersionUID: Long = -6636773355667981618L } public actual object Formats { @@ -47,6 +52,9 @@ public actual class UtcOffset( public actual val FOUR_DIGITS: DateTimeFormat get() = FOUR_DIGIT_OFFSET } + private fun readObject(ois: java.io.ObjectInputStream): Unit = + throw java.io.InvalidObjectException("kotlinx.datetime.UtcOffset must be deserialized via kotlinx.datetime.Ser") + private fun writeReplace(): Any = Ser(Ser.UTC_OFFSET_TAG, this) } diff --git a/core/jvm/test/MaliciousJvmSerializationTest.kt b/core/jvm/test/MaliciousJvmSerializationTest.kt new file mode 100644 index 00000000..a45d9dbd --- /dev/null +++ b/core/jvm/test/MaliciousJvmSerializationTest.kt @@ -0,0 +1,207 @@ +/* + * Copyright 2019-2025 JetBrains s.r.o. and contributors. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package kotlinx.datetime + +import kotlinx.datetime.MaliciousJvmSerializationTest.TestCase.Streams +import java.io.ByteArrayInputStream +import java.io.ObjectInputStream +import java.io.Serializable +import java.lang.reflect.Modifier +import kotlin.reflect.KClass +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.fail + +class MaliciousJvmSerializationTest { + + /** + * This data was generated by running the following Java code (`X` was replaced with [clazz]`.simpleName`, `Y` with + * [delegate]`::class.qualifiedName` and `z` with [delegateFieldName]): + * ```java + * package kotlinx.datetime; + * + * import java.io.*; + * import java.util.*; + * + * public class X implements Serializable { + * private final Y z = ...; + * + * @Serial + * private static final long serialVersionUID = ...; + * + * public static void main(String[] args) throws IOException { + * var bos = new ByteArrayOutputStream(); + * try (var oos = new ObjectOutputStream(bos)) { + * oos.writeObject(new X()); + * } + * System.out.println(HexFormat.of().formatHex(bos.toByteArray())); + * } + * } + * ``` + */ + private class TestCase( + val clazz: KClass, + val serialVersionUID: Long, + val delegateFieldName: String, + val delegate: Serializable, + /** serialVersionUID had the correct value ([serialVersionUID]) in the Java code. */ + val withCorrectSVUID: Streams, + /** serialVersionUID had an incorrect value (42) in the Java code. */ + val withSVUID42: Streams, + ) { + class Streams( + /** `z` was set to `null` in the Java code. */ + val delegateNull: String, + /** `z` was set to [delegate] in the Java code. */ + val delegateValid: String, + ) + } + + @Suppress("RemoveRedundantQualifierName") + private val testCases = listOf( + TestCase( + kotlinx.datetime.LocalDate::class, + serialVersionUID = 7026816023079564263L, + delegateFieldName = "value", + delegate = java.time.LocalDate.of(2025, 4, 26), + withCorrectSVUID = Streams( + delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c44617465618443f17dae33e70200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770703000007e9041a78", + delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c44617465618443f17dae33e70200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b787070", + ), + withSVUID42 = Streams( + delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c44617465000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770703000007e9041a78", + delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c44617465000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b787070", + ), + ), + TestCase( + kotlinx.datetime.LocalDateTime::class, + serialVersionUID = -4261744960416354711L, + delegateFieldName = "value", + delegate = java.time.LocalDateTime.of(2025, 4, 26, 11, 18), + withCorrectSVUID = Streams( + delegateValid = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d65c4db3d89c7126e690200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770905000007e9041a0bed78", + delegateNull = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d65c4db3d89c7126e690200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b787070", + ), + withSVUID42 = Streams( + delegateValid = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d65000000000000002a0200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770905000007e9041a0bed78", + delegateNull = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d65000000000000002a0200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b787070", + ), + ), + TestCase( + kotlinx.datetime.LocalTime::class, + serialVersionUID = -352249606036216323L, + delegateFieldName = "value", + delegate = java.time.LocalTime.of(11, 18), + withCorrectSVUID = Streams( + delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d65fb1c8ed97ff0a5fd0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707703040bed78", + delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d65fb1c8ed97ff0a5fd0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b787070", + ), + withSVUID42 = Streams( + delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d65000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707703040bed78", + delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d65000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b787070", + ), + ), + TestCase( + kotlinx.datetime.UtcOffset::class, + serialVersionUID = -6636773355667981618L, + delegateFieldName = "zoneOffset", + delegate = java.time.ZoneOffset.UTC, + withCorrectSVUID = Streams( + delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f6666736574a3e571cbd0a1face0200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707702080078", + delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f6666736574a3e571cbd0a1face0200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b787070", + ), + withSVUID42 = Streams( + delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f6666736574000000000000002a0200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707702080078", + delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f6666736574000000000000002a0200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b787070", + ), + ), + ) + + @OptIn(ExperimentalStdlibApi::class) + private fun deserialize(stream: String): Any? { + val bis = ByteArrayInputStream(stream.hexToByteArray()) + return ObjectInputStream(bis).use { ois -> + ois.readObject() + } + } + + @Test + fun deserializeMaliciousStreams() { + for (testCase in testCases) { + testCase.ensureAssumptionsHold() + val className = testCase.clazz.qualifiedName!! + testStreamsWithCorrectSVUID(className, testCase.withCorrectSVUID) + testStreamsWithSVUID42(testCase.serialVersionUID, className, testCase.withSVUID42) + } + } + + private fun TestCase.ensureAssumptionsHold() { + val className = clazz.qualifiedName!! + + val actualSerialVersionUID = clazz.java + .getDeclaredField("serialVersionUID") + .apply { isAccessible = true } + .get(null) as Long + if (actualSerialVersionUID == 42L) { + fail("This test assumes that the tested classes don't have a serialVersionUID of 42 but $className does.") + } + if (actualSerialVersionUID != serialVersionUID) { + fail( + "This test assumes that the serialVersionUID of $className is $serialVersionUID but it was " + + "$actualSerialVersionUID." + ) + } + + val field = clazz.java.declaredFields.singleOrNull { !Modifier.isStatic(it.modifiers) } + if (field == null || field.name != delegateFieldName || field.type != delegate.javaClass) { + fail( + "This test assumes that $className has a single instance field named $delegateFieldName of type " + + "${delegate::class.qualifiedName}. The test case for $className should be updated with new " + + "malicious serial streams that represent the changes to $className." + ) + } + } + + private fun testStreamsWithCorrectSVUID(className: String, streams: Streams) { + val testFailureMessage = "Deserialization of a serial stream that tries to bypass kotlinx.datetime.Ser and " + + "has the correct serialVersionUID for $className should fail" + + val expectedIOEMessage = "$className must be deserialized via kotlinx.datetime.Ser" + + // this would actually create a valid instance, but serialization should always go through the proxy + val ioe1 = assertFailsWith(testFailureMessage) { + deserialize(streams.delegateValid) + } + assertEquals(expectedIOEMessage, ioe1.message) + + // this would create an instance that has null in a non-nullable field (e.g., the field + // kotlinx.datetime.LocalDate.value) + // see https://github.com/Kotlin/kotlinx-datetime/pull/373#discussion_r2008922681 + val ioe2 = assertFailsWith(testFailureMessage) { + deserialize(streams.delegateNull) + } + assertEquals(expectedIOEMessage, ioe2.message) + } + + private fun testStreamsWithSVUID42(serialVersionUID: Long, className: String, streams: Streams) { + val testFailureMessage = "Deserialization of a serial stream that tries to bypass kotlinx.datetime.Ser but " + + "has a wrong serialVersionUID for $className should fail" + + val expectedICEMessage = "$className; local class incompatible: stream classdesc serialVersionUID = 42, " + + "local class serialVersionUID = $serialVersionUID" + + val ice1 = assertFailsWith(testFailureMessage) { + deserialize(streams.delegateValid) + } + assertEquals(expectedICEMessage, ice1.message) + + val ice2 = assertFailsWith(testFailureMessage) { + deserialize(streams.delegateNull) + } + assertEquals(expectedICEMessage, ice2.message) + } +}