Skip to content

Separate type equality checker from contract update validator #1787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 4 additions & 241 deletions runtime/contract_update_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ import (

"github.com/onflow/cadence/runtime/ast"
"github.com/onflow/cadence/runtime/common"
"github.com/onflow/cadence/runtime/sema"
)

type ContractUpdateValidator struct {
TypeComparator

location Location
contractName string
oldProgram *ast.Program
newProgram *ast.Program
rootDecl ast.Declaration
currentDecl ast.Declaration
errors []error
}
Expand Down Expand Up @@ -68,7 +68,8 @@ func (validator *ContractUpdateValidator) Validate() error {
return validator.getContractUpdateError()
}

validator.rootDecl = newRootDecl
validator.TypeComparator.RootDeclIdentifier = newRootDecl.DeclarationIdentifier()

validator.checkDeclarationUpdatability(oldRootDecl, newRootDecl)

if validator.hasErrors() {
Expand Down Expand Up @@ -104,7 +105,6 @@ func getRootDeclaration(program *ast.Program) (ast.Declaration, error) {
return nil, &ContractNotFoundError{
Range: ast.NewUnmeteredRangeFromPositioned(program),
}

}

func (validator *ContractUpdateValidator) hasErrors() bool {
Expand Down Expand Up @@ -310,203 +310,6 @@ func (validator *ContractUpdateValidator) checkEnumCases(oldDeclaration ast.Decl
}
}

func (validator *ContractUpdateValidator) CheckNominalTypeEquality(expected *ast.NominalType, found ast.Type) error {
foundNominalType, ok := found.(*ast.NominalType)
if !ok {
return getTypeMismatchError(expected, found)
}

// First check whether the names are equal.
ok = validator.checkNameEquality(expected, foundNominalType)
if !ok {
return getTypeMismatchError(expected, found)
}

return nil
}

func (validator *ContractUpdateValidator) CheckOptionalTypeEquality(expected *ast.OptionalType, found ast.Type) error {
foundOptionalType, ok := found.(*ast.OptionalType)
if !ok {
return getTypeMismatchError(expected, found)
}

return expected.Type.CheckEqual(foundOptionalType.Type, validator)
}

func (validator *ContractUpdateValidator) CheckVariableSizedTypeEquality(expected *ast.VariableSizedType, found ast.Type) error {
foundVarSizedType, ok := found.(*ast.VariableSizedType)
if !ok {
return getTypeMismatchError(expected, found)
}

return expected.Type.CheckEqual(foundVarSizedType.Type, validator)
}

func (validator *ContractUpdateValidator) CheckConstantSizedTypeEquality(expected *ast.ConstantSizedType, found ast.Type) error {
foundConstSizedType, ok := found.(*ast.ConstantSizedType)
if !ok {
return getTypeMismatchError(expected, found)
}

// Check size
if foundConstSizedType.Size.Value.Cmp(expected.Size.Value) != 0 ||
foundConstSizedType.Size.Base != expected.Size.Base {
return getTypeMismatchError(expected, found)
}

// Check type
return expected.Type.CheckEqual(foundConstSizedType.Type, validator)
}

func (validator *ContractUpdateValidator) CheckDictionaryTypeEquality(expected *ast.DictionaryType, found ast.Type) error {
foundDictionaryType, ok := found.(*ast.DictionaryType)
if !ok {
return getTypeMismatchError(expected, found)
}

err := expected.KeyType.CheckEqual(foundDictionaryType.KeyType, validator)
if err != nil {
return err
}

return expected.ValueType.CheckEqual(foundDictionaryType.ValueType, validator)
}

func (validator *ContractUpdateValidator) CheckRestrictedTypeEquality(expected *ast.RestrictedType, found ast.Type) error {
foundRestrictedType, ok := found.(*ast.RestrictedType)
if !ok {
return getTypeMismatchError(expected, found)
}

if expected.Type == nil {
if !isAnyStructOrAnyResourceType(foundRestrictedType.Type) {
return getTypeMismatchError(expected, found)
}
// else go on to check type restrictions
} else if foundRestrictedType.Type == nil {
if !isAnyStructOrAnyResourceType(expected.Type) {
return getTypeMismatchError(expected, found)
}
// else go on to check type restrictions
} else {
// both are not nil
err := expected.Type.CheckEqual(foundRestrictedType.Type, validator)
if err != nil {
return getTypeMismatchError(expected, found)
}
}

if len(expected.Restrictions) != len(foundRestrictedType.Restrictions) {
return getTypeMismatchError(expected, found)
}

for index, expectedRestriction := range expected.Restrictions {
foundRestriction := foundRestrictedType.Restrictions[index]
err := expectedRestriction.CheckEqual(foundRestriction, validator)
if err != nil {
return getTypeMismatchError(expected, found)
}
}

return nil
}

func (validator *ContractUpdateValidator) CheckInstantiationTypeEquality(expected *ast.InstantiationType, found ast.Type) error {
foundInstType, ok := found.(*ast.InstantiationType)
if !ok {
return getTypeMismatchError(expected, found)
}

err := expected.Type.CheckEqual(foundInstType.Type, validator)
if err != nil || len(expected.TypeArguments) != len(foundInstType.TypeArguments) {
return getTypeMismatchError(expected, found)
}

for index, typeArgs := range expected.TypeArguments {
otherTypeArgs := foundInstType.TypeArguments[index]
err := typeArgs.Type.CheckEqual(otherTypeArgs.Type, validator)
if err != nil {
return getTypeMismatchError(expected, found)
}
}

return nil
}

func (validator *ContractUpdateValidator) CheckFunctionTypeEquality(expected *ast.FunctionType, found ast.Type) error {
foundFuncType, ok := found.(*ast.FunctionType)
if !ok || len(expected.ParameterTypeAnnotations) != len(foundFuncType.ParameterTypeAnnotations) {
return getTypeMismatchError(expected, found)
}

for index, expectedParamType := range expected.ParameterTypeAnnotations {
foundParamType := foundFuncType.ParameterTypeAnnotations[index]
err := expectedParamType.Type.CheckEqual(foundParamType.Type, validator)
if err != nil {
return getTypeMismatchError(expected, found)
}
}

return expected.ReturnTypeAnnotation.Type.CheckEqual(foundFuncType.ReturnTypeAnnotation.Type, validator)
}

func (validator *ContractUpdateValidator) CheckReferenceTypeEquality(expected *ast.ReferenceType, found ast.Type) error {
refType, ok := found.(*ast.ReferenceType)
if !ok {
return getTypeMismatchError(expected, found)
}

return expected.Type.CheckEqual(refType.Type, validator)
}

func (validator *ContractUpdateValidator) checkNameEquality(expectedType *ast.NominalType, foundType *ast.NominalType) bool {
isExpectedQualifiedName := expectedType.IsQualifiedName()
isFoundQualifiedName := foundType.IsQualifiedName()

// A field with a composite type can be defined in two ways:
// - Using type name (var x @ResourceName)
// - Using qualified type name (var x @ContractName.ResourceName)

if isExpectedQualifiedName && !isFoundQualifiedName {
return validator.checkIdentifierEquality(expectedType, foundType)
}

if isFoundQualifiedName && !isExpectedQualifiedName {
return validator.checkIdentifierEquality(foundType, expectedType)
}

// At this point, either both are qualified names, or both are simple names.
// Thus, do a one-to-one match.
if expectedType.Identifier.Identifier != foundType.Identifier.Identifier {
return false
}

return identifiersEqual(expectedType.NestedIdentifiers, foundType.NestedIdentifiers)
}

func (validator *ContractUpdateValidator) checkIdentifierEquality(
qualifiedNominalType *ast.NominalType,
simpleNominalType *ast.NominalType,
) bool {

// Situation:
// qualifiedNominalType -> identifier: A, nestedIdentifiers: [foo, bar, ...]
// simpleNominalType -> identifier: foo, nestedIdentifiers: [bar, ...]

// If the first identifier (i.e: 'A') refers to a composite decl that is not the enclosing contract,
// then it must be referring to an imported contract. That means the two types are no longer the same.
if qualifiedNominalType.Identifier.Identifier != validator.rootDecl.DeclarationIdentifier().Identifier {
return false
}

if qualifiedNominalType.NestedIdentifiers[0].Identifier != simpleNominalType.Identifier.Identifier {
return false
}

return identifiersEqual(simpleNominalType.NestedIdentifiers, qualifiedNominalType.NestedIdentifiers[1:])
}

func (validator *ContractUpdateValidator) checkConformances(
oldDecl *ast.CompositeDeclaration,
newDecl *ast.CompositeDeclaration,
Expand Down Expand Up @@ -562,46 +365,6 @@ func (validator *ContractUpdateValidator) getContractUpdateError() error {
}
}

func getTypeMismatchError(expectedType ast.Type, foundType ast.Type) *TypeMismatchError {
return &TypeMismatchError{
ExpectedType: expectedType,
FoundType: foundType,
Range: ast.NewUnmeteredRangeFromPositioned(foundType),
}
}

func identifiersEqual(expected []ast.Identifier, found []ast.Identifier) bool {
if len(expected) != len(found) {
return false
}

for index, element := range found {
if expected[index].Identifier != element.Identifier {
return false
}
}
return true
}

func isAnyStructOrAnyResourceType(astType ast.Type) bool {
// If the restricted type is not stated, then it is either AnyStruct or AnyResource
if astType == nil {
return true
}

nominalType, ok := astType.(*ast.NominalType)
if !ok {
return false
}

switch nominalType.Identifier.Identifier {
case sema.AnyStructType.Name, sema.AnyResourceType.Name:
return true
default:
return false
}
}

func containsEnumsInProgram(program *ast.Program) bool {
declaration, err := getRootDeclaration(program)

Expand Down
Loading