@@ -23,15 +23,15 @@ import (
23
23
24
24
"github.com/onflow/cadence/runtime/ast"
25
25
"github.com/onflow/cadence/runtime/common"
26
- "github.com/onflow/cadence/runtime/sema"
27
26
)
28
27
29
28
type ContractUpdateValidator struct {
29
+ TypeComparator
30
+
30
31
location Location
31
32
contractName string
32
33
oldProgram * ast.Program
33
34
newProgram * ast.Program
34
- rootDecl ast.Declaration
35
35
currentDecl ast.Declaration
36
36
errors []error
37
37
}
@@ -68,7 +68,8 @@ func (validator *ContractUpdateValidator) Validate() error {
68
68
return validator .getContractUpdateError ()
69
69
}
70
70
71
- validator .rootDecl = newRootDecl
71
+ validator .TypeComparator .RootDeclIdentifier = newRootDecl .DeclarationIdentifier ()
72
+
72
73
validator .checkDeclarationUpdatability (oldRootDecl , newRootDecl )
73
74
74
75
if validator .hasErrors () {
@@ -104,7 +105,6 @@ func getRootDeclaration(program *ast.Program) (ast.Declaration, error) {
104
105
return nil , & ContractNotFoundError {
105
106
Range : ast .NewUnmeteredRangeFromPositioned (program ),
106
107
}
107
-
108
108
}
109
109
110
110
func (validator * ContractUpdateValidator ) hasErrors () bool {
@@ -310,203 +310,6 @@ func (validator *ContractUpdateValidator) checkEnumCases(oldDeclaration ast.Decl
310
310
}
311
311
}
312
312
313
- func (validator * ContractUpdateValidator ) CheckNominalTypeEquality (expected * ast.NominalType , found ast.Type ) error {
314
- foundNominalType , ok := found .(* ast.NominalType )
315
- if ! ok {
316
- return getTypeMismatchError (expected , found )
317
- }
318
-
319
- // First check whether the names are equal.
320
- ok = validator .checkNameEquality (expected , foundNominalType )
321
- if ! ok {
322
- return getTypeMismatchError (expected , found )
323
- }
324
-
325
- return nil
326
- }
327
-
328
- func (validator * ContractUpdateValidator ) CheckOptionalTypeEquality (expected * ast.OptionalType , found ast.Type ) error {
329
- foundOptionalType , ok := found .(* ast.OptionalType )
330
- if ! ok {
331
- return getTypeMismatchError (expected , found )
332
- }
333
-
334
- return expected .Type .CheckEqual (foundOptionalType .Type , validator )
335
- }
336
-
337
- func (validator * ContractUpdateValidator ) CheckVariableSizedTypeEquality (expected * ast.VariableSizedType , found ast.Type ) error {
338
- foundVarSizedType , ok := found .(* ast.VariableSizedType )
339
- if ! ok {
340
- return getTypeMismatchError (expected , found )
341
- }
342
-
343
- return expected .Type .CheckEqual (foundVarSizedType .Type , validator )
344
- }
345
-
346
- func (validator * ContractUpdateValidator ) CheckConstantSizedTypeEquality (expected * ast.ConstantSizedType , found ast.Type ) error {
347
- foundConstSizedType , ok := found .(* ast.ConstantSizedType )
348
- if ! ok {
349
- return getTypeMismatchError (expected , found )
350
- }
351
-
352
- // Check size
353
- if foundConstSizedType .Size .Value .Cmp (expected .Size .Value ) != 0 ||
354
- foundConstSizedType .Size .Base != expected .Size .Base {
355
- return getTypeMismatchError (expected , found )
356
- }
357
-
358
- // Check type
359
- return expected .Type .CheckEqual (foundConstSizedType .Type , validator )
360
- }
361
-
362
- func (validator * ContractUpdateValidator ) CheckDictionaryTypeEquality (expected * ast.DictionaryType , found ast.Type ) error {
363
- foundDictionaryType , ok := found .(* ast.DictionaryType )
364
- if ! ok {
365
- return getTypeMismatchError (expected , found )
366
- }
367
-
368
- err := expected .KeyType .CheckEqual (foundDictionaryType .KeyType , validator )
369
- if err != nil {
370
- return err
371
- }
372
-
373
- return expected .ValueType .CheckEqual (foundDictionaryType .ValueType , validator )
374
- }
375
-
376
- func (validator * ContractUpdateValidator ) CheckRestrictedTypeEquality (expected * ast.RestrictedType , found ast.Type ) error {
377
- foundRestrictedType , ok := found .(* ast.RestrictedType )
378
- if ! ok {
379
- return getTypeMismatchError (expected , found )
380
- }
381
-
382
- if expected .Type == nil {
383
- if ! isAnyStructOrAnyResourceType (foundRestrictedType .Type ) {
384
- return getTypeMismatchError (expected , found )
385
- }
386
- // else go on to check type restrictions
387
- } else if foundRestrictedType .Type == nil {
388
- if ! isAnyStructOrAnyResourceType (expected .Type ) {
389
- return getTypeMismatchError (expected , found )
390
- }
391
- // else go on to check type restrictions
392
- } else {
393
- // both are not nil
394
- err := expected .Type .CheckEqual (foundRestrictedType .Type , validator )
395
- if err != nil {
396
- return getTypeMismatchError (expected , found )
397
- }
398
- }
399
-
400
- if len (expected .Restrictions ) != len (foundRestrictedType .Restrictions ) {
401
- return getTypeMismatchError (expected , found )
402
- }
403
-
404
- for index , expectedRestriction := range expected .Restrictions {
405
- foundRestriction := foundRestrictedType .Restrictions [index ]
406
- err := expectedRestriction .CheckEqual (foundRestriction , validator )
407
- if err != nil {
408
- return getTypeMismatchError (expected , found )
409
- }
410
- }
411
-
412
- return nil
413
- }
414
-
415
- func (validator * ContractUpdateValidator ) CheckInstantiationTypeEquality (expected * ast.InstantiationType , found ast.Type ) error {
416
- foundInstType , ok := found .(* ast.InstantiationType )
417
- if ! ok {
418
- return getTypeMismatchError (expected , found )
419
- }
420
-
421
- err := expected .Type .CheckEqual (foundInstType .Type , validator )
422
- if err != nil || len (expected .TypeArguments ) != len (foundInstType .TypeArguments ) {
423
- return getTypeMismatchError (expected , found )
424
- }
425
-
426
- for index , typeArgs := range expected .TypeArguments {
427
- otherTypeArgs := foundInstType .TypeArguments [index ]
428
- err := typeArgs .Type .CheckEqual (otherTypeArgs .Type , validator )
429
- if err != nil {
430
- return getTypeMismatchError (expected , found )
431
- }
432
- }
433
-
434
- return nil
435
- }
436
-
437
- func (validator * ContractUpdateValidator ) CheckFunctionTypeEquality (expected * ast.FunctionType , found ast.Type ) error {
438
- foundFuncType , ok := found .(* ast.FunctionType )
439
- if ! ok || len (expected .ParameterTypeAnnotations ) != len (foundFuncType .ParameterTypeAnnotations ) {
440
- return getTypeMismatchError (expected , found )
441
- }
442
-
443
- for index , expectedParamType := range expected .ParameterTypeAnnotations {
444
- foundParamType := foundFuncType .ParameterTypeAnnotations [index ]
445
- err := expectedParamType .Type .CheckEqual (foundParamType .Type , validator )
446
- if err != nil {
447
- return getTypeMismatchError (expected , found )
448
- }
449
- }
450
-
451
- return expected .ReturnTypeAnnotation .Type .CheckEqual (foundFuncType .ReturnTypeAnnotation .Type , validator )
452
- }
453
-
454
- func (validator * ContractUpdateValidator ) CheckReferenceTypeEquality (expected * ast.ReferenceType , found ast.Type ) error {
455
- refType , ok := found .(* ast.ReferenceType )
456
- if ! ok {
457
- return getTypeMismatchError (expected , found )
458
- }
459
-
460
- return expected .Type .CheckEqual (refType .Type , validator )
461
- }
462
-
463
- func (validator * ContractUpdateValidator ) checkNameEquality (expectedType * ast.NominalType , foundType * ast.NominalType ) bool {
464
- isExpectedQualifiedName := expectedType .IsQualifiedName ()
465
- isFoundQualifiedName := foundType .IsQualifiedName ()
466
-
467
- // A field with a composite type can be defined in two ways:
468
- // - Using type name (var x @ResourceName)
469
- // - Using qualified type name (var x @ContractName.ResourceName)
470
-
471
- if isExpectedQualifiedName && ! isFoundQualifiedName {
472
- return validator .checkIdentifierEquality (expectedType , foundType )
473
- }
474
-
475
- if isFoundQualifiedName && ! isExpectedQualifiedName {
476
- return validator .checkIdentifierEquality (foundType , expectedType )
477
- }
478
-
479
- // At this point, either both are qualified names, or both are simple names.
480
- // Thus, do a one-to-one match.
481
- if expectedType .Identifier .Identifier != foundType .Identifier .Identifier {
482
- return false
483
- }
484
-
485
- return identifiersEqual (expectedType .NestedIdentifiers , foundType .NestedIdentifiers )
486
- }
487
-
488
- func (validator * ContractUpdateValidator ) checkIdentifierEquality (
489
- qualifiedNominalType * ast.NominalType ,
490
- simpleNominalType * ast.NominalType ,
491
- ) bool {
492
-
493
- // Situation:
494
- // qualifiedNominalType -> identifier: A, nestedIdentifiers: [foo, bar, ...]
495
- // simpleNominalType -> identifier: foo, nestedIdentifiers: [bar, ...]
496
-
497
- // If the first identifier (i.e: 'A') refers to a composite decl that is not the enclosing contract,
498
- // then it must be referring to an imported contract. That means the two types are no longer the same.
499
- if qualifiedNominalType .Identifier .Identifier != validator .rootDecl .DeclarationIdentifier ().Identifier {
500
- return false
501
- }
502
-
503
- if qualifiedNominalType .NestedIdentifiers [0 ].Identifier != simpleNominalType .Identifier .Identifier {
504
- return false
505
- }
506
-
507
- return identifiersEqual (simpleNominalType .NestedIdentifiers , qualifiedNominalType .NestedIdentifiers [1 :])
508
- }
509
-
510
313
func (validator * ContractUpdateValidator ) checkConformances (
511
314
oldDecl * ast.CompositeDeclaration ,
512
315
newDecl * ast.CompositeDeclaration ,
@@ -562,46 +365,6 @@ func (validator *ContractUpdateValidator) getContractUpdateError() error {
562
365
}
563
366
}
564
367
565
- func getTypeMismatchError (expectedType ast.Type , foundType ast.Type ) * TypeMismatchError {
566
- return & TypeMismatchError {
567
- ExpectedType : expectedType ,
568
- FoundType : foundType ,
569
- Range : ast .NewUnmeteredRangeFromPositioned (foundType ),
570
- }
571
- }
572
-
573
- func identifiersEqual (expected []ast.Identifier , found []ast.Identifier ) bool {
574
- if len (expected ) != len (found ) {
575
- return false
576
- }
577
-
578
- for index , element := range found {
579
- if expected [index ].Identifier != element .Identifier {
580
- return false
581
- }
582
- }
583
- return true
584
- }
585
-
586
- func isAnyStructOrAnyResourceType (astType ast.Type ) bool {
587
- // If the restricted type is not stated, then it is either AnyStruct or AnyResource
588
- if astType == nil {
589
- return true
590
- }
591
-
592
- nominalType , ok := astType .(* ast.NominalType )
593
- if ! ok {
594
- return false
595
- }
596
-
597
- switch nominalType .Identifier .Identifier {
598
- case sema .AnyStructType .Name , sema .AnyResourceType .Name :
599
- return true
600
- default :
601
- return false
602
- }
603
- }
604
-
605
368
func containsEnumsInProgram (program * ast.Program ) bool {
606
369
declaration , err := getRootDeclaration (program )
607
370
0 commit comments