Skip to content

Commit b9bf808

Browse files
authored
Merge pull request #1787 from onflow/supun/refactor-type-equality-checker
Separate type equality checker from contract update validator
2 parents 6d1ce5a + f8bb44a commit b9bf808

File tree

2 files changed

+272
-241
lines changed

2 files changed

+272
-241
lines changed

runtime/contract_update_validation.go

Lines changed: 4 additions & 241 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ import (
2323

2424
"github.com/onflow/cadence/runtime/ast"
2525
"github.com/onflow/cadence/runtime/common"
26-
"github.com/onflow/cadence/runtime/sema"
2726
)
2827

2928
type ContractUpdateValidator struct {
29+
TypeComparator
30+
3031
location Location
3132
contractName string
3233
oldProgram *ast.Program
3334
newProgram *ast.Program
34-
rootDecl ast.Declaration
3535
currentDecl ast.Declaration
3636
errors []error
3737
}
@@ -68,7 +68,8 @@ func (validator *ContractUpdateValidator) Validate() error {
6868
return validator.getContractUpdateError()
6969
}
7070

71-
validator.rootDecl = newRootDecl
71+
validator.TypeComparator.RootDeclIdentifier = newRootDecl.DeclarationIdentifier()
72+
7273
validator.checkDeclarationUpdatability(oldRootDecl, newRootDecl)
7374

