From 53976d38b1bd6917b8fa4d1dd4f009728ece3adb Mon Sep 17 00:00:00 2001 From: Levi Broderick Date: Tue, 14 Jul 2020 23:12:10 -0700 Subject: [PATCH] [release/5.0-preview7] Disallow unrestricted polymorphic deserialization in DataSet (#39314) Fixes CVE-2020-1147 https://portal.msrc.microsoft.com/en-us/security-guidance/advisory/CVE-2020-1147 See also https://go.microsoft.com/fwlink/?linkid=2132227. --- .../src/Resources/Strings.resx | 1 + .../src/System.Data.Common.csproj | 25 +- .../src/System/Data/Common/ObjectStorage.cs | 3 + .../src/System/Data/DataColumn.cs | 1 + .../src/System/Data/DataException.cs | 1 + .../src/System/Data/DataSet.cs | 6 + .../src/System/Data/DataTable.cs | 6 + .../src/System/Data/Filter/FunctionNode.cs | 24 +- .../System/Data/LocalAppContextSwitches.cs | 18 + .../src/System/Data/TypeLimiter.cs | 305 ++++++++++++ .../tests/System.Data.Common.Tests.csproj | 1 + .../Data/RestrictedTypeHandlingTests.cs | 446 ++++++++++++++++++ 12 files changed, 824 insertions(+), 13 deletions(-) create mode 100644 src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs create mode 100644 src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs create mode 100644 src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs diff --git a/src/libraries/System.Data.Common/src/Resources/Strings.resx b/src/libraries/System.Data.Common/src/Resources/Strings.resx index 4535a6944bffc1..671bf490bc1351 100644 --- a/src/libraries/System.Data.Common/src/Resources/Strings.resx +++ b/src/libraries/System.Data.Common/src/Resources/Strings.resx @@ -165,6 +165,7 @@ '{0}' argument is out of range. '{0}' argument cannot be null. '{0}' argument contains null value. + Type '{0}' is not allowed here. See https://go.microsoft.com/fwlink/?linkid=2132227 for more details. Cannot find column {0}. Column '{0}' already belongs to this DataTable. Column '{0}' already belongs to another DataTable. diff --git a/src/libraries/System.Data.Common/src/System.Data.Common.csproj b/src/libraries/System.Data.Common/src/System.Data.Common.csproj index 20091047d6b9b5..7beb59ac3a8bb6 100644 --- a/src/libraries/System.Data.Common/src/System.Data.Common.csproj +++ b/src/libraries/System.Data.Common/src/System.Data.Common.csproj @@ -2,7 +2,7 @@ System.Data.Common true - $(NetCoreAppCurrent) + $(NetCoreAppCurrent)-Windows_NT;$(NetCoreAppCurrent)-Unix @@ -123,6 +123,10 @@ + + + Common\System\LocalAppContextSwitches.Common.cs + @@ -156,6 +160,7 @@ + @@ -295,25 +300,23 @@ - + + + + + + + - - - + - - - - - - diff --git a/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs b/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs index ffe4960cce7b87..dc3fe0533569d4 100644 --- a/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs +++ b/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs @@ -406,6 +406,9 @@ public override object ConvertXmlToObject(XmlReader xmlReader, XmlRootAttribute if (type == typeof(object)) throw ExceptionBuilder.CanNotDeserializeObjectType(); + + TypeLimiter.EnsureTypeIsAllowed(type); + if (!isBaseCLRType) { retValue = System.Activator.CreateInstance(type, true); diff --git a/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs b/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs index d44ec1e0b88751..560f6ecd7e7148 100644 --- a/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs +++ b/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs @@ -143,6 +143,7 @@ public DataColumn(string columnName, Type dataType, string expr, MappingType typ private void UpdateColumnType(Type type, StorageType typeCode) { + TypeLimiter.EnsureTypeIsAllowed(type); _dataType = type; _storageType = typeCode; if (StorageType.DateTime != typeCode) diff --git a/src/libraries/System.Data.Common/src/System/Data/DataException.cs b/src/libraries/System.Data.Common/src/System/Data/DataException.cs index 7aa84800747b2e..aede6f282f193f 100644 --- a/src/libraries/System.Data.Common/src/System/Data/DataException.cs +++ b/src/libraries/System.Data.Common/src/System/Data/DataException.cs @@ -350,6 +350,7 @@ private static void ThrowDataException(string error, Exception innerException) public static Exception ArgumentOutOfRange(string paramName) => _ArgumentOutOfRange(paramName, SR.Format(SR.Data_ArgumentOutOfRange, paramName)); public static Exception BadObjectPropertyAccess(string error) => _InvalidOperation(SR.Format(SR.DataConstraint_BadObjectPropertyAccess, error)); public static Exception ArgumentContainsNull(string paramName) => _Argument(paramName, SR.Format(SR.Data_ArgumentContainsNull, paramName)); + public static Exception TypeNotAllowed(Type type) => _InvalidOperation(SR.Format(SR.Data_TypeNotAllowed, type.AssemblyQualifiedName)); // diff --git a/src/libraries/System.Data.Common/src/System/Data/DataSet.cs b/src/libraries/System.Data.Common/src/System/Data/DataSet.cs index f4d6ad887d5ce1..c09b35203a9c5e 100644 --- a/src/libraries/System.Data.Common/src/System/Data/DataSet.cs +++ b/src/libraries/System.Data.Common/src/System/Data/DataSet.cs @@ -1961,9 +1961,11 @@ private void WriteXmlSchema(XmlWriter writer, SchemaFormat schemaFormat, Convert internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving) { + IDisposable restrictedScope = null; long logScopeId = DataCommonEventSource.Log.EnterScope(" {0}, denyResolving={1}", ObjectID, denyResolving); try { + restrictedScope = TypeLimiter.EnterRestrictedScope(this); DataTable.DSRowDiffIdUsageSection rowDiffIdUsage = default; try { @@ -2231,6 +2233,7 @@ internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving) } finally { + restrictedScope?.Dispose(); DataCommonEventSource.Log.ExitScope(logScopeId); } } @@ -2467,9 +2470,11 @@ private void ReadXmlDiffgram(XmlReader reader) internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolving) { + IDisposable restictedScope = null; long logScopeId = DataCommonEventSource.Log.EnterScope(" {0}, mode={1}, denyResolving={2}", ObjectID, mode, denyResolving); try { + restictedScope = TypeLimiter.EnterRestrictedScope(this); XmlReadMode ret = mode; if (reader == null) @@ -2711,6 +2716,7 @@ internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolv } finally { + restictedScope?.Dispose(); DataCommonEventSource.Log.ExitScope(logScopeId); } } diff --git a/src/libraries/System.Data.Common/src/System/Data/DataTable.cs b/src/libraries/System.Data.Common/src/System/Data/DataTable.cs index 917908ec2a451f..5b2fa268c94c60 100644 --- a/src/libraries/System.Data.Common/src/System/Data/DataTable.cs +++ b/src/libraries/System.Data.Common/src/System/Data/DataTable.cs @@ -5659,9 +5659,11 @@ private bool IsEmptyXml(XmlReader reader) internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving) { + IDisposable restrictedScope = null; long logScopeId = DataCommonEventSource.Log.EnterScope(" {0}, denyResolving={1}", ObjectID, denyResolving); try { + restrictedScope = TypeLimiter.EnterRestrictedScope(this); RowDiffIdUsageSection rowDiffIdUsage = default; try { @@ -5896,15 +5898,18 @@ internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving) } finally { + restrictedScope?.Dispose(); DataCommonEventSource.Log.ExitScope(logScopeId); } } internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolving) { + IDisposable restrictedScope = null; RowDiffIdUsageSection rowDiffIdUsage = default; try { + restrictedScope = TypeLimiter.EnterRestrictedScope(this); bool fSchemaFound = false; bool fDataFound = false; bool fIsXdr = false; @@ -6190,6 +6195,7 @@ internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolv } finally { + restrictedScope?.Dispose(); // prepare and cleanup rowDiffId hashtable rowDiffIdUsage.Cleanup(); } diff --git a/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs b/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs index 33aa941e3034c1..472cd9be2b7e13 100644 --- a/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs +++ b/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs @@ -6,6 +6,7 @@ using System.Data.Common; using System.Data.SqlTypes; using System.Diagnostics; +using System.Runtime.Serialization; namespace System.Data { @@ -16,6 +17,7 @@ internal sealed class FunctionNode : ExpressionNode internal int _argumentCount = 0; internal const int initialCapacity = 1; internal ExpressionNode[] _arguments; + private readonly TypeLimiter _capturedLimiter = null; private static readonly Function[] s_funcs = new Function[] { new Function("Abs", FunctionId.Abs, typeof(object), true, false, 1, typeof(object), null, null), @@ -40,6 +42,12 @@ internal sealed class FunctionNode : ExpressionNode internal FunctionNode(DataTable table, string name) : base(table) { + // Because FunctionNode instances are created eagerly but evaluated lazily, + // we need to capture the deserialization scope here. The scope could be + // null if no deserialization is in progress. + + _capturedLimiter = TypeLimiter.Capture(); + _name = name; for (int i = 0; i < s_funcs.Length; i++) { @@ -289,6 +297,11 @@ private Type GetDataType(ExpressionNode node) throw ExprException.InvalidType(typeName); } + // ReadXml might not be on the current call stack. So we'll use the TypeLimiter + // that was captured when this FunctionNode instance was created. + + TypeLimiter.EnsureTypeIsAllowed(dataType, _capturedLimiter); + return dataType; } @@ -494,10 +507,17 @@ private object EvalFunction(FunctionId id, object[] argumentValues, DataRow row, { return SqlConvert.ChangeType2((decimal)SqlConvert.ChangeType2(argumentValues[0], StorageType.Decimal, typeof(decimal), FormatProvider), mytype, type, FormatProvider); } - return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider); } - return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider); + // The Convert function can be called lazily, outside of a previous Serialization Guard scope. + // If there was a type limiter scope on the stack at the time this Convert function was created, + // we must manually re-enter the Serialization Guard scope. + + DeserializationToken deserializationToken = (_capturedLimiter != null) ? SerializationInfo.StartDeserialization() : default; + using (deserializationToken) + { + return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider); + } } return argumentValues[0]; diff --git a/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs b/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs new file mode 100644 index 00000000000000..42afa2b8fedbe9 --- /dev/null +++ b/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; + +namespace System +{ + internal static partial class LocalAppContextSwitches + { + private static int s_allowArbitraryTypeInstantiation; + public static bool AllowArbitraryTypeInstantiation + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => GetCachedSwitchValue("Switch.System.Data.AllowArbitraryDataSetTypeInstantiation", ref s_allowArbitraryTypeInstantiation); + } + } +} diff --git a/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs b/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs new file mode 100644 index 00000000000000..1ff77cb99527b6 --- /dev/null +++ b/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs @@ -0,0 +1,305 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Drawing; +using System.Linq; +using System.Numerics; +using System.Runtime.Serialization; + +namespace System.Data +{ + internal sealed class TypeLimiter + { + [ThreadStatic] + private static Scope s_activeScope; + + private Scope m_instanceScope; + + private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes"; + + private TypeLimiter(Scope scope) + { + Debug.Assert(scope != null); + m_instanceScope = scope; + } + + private static bool IsTypeLimitingDisabled + => LocalAppContextSwitches.AllowArbitraryTypeInstantiation; + + /// + /// Captures the current instance so that future + /// type checks can be performed against the allow list that was active during + /// the current deserialization scope. + /// + /// + /// Returns null if no limiter is active. + /// + public static TypeLimiter Capture() + { + Scope activeScope = s_activeScope; + return (activeScope != null) ? new TypeLimiter(activeScope) : null; + } + + /// + /// Ensures the requested type is allowed by the rules of the active + /// deserialization scope. If a captured scope is provided, we'll use + /// that previously captured scope rather than the thread-static active + /// scope. + /// + /// + /// If is not allowed. + /// + public static void EnsureTypeIsAllowed(Type type, TypeLimiter capturedLimiter = null) + { + if (type is null) + { + return; // nothing to check + } + + Scope capturedScope = capturedLimiter?.m_instanceScope ?? s_activeScope; + if (capturedScope is null) + { + return; // we're not in a restricted scope + } + + if (capturedScope.IsAllowedType(type)) + { + return; // type was explicitly allowed + } + + // We encountered a type that wasn't in the allow list. + // Throw an exception to fail the current operation. + + throw ExceptionBuilder.TypeNotAllowed(type); + } + + public static IDisposable EnterRestrictedScope(DataSet dataSet) + { + if (IsTypeLimitingDisabled) + { + return null; // protections aren't enabled + } + + Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataSet)); + s_activeScope = newScope; + return newScope; + } + + public static IDisposable EnterRestrictedScope(DataTable dataTable) + { + if (IsTypeLimitingDisabled) + { + return null; // protections aren't enabled + } + + Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataTable)); + s_activeScope = newScope; + return newScope; + } + + /// + /// Given a , returns all of the + /// values declared on the instance. + /// + private static IEnumerable GetPreviouslyDeclaredDataTypes(DataTable dataTable) + { + return (dataTable != null) + ? dataTable.Columns.Cast().Select(column => column.DataType) + : Enumerable.Empty(); + } + + /// + /// Given a , returns all of the + /// values declared on the instance. + /// + private static IEnumerable GetPreviouslyDeclaredDataTypes(DataSet dataSet) + { + return (dataSet != null) + ? dataSet.Tables.Cast().SelectMany(table => GetPreviouslyDeclaredDataTypes(table)) + : Enumerable.Empty(); + } + + private sealed class Scope : IDisposable + { + /// + /// Types which are always allowed, unconditionally. + /// + private static readonly HashSet s_allowedTypes = new HashSet() + { + /* primitives */ + typeof(bool), + typeof(char), + typeof(sbyte), + typeof(byte), + typeof(short), + typeof(ushort), + typeof(int), + typeof(uint), + typeof(long), + typeof(ulong), + typeof(float), + typeof(double), + typeof(decimal), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(TimeSpan), + typeof(string), + typeof(Guid), + typeof(SqlBinary), + typeof(SqlBoolean), + typeof(SqlByte), + typeof(SqlBytes), + typeof(SqlChars), + typeof(SqlDateTime), + typeof(SqlDecimal), + typeof(SqlDouble), + typeof(SqlGuid), + typeof(SqlInt16), + typeof(SqlInt32), + typeof(SqlInt64), + typeof(SqlMoney), + typeof(SqlSingle), + typeof(SqlString), + + /* non-primitives, but common */ + typeof(object), + typeof(Type), + typeof(BigInteger), + typeof(Uri), + + /* frequently used System.Drawing types */ + typeof(Color), + typeof(Point), + typeof(PointF), + typeof(Rectangle), + typeof(RectangleF), + typeof(Size), + typeof(SizeF), + }; + + /// + /// Types which are allowed within the context of this scope. + /// + private HashSet m_allowedTypes; + + /// + /// This thread's previous scope. + /// + private readonly Scope m_previousScope; + + /// + /// The Serialization Guard token associated with this scope. + /// + private readonly DeserializationToken m_deserializationToken; + + internal Scope(Scope previousScope, IEnumerable allowedTypes) + { + Debug.Assert(allowedTypes != null); + + m_previousScope = previousScope; + m_allowedTypes = new HashSet(allowedTypes.Where(type => type != null)); + m_deserializationToken = SerializationInfo.StartDeserialization(); + } + + public void Dispose() + { + if (this != s_activeScope) + { + // Stacks should never be popped out of order. + // We want to trap this condition in production. + Debug.Fail("Scope was popped out of order."); + throw new ObjectDisposedException(GetType().FullName); + } + + m_deserializationToken.Dispose(); // it's a readonly struct, but Dispose still works properly + s_activeScope = m_previousScope; // could be null + } + + public bool IsAllowedType(Type type) + { + Debug.Assert(type != null); + + // Is the incoming type unconditionally allowed? + + if (IsTypeUnconditionallyAllowed(type)) + { + return true; + } + + // The incoming type is allowed if the current scope or any nested inner + // scope allowed it. + + for (Scope currentScope = this; currentScope != null; currentScope = currentScope.m_previousScope) + { + if (currentScope.m_allowedTypes.Contains(type)) + { + return true; + } + } + + // Did the application programmatically allow this type to be deserialized? + + Type[] appDomainAllowedTypes = (Type[])AppDomain.CurrentDomain.GetData(AppDomainDataSetDefaultAllowedTypesKey); + if (appDomainAllowedTypes != null) + { + for (int i = 0; i < appDomainAllowedTypes.Length; i++) + { + if (type == appDomainAllowedTypes[i]) + { + return true; + } + } + } + + // All checks failed + + return false; + } + + private static bool IsTypeUnconditionallyAllowed(Type type) + { + TryAgain: + Debug.Assert(type != null); + + // Check the list of unconditionally allowed types. + + if (s_allowedTypes.Contains(type)) + { + return true; + } + + // Enums are also always allowed, as we optimistically assume the app + // developer didn't define a dangerous enum type. + + if (type.IsEnum) + { + return true; + } + + // Allow single-dimensional arrays of any unconditionally allowed type. + + if (type.IsSZArray) + { + type = type.GetElementType(); + goto TryAgain; + } + + // Allow generic lists of any unconditionally allowed type. + + if (type.IsGenericType && !type.IsGenericTypeDefinition && type.GetGenericTypeDefinition() == typeof(List<>)) + { + type = type.GetGenericArguments()[0]; + goto TryAgain; + } + + // All checks failed. + + return false; + } + } + } +} diff --git a/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj b/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj index 4d0933db75b09a..9742209ccf9f3b 100644 --- a/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj +++ b/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj @@ -112,6 +112,7 @@ + diff --git a/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs b/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs new file mode 100644 index 00000000000000..3f76cebc8053df --- /dev/null +++ b/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs @@ -0,0 +1,446 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Drawing; +using System.IO; +using System.Numerics; +using System.Runtime.Serialization; +using System.Text; +using System.Xml; +using System.Xml.Schema; +using System.Xml.Serialization; +using Xunit; +using Xunit.Sdk; + +namespace System.Data.Tests +{ + // !! Important !! + // These tests manipulate global state, so they cannot be run in parallel with one another. + // We rely on xunit's default behavior of not parallelizing unit tests declared on the same + // test class: see https://xunit.net/docs/running-tests-in-parallel.html. + public class RestrictedTypeHandlingTests + { + private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes"; + + private static readonly Type[] _alwaysAllowedTypes = new Type[] + { + /* primitives */ + typeof(bool), + typeof(char), + typeof(sbyte), + typeof(byte), + typeof(short), + typeof(ushort), + typeof(int), + typeof(uint), + typeof(long), + typeof(ulong), + typeof(float), + typeof(double), + typeof(decimal), + typeof(DateTime), + typeof(DateTimeOffset), + typeof(TimeSpan), + typeof(string), + typeof(Guid), + typeof(SqlBinary), + typeof(SqlBoolean), + typeof(SqlByte), + typeof(SqlBytes), + typeof(SqlChars), + typeof(SqlDateTime), + typeof(SqlDecimal), + typeof(SqlDouble), + typeof(SqlGuid), + typeof(SqlInt16), + typeof(SqlInt32), + typeof(SqlInt64), + typeof(SqlMoney), + typeof(SqlSingle), + typeof(SqlString), + + /* non-primitives, but common */ + typeof(object), + typeof(Type), + typeof(BigInteger), + typeof(Uri), + + /* frequently used System.Drawing types */ + typeof(Color), + typeof(Point), + typeof(PointF), + typeof(Rectangle), + typeof(RectangleF), + typeof(Size), + typeof(SizeF), + + /* to test that enums are allowed */ + typeof(StringComparison), + }; + + public static IEnumerable AllowedTypes() + { + foreach (Type type in _alwaysAllowedTypes) + { + yield return new object[] { type }; // T + yield return new object[] { type.MakeArrayType() }; // T[] (SZArray) + yield return new object[] { type.MakeArrayType().MakeArrayType() }; // T[][] (jagged array) + yield return new object[] { typeof(List<>).MakeGenericType(type) }; // List + } + } + + public static IEnumerable ForbiddenTypes() + { + // StringBuilder isn't in the allow list + + yield return new object[] { typeof(StringBuilder) }; + yield return new object[] { typeof(StringBuilder[]) }; + + // multi-dim arrays and non-sz arrays are forbidden + + yield return new object[] { typeof(int[,]) }; + yield return new object[] { Array.CreateInstance(typeof(int), new[] { 1 }, new[] { 1 }).GetType() }; + + // HashSet isn't in the allow list + + yield return new object[] { typeof(HashSet) }; + + // DataSet / DataTable / SqlXml aren't in the allow list + + yield return new object[] { typeof(DataSet) }; + yield return new object[] { typeof(DataTable) }; + yield return new object[] { typeof(SqlXml) }; + + // Enum, Array, and other base types aren't allowed + + yield return new object[] { typeof(Enum) }; + yield return new object[] { typeof(Array) }; + yield return new object[] { typeof(ValueType) }; + yield return new object[] { typeof(void) }; + } + + [Theory] + [MemberData(nameof(AllowedTypes))] + public void DataTable_ReadXml_AllowsKnownTypes(Type type) + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act + + table = ReadXml(asXml); + + // Assert + + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(type, table.Columns[0].DataType); + } + + [Theory] + [MemberData(nameof(ForbiddenTypes))] + public void DataTable_ReadXml_ForbidsUnknownTypes(Type type) + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act & assert + + Assert.Throws(() => ReadXml(asXml)); + } + + [Fact] + public void DataTable_ReadXml_HandlesXmlSerializableTypes() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", typeof(object)); + table.Rows.Add(new MyXmlSerializableClass()); + + string asXml = WriteXmlWithSchema(table.WriteXml, XmlWriteMode.IgnoreSchema); + + // Act & assert + // MyXmlSerializableClass shouldn't be allowed as a member for a column + // typed as 'object'. + + table.Rows.Clear(); + Assert.Throws(() => table.ReadXml(new StringReader(asXml))); + } + + [Theory] + [MemberData(nameof(ForbiddenTypes))] + public void DataTable_ReadXmlSchema_AllowsUnknownTypes(Type type) + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act + + table = new DataTable(); + table.ReadXmlSchema(new StringReader(asXml)); + + // Assert + + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(type, table.Columns[0].DataType); + } + + [Fact] + public void DataTable_HonorsGloballyDefinedAllowList() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", typeof(MyCustomClass)); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act & assert 1 + // First call should fail since MyCustomClass not allowed + + Assert.Throws(() => ReadXml(asXml)); + + // Act & assert 2 + // Deserialization should succeed since it's now in the allow list + + try + { + AppDomain.CurrentDomain.SetData(AppDomainDataSetDefaultAllowedTypesKey, new Type[] + { + typeof(MyCustomClass) + }); + + table = ReadXml(asXml); + + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(typeof(MyCustomClass), table.Columns[0].DataType); + } + finally + { + AppDomain.CurrentDomain.SetData(AppDomainDataSetDefaultAllowedTypesKey, null); + } + } + + [Fact] + public void DataColumn_ConvertExpression_SubjectToAllowList_Success() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", typeof(object), "CONVERT('42', 'System.Int32')"); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act + + table = ReadXml(asXml); + + // Assert + + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(typeof(object), table.Columns[0].DataType); + Assert.Equal("CONVERT('42', 'System.Int32')", table.Columns[0].Expression); + } + + [Fact] + public void DataColumn_ConvertExpression_SubjectToAllowList_Failure() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("ColumnA", typeof(object)); + table.Columns.Add("ColumnB", typeof(object), "CONVERT(ColumnA, 'System.Text.StringBuilder')"); + + string asXml = WriteXmlWithSchema(table.WriteXml); + + // Act + // 'StringBuilder' isn't in the allow list, but we're not yet hydrating the Type + // object so we won't check it just yet. + + table = ReadXml(asXml); + + // Assert - the CONVERT function node should have captured the active allow list + // at construction and should apply it now. + + Assert.Throws(() => table.Rows.Add(new StringBuilder())); + } + + [Theory] + [MemberData(nameof(AllowedTypes))] + public void DataSet_ReadXml_AllowsKnownTypes(Type type) + { + // Arrange + + DataSet set = new DataSet("MySet"); + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + set.Tables.Add(table); + + string asXml = WriteXmlWithSchema(set.WriteXml); + + // Act + + table = null; + set = ReadXml(asXml); + + // Assert + + Assert.Equal("MySet", set.DataSetName); + Assert.Equal(1, set.Tables.Count); + + table = set.Tables[0]; + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(type, table.Columns[0].DataType); + } + + [Theory] + [MemberData(nameof(ForbiddenTypes))] + public void DataSet_ReadXml_ForbidsUnknownTypes(Type type) + { + // Arrange + + DataSet set = new DataSet("MySet"); + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + set.Tables.Add(table); + + string asXml = WriteXmlWithSchema(set.WriteXml); + + // Act & assert + + Assert.Throws(() => ReadXml(asXml)); + } + + [Theory] + [MemberData(nameof(ForbiddenTypes))] + public void DataSet_ReadXmlSchema_AllowsUnknownTypes(Type type) + { + // Arrange + + DataSet set = new DataSet("MySet"); + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", type); + set.Tables.Add(table); + + string asXml = WriteXmlWithSchema(set.WriteXml); + + // Act + + set = new DataSet(); + set.ReadXmlSchema(new StringReader(asXml)); + + // Assert + + Assert.Equal("MySet", set.DataSetName); + Assert.Equal(1, set.Tables.Count); + + table = set.Tables[0]; + Assert.Equal("MyTable", table.TableName); + Assert.Equal(1, table.Columns.Count); + Assert.Equal("MyColumn", table.Columns[0].ColumnName); + Assert.Equal(type, table.Columns[0].DataType); + } + + [Fact] + public void SerializationGuard_BlocksFileAccessOnDeserialize() + { + // Arrange + + DataTable table = new DataTable("MyTable"); + table.Columns.Add("MyColumn", typeof(MyCustomClassThatWritesToAFile)); + table.Rows.Add(new MyCustomClassThatWritesToAFile()); + + string asXml = WriteXmlWithSchema(table.WriteXml); + table.Rows.Clear(); + + // Act & assert + + Assert.Throws(() => table.ReadXml(new StringReader(asXml))); + } + + private static string WriteXmlWithSchema(Action writeMethod, XmlWriteMode xmlWriteMode = XmlWriteMode.WriteSchema) + { + StringWriter writer = new StringWriter(); + writeMethod(writer, xmlWriteMode); + return writer.ToString(); + } + + private static T ReadXml(string xml) where T : IXmlSerializable, new() + { + T newObj = new T(); + newObj.ReadXml(new XmlTextReader(new StringReader(xml)) { XmlResolver = null }); // suppress DTDs, same as runtime code + return newObj; + } + + private sealed class MyCustomClass + { + } + + public sealed class MyXmlSerializableClass : IXmlSerializable + { + public XmlSchema GetSchema() + { + return null; + } + + public void ReadXml(XmlReader reader) + { + return; // no-op + } + + public void WriteXml(XmlWriter writer) + { + writer.WriteElementString("MyElement", "MyValue"); + } + } + + private sealed class MyCustomClassThatWritesToAFile : IXmlSerializable + { + public XmlSchema GetSchema() + { + return null; + } + + public void ReadXml(XmlReader reader) + { + // This should be called within a Serialization Guard scope, so the file write + // should fail. + + string tempPath = Path.GetTempFileName(); + File.WriteAllText(tempPath, "This better not be written..."); + File.Delete(tempPath); + throw new XunitException("Unreachable code (SerializationGuard should have kicked in)"); + } + + public void WriteXml(XmlWriter writer) + { + writer.WriteElementString("MyElement", "MyValue"); + } + } + } +}