1
1
#nullable enable
2
2
3
+ using System . Diagnostics . CodeAnalysis ;
3
4
using HotChocolate . Configuration ;
4
5
using HotChocolate . Configuration . Validation ;
5
6
using HotChocolate . Language ;
@@ -37,25 +38,25 @@ public static Schema Create(
37
38
{
38
39
var typeInterceptors = new List < TypeInterceptor > ( ) ;
39
40
40
- if ( context . Options . StrictRuntimeTypeValidation &&
41
- ! builder . _typeInterceptors . Contains ( typeof ( TypeValidationTypeInterceptor ) ) )
41
+ if ( context . Options . StrictRuntimeTypeValidation
42
+ && ! builder . _typeInterceptors . Contains ( typeof ( TypeValidationTypeInterceptor ) ) )
42
43
{
43
44
builder . _typeInterceptors . Add ( typeof ( TypeValidationTypeInterceptor ) ) ;
44
45
}
45
46
46
- if ( context . Options . EnableFlagEnums &&
47
- ! builder . _typeInterceptors . Contains ( typeof ( FlagsEnumInterceptor ) ) )
47
+ if ( context . Options . EnableFlagEnums
48
+ && ! builder . _typeInterceptors . Contains ( typeof ( FlagsEnumInterceptor ) ) )
48
49
{
49
50
builder . _typeInterceptors . Add ( typeof ( FlagsEnumInterceptor ) ) ;
50
51
}
51
52
52
- if ( context . Options . RemoveUnusedTypeSystemDirectives &&
53
- ! builder . _typeInterceptors . Contains ( typeof ( DirectiveTypeInterceptor ) ) )
53
+ if ( context . Options . RemoveUnusedTypeSystemDirectives
54
+ && ! builder . _typeInterceptors . Contains ( typeof ( DirectiveTypeInterceptor ) ) )
54
55
{
55
56
builder . _typeInterceptors . Add ( typeof ( DirectiveTypeInterceptor ) ) ;
56
57
}
57
58
58
- if ( builder . _schemaFirstTypeInterceptor is not null )
59
+ if ( builder . _schemaFirstTypeInterceptor is not null )
59
60
{
60
61
typeInterceptors . Add ( builder . _schemaFirstTypeInterceptor ) ;
61
62
}
@@ -259,8 +260,7 @@ private static void InitializeInterceptors<T>(
259
260
List < T > interceptors )
260
261
where T : class
261
262
{
262
- if ( services is not EmptyServiceProvider &&
263
- services . GetService < IEnumerable < T > > ( ) is { } fromService )
263
+ if ( services is not EmptyServiceProvider && services . GetService < IEnumerable < T > > ( ) is { } fromService )
264
264
{
265
265
interceptors . AddRange ( fromService ) ;
266
266
}
@@ -293,28 +293,28 @@ private static RootTypeKind GetOperationKind(
293
293
if ( type is ObjectType objectType )
294
294
{
295
295
if ( IsOperationType (
296
- objectType ,
297
- OperationType . Query ,
298
- typeInspector ,
299
- operations ) )
296
+ objectType ,
297
+ OperationType . Query ,
298
+ typeInspector ,
299
+ operations ) )
300
300
{
301
301
return RootTypeKind . Query ;
302
302
}
303
303
304
304
if ( IsOperationType (
305
- objectType ,
306
- OperationType . Mutation ,
307
- typeInspector ,
308
- operations ) )
305
+ objectType ,
306
+ OperationType . Mutation ,
307
+ typeInspector ,
308
+ operations ) )
309
309
{
310
310
return RootTypeKind . Mutation ;
311
311
}
312
312
313
313
if ( IsOperationType (
314
- objectType ,
315
- OperationType . Subscription ,
316
- typeInspector ,
317
- operations ) )
314
+ objectType ,
315
+ OperationType . Subscription ,
316
+ typeInspector ,
317
+ operations ) )
318
318
{
319
319
return RootTypeKind . Subscription ;
320
320
}
@@ -338,8 +338,8 @@ private static bool IsOperationType(
338
338
339
339
if ( typeRef is ExtendedTypeReference cr )
340
340
{
341
- return cr . Type . Equals ( typeInspector . GetType ( objectType . GetType ( ) ) ) ||
342
- cr . Type . Equals ( typeInspector . GetType ( objectType . RuntimeType ) ) ;
341
+ return cr . Type . Equals ( typeInspector . GetType ( objectType . GetType ( ) ) )
342
+ || cr . Type . Equals ( typeInspector . GetType ( objectType . RuntimeType ) ) ;
343
343
}
344
344
345
345
if ( typeRef is SyntaxTypeReference str )
@@ -437,26 +437,29 @@ private static void ResolveOperations(
437
437
{
438
438
if ( operations . Count == 0 )
439
439
{
440
- schemaDef . QueryType = GetObjectType ( OperationTypeNames . Query ) ;
441
- schemaDef . MutationType = GetObjectType ( OperationTypeNames . Mutation ) ;
442
- schemaDef . SubscriptionType = GetObjectType ( OperationTypeNames . Subscription ) ;
440
+ schemaDef . QueryType = GetObjectType ( OperationTypeNames . Query , OperationType . Query ) ;
441
+ schemaDef . MutationType = GetObjectType ( OperationTypeNames . Mutation , OperationType . Mutation ) ;
442
+ schemaDef . SubscriptionType = GetObjectType ( OperationTypeNames . Subscription , OperationType . Subscription ) ;
443
443
}
444
444
else
445
445
{
446
446
schemaDef . QueryType = GetOperationType ( OperationType . Query ) ;
447
447
schemaDef . MutationType = GetOperationType ( OperationType . Mutation ) ;
448
448
schemaDef . SubscriptionType = GetOperationType ( OperationType . Subscription ) ;
449
449
}
450
-
451
450
return ;
452
451
453
- ObjectType ? GetObjectType ( string typeName )
452
+ ObjectType ? GetObjectType ( string typeName , OperationType expectedOperation )
454
453
{
455
454
foreach ( var registeredType in typeRegistry . Types )
456
455
{
457
- if ( registeredType . Type is ObjectType objectType &&
458
- objectType . Name . EqualsOrdinal ( typeName ) )
456
+ if ( registeredType . Type . Name . EqualsOrdinal ( typeName ) )
459
457
{
458
+ if ( registeredType . Type is not ObjectType objectType )
459
+ {
460
+ Throw ( ( INamedType ) registeredType . Type , expectedOperation ) ;
461
+ }
462
+
460
463
return objectType ;
461
464
}
462
465
}
@@ -466,30 +469,69 @@ private static void ResolveOperations(
466
469
467
470
ObjectType ? GetOperationType ( OperationType operation )
468
471
{
469
- if ( operations . TryGetValue ( operation , out var reference ) )
472
+ if ( ! operations . TryGetValue ( operation , out var reference ) )
470
473
{
471
- if ( reference is SchemaTypeReference sr )
474
+ return null ;
475
+ }
476
+
477
+ switch ( reference )
478
+ {
479
+ case SchemaTypeReference str :
472
480
{
473
- return ( ObjectType ) sr . Type ;
481
+ if ( str . Type is not ObjectType ot )
482
+ {
483
+ Throw ( ( INamedType ) str . Type , operation ) ;
484
+ }
485
+
486
+ return ot ;
474
487
}
475
488
476
- if ( reference is ExtendedTypeReference cr &&
477
- typeRegistry . TryGetType ( cr , out var registeredType ) )
489
+ case ExtendedTypeReference cr when typeRegistry . TryGetType ( cr , out var registeredType ) :
478
490
{
479
- return ( ObjectType ) registeredType . Type ;
491
+ if ( registeredType . Type is not ObjectType ot )
492
+ {
493
+ Throw ( ( INamedType ) registeredType . Type , operation ) ;
494
+ }
495
+
496
+ return ot ;
480
497
}
481
498
482
- if ( reference is SyntaxTypeReference str )
499
+ case SyntaxTypeReference str :
483
500
{
484
501
var namedType = str . Type . NamedType ( ) ;
485
- return typeRegistry . Types
502
+ var type = typeRegistry . Types
486
503
. Select ( t => t . Type )
487
- . OfType < ObjectType > ( )
488
504
. FirstOrDefault ( t => t . Name . EqualsOrdinal ( namedType . Name . Value ) ) ;
505
+
506
+ if ( type is null )
507
+ {
508
+ return null ;
509
+ }
510
+
511
+ if ( type is not ObjectType ot )
512
+ {
513
+ Throw ( ( INamedType ) type , operation ) ;
514
+ }
515
+
516
+ return ot ;
489
517
}
518
+
519
+ default :
520
+ return null ;
490
521
}
522
+ }
491
523
492
- return null ;
524
+ [ DoesNotReturn ]
525
+ static void Throw ( INamedType namedType , OperationType operation )
526
+ {
527
+ throw SchemaErrorBuilder . New ( )
528
+ . SetMessage (
529
+ "Cannot register `{0}` as {1} type as it is not an object type. `{0}` is of type `{2}`." ,
530
+ namedType . Name ,
531
+ operation ,
532
+ namedType . GetType ( ) . FullName )
533
+ . SetTypeSystemObject ( ( TypeSystemObjectBase ) namedType )
534
+ . BuildException ( ) ;
493
535
}
494
536
}
495
537
@@ -529,7 +571,7 @@ internal static void AddCoreSchemaServices(IServiceCollection services, LazySche
529
571
. ToArray ( ) ;
530
572
}
531
573
532
- if ( allSerializers is null || allSerializers . Length == 0 )
574
+ if ( allSerializers is null || allSerializers . Length == 0 )
533
575
{
534
576
allSerializers =
535
577
[
0 commit comments