Skip to content

Commit b99d3ae

Browse files
authored
Get navigation props working on concrete types where the name property is not present on the TPH Base (#1254)
1 parent 9e9abe6 commit b99d3ae

21 files changed

Lines changed: 657 additions & 13 deletions

src/GraphQL.EntityFramework/IncludeAppender.cs

Lines changed: 258 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using System.Reflection;
2+
13
class IncludeAppender(
24
IReadOnlyDictionary<Type, IReadOnlyDictionary<string, Navigation>> navigations,
35
IReadOnlyDictionary<Type, List<string>> keyNames,
@@ -64,29 +66,85 @@ static IQueryable<TItem> AddIncludesFromProjection<TItem>(
6466
FieldProjectionInfo projection)
6567
where TItem : class
6668
{
67-
if (projection.Navigations is not { Count: > 0 })
69+
var visitedTypes = new HashSet<Type> { typeof(TItem) };
70+
71+
if (projection.Navigations is { Count: > 0 })
6872
{
69-
return query;
73+
foreach (var (navName, navProjection) in projection.Navigations)
74+
{
75+
if (IsVisitedOrBaseType(navProjection.EntityType, visitedTypes))
76+
{
77+
continue;
78+
}
79+
80+
visitedTypes.Add(navProjection.EntityType);
81+
query = query.Include(navName);
82+
query = AddNestedIncludes(query, navName, navProjection.Projection, visitedTypes);
83+
visitedTypes.Remove(navProjection.EntityType);
84+
}
7085
}
7186

72-
var visitedTypes = new HashSet<Type> { typeof(TItem) };
87+
// Add derived-type navigation includes for TPH inline fragments
88+
// e.g. query.Include(e => ((GroupAccessRule)e).Group)
89+
if (projection.DerivedNavigations is { Count: > 0 })
90+
{
91+
query = AddDerivedTypeIncludes(query, projection.DerivedNavigations, visitedTypes);
92+
}
7393

74-
foreach (var (navName, navProjection) in projection.Navigations)
94+
return query;
95+
}
96+
97+
static IQueryable<TItem> AddDerivedTypeIncludes<TItem>(
98+
IQueryable<TItem> query,
99+
Dictionary<Type, Dictionary<string, NavigationProjectionInfo>> derivedNavigations,
100+
HashSet<Type> visitedTypes)
101+
where TItem : class
102+
{
103+
var itemType = typeof(TItem);
104+
var parameter = Expression.Parameter(itemType, "e");
105+
106+
foreach (var (derivedType, navDict) in derivedNavigations)
75107
{
76-
if (IsVisitedOrBaseType(navProjection.EntityType, visitedTypes))
108+
// Cast: (DerivedType)e
109+
var cast = Expression.Convert(parameter, derivedType);
110+
111+
foreach (var (navName, navProjection) in navDict)
77112
{
78-
continue;
79-
}
113+
if (IsVisitedOrBaseType(navProjection.EntityType, visitedTypes))
114+
{
115+
continue;
116+
}
80117

81-
visitedTypes.Add(navProjection.EntityType);
82-
query = query.Include(navName);
83-
query = AddNestedIncludes(query, navName, navProjection.Projection, visitedTypes);
84-
visitedTypes.Remove(navProjection.EntityType);
118+
// Property access: ((DerivedType)e).Navigation
119+
var property = derivedType.GetProperty(navName);
120+
if (property == null)
121+
{
122+
continue;
123+
}
124+
125+
var propertyAccess = Expression.Property(cast, property);
126+
127+
// Build lambda: e => ((DerivedType)e).Navigation
128+
var lambda = Expression.Lambda(propertyAccess, parameter);
129+
130+
// Call EntityFrameworkQueryableExtensions.Include(query, lambda)
131+
var includeMethod = GetIncludeMethod(itemType, property.PropertyType);
132+
query = (IQueryable<TItem>)includeMethod.Invoke(null, [query, lambda])!;
133+
}
85134
}
86135

87136
return query;
88137
}
89138

139+
static MethodInfo GetIncludeMethod(Type entityType, Type propertyType) =>
140+
typeof(EntityFrameworkQueryableExtensions)
141+
.GetMethods(BindingFlags.Static | BindingFlags.Public)
142+
.First(_ => _.Name == "Include" &&
143+
_.GetGenericArguments().Length == 2 &&
144+
_.GetParameters().Length == 2 &&
145+
_.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>))
146+
.MakeGenericMethod(entityType, propertyType);
147+
90148
static IQueryable<TItem> AddNestedIncludes<TItem>(
91149
IQueryable<TItem> query,
92150
string includePath,
@@ -180,7 +238,195 @@ FieldProjectionInfo GetProjectionInfo(
180238
}
181239
}
182240

