diff --git a/src/libraries/System.Data.Common/src/Resources/Strings.resx b/src/libraries/System.Data.Common/src/Resources/Strings.resx
index 4535a6944bffc1..671bf490bc1351 100644
--- a/src/libraries/System.Data.Common/src/Resources/Strings.resx
+++ b/src/libraries/System.Data.Common/src/Resources/Strings.resx
@@ -165,6 +165,7 @@
'{0}' argument is out of range.
'{0}' argument cannot be null.
'{0}' argument contains null value.
+ Type '{0}' is not allowed here. See https://go.microsoft.com/fwlink/?linkid=2132227 for more details.
Cannot find column {0}.
Column '{0}' already belongs to this DataTable.
Column '{0}' already belongs to another DataTable.
diff --git a/src/libraries/System.Data.Common/src/System.Data.Common.csproj b/src/libraries/System.Data.Common/src/System.Data.Common.csproj
index 20091047d6b9b5..7beb59ac3a8bb6 100644
--- a/src/libraries/System.Data.Common/src/System.Data.Common.csproj
+++ b/src/libraries/System.Data.Common/src/System.Data.Common.csproj
@@ -2,7 +2,7 @@
System.Data.Common
true
- $(NetCoreAppCurrent)
+ $(NetCoreAppCurrent)-Windows_NT;$(NetCoreAppCurrent)-Unix
@@ -123,6 +123,10 @@
+
+
+ Common\System\LocalAppContextSwitches.Common.cs
+
@@ -156,6 +160,7 @@
+
@@ -295,25 +300,23 @@
-
+
+
+
+
+
+
+
-
-
-
+
-
-
-
-
-
-
diff --git a/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs b/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs
index ffe4960cce7b87..dc3fe0533569d4 100644
--- a/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/Common/ObjectStorage.cs
@@ -406,6 +406,9 @@ public override object ConvertXmlToObject(XmlReader xmlReader, XmlRootAttribute
if (type == typeof(object))
throw ExceptionBuilder.CanNotDeserializeObjectType();
+
+ TypeLimiter.EnsureTypeIsAllowed(type);
+
if (!isBaseCLRType)
{
retValue = System.Activator.CreateInstance(type, true);
diff --git a/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs b/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs
index d44ec1e0b88751..560f6ecd7e7148 100644
--- a/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/DataColumn.cs
@@ -143,6 +143,7 @@ public DataColumn(string columnName, Type dataType, string expr, MappingType typ
private void UpdateColumnType(Type type, StorageType typeCode)
{
+ TypeLimiter.EnsureTypeIsAllowed(type);
_dataType = type;
_storageType = typeCode;
if (StorageType.DateTime != typeCode)
diff --git a/src/libraries/System.Data.Common/src/System/Data/DataException.cs b/src/libraries/System.Data.Common/src/System/Data/DataException.cs
index 7aa84800747b2e..aede6f282f193f 100644
--- a/src/libraries/System.Data.Common/src/System/Data/DataException.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/DataException.cs
@@ -350,6 +350,7 @@ private static void ThrowDataException(string error, Exception innerException)
public static Exception ArgumentOutOfRange(string paramName) => _ArgumentOutOfRange(paramName, SR.Format(SR.Data_ArgumentOutOfRange, paramName));
public static Exception BadObjectPropertyAccess(string error) => _InvalidOperation(SR.Format(SR.DataConstraint_BadObjectPropertyAccess, error));
public static Exception ArgumentContainsNull(string paramName) => _Argument(paramName, SR.Format(SR.Data_ArgumentContainsNull, paramName));
+ public static Exception TypeNotAllowed(Type type) => _InvalidOperation(SR.Format(SR.Data_TypeNotAllowed, type.AssemblyQualifiedName));
//
diff --git a/src/libraries/System.Data.Common/src/System/Data/DataSet.cs b/src/libraries/System.Data.Common/src/System/Data/DataSet.cs
index f4d6ad887d5ce1..c09b35203a9c5e 100644
--- a/src/libraries/System.Data.Common/src/System/Data/DataSet.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/DataSet.cs
@@ -1961,9 +1961,11 @@ private void WriteXmlSchema(XmlWriter writer, SchemaFormat schemaFormat, Convert
internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving)
{
+ IDisposable restrictedScope = null;
long logScopeId = DataCommonEventSource.Log.EnterScope(" {0}, denyResolving={1}", ObjectID, denyResolving);
try
{
+ restrictedScope = TypeLimiter.EnterRestrictedScope(this);
DataTable.DSRowDiffIdUsageSection rowDiffIdUsage = default;
try
{
@@ -2231,6 +2233,7 @@ internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving)
}
finally
{
+ restrictedScope?.Dispose();
DataCommonEventSource.Log.ExitScope(logScopeId);
}
}
@@ -2467,9 +2470,11 @@ private void ReadXmlDiffgram(XmlReader reader)
internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolving)
{
+ IDisposable restictedScope = null;
long logScopeId = DataCommonEventSource.Log.EnterScope(" {0}, mode={1}, denyResolving={2}", ObjectID, mode, denyResolving);
try
{
+ restictedScope = TypeLimiter.EnterRestrictedScope(this);
XmlReadMode ret = mode;
if (reader == null)
@@ -2711,6 +2716,7 @@ internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolv
}
finally
{
+ restictedScope?.Dispose();
DataCommonEventSource.Log.ExitScope(logScopeId);
}
}
diff --git a/src/libraries/System.Data.Common/src/System/Data/DataTable.cs b/src/libraries/System.Data.Common/src/System/Data/DataTable.cs
index 917908ec2a451f..5b2fa268c94c60 100644
--- a/src/libraries/System.Data.Common/src/System/Data/DataTable.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/DataTable.cs
@@ -5659,9 +5659,11 @@ private bool IsEmptyXml(XmlReader reader)
internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving)
{
+ IDisposable restrictedScope = null;
long logScopeId = DataCommonEventSource.Log.EnterScope(" {0}, denyResolving={1}", ObjectID, denyResolving);
try
{
+ restrictedScope = TypeLimiter.EnterRestrictedScope(this);
RowDiffIdUsageSection rowDiffIdUsage = default;
try
{
@@ -5896,15 +5898,18 @@ internal XmlReadMode ReadXml(XmlReader reader, bool denyResolving)
}
finally
{
+ restrictedScope?.Dispose();
DataCommonEventSource.Log.ExitScope(logScopeId);
}
}
internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolving)
{
+ IDisposable restrictedScope = null;
RowDiffIdUsageSection rowDiffIdUsage = default;
try
{
+ restrictedScope = TypeLimiter.EnterRestrictedScope(this);
bool fSchemaFound = false;
bool fDataFound = false;
bool fIsXdr = false;
@@ -6190,6 +6195,7 @@ internal XmlReadMode ReadXml(XmlReader reader, XmlReadMode mode, bool denyResolv
}
finally
{
+ restrictedScope?.Dispose();
// prepare and cleanup rowDiffId hashtable
rowDiffIdUsage.Cleanup();
}
diff --git a/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs b/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs
index 33aa941e3034c1..472cd9be2b7e13 100644
--- a/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs
+++ b/src/libraries/System.Data.Common/src/System/Data/Filter/FunctionNode.cs
@@ -6,6 +6,7 @@
using System.Data.Common;
using System.Data.SqlTypes;
using System.Diagnostics;
+using System.Runtime.Serialization;
namespace System.Data
{
@@ -16,6 +17,7 @@ internal sealed class FunctionNode : ExpressionNode
internal int _argumentCount = 0;
internal const int initialCapacity = 1;
internal ExpressionNode[] _arguments;
+ private readonly TypeLimiter _capturedLimiter = null;
private static readonly Function[] s_funcs = new Function[] {
new Function("Abs", FunctionId.Abs, typeof(object), true, false, 1, typeof(object), null, null),
@@ -40,6 +42,12 @@ internal sealed class FunctionNode : ExpressionNode
internal FunctionNode(DataTable table, string name) : base(table)
{
+ // Because FunctionNode instances are created eagerly but evaluated lazily,
+ // we need to capture the deserialization scope here. The scope could be
+ // null if no deserialization is in progress.
+
+ _capturedLimiter = TypeLimiter.Capture();
+
_name = name;
for (int i = 0; i < s_funcs.Length; i++)
{
@@ -289,6 +297,11 @@ private Type GetDataType(ExpressionNode node)
throw ExprException.InvalidType(typeName);
}
+ // ReadXml might not be on the current call stack. So we'll use the TypeLimiter
+ // that was captured when this FunctionNode instance was created.
+
+ TypeLimiter.EnsureTypeIsAllowed(dataType, _capturedLimiter);
+
return dataType;
}
@@ -494,10 +507,17 @@ private object EvalFunction(FunctionId id, object[] argumentValues, DataRow row,
{
return SqlConvert.ChangeType2((decimal)SqlConvert.ChangeType2(argumentValues[0], StorageType.Decimal, typeof(decimal), FormatProvider), mytype, type, FormatProvider);
}
- return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider);
}
- return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider);
+ // The Convert function can be called lazily, outside of a previous Serialization Guard scope.
+ // If there was a type limiter scope on the stack at the time this Convert function was created,
+ // we must manually re-enter the Serialization Guard scope.
+
+ DeserializationToken deserializationToken = (_capturedLimiter != null) ? SerializationInfo.StartDeserialization() : default;
+ using (deserializationToken)
+ {
+ return SqlConvert.ChangeType2(argumentValues[0], mytype, type, FormatProvider);
+ }
}
return argumentValues[0];
diff --git a/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs b/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs
new file mode 100644
index 00000000000000..42afa2b8fedbe9
--- /dev/null
+++ b/src/libraries/System.Data.Common/src/System/Data/LocalAppContextSwitches.cs
@@ -0,0 +1,18 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+
+namespace System
+{
+ internal static partial class LocalAppContextSwitches
+ {
+ private static int s_allowArbitraryTypeInstantiation;
+ public static bool AllowArbitraryTypeInstantiation
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ get => GetCachedSwitchValue("Switch.System.Data.AllowArbitraryDataSetTypeInstantiation", ref s_allowArbitraryTypeInstantiation);
+ }
+ }
+}
diff --git a/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs b/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs
new file mode 100644
index 00000000000000..1ff77cb99527b6
--- /dev/null
+++ b/src/libraries/System.Data.Common/src/System/Data/TypeLimiter.cs
@@ -0,0 +1,305 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using System.Diagnostics;
+using System.Drawing;
+using System.Linq;
+using System.Numerics;
+using System.Runtime.Serialization;
+
+namespace System.Data
+{
+ internal sealed class TypeLimiter
+ {
+ [ThreadStatic]
+ private static Scope s_activeScope;
+
+ private Scope m_instanceScope;
+
+ private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes";
+
+ private TypeLimiter(Scope scope)
+ {
+ Debug.Assert(scope != null);
+ m_instanceScope = scope;
+ }
+
+ private static bool IsTypeLimitingDisabled
+ => LocalAppContextSwitches.AllowArbitraryTypeInstantiation;
+
+ ///
+ /// Captures the current instance so that future
+ /// type checks can be performed against the allow list that was active during
+ /// the current deserialization scope.
+ ///
+ ///
+ /// Returns null if no limiter is active.
+ ///
+ public static TypeLimiter Capture()
+ {
+ Scope activeScope = s_activeScope;
+ return (activeScope != null) ? new TypeLimiter(activeScope) : null;
+ }
+
+ ///
+ /// Ensures the requested type is allowed by the rules of the active
+ /// deserialization scope. If a captured scope is provided, we'll use
+ /// that previously captured scope rather than the thread-static active
+ /// scope.
+ ///
+ ///
+ /// If is not allowed.
+ ///
+ public static void EnsureTypeIsAllowed(Type type, TypeLimiter capturedLimiter = null)
+ {
+ if (type is null)
+ {
+ return; // nothing to check
+ }
+
+ Scope capturedScope = capturedLimiter?.m_instanceScope ?? s_activeScope;
+ if (capturedScope is null)
+ {
+ return; // we're not in a restricted scope
+ }
+
+ if (capturedScope.IsAllowedType(type))
+ {
+ return; // type was explicitly allowed
+ }
+
+ // We encountered a type that wasn't in the allow list.
+ // Throw an exception to fail the current operation.
+
+ throw ExceptionBuilder.TypeNotAllowed(type);
+ }
+
+ public static IDisposable EnterRestrictedScope(DataSet dataSet)
+ {
+ if (IsTypeLimitingDisabled)
+ {
+ return null; // protections aren't enabled
+ }
+
+ Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataSet));
+ s_activeScope = newScope;
+ return newScope;
+ }
+
+ public static IDisposable EnterRestrictedScope(DataTable dataTable)
+ {
+ if (IsTypeLimitingDisabled)
+ {
+ return null; // protections aren't enabled
+ }
+
+ Scope newScope = new Scope(s_activeScope, GetPreviouslyDeclaredDataTypes(dataTable));
+ s_activeScope = newScope;
+ return newScope;
+ }
+
+ ///
+ /// Given a , returns all of the
+ /// values declared on the instance.
+ ///
+ private static IEnumerable GetPreviouslyDeclaredDataTypes(DataTable dataTable)
+ {
+ return (dataTable != null)
+ ? dataTable.Columns.Cast().Select(column => column.DataType)
+ : Enumerable.Empty();
+ }
+
+ ///
+ /// Given a , returns all of the
+ /// values declared on the instance.
+ ///
+ private static IEnumerable GetPreviouslyDeclaredDataTypes(DataSet dataSet)
+ {
+ return (dataSet != null)
+ ? dataSet.Tables.Cast().SelectMany(table => GetPreviouslyDeclaredDataTypes(table))
+ : Enumerable.Empty();
+ }
+
+ private sealed class Scope : IDisposable
+ {
+ ///
+ /// Types which are always allowed, unconditionally.
+ ///
+ private static readonly HashSet s_allowedTypes = new HashSet()
+ {
+ /* primitives */
+ typeof(bool),
+ typeof(char),
+ typeof(sbyte),
+ typeof(byte),
+ typeof(short),
+ typeof(ushort),
+ typeof(int),
+ typeof(uint),
+ typeof(long),
+ typeof(ulong),
+ typeof(float),
+ typeof(double),
+ typeof(decimal),
+ typeof(DateTime),
+ typeof(DateTimeOffset),
+ typeof(TimeSpan),
+ typeof(string),
+ typeof(Guid),
+ typeof(SqlBinary),
+ typeof(SqlBoolean),
+ typeof(SqlByte),
+ typeof(SqlBytes),
+ typeof(SqlChars),
+ typeof(SqlDateTime),
+ typeof(SqlDecimal),
+ typeof(SqlDouble),
+ typeof(SqlGuid),
+ typeof(SqlInt16),
+ typeof(SqlInt32),
+ typeof(SqlInt64),
+ typeof(SqlMoney),
+ typeof(SqlSingle),
+ typeof(SqlString),
+
+ /* non-primitives, but common */
+ typeof(object),
+ typeof(Type),
+ typeof(BigInteger),
+ typeof(Uri),
+
+ /* frequently used System.Drawing types */
+ typeof(Color),
+ typeof(Point),
+ typeof(PointF),
+ typeof(Rectangle),
+ typeof(RectangleF),
+ typeof(Size),
+ typeof(SizeF),
+ };
+
+ ///
+ /// Types which are allowed within the context of this scope.
+ ///
+ private HashSet m_allowedTypes;
+
+ ///
+ /// This thread's previous scope.
+ ///
+ private readonly Scope m_previousScope;
+
+ ///
+ /// The Serialization Guard token associated with this scope.
+ ///
+ private readonly DeserializationToken m_deserializationToken;
+
+ internal Scope(Scope previousScope, IEnumerable allowedTypes)
+ {
+ Debug.Assert(allowedTypes != null);
+
+ m_previousScope = previousScope;
+ m_allowedTypes = new HashSet(allowedTypes.Where(type => type != null));
+ m_deserializationToken = SerializationInfo.StartDeserialization();
+ }
+
+ public void Dispose()
+ {
+ if (this != s_activeScope)
+ {
+ // Stacks should never be popped out of order.
+ // We want to trap this condition in production.
+ Debug.Fail("Scope was popped out of order.");
+ throw new ObjectDisposedException(GetType().FullName);
+ }
+
+ m_deserializationToken.Dispose(); // it's a readonly struct, but Dispose still works properly
+ s_activeScope = m_previousScope; // could be null
+ }
+
+ public bool IsAllowedType(Type type)
+ {
+ Debug.Assert(type != null);
+
+ // Is the incoming type unconditionally allowed?
+
+ if (IsTypeUnconditionallyAllowed(type))
+ {
+ return true;
+ }
+
+ // The incoming type is allowed if the current scope or any nested inner
+ // scope allowed it.
+
+ for (Scope currentScope = this; currentScope != null; currentScope = currentScope.m_previousScope)
+ {
+ if (currentScope.m_allowedTypes.Contains(type))
+ {
+ return true;
+ }
+ }
+
+ // Did the application programmatically allow this type to be deserialized?
+
+ Type[] appDomainAllowedTypes = (Type[])AppDomain.CurrentDomain.GetData(AppDomainDataSetDefaultAllowedTypesKey);
+ if (appDomainAllowedTypes != null)
+ {
+ for (int i = 0; i < appDomainAllowedTypes.Length; i++)
+ {
+ if (type == appDomainAllowedTypes[i])
+ {
+ return true;
+ }
+ }
+ }
+
+ // All checks failed
+
+ return false;
+ }
+
+ private static bool IsTypeUnconditionallyAllowed(Type type)
+ {
+ TryAgain:
+ Debug.Assert(type != null);
+
+ // Check the list of unconditionally allowed types.
+
+ if (s_allowedTypes.Contains(type))
+ {
+ return true;
+ }
+
+ // Enums are also always allowed, as we optimistically assume the app
+ // developer didn't define a dangerous enum type.
+
+ if (type.IsEnum)
+ {
+ return true;
+ }
+
+ // Allow single-dimensional arrays of any unconditionally allowed type.
+
+ if (type.IsSZArray)
+ {
+ type = type.GetElementType();
+ goto TryAgain;
+ }
+
+ // Allow generic lists of any unconditionally allowed type.
+
+ if (type.IsGenericType && !type.IsGenericTypeDefinition && type.GetGenericTypeDefinition() == typeof(List<>))
+ {
+ type = type.GetGenericArguments()[0];
+ goto TryAgain;
+ }
+
+ // All checks failed.
+
+ return false;
+ }
+ }
+ }
+}
diff --git a/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj b/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj
index 4d0933db75b09a..9742209ccf9f3b 100644
--- a/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj
+++ b/src/libraries/System.Data.Common/tests/System.Data.Common.Tests.csproj
@@ -112,6 +112,7 @@
+
diff --git a/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs b/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs
new file mode 100644
index 00000000000000..3f76cebc8053df
--- /dev/null
+++ b/src/libraries/System.Data.Common/tests/System/Data/RestrictedTypeHandlingTests.cs
@@ -0,0 +1,446 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using System.Drawing;
+using System.IO;
+using System.Numerics;
+using System.Runtime.Serialization;
+using System.Text;
+using System.Xml;
+using System.Xml.Schema;
+using System.Xml.Serialization;
+using Xunit;
+using Xunit.Sdk;
+
+namespace System.Data.Tests
+{
+ // !! Important !!
+ // These tests manipulate global state, so they cannot be run in parallel with one another.
+ // We rely on xunit's default behavior of not parallelizing unit tests declared on the same
+ // test class: see https://xunit.net/docs/running-tests-in-parallel.html.
+ public class RestrictedTypeHandlingTests
+ {
+ private const string AppDomainDataSetDefaultAllowedTypesKey = "System.Data.DataSetDefaultAllowedTypes";
+
+ private static readonly Type[] _alwaysAllowedTypes = new Type[]
+ {
+ /* primitives */
+ typeof(bool),
+ typeof(char),
+ typeof(sbyte),
+ typeof(byte),
+ typeof(short),
+ typeof(ushort),
+ typeof(int),
+ typeof(uint),
+ typeof(long),
+ typeof(ulong),
+ typeof(float),
+ typeof(double),
+ typeof(decimal),
+ typeof(DateTime),
+ typeof(DateTimeOffset),
+ typeof(TimeSpan),
+ typeof(string),
+ typeof(Guid),
+ typeof(SqlBinary),
+ typeof(SqlBoolean),
+ typeof(SqlByte),
+ typeof(SqlBytes),
+ typeof(SqlChars),
+ typeof(SqlDateTime),
+ typeof(SqlDecimal),
+ typeof(SqlDouble),
+ typeof(SqlGuid),
+ typeof(SqlInt16),
+ typeof(SqlInt32),
+ typeof(SqlInt64),
+ typeof(SqlMoney),
+ typeof(SqlSingle),
+ typeof(SqlString),
+
+ /* non-primitives, but common */
+ typeof(object),
+ typeof(Type),
+ typeof(BigInteger),
+ typeof(Uri),
+
+ /* frequently used System.Drawing types */
+ typeof(Color),
+ typeof(Point),
+ typeof(PointF),
+ typeof(Rectangle),
+ typeof(RectangleF),
+ typeof(Size),
+ typeof(SizeF),
+
+ /* to test that enums are allowed */
+ typeof(StringComparison),
+ };
+
+ public static IEnumerable