|
| 1 | +using System.Reflection; |
| 2 | + |
1 | 3 | class IncludeAppender( |
2 | 4 | IReadOnlyDictionary<Type, IReadOnlyDictionary<string, Navigation>> navigations, |
3 | 5 | IReadOnlyDictionary<Type, List<string>> keyNames, |
@@ -64,29 +66,85 @@ static IQueryable<TItem> AddIncludesFromProjection<TItem>( |
64 | 66 | FieldProjectionInfo projection) |
65 | 67 | where TItem : class |
66 | 68 | { |
67 | | - if (projection.Navigations is not { Count: > 0 }) |
| 69 | + var visitedTypes = new HashSet<Type> { typeof(TItem) }; |
| 70 | + |
| 71 | + if (projection.Navigations is { Count: > 0 }) |
68 | 72 | { |
69 | | - return query; |
| 73 | + foreach (var (navName, navProjection) in projection.Navigations) |
| 74 | + { |
| 75 | + if (IsVisitedOrBaseType(navProjection.EntityType, visitedTypes)) |
| 76 | + { |
| 77 | + continue; |
| 78 | + } |
| 79 | + |
| 80 | + visitedTypes.Add(navProjection.EntityType); |
| 81 | + query = query.Include(navName); |
| 82 | + query = AddNestedIncludes(query, navName, navProjection.Projection, visitedTypes); |
| 83 | + visitedTypes.Remove(navProjection.EntityType); |
| 84 | + } |
70 | 85 | } |
71 | 86 |
|
72 | | - var visitedTypes = new HashSet<Type> { typeof(TItem) }; |
| 87 | + // Add derived-type navigation includes for TPH inline fragments |
| 88 | + // e.g. query.Include(e => ((GroupAccessRule)e).Group) |
| 89 | + if (projection.DerivedNavigations is { Count: > 0 }) |
| 90 | + { |
| 91 | + query = AddDerivedTypeIncludes(query, projection.DerivedNavigations, visitedTypes); |
| 92 | + } |
73 | 93 |
|
74 | | - foreach (var (navName, navProjection) in projection.Navigations) |
| 94 | + return query; |
| 95 | + } |
| 96 | + |
| 97 | + static IQueryable<TItem> AddDerivedTypeIncludes<TItem>( |
| 98 | + IQueryable<TItem> query, |
| 99 | + Dictionary<Type, Dictionary<string, NavigationProjectionInfo>> derivedNavigations, |
| 100 | + HashSet<Type> visitedTypes) |
| 101 | + where TItem : class |
| 102 | + { |
| 103 | + var itemType = typeof(TItem); |
| 104 | + var parameter = Expression.Parameter(itemType, "e"); |
| 105 | + |
| 106 | + foreach (var (derivedType, navDict) in derivedNavigations) |
75 | 107 | { |
76 | | - if (IsVisitedOrBaseType(navProjection.EntityType, visitedTypes)) |
| 108 | + // Cast: (DerivedType)e |
| 109 | + var cast = Expression.Convert(parameter, derivedType); |
| 110 | + |
| 111 | + foreach (var (navName, navProjection) in navDict) |
77 | 112 | { |
78 | | - continue; |
79 | | - } |
| 113 | + if (IsVisitedOrBaseType(navProjection.EntityType, visitedTypes)) |
| 114 | + { |
| 115 | + continue; |
| 116 | + } |
80 | 117 |
|
81 | | - visitedTypes.Add(navProjection.EntityType); |
82 | | - query = query.Include(navName); |
83 | | - query = AddNestedIncludes(query, navName, navProjection.Projection, visitedTypes); |
84 | | - visitedTypes.Remove(navProjection.EntityType); |
| 118 | + // Property access: ((DerivedType)e).Navigation |
| 119 | + var property = derivedType.GetProperty(navName); |
| 120 | + if (property == null) |
| 121 | + { |
| 122 | + continue; |
| 123 | + } |
| 124 | + |
| 125 | + var propertyAccess = Expression.Property(cast, property); |
| 126 | + |
| 127 | + // Build lambda: e => ((DerivedType)e).Navigation |
| 128 | + var lambda = Expression.Lambda(propertyAccess, parameter); |
| 129 | + |
| 130 | + // Call EntityFrameworkQueryableExtensions.Include(query, lambda) |
| 131 | + var includeMethod = GetIncludeMethod(itemType, property.PropertyType); |
| 132 | + query = (IQueryable<TItem>)includeMethod.Invoke(null, [query, lambda])!; |
| 133 | + } |
85 | 134 | } |
86 | 135 |
|
87 | 136 | return query; |
88 | 137 | } |
89 | 138 |
|
| 139 | + static MethodInfo GetIncludeMethod(Type entityType, Type propertyType) => |
| 140 | + typeof(EntityFrameworkQueryableExtensions) |
| 141 | + .GetMethods(BindingFlags.Static | BindingFlags.Public) |
| 142 | + .First(_ => _.Name == "Include" && |
| 143 | + _.GetGenericArguments().Length == 2 && |
| 144 | + _.GetParameters().Length == 2 && |
| 145 | + _.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>)) |
| 146 | + .MakeGenericMethod(entityType, propertyType); |
| 147 | + |
90 | 148 | static IQueryable<TItem> AddNestedIncludes<TItem>( |
91 | 149 | IQueryable<TItem> query, |
92 | 150 | string includePath, |
@@ -180,7 +238,195 @@ FieldProjectionInfo GetProjectionInfo( |
180 | 238 | } |
181 | 239 | } |
182 | 240 |
|
183 | | - return new(scalarFields, keys, foreignKeyNames, navProjections); |
| 241 | + // Scan for derived-type navigations from inline fragments (TPH support) |
| 242 | + var derivedNavigations = GetDerivedNavigationsFromFragments(context); |
| 243 | + |
| 244 | + return new(scalarFields, keys, foreignKeyNames, navProjections, derivedNavigations); |
| 245 | + } |
| 246 | + |
| 247 | + Dictionary<Type, Dictionary<string, NavigationProjectionInfo>>? GetDerivedNavigationsFromFragments( |
| 248 | + IResolveFieldContext context) |
| 249 | + { |
| 250 | + var selectionSet = GetLeafSelectionSet(context); |
| 251 | + if (selectionSet?.Selections is null) |
| 252 | + { |
| 253 | + return null; |
| 254 | + } |
| 255 | + |
| 256 | + Dictionary<Type, Dictionary<string, NavigationProjectionInfo>>? result = null; |
| 257 | + |
| 258 | + foreach (var selection in selectionSet.Selections) |
| 259 | + { |
| 260 | + GraphQLTypeCondition? typeCondition; |
| 261 | + GraphQLSelectionSet? fragmentSelectionSet; |
| 262 | + |
| 263 | + switch (selection) |
| 264 | + { |
| 265 | + case GraphQLInlineFragment inlineFragment: |
| 266 | + typeCondition = inlineFragment.TypeCondition; |
| 267 | + fragmentSelectionSet = inlineFragment.SelectionSet; |
| 268 | + break; |
| 269 | + case GraphQLFragmentSpread fragmentSpread: |
| 270 | + { |
| 271 | + var name = fragmentSpread.FragmentName.Name; |
| 272 | + var fragmentDefinition = context.Document.Definitions |
| 273 | + .OfType<GraphQLFragmentDefinition>() |
| 274 | + .SingleOrDefault(_ => _.FragmentName.Name == name); |
| 275 | + if (fragmentDefinition is null) |
| 276 | + { |
| 277 | + continue; |
| 278 | + } |
| 279 | + |
| 280 | + typeCondition = fragmentDefinition.TypeCondition; |
| 281 | + fragmentSelectionSet = fragmentDefinition.SelectionSet; |
| 282 | + break; |
| 283 | + } |
| 284 | + default: |
| 285 | + continue; |
| 286 | + } |
| 287 | + |
| 288 | + if (typeCondition is null || fragmentSelectionSet?.Selections is null) |
| 289 | + { |
| 290 | + continue; |
| 291 | + } |
| 292 | + |
| 293 | + var typeName = typeCondition.Type.Name.StringValue; |
| 294 | + |
| 295 | + // Find the CLR type for this GraphQL type name using the schema |
| 296 | + if (!TryFindDerivedClrType(typeName, context.Schema, out var derivedType)) |
| 297 | + { |
| 298 | + continue; |
| 299 | + } |
| 300 | + |
| 301 | + // Get navigation properties for this derived type |
| 302 | + if (!navigations.TryGetValue(derivedType, out var derivedNavProps)) |
| 303 | + { |
| 304 | + continue; |
| 305 | + } |
| 306 | + |
| 307 | + // Process fields in this fragment against the derived type's navigation properties |
| 308 | + foreach (var field in fragmentSelectionSet.Selections.OfType<GraphQLField>()) |
| 309 | + { |
| 310 | + var fieldName = field.Name.StringValue; |
| 311 | + if (!derivedNavProps.TryGetValue(fieldName, out var navigation)) |
| 312 | + { |
| 313 | + continue; |
| 314 | + } |
| 315 | + |
| 316 | + result ??= []; |
| 317 | + if (!result.TryGetValue(derivedType, out var derivedNavs)) |
| 318 | + { |
| 319 | + derivedNavs = []; |
| 320 | + result[derivedType] = derivedNavs; |
| 321 | + } |
| 322 | + |
| 323 | + if (derivedNavs.ContainsKey(navigation.Name)) |
| 324 | + { |
| 325 | + continue; |
| 326 | + } |
| 327 | + |
| 328 | + var navType = navigation.Type; |
| 329 | + navigations.TryGetValue(navType, out var nestedNavProps); |
| 330 | + keyNames.TryGetValue(navType, out var nestedKeys); |
| 331 | + foreignKeys.TryGetValue(navType, out var nestedFks); |
| 332 | + |
| 333 | + derivedNavs[navigation.Name] = new( |
| 334 | + navType, |
| 335 | + navigation.IsCollection, |
| 336 | + GetNestedProjection(field.SelectionSet, nestedNavProps, nestedKeys, nestedFks, context)); |
| 337 | + } |
| 338 | + } |
| 339 | + |
| 340 | + return result; |
| 341 | + } |
| 342 | + |
| 343 | + /// <summary> |
| 344 | + /// Navigate through connection wrapper fields (edges/items/node) to find the leaf selection set |
| 345 | + /// that contains the actual entity fields and inline fragments. |
| 346 | + /// </summary> |
| 347 | + static GraphQLSelectionSet? GetLeafSelectionSet(IResolveFieldContext context) |
| 348 | + { |
| 349 | + var selectionSet = context.FieldAst.SelectionSet; |
| 350 | + if (selectionSet?.Selections is null) |
| 351 | + { |
| 352 | + return null; |
| 353 | + } |
| 354 | + |
| 355 | + // Drill through connection wrapper fields |
| 356 | + while (true) |
| 357 | + { |
| 358 | + var found = false; |
| 359 | + foreach (var selection in selectionSet.Selections) |
| 360 | + { |
| 361 | + if (selection is GraphQLField field && IsConnectionNodeName(field.Name.StringValue)) |
| 362 | + { |
| 363 | + if (field.SelectionSet is not null) |
| 364 | + { |
| 365 | + selectionSet = field.SelectionSet; |
| 366 | + found = true; |
| 367 | + break; |
| 368 | + } |
| 369 | + } |
| 370 | + } |
| 371 | + |
| 372 | + if (!found) |
| 373 | + { |
| 374 | + break; |
| 375 | + } |
| 376 | + } |
| 377 | + |
| 378 | + return selectionSet; |
| 379 | + } |
| 380 | + |
| 381 | + bool TryFindDerivedClrType(string graphQlTypeName, ISchema schema, [NotNullWhen(true)] out Type? clrType) |
| 382 | + { |
| 383 | + clrType = null; |
| 384 | + |
| 385 | + // Use the schema's type lookup to resolve GraphQL type name → CLR type |
| 386 | + var graphType = schema.AllTypes.FirstOrDefault(_ => _.Name == graphQlTypeName); |
| 387 | + if (graphType is not null) |
| 388 | + { |
| 389 | + // Walk the type hierarchy to find the CLR type from the generic arguments |
| 390 | + var graphClrType = GetSourceType(graphType.GetType()); |
| 391 | + if (graphClrType is not null && navigations.ContainsKey(graphClrType)) |
| 392 | + { |
| 393 | + clrType = graphClrType; |
| 394 | + return true; |
| 395 | + } |
| 396 | + } |
| 397 | + |
| 398 | + // Fallback: match CLR type name directly |
| 399 | + foreach (var type in navigations.Keys) |
| 400 | + { |
| 401 | + if (string.Equals(type.Name, graphQlTypeName, StringComparison.OrdinalIgnoreCase)) |
| 402 | + { |
| 403 | + clrType = type; |
| 404 | + return true; |
| 405 | + } |
| 406 | + } |
| 407 | + |
| 408 | + return false; |
| 409 | + } |
| 410 | + |
| 411 | + static Type? GetSourceType(Type graphType) |
| 412 | + { |
| 413 | + var type = graphType; |
| 414 | + while (type is not null) |
| 415 | + { |
| 416 | + if (type.IsGenericType) |
| 417 | + { |
| 418 | + var genericDef = type.GetGenericTypeDefinition(); |
| 419 | + if (genericDef == typeof(ObjectGraphType<>) || |
| 420 | + genericDef == typeof(InterfaceGraphType<>)) |
| 421 | + { |
| 422 | + return type.GenericTypeArguments[0]; |
| 423 | + } |
| 424 | + } |
| 425 | + |
| 426 | + type = type.BaseType; |
| 427 | + } |
| 428 | + |
| 429 | + return null; |
184 | 430 | } |
185 | 431 |
|
186 | 432 | void ProcessConnectionNodeFields( |
|
0 commit comments