7475
if validator.hasErrors() {
@@ -104,7 +105,6 @@ func getRootDeclaration(program *ast.Program) (ast.Declaration, error) {
104105
return nil, &ContractNotFoundError{
105106
Range: ast.NewUnmeteredRangeFromPositioned(program),
106107
}
107-
108108
}
109109

110110
func (validator *ContractUpdateValidator) hasErrors() bool {
@@ -310,203 +310,6 @@ func (validator *ContractUpdateValidator) checkEnumCases(oldDeclaration ast.Decl
310310
}
311311
}
312312

313-
func (validator *ContractUpdateValidator) CheckNominalTypeEquality(expected *ast.NominalType, found ast.Type) error {
314-
foundNominalType, ok := found.(*ast.NominalType)
315-
if !ok {
316-
return getTypeMismatchError(expected, found)
317-
}
318-
319-
// First check whether the names are equal.
320-
ok = validator.checkNameEquality(expected, foundNominalType)
321-
if !ok {
322-
return getTypeMismatchError(expected, found)
323-
}
324-
325-
return nil
326-
}
327-
328-
func (validator *ContractUpdateValidator) CheckOptionalTypeEquality(expected *ast.OptionalType, found ast.Type) error {
329-
foundOptionalType, ok := found.(*ast.OptionalType)
330-
if !ok {
331-
return getTypeMismatchError(expected, found)
332-
}
333-
334-
return expected.Type.CheckEqual(foundOptionalType.Type, validator)
335-
}
336-
337-
func (validator *ContractUpdateValidator) CheckVariableSizedTypeEquality(expected *ast.VariableSizedType, found ast.Type) error {
338-
foundVarSizedType, ok := found.(*ast.VariableSizedType)
339-
if !ok {
340-
return getTypeMismatchError(expected, found)
341-
}
342-
343-
return expected.Type.CheckEqual(foundVarSizedType.Type, validator)
344-
}
345-
346-
func (validator *ContractUpdateValidator) CheckConstantSizedTypeEquality(expected *ast.ConstantSizedType, found ast.Type) error {
347-
foundConstSizedType, ok := found.(*ast.ConstantSizedType)
348-
if !ok {
349-
return getTypeMismatchError(expected, found)
350-
}
351-
352-
// Check size
353-
if foundConstSizedType.Size.Value.Cmp(expected.Size.Value) != 0 ||
354-
foundConstSizedType.Size.Base != expected.Size.Base {
355-
return getTypeMismatchError(expected, found)
356-
}
357-
358-
// Check type
359-
return expected.Type.CheckEqual(foundConstSizedType.Type, validator)
360-
}
361-
362-
func (validator *ContractUpdateValidator) CheckDictionaryTypeEquality(expected *ast.DictionaryType, found ast.Type) error {
363-
foundDictionaryType, ok := found.(*ast.DictionaryType)
364-
if !ok {
365-
return getTypeMismatchError(expected, found)
366-
}
367-
368-
err := expected.KeyType.CheckEqual(foundDictionaryType.KeyType, validator)
369-
if err != nil {
370-
return err
371-
}
372-
373-
return expected.ValueType.CheckEqual(foundDictionaryType.ValueType, validator)
374-
}
375-
376-
func (validator *ContractUpdateValidator) CheckRestrictedTypeEquality(expected *ast.RestrictedType, found ast.Type) error {
377-
foundRestrictedType, ok := found.(*ast.RestrictedType)
378-
if !ok {
379-
return getTypeMismatchError(expected, found)
380-
}
381-
382-
if expected.Type == nil {
383-
if !isAnyStructOrAnyResourceType(foundRestrictedType.Type) {
384-
return getTypeMismatchError(expected, found)
385-
}
386-
// else go on to check type restrictions
387-
} else if foundRestrictedType.Type == nil {
388-
if !isAnyStructOrAnyResourceType(expected.Type) {
389-
return getTypeMismatchError(expected, found)
390-
}
391-
// else go on to check type restrictions
392-
} else {
393-
// both are not nil
394-
err := expected.Type.CheckEqual(foundRestrictedType.Type, validator)
395-
if err != nil {
396-
return getTypeMismatchError(expected, found)
397-
}
398-
}
399-
400-
if len(expected.Restrictions) != len(foundRestrictedType.Restrictions) {
401-
return getTypeMismatchError(expected, found)
402-
}
403-
404-
for index, expectedRestriction := range expected.Restrictions {
405-
foundRestriction := foundRestrictedType.Restrictions[index]
406-
err := expectedRestriction.CheckEqual(foundRestriction, validator)
407-
if err != nil {
408-
return getTypeMismatchError(expected, found)
409-
}
410-
}
411-
412-
return nil
413-
}
414-
415-
func (validator *ContractUpdateValidator) CheckInstantiationTypeEquality(expected *ast.InstantiationType, found ast.Type) error {
416-
foundInstType, ok := found.(*ast.InstantiationType)
417-
if !ok {
418-
return getTypeMismatchError(expected, found)
419-
}
420-
421-
err := expected.Type.CheckEqual(foundInstType.Type, validator)
422-
if err != nil || len(expected.TypeArguments) != len(foundInstType.TypeArguments) {
423-
return getTypeMismatchError(expected, found)
424-
}
425-
426-
for index, typeArgs := range expected.TypeArguments {
427-
otherTypeArgs := foundInstType.TypeArguments[index]
428-
err := typeArgs.Type.CheckEqual(otherTypeArgs.Type, validator)
429-
if err != nil {
430-
return getTypeMismatchError(expected, found)
431-
}
432-
}
433-
434-
return nil
435-
}
436-
437-
func (validator *ContractUpdateValidator) CheckFunctionTypeEquality(expected *ast.FunctionType, found ast.Type) error {
438-
foundFuncType, ok := found.(*ast.FunctionType)
439-
if !ok || len(expected.ParameterTypeAnnotations) != len(foundFuncType.ParameterTypeAnnotations) {
440-
return getTypeMismatchError(expected, found)
441-
}
442-
443-
for index, expectedParamType := range expected.ParameterTypeAnnotations {
444-
foundParamType := foundFuncType.ParameterTypeAnnotations[index]
445-
err := expectedParamType.Type.CheckEqual(foundParamType.Type, validator)
446-
if err != nil {
447-
return getTypeMismatchError(expected, found)
448-
}
449-
}
450-
451-
return expected.ReturnTypeAnnotation.Type.CheckEqual(foundFuncType.ReturnTypeAnnotation.Type, validator)
452-
}
453-
454-
func (validator *ContractUpdateValidator) CheckReferenceTypeEquality(expected *ast.ReferenceType, found ast.Type) error {
455-
refType, ok := found.(*ast.ReferenceType)
456-
if !ok {
457-
return getTypeMismatchError(expected, found)
458-
}
459-
460-
return expected.Type.CheckEqual(refType.Type, validator)
461-
}
462-
463-
func (validator *ContractUpdateValidator) checkNameEquality(expectedType *ast.NominalType, foundType *ast.NominalType) bool {
464-
isExpectedQualifiedName := expectedType.IsQualifiedName()
465-
isFoundQualifiedName := foundType.IsQualifiedName()
466-
467-
// A field with a composite type can be defined in two ways:
468-
// - Using type name (var x @ResourceName)
469-
// - Using qualified type name (var x @ContractName.ResourceName)
470-
471-
if isExpectedQualifiedName && !isFoundQualifiedName {
472-
return validator.checkIdentifierEquality(expectedType, foundType)
473-
}
474-
475-
if isFoundQualifiedName && !isExpectedQualifiedName {
476-
return validator.checkIdentifierEquality(foundType, expectedType)
477-
}
478-
479-
// At this point, either both are qualified names, or both are simple names.
480-
// Thus, do a one-to-one match.
481-
if expectedType.Identifier.Identifier != foundType.Identifier.Identifier {
482-
return false
483-
}
484-
485-
return identifiersEqual(expectedType.NestedIdentifiers, foundType.NestedIdentifiers)
486-
}
487-
488-
func (validator *ContractUpdateValidator) checkIdentifierEquality(
489-
qualifiedNominalType *ast.NominalType,
490-
simpleNominalType *ast.NominalType,
491-
) bool {
492-
493-
// Situation:
494-
// qualifiedNominalType -> identifier: A, nestedIdentifiers: [foo, bar, ...]
495-
// simpleNominalType -> identifier: foo, nestedIdentifiers: [bar, ...]
496-
497-
// If the first identifier (i.e: 'A') refers to a composite decl that is not the enclosing contract,
498-
// then it must be referring to an imported contract. That means the two types are no longer the same.
499-
if qualifiedNominalType.Identifier.Identifier != validator.rootDecl.DeclarationIdentifier().Identifier {
500-
return false
501-
}
502-
503-
if qualifiedNominalType.NestedIdentifiers[0].Identifier != simpleNominalType.Identifier.Identifier {
504-
return false
505-
}
506-
507-
return identifiersEqual(simpleNominalType.NestedIdentifiers, qualifiedNominalType.NestedIdentifiers[1:])
508-
}
509-
510313
func (validator *ContractUpdateValidator) checkConformances(
511314
oldDecl *ast.CompositeDeclaration,
512315
newDecl *ast.CompositeDeclaration,
@@ -562,46 +365,6 @@ func (validator *ContractUpdateValidator) getContractUpdateError() error {
562365
}
563366
}
564367

565-
func getTypeMismatchError(expectedType ast.Type, foundType ast.Type) *TypeMismatchError {
566-
return &TypeMismatchError{
567-
ExpectedType: expectedType,
568-
FoundType: foundType,
569-
Range: ast.NewUnmeteredRangeFromPositioned(foundType),
570-
}
571-
}
572-
573-
func identifiersEqual(expected []ast.Identifier, found []ast.Identifier) bool {
574-
if len(expected) != len(found) {
575-
return false
576-
}
577-
578-
for index, element := range found {
579-
if expected[index].Identifier != element.Identifier {
580-
return false
581-
}
582-
}
583-
return true
584-
}
585-
586-
func isAnyStructOrAnyResourceType(astType ast.Type) bool {
587-
// If the restricted type is not stated, then it is either AnyStruct or AnyResource
588-
if astType == nil {
589-
return true
590-
}
591-
592-
nominalType, ok := astType.(*ast.NominalType)
593-
if !ok {
594-
return false
595-
}
596-
597-
switch nominalType.Identifier.Identifier {
598-
case sema.AnyStructType.Name, sema.AnyResourceType.Name:
599-
return true
600-
default:
601-
return false
602-
}
603-
}
604-
605368
func containsEnumsInProgram(program *ast.Program) bool {
606369
declaration, err := getRootDeclaration(program)
607370

0 commit comments

Comments
 (0)