183-
return new(scalarFields, keys, foreignKeyNames, navProjections);
241+
// Scan for derived-type navigations from inline fragments (TPH support)
242+
var derivedNavigations = GetDerivedNavigationsFromFragments(context);
243+
244+
return new(scalarFields, keys, foreignKeyNames, navProjections, derivedNavigations);
245+
}
246+
247+
Dictionary<Type, Dictionary<string, NavigationProjectionInfo>>? GetDerivedNavigationsFromFragments(
248+
IResolveFieldContext context)
249+
{
250+
var selectionSet = GetLeafSelectionSet(context);
251+
if (selectionSet?.Selections is null)
252+
{
253+
return null;
254+
}
255+
256+
Dictionary<Type, Dictionary<string, NavigationProjectionInfo>>? result = null;
257+
258+
foreach (var selection in selectionSet.Selections)
259+
{
260+
GraphQLTypeCondition? typeCondition;
261+
GraphQLSelectionSet? fragmentSelectionSet;
262+
263+
switch (selection)
264+
{
265+
case GraphQLInlineFragment inlineFragment:
266+
typeCondition = inlineFragment.TypeCondition;
267+
fragmentSelectionSet = inlineFragment.SelectionSet;
268+
break;
269+
case GraphQLFragmentSpread fragmentSpread:
270+
{
271+
var name = fragmentSpread.FragmentName.Name;
272+
var fragmentDefinition = context.Document.Definitions
273+
.OfType<GraphQLFragmentDefinition>()
274+
.SingleOrDefault(_ => _.FragmentName.Name == name);
275+
if (fragmentDefinition is null)
276+
{
277+
continue;
278+
}
279+
280+
typeCondition = fragmentDefinition.TypeCondition;
281+
fragmentSelectionSet = fragmentDefinition.SelectionSet;
282+
break;
283+
}
284+
default:
285+
continue;
286+
}
287+
288+
if (typeCondition is null || fragmentSelectionSet?.Selections is null)
289+
{
290+
continue;
291+
}
292+
293+
var typeName = typeCondition.Type.Name.StringValue;
294+
295+
// Find the CLR type for this GraphQL type name using the schema
296+
if (!TryFindDerivedClrType(typeName, context.Schema, out var derivedType))
297+
{
298+
continue;
299+
}
300+
301+
// Get navigation properties for this derived type
302+
if (!navigations.TryGetValue(derivedType, out var derivedNavProps))
303+
{
304+
continue;
305+
}
306+
307+
// Process fields in this fragment against the derived type's navigation properties
308+
foreach (var field in fragmentSelectionSet.Selections.OfType<GraphQLField>())
309+
{
310+
var fieldName = field.Name.StringValue;
311+
if (!derivedNavProps.TryGetValue(fieldName, out var navigation))
312+
{
313+
continue;
314+
}
315+
316+
result ??= [];
317+
if (!result.TryGetValue(derivedType, out var derivedNavs))
318+
{
319+
derivedNavs = [];
320+
result[derivedType] = derivedNavs;
321+
}
322+
323+
if (derivedNavs.ContainsKey(navigation.Name))
324+
{
325+
continue;
326+
}
327+
328+
var navType = navigation.Type;
329+
navigations.TryGetValue(navType, out var nestedNavProps);
330+
keyNames.TryGetValue(navType, out var nestedKeys);
331+
foreignKeys.TryGetValue(navType, out var nestedFks);
332+
333+
derivedNavs[navigation.Name] = new(
334+
navType,
335+
navigation.IsCollection,
336+
GetNestedProjection(field.SelectionSet, nestedNavProps, nestedKeys, nestedFks, context));
337+
}
338+
}
339+
340+
return result;
341+
}
342+
343+
/// <summary>
344+
/// Navigate through connection wrapper fields (edges/items/node) to find the leaf selection set
345+
/// that contains the actual entity fields and inline fragments.
346+
/// </summary>
347+
static GraphQLSelectionSet? GetLeafSelectionSet(IResolveFieldContext context)
348+
{
349+
var selectionSet = context.FieldAst.SelectionSet;
350+
if (selectionSet?.Selections is null)
351+
{
352+
return null;
353+
}
354+
355+
// Drill through connection wrapper fields
356+
while (true)
357+
{
358+
var found = false;
359+
foreach (var selection in selectionSet.Selections)
360+
{
361+
if (selection is GraphQLField field && IsConnectionNodeName(field.Name.StringValue))
362+
{
363+
if (field.SelectionSet is not null)
364+
{
365+
selectionSet = field.SelectionSet;
366+
found = true;
367+
break;
368+
}
369+
}
370+
}
371+
372+
if (!found)
373+
{
374+
break;
375+
}
376+
}
377+
378+
return selectionSet;
379+
}
380+
381+
bool TryFindDerivedClrType(string graphQlTypeName, ISchema schema, [NotNullWhen(true)] out Type? clrType)
382+
{
383+
clrType = null;
384+
385+
// Use the schema's type lookup to resolve GraphQL type name → CLR type
386+
var graphType = schema.AllTypes.FirstOrDefault(_ => _.Name == graphQlTypeName);
387+
if (graphType is not null)
388+
{
389+
// Walk the type hierarchy to find the CLR type from the generic arguments
390+
var graphClrType = GetSourceType(graphType.GetType());
391+
if (graphClrType is not null && navigations.ContainsKey(graphClrType))
392+
{
393+
clrType = graphClrType;
394+
return true;
395+
}
396+
}
397+
398+
// Fallback: match CLR type name directly
399+
foreach (var type in navigations.Keys)
400+
{
401+
if (string.Equals(type.Name, graphQlTypeName, StringComparison.OrdinalIgnoreCase))
402+
{
403+
clrType = type;
404+
return true;
405+
}
406+
}
407+
408+
return false;
409+
}
410+
411+
static Type? GetSourceType(Type graphType)
412+
{
413+
var type = graphType;
414+
while (type is not null)
415+
{
416+
if (type.IsGenericType)
417+
{
418+
var genericDef = type.GetGenericTypeDefinition();
419+
if (genericDef == typeof(ObjectGraphType<>) ||
420+
genericDef == typeof(InterfaceGraphType<>))
421+
{
422+
return type.GenericTypeArguments[0];
423+
}
424+
}
425+
426+
type = type.BaseType;
427+
}
428+
429+
return null;
184430
}
185431

