1
+ namespace Xunit
2
+ {
3
+ using System ;
4
+ using System . Collections ;
5
+ using System . Collections . Generic ;
6
+ using System . Linq ;
7
+ using System . Reflection ;
8
+
9
+ /// <summary>
10
+ /// Specifies which member should provide data for this parameter used for running the test method.
11
+ /// </summary>
12
+ [ AttributeUsage ( AttributeTargets . Parameter , AllowMultiple = true ) ]
13
+ public class CombinatorialMemberDataAttribute : Attribute
14
+ {
15
+ /// <summary>
16
+ /// Initializes a new instance of the <see cref="CombinatorialMemberDataAttribute"/> class.
17
+ /// </summary>
18
+ /// <param name="memberName">The name of the public static member on the test class that will provide the test data</param>
19
+ /// <param name="parameters">The parameters for the member (only supported for methods; ignored for everything else)</param>
20
+ public CombinatorialMemberDataAttribute ( string memberName , params object [ ] parameters )
21
+ {
22
+ this . MemberName = memberName ?? throw new ArgumentNullException ( nameof ( memberName ) ) ;
23
+ this . Parameters = parameters ;
24
+ }
25
+
26
+ /// <summary>
27
+ /// Gets the member name.
28
+ /// </summary>
29
+ public string MemberName { get ; }
30
+
31
+ /// <summary>
32
+ /// Gets or sets the type to retrieve the member from. If not set, then the property will be
33
+ /// retrieved from the unit test class.
34
+ /// </summary>
35
+ public Type MemberType { get ; set ; }
36
+
37
+ /// <summary>
38
+ /// Gets or sets the parameters passed to the member. Only supported for static methods.
39
+ /// </summary>
40
+ public object [ ] Parameters { get ; }
41
+
42
+ /// <summary>
43
+ /// Gets the values that should be passed to this parameter on the test method.
44
+ /// </summary>
45
+ /// <param name="parameterInfo">The parameter for which the data should be provided</param>
46
+ /// <returns>An array of values.</returns>
47
+ public object [ ] GetValues ( ParameterInfo parameterInfo )
48
+ {
49
+ var testMethod = parameterInfo . Member ;
50
+
51
+ var type = this . MemberType ?? testMethod ? . DeclaringType ;
52
+
53
+ if ( type == null )
54
+ {
55
+ return new object [ 0 ] ;
56
+ }
57
+
58
+ var accessor = this . GetPropertyAccessor ( type , parameterInfo ) ?? this . GetMethodAccessor ( type , parameterInfo ) ?? this . GetFieldAccessor ( type , parameterInfo ) ;
59
+ if ( accessor == null )
60
+ {
61
+ var parameterText = this . Parameters ? . Length > 0 ? $ " with parameter types: { string . Join ( ", " , this . Parameters . Select ( p => p ? . GetType ( ) . FullName ?? "(null)" ) ) } " : string . Empty ;
62
+ throw new ArgumentException ( $ "Could not find public static member (property, field, or method) named '{ this . MemberName } ' on { type . FullName } { parameterText } ") ;
63
+ }
64
+
65
+ var obj = ( IEnumerable ) accessor ( ) ;
66
+ return obj . Cast < object > ( ) . ToArray ( ) ;
67
+ }
68
+
69
+ private Func < object > GetPropertyAccessor ( Type type , ParameterInfo parameterInfo )
70
+ {
71
+ PropertyInfo propInfo = null ;
72
+ for ( var reflectionType = type ; reflectionType != null ; reflectionType = reflectionType . GetTypeInfo ( ) . BaseType )
73
+ {
74
+ propInfo = reflectionType . GetRuntimeProperty ( this . MemberName ) ;
75
+ if ( propInfo != null )
76
+ {
77
+ break ;
78
+ }
79
+ }
80
+
81
+ if ( propInfo ? . GetMethod == null || ! propInfo . GetMethod . IsStatic )
82
+ {
83
+ return null ;
84
+ }
85
+
86
+ this . EnsureValidMemberDataType ( propInfo . PropertyType , propInfo . DeclaringType , parameterInfo ) ;
87
+
88
+ return ( ) => propInfo . GetValue ( null , null ) ;
89
+ }
90
+
91
+ private Func < object > GetMethodAccessor ( Type type , ParameterInfo parameterInfo )
92
+ {
93
+ MethodInfo methodInfo = null ;
94
+ var parameterTypes = this . Parameters == null
95
+ ? new Type [ 0 ]
96
+ : this . Parameters . Select ( p => p . GetType ( ) ) . ToArray ( ) ;
97
+ for ( var reflectionType = type ; reflectionType != null ; reflectionType = reflectionType . GetTypeInfo ( ) . BaseType )
98
+ {
99
+ methodInfo = reflectionType . GetRuntimeMethods ( ) . FirstOrDefault ( m => m . Name == this . MemberName && this . ParameterTypesCompatible ( m . GetParameters ( ) , parameterTypes ) ) ;
100
+
101
+ if ( methodInfo != null )
102
+ {
103
+ break ;
104
+ }
105
+ }
106
+
107
+ if ( methodInfo == null || ! methodInfo . IsStatic )
108
+ {
109
+ return null ;
110
+ }
111
+
112
+ this . EnsureValidMemberDataType ( methodInfo . ReturnType , methodInfo . DeclaringType , parameterInfo ) ;
113
+
114
+ return ( ) => methodInfo . Invoke ( null , this . Parameters ) ;
115
+ }
116
+
117
+ private bool ParameterTypesCompatible ( ParameterInfo [ ] parameters , Type [ ] parameterTypes )
118
+ {
119
+ if ( parameters . Length != parameterTypes . Length )
120
+ {
121
+ return false ;
122
+ }
123
+
124
+ for ( var i = 0 ; i < parameters . Length ; i ++ )
125
+ {
126
+ if ( parameterTypes [ i ] != null && ! parameters [ i ] . ParameterType . GetTypeInfo ( )
127
+ . IsAssignableFrom ( parameterTypes [ i ] . GetTypeInfo ( ) ) )
128
+ {
129
+ return false ;
130
+ }
131
+ }
132
+
133
+ return true ;
134
+ }
135
+
136
+ private Func < object > GetFieldAccessor ( Type type , ParameterInfo parameterInfo )
137
+ {
138
+ FieldInfo fieldInfo = null ;
139
+ for ( var reflectionType = type ; reflectionType != null ; reflectionType = reflectionType . GetTypeInfo ( ) . BaseType )
140
+ {
141
+ fieldInfo = reflectionType . GetRuntimeField ( this . MemberName ) ;
142
+
143
+ if ( fieldInfo != null )
144
+ {
145
+ break ;
146
+ }
147
+ }
148
+
149
+ if ( fieldInfo == null || ! fieldInfo . IsStatic )
150
+ {
151
+ return null ;
152
+ }
153
+
154
+ this . EnsureValidMemberDataType ( fieldInfo . FieldType , fieldInfo . DeclaringType , parameterInfo ) ;
155
+
156
+ return ( ) => fieldInfo . GetValue ( null ) ;
157
+ }
158
+
159
+ private void EnsureValidMemberDataType ( Type type , Type declaringType , ParameterInfo parameterType )
160
+ {
161
+ var enumerableTypeInfo = typeof ( IEnumerable ) . GetTypeInfo ( ) ;
162
+
163
+ if ( ! enumerableTypeInfo . IsAssignableFrom ( type . GetTypeInfo ( ) ) )
164
+ {
165
+ throw new ArgumentException ( $ "Member { this . MemberName } on { type . FullName } did not return IEnumerable") ;
166
+ }
167
+
168
+ var enumerableGenericType = this . GetEnumerableType ( type ) ;
169
+ if ( enumerableTypeInfo . IsAssignableFrom ( enumerableGenericType ) )
170
+ {
171
+ throw new ArgumentException (
172
+ $ "Member { this . MemberName } on { declaringType . FullName } returned an IEnumerable<object[]>, which is not supported") ;
173
+ }
174
+
175
+ if ( ! enumerableGenericType . IsAssignableFrom ( parameterType . ParameterType . GetTypeInfo ( ) ) )
176
+ {
177
+ throw new ArgumentException (
178
+ $ "Parameter type { parameterType . ParameterType . FullName } is not compatible with returned member type { enumerableGenericType . FullName } ") ;
179
+ }
180
+ }
181
+
182
+ private TypeInfo GetEnumerableType ( Type enumerableType )
183
+ {
184
+ var enumerableGenericTypeDefinition = enumerableType . GetTypeInfo ( ) . GetGenericArguments ( ) ;
185
+ if ( enumerableGenericTypeDefinition != null )
186
+ {
187
+ return enumerableGenericTypeDefinition [ 0 ] . GetTypeInfo ( ) ;
188
+ }
189
+
190
+ foreach ( var implementedInterface in enumerableType . GetTypeInfo ( ) . ImplementedInterfaces )
191
+ {
192
+ var interfaceTypeInfo = implementedInterface . GetTypeInfo ( ) ;
193
+ if ( interfaceTypeInfo . IsGenericType && interfaceTypeInfo . GetGenericTypeDefinition ( ) == typeof ( IEnumerable < > ) )
194
+ {
195
+ return interfaceTypeInfo . GetGenericArguments ( ) [ 0 ] . GetTypeInfo ( ) ;
196
+ }
197
+ }
198
+
199
+ return null ;
200
+ }
201
+ }
202
+ }
0 commit comments