Skip to content

Commit 16095a7

Browse files
committed
Better error for invalid root type. (#8104)
1 parent c73b1ef commit 16095a7

File tree

1 file changed

+83
-41
lines changed

1 file changed

+83
-41
lines changed

src/HotChocolate/Core/src/Types/SchemaBuilder.Setup.cs

+83-41
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#nullable enable
22

3+
using System.Diagnostics.CodeAnalysis;
34
using HotChocolate.Configuration;
45
using HotChocolate.Configuration.Validation;
56
using HotChocolate.Language;
@@ -37,25 +38,25 @@ public static Schema Create(
3738
{
3839
var typeInterceptors = new List<TypeInterceptor>();
3940

40-
if (context.Options.StrictRuntimeTypeValidation &&
41-
!builder._typeInterceptors.Contains(typeof(TypeValidationTypeInterceptor)))
41+
if (context.Options.StrictRuntimeTypeValidation
42+
&& !builder._typeInterceptors.Contains(typeof(TypeValidationTypeInterceptor)))
4243
{
4344
builder._typeInterceptors.Add(typeof(TypeValidationTypeInterceptor));
4445
}
4546

46-
if (context.Options.EnableFlagEnums &&
47-
!builder._typeInterceptors.Contains(typeof(FlagsEnumInterceptor)))
47+
if (context.Options.EnableFlagEnums
48+
&& !builder._typeInterceptors.Contains(typeof(FlagsEnumInterceptor)))
4849
{
4950
builder._typeInterceptors.Add(typeof(FlagsEnumInterceptor));
5051
}
5152

52-
if (context.Options.RemoveUnusedTypeSystemDirectives &&
53-
!builder._typeInterceptors.Contains(typeof(DirectiveTypeInterceptor)))
53+
if (context.Options.RemoveUnusedTypeSystemDirectives
54+
&& !builder._typeInterceptors.Contains(typeof(DirectiveTypeInterceptor)))
5455
{
5556
builder._typeInterceptors.Add(typeof(DirectiveTypeInterceptor));
5657
}
5758

58-
if(builder._schemaFirstTypeInterceptor is not null)
59+
if (builder._schemaFirstTypeInterceptor is not null)
5960
{
6061
typeInterceptors.Add(builder._schemaFirstTypeInterceptor);
6162
}
@@ -259,8 +260,7 @@ private static void InitializeInterceptors<T>(
259260
List<T> interceptors)
260261
where T : class
261262
{
262-
if (services is not EmptyServiceProvider &&
263-
services.GetService<IEnumerable<T>>() is { } fromService)
263+
if (services is not EmptyServiceProvider && services.GetService<IEnumerable<T>>() is { } fromService)
264264
{
265265
interceptors.AddRange(fromService);
266266
}
@@ -293,28 +293,28 @@ private static RootTypeKind GetOperationKind(
293293
if (type is ObjectType objectType)
294294
{
295295
if (IsOperationType(
296-
objectType,
297-
OperationType.Query,
298-
typeInspector,
299-
operations))
296+
objectType,
297+
OperationType.Query,
298+
typeInspector,
299+
operations))
300300
{
301301
return RootTypeKind.Query;
302302
}
303303

304304
if (IsOperationType(
305-
objectType,
306-
OperationType.Mutation,
307-
typeInspector,
308-
operations))
305+
objectType,
306+
OperationType.Mutation,
307+
typeInspector,
308+
operations))
309309
{
310310
return RootTypeKind.Mutation;
311311
}
312312

313313
if (IsOperationType(
314-
objectType,
315-
OperationType.Subscription,
316-
typeInspector,
317-
operations))
314+
objectType,
315+
OperationType.Subscription,
316+
typeInspector,
317+
operations))
318318
{
319319
return RootTypeKind.Subscription;
320320
}
@@ -338,8 +338,8 @@ private static bool IsOperationType(
338338

339339
if (typeRef is ExtendedTypeReference cr)
340340
{
341-
return cr.Type.Equals(typeInspector.GetType(objectType.GetType())) ||
342-
cr.Type.Equals(typeInspector.GetType(objectType.RuntimeType));
341+
return cr.Type.Equals(typeInspector.GetType(objectType.GetType()))
342+
|| cr.Type.Equals(typeInspector.GetType(objectType.RuntimeType));
343343
}
344344

345345
if (typeRef is SyntaxTypeReference str)
@@ -437,26 +437,29 @@ private static void ResolveOperations(
437437
{
438438
if (operations.Count == 0)
439439
{
440-
schemaDef.QueryType = GetObjectType(OperationTypeNames.Query);
441-
schemaDef.MutationType = GetObjectType(OperationTypeNames.Mutation);
442-
schemaDef.SubscriptionType = GetObjectType(OperationTypeNames.Subscription);
440+
schemaDef.QueryType = GetObjectType(OperationTypeNames.Query, OperationType.Query);
441+
schemaDef.MutationType = GetObjectType(OperationTypeNames.Mutation, OperationType.Mutation);
442+
schemaDef.SubscriptionType = GetObjectType(OperationTypeNames.Subscription, OperationType.Subscription);
443443
}
444444
else
445445
{
446446
schemaDef.QueryType = GetOperationType(OperationType.Query);
447447
schemaDef.MutationType = GetOperationType(OperationType.Mutation);
448448
schemaDef.SubscriptionType = GetOperationType(OperationType.Subscription);
449449
}
450-
451450
return;
452451

453-
ObjectType? GetObjectType(string typeName)
452+
ObjectType? GetObjectType(string typeName, OperationType expectedOperation)
454453
{
455454
foreach (var registeredType in typeRegistry.Types)
456455
{
457-
if (registeredType.Type is ObjectType objectType &&
458-
objectType.Name.EqualsOrdinal(typeName))
456+
if (registeredType.Type.Name.EqualsOrdinal(typeName))
459457
{
458+
if (registeredType.Type is not ObjectType objectType)
459+
{
460+
Throw((INamedType)registeredType.Type, expectedOperation);
461+
}
462+
460463
return objectType;
461464
}
462465
}
@@ -466,30 +469,69 @@ private static void ResolveOperations(
466469

467470
ObjectType? GetOperationType(OperationType operation)
468471
{
469-
if (operations.TryGetValue(operation, out var reference))
472+
if (!operations.TryGetValue(operation, out var reference))
470473
{
471-
if (reference is SchemaTypeReference sr)
474+
return null;
475+
}
476+
477+
switch (reference)
478+
{
479+
case SchemaTypeReference str:
472480
{
473-
return (ObjectType)sr.Type;
481+
if (str.Type is not ObjectType ot)
482+
{
483+
Throw((INamedType)str.Type, operation);
484+
}
485+
486+
return ot;
474487
}
475488

476-
if (reference is ExtendedTypeReference cr &&
477-
typeRegistry.TryGetType(cr, out var registeredType))
489+
case ExtendedTypeReference cr when typeRegistry.TryGetType(cr, out var registeredType):
478490
{
479-
return (ObjectType)registeredType.Type;
491+
if (registeredType.Type is not ObjectType ot)
492+
{
493+
Throw((INamedType)registeredType.Type, operation);
494+
}
495+
496+
return ot;
480497
}
481498

482-
if (reference is SyntaxTypeReference str)
499+
case SyntaxTypeReference str:
483500
{
484501
var namedType = str.Type.NamedType();
485-
return typeRegistry.Types
502+
var type = typeRegistry.Types
486503
.Select(t => t.Type)
487-
.OfType<ObjectType>()
488504
.FirstOrDefault(t => t.Name.EqualsOrdinal(namedType.Name.Value));
505+
506+
if (type is null)
507+
{
508+
return null;
509+
}
510+
511+
if (type is not ObjectType ot)
512+
{
513+
Throw((INamedType)type, operation);
514+
}
515+
516+
return ot;
489517
}
518+
519+
default:
520+
return null;
490521
}
522+
}
491523

492-
return null;
524+
[DoesNotReturn]
525+
static void Throw(INamedType namedType, OperationType operation)
526+
{
527+
throw SchemaErrorBuilder.New()
528+
.SetMessage(
529+
"Cannot register `{0}` as {1} type as it is not an object type. `{0}` is of type `{2}`.",
530+
namedType.Name,
531+
operation,
532+
namedType.GetType().FullName)
533+
.SetTypeSystemObject((TypeSystemObjectBase)namedType)
534+
.BuildException();
493535
}
494536
}
495537

@@ -529,7 +571,7 @@ internal static void AddCoreSchemaServices(IServiceCollection services, LazySche
529571
.ToArray();
530572
}
531573

532-
if(allSerializers is null || allSerializers.Length == 0)
574+
if (allSerializers is null || allSerializers.Length == 0)
533575
{
534576
allSerializers =
535577
[

0 commit comments

Comments
 (0)