186432
void ProcessConnectionNodeFields(

src/GraphQL.EntityFramework/SelectProjection/FieldProjectionInfo.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ record FieldProjectionInfo(
22
HashSet<string> ScalarFields,
33
List<string>? KeyNames,
44
IReadOnlySet<string>? ForeignKeyNames,
5-
Dictionary<string, NavigationProjectionInfo>? Navigations);
5+
Dictionary<string, NavigationProjectionInfo>? Navigations,
6+
Dictionary<Type, Dictionary<string, NavigationProjectionInfo>>? DerivedNavigations = null);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
public class CategoryEntity
2+
{
3+
public Guid Id { get; set; } = Guid.NewGuid();
4+
public string? Name { get; set; }
5+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
public class CategoryGraphType :
2+
EfObjectGraphType<IntegrationDbContext, CategoryEntity>
3+
{
4+
public CategoryGraphType(IEfGraphQLService<IntegrationDbContext> graphQlService) :
5+
base(graphQlService) =>
6+
AutoMap();
7+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
public class RegionEntity
2+
{
3+
public Guid Id { get; set; } = Guid.NewGuid();
4+
public string? Name { get; set; }
5+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
public class RegionGraphType :
2+
EfObjectGraphType<IntegrationDbContext, RegionEntity>
3+
{
4+
public RegionGraphType(IEfGraphQLService<IntegrationDbContext> graphQlService) :
5+
base(graphQlService) =>
6+
AutoMap();
7+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
public abstract class TphDerivedNavBaseEntity
2+
{
3+
public Guid Id { get; set; } = Guid.NewGuid();
4+
public string? Property { get; set; }
5+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
public class TphDerivedNavBaseGraphType(IEfGraphQLService<IntegrationDbContext> graphQlService) :
2+
EfInterfaceGraphType<IntegrationDbContext, TphDerivedNavBaseEntity>(graphQlService);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
public class TphDerivedNavCategoryEntity : TphDerivedNavBaseEntity
2+
{
3+
public Guid? CategoryId { get; set; }
4+
public CategoryEntity? Category { get; set; }
5+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
public class TphDerivedNavCategoryGraphType :
2+
EfObjectGraphType<IntegrationDbContext, TphDerivedNavCategoryEntity>
3+
{
4+
public TphDerivedNavCategoryGraphType(IEfGraphQLService<IntegrationDbContext> graphQlService) :
5+
base(graphQlService)
6+
{
7+
AutoMap();
8+
Interface<TphDerivedNavBaseGraphType>();
9+
IsTypeOf = _ => _ is TphDerivedNavCategoryEntity;
10+
}
11+
}

0 commit comments

Comments
 (0)