From c626c370e3ac84c85d7cf178aedaeb2d42c14ae6 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 30 Jan 2024 13:11:34 -0500 Subject: [PATCH 01/17] refactor update validator to allow multiple implementations of field checking --- runtime/stdlib/contract_update_validation.go | 105 ++++++++---------- .../legacy_contract_upgrade_validator.go | 77 +++++++++++++ 2 files changed, 126 insertions(+), 56 deletions(-) create mode 100644 runtime/stdlib/legacy_contract_upgrade_validator.go diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index b62f157ef..38a6d3b53 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -28,7 +28,15 @@ import ( ) type UpdateValidator interface { + ast.TypeEqualityChecker + Validate() error + report(error) + + getCurrentDeclaration() ast.Declaration + setCurrentDeclaration(ast.Declaration) + + checkField(oldField *ast.FieldDeclaration, newField *ast.FieldDeclaration) } type ContractUpdateValidator struct { @@ -63,21 +71,29 @@ func NewContractUpdateValidator( } } +func (validator *ContractUpdateValidator) getCurrentDeclaration() ast.Declaration { + return validator.currentDecl +} + +func (validator *ContractUpdateValidator) setCurrentDeclaration(decl ast.Declaration) { + validator.currentDecl = decl +} + // Validate validates the contract update, and returns an error if it is an invalid update. func (validator *ContractUpdateValidator) Validate() error { - oldRootDecl := validator.getRootDeclaration(validator.oldProgram) + oldRootDecl := getRootDeclaration(validator, validator.oldProgram) if validator.hasErrors() { return validator.getContractUpdateError() } - newRootDecl := validator.getRootDeclaration(validator.newProgram) + newRootDecl := getRootDeclaration(validator, validator.newProgram) if validator.hasErrors() { return validator.getContractUpdateError() } validator.TypeComparator.RootDeclIdentifier = newRootDecl.DeclarationIdentifier() - validator.checkDeclarationUpdatability(oldRootDecl, newRootDecl) + checkDeclarationUpdatability(validator, oldRootDecl, newRootDecl) if validator.hasErrors() { return validator.getContractUpdateError() @@ -86,8 +102,8 @@ func (validator *ContractUpdateValidator) Validate() error { return nil } -func (validator *ContractUpdateValidator) getRootDeclaration(program *ast.Program) ast.Declaration { - decl, err := getRootDeclaration(program) +func getRootDeclaration(validator UpdateValidator, program *ast.Program) ast.Declaration { + decl, err := getRootDeclarationOfProgram(program) if err != nil { validator.report(&ContractNotFoundError{ @@ -98,7 +114,7 @@ func (validator *ContractUpdateValidator) getRootDeclaration(program *ast.Progra return decl } -func getRootDeclaration(program *ast.Program) (ast.Declaration, error) { +func getRootDeclarationOfProgram(program *ast.Program) (ast.Declaration, error) { compositeDecl := program.SoleContractDeclaration() if compositeDecl != nil { return compositeDecl, nil @@ -118,7 +134,8 @@ func (validator *ContractUpdateValidator) hasErrors() bool { return len(validator.errors) > 0 } -func (validator *ContractUpdateValidator) checkDeclarationUpdatability( +func checkDeclarationUpdatability( + validator UpdateValidator, oldDeclaration ast.Declaration, newDeclaration ast.Declaration, ) { @@ -137,24 +154,28 @@ func (validator *ContractUpdateValidator) checkDeclarationUpdatability( return } - parentDecl := validator.currentDecl - validator.currentDecl = newDeclaration + parentDecl := validator.getCurrentDeclaration() + validator.setCurrentDeclaration(newDeclaration) defer func() { - validator.currentDecl = parentDecl + validator.setCurrentDeclaration(parentDecl) }() - validator.checkFields(oldDeclaration, newDeclaration) + checkFields(validator, oldDeclaration, newDeclaration) - validator.checkNestedDeclarations(oldDeclaration, newDeclaration) + checkNestedDeclarations(validator, oldDeclaration, newDeclaration) if newDecl, ok := newDeclaration.(*ast.CompositeDeclaration); ok { if oldDecl, ok := oldDeclaration.(*ast.CompositeDeclaration); ok { - validator.checkConformances(oldDecl, newDecl) + checkConformances(validator, oldDecl, newDecl) } } } -func (validator *ContractUpdateValidator) checkFields(oldDeclaration ast.Declaration, newDeclaration ast.Declaration) { +func checkFields( + validator UpdateValidator, + oldDeclaration ast.Declaration, + newDeclaration ast.Declaration, +) { oldFields := oldDeclaration.DeclarationMembers().FieldsByIdentifier() newFields := newDeclaration.DeclarationMembers().Fields() @@ -192,7 +213,8 @@ func (validator *ContractUpdateValidator) checkField(oldField *ast.FieldDeclarat } } -func (validator *ContractUpdateValidator) checkNestedDeclarations( +func checkNestedDeclarations( + validator UpdateValidator, oldDeclaration ast.Declaration, newDeclaration ast.Declaration, ) { @@ -208,7 +230,7 @@ func (validator *ContractUpdateValidator) checkNestedDeclarations( continue } - validator.checkDeclarationUpdatability(oldNestedDecl, newNestedDecl) + checkDeclarationUpdatability(validator, oldNestedDecl, newNestedDecl) // If there's a matching new decl, then remove the old one from the map. delete(oldNominalTypeDecls, newNestedDecl.Identifier.Identifier) @@ -223,7 +245,7 @@ func (validator *ContractUpdateValidator) checkNestedDeclarations( continue } - validator.checkDeclarationUpdatability(oldNestedDecl, newNestedDecl) + checkDeclarationUpdatability(validator, oldNestedDecl, newNestedDecl) // If there's a matching new decl, then remove the old one from the map. delete(oldNominalTypeDecls, newNestedDecl.Identifier.Identifier) @@ -238,7 +260,7 @@ func (validator *ContractUpdateValidator) checkNestedDeclarations( continue } - validator.checkDeclarationUpdatability(oldNestedDecl, newNestedDecl) + checkDeclarationUpdatability(validator, oldNestedDecl, newNestedDecl) // If there's a matching new decl, then remove the old one from the map. delete(oldNominalTypeDecls, newNestedDecl.Identifier.Identifier) @@ -270,7 +292,7 @@ func (validator *ContractUpdateValidator) checkNestedDeclarations( } // Check enum-cases, if there are any. - validator.checkEnumCases(oldDeclaration, newDeclaration) + checkEnumCases(validator, oldDeclaration, newDeclaration) } func getNestedNominalTypeDecls(declaration ast.Declaration) map[string]ast.Declaration { @@ -297,7 +319,11 @@ func getNestedNominalTypeDecls(declaration ast.Declaration) map[string]ast.Decla // checkEnumCases validates updating enum cases. Updated enum must: // - Have at-least the same number of enum-cases as the old enum (Adding is allowed, but no removals). // - Preserve the order of the old enum-cases (Adding to top/middle is not allowed, swapping is not allowed). -func (validator *ContractUpdateValidator) checkEnumCases(oldDeclaration ast.Declaration, newDeclaration ast.Declaration) { +func checkEnumCases( + validator UpdateValidator, + oldDeclaration ast.Declaration, + newDeclaration ast.Declaration, +) { newEnumCases := newDeclaration.DeclarationMembers().EnumCases() oldEnumCases := oldDeclaration.DeclarationMembers().EnumCases() @@ -337,7 +363,8 @@ func (validator *ContractUpdateValidator) checkEnumCases(oldDeclaration ast.Decl } } -func (validator *ContractUpdateValidator) checkConformances( +func checkConformances( + validator UpdateValidator, oldDecl *ast.CompositeDeclaration, newDecl *ast.CompositeDeclaration, ) { @@ -393,7 +420,7 @@ func (validator *ContractUpdateValidator) getContractUpdateError() error { } func containsEnumsInProgram(program *ast.Program) bool { - declaration, err := getRootDeclaration(program) + declaration, err := getRootDeclarationOfProgram(program) if err != nil { return false @@ -617,37 +644,3 @@ func (e *MissingDeclarationError) Error() string { e.Name, ) } - -type LegacyContractUpdateValidator struct { - TypeComparator - - location common.Location - contractName string - oldProgram *ast.Program - newProgram *ast.Program -} - -// NewContractUpdateValidator initializes and returns a validator, without performing any validation. -// Invoke the `Validate()` method of the validator returned, to start validating the contract. -func NewLegacyContractUpdateValidator( - location common.Location, - contractName string, - oldProgram *ast.Program, - newProgram *ast.Program, -) *LegacyContractUpdateValidator { - - return &LegacyContractUpdateValidator{ - location: location, - oldProgram: oldProgram, - newProgram: newProgram, - contractName: contractName, - } -} - -var _ UpdateValidator = &LegacyContractUpdateValidator{} - -// Validate validates the contract update, and returns an error if it is an invalid update. -// TODO: for now this is empty until we determine what validation is necessary for a Cadence 1.0 upgrade -func (validator *LegacyContractUpdateValidator) Validate() error { - return nil -} diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go new file mode 100644 index 000000000..51261ce95 --- /dev/null +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -0,0 +1,77 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stdlib + +import ( + "github.com/onflow/cadence/runtime/ast" + "github.com/onflow/cadence/runtime/common" +) + +type LegacyContractUpdateValidator struct { + TypeComparator + + underlyingUpdateValidator *ContractUpdateValidator +} + +// NewContractUpdateValidator initializes and returns a validator, without performing any validation. +// Invoke the `Validate()` method of the validator returned, to start validating the contract. +func NewLegacyContractUpdateValidator( + location common.Location, + contractName string, + oldProgram *ast.Program, + newProgram *ast.Program, +) *LegacyContractUpdateValidator { + + underlyingValidator := NewContractUpdateValidator(location, contractName, oldProgram, newProgram) + + return &LegacyContractUpdateValidator{ + underlyingUpdateValidator: underlyingValidator, + } +} + +var _ UpdateValidator = &LegacyContractUpdateValidator{} + +func (validator *LegacyContractUpdateValidator) getCurrentDeclaration() ast.Declaration { + return validator.underlyingUpdateValidator.getCurrentDeclaration() +} + +func (validator *LegacyContractUpdateValidator) setCurrentDeclaration(decl ast.Declaration) { + validator.underlyingUpdateValidator.setCurrentDeclaration(decl) +} + +// Validate validates the contract update, and returns an error if it is an invalid update. +func (validator *LegacyContractUpdateValidator) Validate() error { + return validator.underlyingUpdateValidator.Validate() +} + +func (validator *LegacyContractUpdateValidator) report(err error) { + validator.underlyingUpdateValidator.report(err) +} + +func (validator *LegacyContractUpdateValidator) checkField(oldField *ast.FieldDeclaration, newField *ast.FieldDeclaration) { + err := oldField.TypeAnnotation.Type.CheckEqual(newField.TypeAnnotation.Type, validator) + if err != nil { + validator.report(&FieldMismatchError{ + DeclName: validator.getCurrentDeclaration().DeclarationIdentifier().Identifier, + FieldName: newField.Identifier.Identifier, + Err: err, + Range: ast.NewUnmeteredRangeFromPositioned(newField.TypeAnnotation), + }) + } +} From c2fcb6ad4d5298a8620000df59d79a5fa61221b8 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 30 Jan 2024 13:37:39 -0500 Subject: [PATCH 02/17] add some basic tests for the upgrade validator --- ...legacy_contract_upgrade_validation_test.go | 393 ++++++++++++++++++ 1 file changed, 393 insertions(+) create mode 100644 runtime/stdlib/legacy_contract_upgrade_validation_test.go diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go new file mode 100644 index 000000000..e3a4f320d --- /dev/null +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -0,0 +1,393 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stdlib_test + +import ( + "testing" + + "github.com/onflow/cadence/runtime/old_parser" + "github.com/onflow/cadence/runtime/parser" + "github.com/onflow/cadence/runtime/stdlib" + "github.com/onflow/cadence/runtime/tests/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testContractUpdate(t *testing.T, oldCode string, newCode string) error { + oldProgram, err := old_parser.ParseProgram(nil, []byte(oldCode), old_parser.Config{}) + require.NoError(t, err) + + newProgram, err := parser.ParseProgram(nil, []byte(newCode), parser.Config{}) + require.NoError(t, err) + + upgradeValidator := stdlib.NewLegacyContractUpdateValidator(utils.TestLocation, "Test", oldProgram, newProgram) + return upgradeValidator.Validate() +} + +func getSingleContractUpdateErrorCause(t *testing.T, err error, contractName string) error { + updateErr := getContractUpdateError(t, err, contractName) + + require.Len(t, updateErr.Errors, 1) + return updateErr.Errors[0] +} + +func getContractUpdateError(t *testing.T, err error, contractName string) *stdlib.ContractUpdateError { + require.Error(t, err) + + var contractUpdateErr *stdlib.ContractUpdateError + require.ErrorAs(t, err, &contractUpdateErr) + + assert.Equal(t, contractName, contractUpdateErr.ContractName) + + return contractUpdateErr +} + +func assertFieldTypeMismatchError( + t *testing.T, + err error, + erroneousDeclName string, + fieldName string, + expectedType string, + foundType string, +) { + var fieldMismatchError *stdlib.FieldMismatchError + require.ErrorAs(t, err, &fieldMismatchError) + + assert.Equal(t, fieldName, fieldMismatchError.FieldName) + assert.Equal(t, erroneousDeclName, fieldMismatchError.DeclName) + + var typeMismatchError *stdlib.TypeMismatchError + assert.ErrorAs(t, fieldMismatchError.Err, &typeMismatchError) + + assert.Equal(t, expectedType, typeMismatchError.ExpectedType.String()) + assert.Equal(t, foundType, typeMismatchError.FoundType.String()) +} + +func TestContractUpgradeFieldAccess(t *testing.T) { + + t.Parallel() + + t.Run("change field access to entitlement", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + access(all) contract Test { + access(all) var a: Int + init() { + self.a = 0 + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(E) var a: Int + init() { + self.a = 0 + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + + t.Run("change field access to all", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub var a: Int + init() { + self.a = 0 + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) var a: Int + init() { + self.a = 0 + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) +} + +func TestContractUpgradeFieldType(t *testing.T) { + + t.Parallel() + + t.Run("change field types illegally", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + access(all) contract Test { + access(all) var a: Int + init() { + self.a = 0 + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) var a: String + init() { + self.a = "hello" + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldTypeMismatchError(t, cause, "Test", "a", "Int", "String") + + }) + + t.Run("change field type reference auth", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub var a: &Int + init() { + self.a = "hello" + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) var a: auth(E) &Int + init() { + self.a = 0 + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + + t.Run("change field type capability reference auth", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub var a: Capability<&Int> + init() { + self.a = "hello" + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) var a: Capability + init() { + self.a = 0 + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + // TODO: this should not be allowed, the migration is not going to change the underlying referenced type of `a` + require.NoError(t, err) + }) + + t.Run("change field type restricted type", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub resource interface I {} + pub resource R:I {} + + pub var a: @R{I} + init() { + self.a <- create R() + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) resource interface I {} + access(all) resource R:I {} + + access(all) var a: @{I} + init() { + self.a <- create R() + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) +} + +func TestContractUpgradeIntersectionFieldType(t *testing.T) { + + t.Parallel() + + t.Run("change field type restricted reference type", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub resource interface I {} + pub resource R:I {} + + pub var a: &R{I}? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) resource interface I {} + access(all) resource R:I {} + + access(all) var a: &R? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + // TODO: this should be allowed, as the migration will convert `&R{I}` to `&R` + require.NoError(t, err) + }) + + t.Run("change field type restricted entitled reference type", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub resource interface I { + pub fun foo() + } + pub resource R:I { + pub fun foo() + } + + pub var a: &R{I}? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) resource interface I { + access(E) fun foo() {} + } + access(all) resource R:I { + access(E) fun foo() {} + } + + access(all) var a: auth(E) &R? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + // TODO: this should be allowed, as the migration will convert `&R{I}` to `auth(E) &R` + require.NoError(t, err) + }) + + t.Run("change field type restricted entitled reference type with too many granted entitlements", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub resource interface I { + pub fun foo() + } + pub resource R:I { + pub fun foo() + pub fun bar() + } + + pub var a: &R{I}? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) entitlement F + access(all) resource interface I { + access(E) fun foo() {} + } + access(all) resource R:I { + access(E) fun foo() {} + access(F) fun bar() {} + } + + access(all) var a: auth(E, F) &R? + init() { + self.a = nil + } + } + ` + + // TODO: this should not be allowed, as the migration will convert `&R{I}` to `auth(E) &R` + err := testContractUpdate(t, oldCode, newCode) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldTypeMismatchError(t, cause, "Test", "a", "&R{I}?", "auth(E, F) &R?") + }) +} From c222cf604a1487627f02655fe8a41dbdeeccf5f4 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 30 Jan 2024 14:15:47 -0500 Subject: [PATCH 03/17] upgrade rules for intersection types --- ...legacy_contract_upgrade_validation_test.go | 22 +++--- .../legacy_contract_upgrade_validator.go | 68 ++++++++++++++++--- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go index e3a4f320d..5831f3b67 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -223,12 +223,19 @@ func TestContractUpgradeFieldType(t *testing.T) { } ` + // TODO: this should not be allowed, the migration is not going to change the underlying referenced type of `a` err := testContractUpdate(t, oldCode, newCode) - // TODO: this should not be allowed, the migration is not going to change the underlying referenced type of `a` - require.NoError(t, err) + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldTypeMismatchError(t, cause, "Test", "a", "Capability<&Int>", "Capability") }) +} + +func TestContractUpgradeIntersectionFieldType(t *testing.T) { + + t.Parallel() + t.Run("change field type restricted type", func(t *testing.T) { t.Parallel() @@ -259,13 +266,10 @@ func TestContractUpgradeFieldType(t *testing.T) { err := testContractUpdate(t, oldCode, newCode) - require.NoError(t, err) + // This is not allowed because `@R{I}` is converted to `@R`, not `@{I}` + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldTypeMismatchError(t, cause, "Test", "a", "R", "{I}") }) -} - -func TestContractUpgradeIntersectionFieldType(t *testing.T) { - - t.Parallel() t.Run("change field type restricted reference type", func(t *testing.T) { @@ -297,7 +301,6 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { err := testContractUpdate(t, oldCode, newCode) - // TODO: this should be allowed, as the migration will convert `&R{I}` to `&R` require.NoError(t, err) }) @@ -340,7 +343,6 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { err := testContractUpdate(t, oldCode, newCode) - // TODO: this should be allowed, as the migration will convert `&R{I}` to `auth(E) &R` require.NoError(t, err) }) diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index 51261ce95..772836fb4 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -57,21 +57,71 @@ func (validator *LegacyContractUpdateValidator) setCurrentDeclaration(decl ast.D // Validate validates the contract update, and returns an error if it is an invalid update. func (validator *LegacyContractUpdateValidator) Validate() error { - return validator.underlyingUpdateValidator.Validate() + underlyingValidator := validator.underlyingUpdateValidator + + oldRootDecl := getRootDeclaration(validator, underlyingValidator.oldProgram) + if underlyingValidator.hasErrors() { + return underlyingValidator.getContractUpdateError() + } + + newRootDecl := getRootDeclaration(validator, underlyingValidator.newProgram) + if underlyingValidator.hasErrors() { + return underlyingValidator.getContractUpdateError() + } + + validator.TypeComparator.RootDeclIdentifier = newRootDecl.DeclarationIdentifier() + + checkDeclarationUpdatability(validator, oldRootDecl, newRootDecl) + + if underlyingValidator.hasErrors() { + return underlyingValidator.getContractUpdateError() + } + + return nil } func (validator *LegacyContractUpdateValidator) report(err error) { validator.underlyingUpdateValidator.report(err) } +func (validator *LegacyContractUpdateValidator) checkTypeUpgradability(oldType ast.Type, newType ast.Type) error { + + switch oldType := oldType.(type) { + case *ast.OptionalType: + if newOptional, isOptional := newType.(*ast.OptionalType); isOptional { + return validator.checkTypeUpgradability(oldType.Type, newOptional.Type) + } + case *ast.ReferenceType: + if newReference, isReference := newType.(*ast.ReferenceType); isReference { + return validator.checkTypeUpgradability(oldType.Type, newReference.Type) + } + case *ast.IntersectionType: + // intersection types cannot be upgraded unless they have a legacy restricted type, + // in which case they must be upgraded according to the migration rules: i.e. R{I} -> R + if oldType.LegacyRestrictedType == nil { + break + } + return validator.checkTypeUpgradability(oldType.LegacyRestrictedType, newType) + } + + return oldType.CheckEqual(newType, validator) + +} + func (validator *LegacyContractUpdateValidator) checkField(oldField *ast.FieldDeclaration, newField *ast.FieldDeclaration) { - err := oldField.TypeAnnotation.Type.CheckEqual(newField.TypeAnnotation.Type, validator) - if err != nil { - validator.report(&FieldMismatchError{ - DeclName: validator.getCurrentDeclaration().DeclarationIdentifier().Identifier, - FieldName: newField.Identifier.Identifier, - Err: err, - Range: ast.NewUnmeteredRangeFromPositioned(newField.TypeAnnotation), - }) + oldType := oldField.TypeAnnotation.Type + newType := newField.TypeAnnotation.Type + + err := validator.checkTypeUpgradability(oldType, newType) + if err == nil { + return } + + validator.report(&FieldMismatchError{ + DeclName: validator.getCurrentDeclaration().DeclarationIdentifier().Identifier, + FieldName: newField.Identifier.Identifier, + Err: err, + Range: ast.NewUnmeteredRangeFromPositioned(newField.TypeAnnotation), + }) + } From 3782d878fc1ab3ed7b0415a6508b4b19a03ba778 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Tue, 30 Jan 2024 16:35:54 -0500 Subject: [PATCH 04/17] properly plumb through arrays, dicts, and capabilities --- runtime/stdlib/account.go | 1 + ...legacy_contract_upgrade_validation_test.go | 206 ++++++++++++++---- .../legacy_contract_upgrade_validator.go | 46 ++++ 3 files changed, 211 insertions(+), 42 deletions(-) diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index 6a0b9b219..8ca0a88b5 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1619,6 +1619,7 @@ func changeAccountContracts( contractName, oldProgram, program.Program, + program.Elaboration, ) } else { validator = NewContractUpdateValidator( diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go index 5831f3b67..6e0e9414d 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -23,6 +23,7 @@ import ( "github.com/onflow/cadence/runtime/old_parser" "github.com/onflow/cadence/runtime/parser" + "github.com/onflow/cadence/runtime/sema" "github.com/onflow/cadence/runtime/stdlib" "github.com/onflow/cadence/runtime/tests/utils" "github.com/stretchr/testify/assert" @@ -36,7 +37,20 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error { newProgram, err := parser.ParseProgram(nil, []byte(newCode), parser.Config{}) require.NoError(t, err) - upgradeValidator := stdlib.NewLegacyContractUpdateValidator(utils.TestLocation, "Test", oldProgram, newProgram) + checker, err := sema.NewChecker( + newProgram, + utils.TestLocation, + nil, + &sema.Config{ + AccessCheckMode: sema.AccessCheckModeStrict, + AttachmentsEnabled: true, + }) + require.NoError(t, err) + + err = checker.Check() + require.NoError(t, err) + + upgradeValidator := stdlib.NewLegacyContractUpdateValidator(utils.TestLocation, "Test", oldProgram, newProgram, checker.Elaboration) return upgradeValidator.Validate() } @@ -89,20 +103,24 @@ func TestContractUpgradeFieldAccess(t *testing.T) { const oldCode = ` access(all) contract Test { - access(all) var a: Int - init() { - self.a = 0 - } + access(all) resource R { + access(all) var a: Int + init() { + self.a = 0 + } + } } ` const newCode = ` access(all) contract Test { access(all) entitlement E - access(E) var a: Int - init() { - self.a = 0 - } + access(all) resource R { + access(E) var a: Int + init() { + self.a = 0 + } + } } ` @@ -172,15 +190,15 @@ func TestContractUpgradeFieldType(t *testing.T) { }) - t.Run("change field type reference auth", func(t *testing.T) { + t.Run("change field type capability reference auth", func(t *testing.T) { t.Parallel() const oldCode = ` pub contract Test { - pub var a: &Int + pub var a: Capability<&Int>? init() { - self.a = "hello" + self.a = nil } } ` @@ -188,55 +206,94 @@ func TestContractUpgradeFieldType(t *testing.T) { const newCode = ` access(all) contract Test { access(all) entitlement E - access(all) var a: auth(E) &Int + access(all) var a: Capability? init() { - self.a = 0 + self.a = nil } } ` + // TODO: this should not be allowed, the migration is not going to change the underlying referenced type of `a` err := testContractUpdate(t, oldCode, newCode) - require.NoError(t, err) + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldTypeMismatchError(t, cause, "Test", "a", "Capability<&Int>", "Capability") }) +} - t.Run("change field type capability reference auth", func(t *testing.T) { +func TestContractUpgradeIntersectionFieldType(t *testing.T) { + + t.Parallel() + + t.Run("change field type restricted type", func(t *testing.T) { t.Parallel() const oldCode = ` pub contract Test { - pub var a: Capability<&Int> + pub resource interface I {} + pub resource R:I {} + + pub var a: @R{I} init() { - self.a = "hello" + self.a <- create R() } } ` const newCode = ` access(all) contract Test { - access(all) entitlement E - access(all) var a: Capability + access(all) resource interface I {} + access(all) resource R:I {} + + access(all) var a: @{I} init() { - self.a = 0 + self.a <- create R() } } ` - // TODO: this should not be allowed, the migration is not going to change the underlying referenced type of `a` err := testContractUpdate(t, oldCode, newCode) + // This is not allowed because `@R{I}` is converted to `@R`, not `@{I}` cause := getSingleContractUpdateErrorCause(t, err, "Test") - assertFieldTypeMismatchError(t, cause, "Test", "a", "Capability<&Int>", "Capability") + assertFieldTypeMismatchError(t, cause, "Test", "a", "R", "{I}") }) -} + t.Run("change field type restricted type variable sized", func(t *testing.T) { -func TestContractUpgradeIntersectionFieldType(t *testing.T) { + t.Parallel() - t.Parallel() + const oldCode = ` + pub contract Test { + pub resource interface I {} + pub resource R:I {} - t.Run("change field type restricted type", func(t *testing.T) { + pub var a: @[R{I}] + init() { + self.a <- [<- create R()] + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) resource interface I {} + access(all) resource R:I {} + + access(all) var a: @[R] + init() { + self.a <- [<- create R()] + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + + t.Run("change field type restricted type constant sized", func(t *testing.T) { t.Parallel() @@ -245,9 +302,9 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { pub resource interface I {} pub resource R:I {} - pub var a: @R{I} + pub var a: @[R{I}; 1] init() { - self.a <- create R() + self.a <- [<- create R()] } } ` @@ -257,18 +314,83 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { access(all) resource interface I {} access(all) resource R:I {} - access(all) var a: @{I} + access(all) var a: @[R; 1] init() { - self.a <- create R() + self.a <- [<- create R()] + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + + t.Run("change field type restricted type constant sized with size change", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub resource interface I {} + pub resource R:I {} + + pub var a: @[R{I}; 1] + init() { + self.a <- [<- create R()] + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) resource interface I {} + access(all) resource R:I {} + + access(all) var a: @[R; 2] + init() { + self.a <- [<- create R(), <- create R()] } } ` err := testContractUpdate(t, oldCode, newCode) - // This is not allowed because `@R{I}` is converted to `@R`, not `@{I}` cause := getSingleContractUpdateErrorCause(t, err, "Test") - assertFieldTypeMismatchError(t, cause, "Test", "a", "R", "{I}") + assertFieldTypeMismatchError(t, cause, "Test", "a", "[{I}; 1]", "[R; 2]") + }) + + t.Run("change field type restricted type dict", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub resource interface I {} + pub resource R:I {} + + pub var a: @{Int: R{I}} + init() { + self.a <- {0: <- create R()} + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) resource interface I {} + access(all) resource R:I {} + + access(all) var a: @{Int: R} + init() { + self.a <- {0: <- create R()} + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) }) t.Run("change field type restricted reference type", func(t *testing.T) { @@ -280,7 +402,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { pub resource interface I {} pub resource R:I {} - pub var a: &R{I}? + pub var a: Capability<&R{I}>? init() { self.a = nil } @@ -292,7 +414,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { access(all) resource interface I {} access(all) resource R:I {} - access(all) var a: &R? + access(all) var a: Capability<&R>? init() { self.a = nil } @@ -317,7 +439,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { pub fun foo() } - pub var a: &R{I}? + pub var a: Capability<&R{I}>? init() { self.a = nil } @@ -328,13 +450,13 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { access(all) contract Test { access(all) entitlement E access(all) resource interface I { - access(E) fun foo() {} + access(E) fun foo() } access(all) resource R:I { access(E) fun foo() {} } - access(all) var a: auth(E) &R? + access(all) var a: Capability? init() { self.a = nil } @@ -360,7 +482,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { pub fun bar() } - pub var a: &R{I}? + pub var a: Capability<&R{I}>? init() { self.a = nil } @@ -372,14 +494,14 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { access(all) entitlement E access(all) entitlement F access(all) resource interface I { - access(E) fun foo() {} + access(E) fun foo() } access(all) resource R:I { access(E) fun foo() {} access(F) fun bar() {} } - access(all) var a: auth(E, F) &R? + access(all) var a: Capability? init() { self.a = nil } @@ -390,6 +512,6 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { err := testContractUpdate(t, oldCode, newCode) cause := getSingleContractUpdateErrorCause(t, err, "Test") - assertFieldTypeMismatchError(t, cause, "Test", "a", "&R{I}?", "auth(E, F) &R?") + assertFieldTypeMismatchError(t, cause, "Test", "a", "Capability", "Capability") }) } diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index 772836fb4..c8dba50e0 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -21,11 +21,14 @@ package stdlib import ( "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/sema" ) type LegacyContractUpdateValidator struct { TypeComparator + newElaboration *sema.Elaboration + underlyingUpdateValidator *ContractUpdateValidator } @@ -36,12 +39,14 @@ func NewLegacyContractUpdateValidator( contractName string, oldProgram *ast.Program, newProgram *ast.Program, + newElaboration *sema.Elaboration, ) *LegacyContractUpdateValidator { underlyingValidator := NewContractUpdateValidator(location, contractName, oldProgram, newProgram) return &LegacyContractUpdateValidator{ underlyingUpdateValidator: underlyingValidator, + newElaboration: newElaboration, } } @@ -102,6 +107,47 @@ func (validator *LegacyContractUpdateValidator) checkTypeUpgradability(oldType a break } return validator.checkTypeUpgradability(oldType.LegacyRestrictedType, newType) + case *ast.VariableSizedType: + if newVariableSizedType, isVariableSizedType := newType.(*ast.VariableSizedType); isVariableSizedType { + return validator.checkTypeUpgradability(oldType.Type, newVariableSizedType.Type) + } + case *ast.ConstantSizedType: + if newConstantSizedType, isConstantSizedType := newType.(*ast.ConstantSizedType); isConstantSizedType { + if oldType.Size.Value.Cmp(newConstantSizedType.Size.Value) != 0 || + oldType.Size.Base != newConstantSizedType.Size.Base { + return newTypeMismatchError(oldType, newConstantSizedType) + } + return validator.checkTypeUpgradability(oldType.Type, newConstantSizedType.Type) + } + case *ast.DictionaryType: + if newDictionaryType, isDictionaryType := newType.(*ast.DictionaryType); isDictionaryType { + err := validator.checkTypeUpgradability(oldType.KeyType, newDictionaryType.KeyType) + if err != nil { + return err + } + return validator.checkTypeUpgradability(oldType.ValueType, newDictionaryType.ValueType) + } + case *ast.InstantiationType: + // if the type is a Capability, allow the borrow type to change according to the normal upgrade rules + if oldNominalType, isNominal := oldType.Type.(*ast.NominalType); isNominal && + oldNominalType.Identifier.Identifier == "Capability" { + + if instantiationType, isInstantiation := newType.(*ast.InstantiationType); isInstantiation { + if newNominalType, isNominal := oldType.Type.(*ast.NominalType); isNominal && + newNominalType.Identifier.Identifier == "Capability" { + + // Capability insantiation types must have exactly 1 type argument + if len(oldType.TypeArguments) != 1 || len(instantiationType.TypeArguments) != 1 { + break + } + + oldTypeArg := oldType.TypeArguments[0] + newTypeArg := instantiationType.TypeArguments[0] + + return validator.checkTypeUpgradability(oldTypeArg.Type, newTypeArg.Type) + } + } + } } return oldType.CheckEqual(newType, validator) From 01c2da154ac413a0fdd5a82312cfa0cbc398e474 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 31 Jan 2024 10:52:16 -0500 Subject: [PATCH 05/17] in progress entitlements check --- runtime/stdlib/legacy_contract_upgrade_validator.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index c8dba50e0..ac84651e1 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -89,6 +89,10 @@ func (validator *LegacyContractUpdateValidator) report(err error) { validator.underlyingUpdateValidator.report(err) } +func (validator *LegacyContractUpdateValidator) checkEntitlementsUpgrade(oldType *ast.ReferenceType, newType *ast.ReferenceType) error { + newAuthorization := newType.Authorization +} + func (validator *LegacyContractUpdateValidator) checkTypeUpgradability(oldType ast.Type, newType ast.Type) error { switch oldType := oldType.(type) { @@ -98,6 +102,12 @@ func (validator *LegacyContractUpdateValidator) checkTypeUpgradability(oldType a } case *ast.ReferenceType: if newReference, isReference := newType.(*ast.ReferenceType); isReference { + if newReference.Authorization != nil { + err := validator.checkEntitlementsUpgrade(oldType, newReference) + if err != nil { + return err + } + } return validator.checkTypeUpgradability(oldType.Type, newReference.Type) } case *ast.IntersectionType: From f66371dd2ca063d44fa78c00c77aea89cf754cf6 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 31 Jan 2024 16:39:47 -0500 Subject: [PATCH 06/17] verification for adding entitlements to references to composites --- ...legacy_contract_upgrade_validation_test.go | 217 +++++++++++++++++- .../legacy_contract_upgrade_validator.go | 155 ++++++++++++- 2 files changed, 369 insertions(+), 3 deletions(-) diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go index 6e0e9414d..7f6c88b3b 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -93,6 +93,27 @@ func assertFieldTypeMismatchError( assert.Equal(t, foundType, typeMismatchError.FoundType.String()) } +func assertFieldAuthorizationMismatchError( + t *testing.T, + err error, + erroneousDeclName string, + fieldName string, + expectedType string, + foundType string, +) { + var fieldMismatchError *stdlib.FieldMismatchError + require.ErrorAs(t, err, &fieldMismatchError) + + assert.Equal(t, fieldName, fieldMismatchError.FieldName) + assert.Equal(t, erroneousDeclName, fieldMismatchError.DeclName) + + var authorizationMismatchError *stdlib.AuthorizationMismatchError + assert.ErrorAs(t, fieldMismatchError.Err, &authorizationMismatchError) + + assert.Equal(t, expectedType, authorizationMismatchError.ExpectedAuthorization.String()) + assert.Equal(t, foundType, authorizationMismatchError.FoundAuthorization.String()) +} + func TestContractUpgradeFieldAccess(t *testing.T) { t.Parallel() @@ -190,6 +211,43 @@ func TestContractUpgradeFieldType(t *testing.T) { }) + t.Run("change field intersection types illegally", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + access(all) contract Test { + access(all) struct interface I {} + access(all) struct interface J {} + access(all) struct S: I, J {} + + access(all) var a: {I} + init() { + self.a = S() + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) struct interface I {} + access(all) struct interface J {} + access(all) struct S: I, J {} + + access(all) var a: {I, J} + init() { + self.a = S() + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldTypeMismatchError(t, cause, "Test", "a", "{I}", "{I, J}") + + }) + t.Run("change field type capability reference auth", func(t *testing.T) { t.Parallel() @@ -213,11 +271,166 @@ func TestContractUpgradeFieldType(t *testing.T) { } ` - // TODO: this should not be allowed, the migration is not going to change the underlying referenced type of `a` err := testContractUpdate(t, oldCode, newCode) cause := getSingleContractUpdateErrorCause(t, err, "Test") - assertFieldTypeMismatchError(t, cause, "Test", "a", "Capability<&Int>", "Capability") + assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "all", "E") + }) + + t.Run("change field type capability reference auth allowed composite", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub struct S { + pub fun foo() {} + } + + pub var a: Capability<&S>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + + access(all) struct S { + access(E) fun foo() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + + t.Run("change field type capability reference auth allowed too many entitlements", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub struct S { + pub fun foo() {} + } + + pub var a: Capability<&S>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) entitlement F + + access(all) struct S { + access(E) fun foo() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F") + }) + + t.Run("change field type capability reference auth fewer entitlements", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub struct S { + pub fun foo() {} + pub fun bar() {} + } + + pub var a: Capability<&S>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) entitlement F + + access(all) struct S { + access(E) fun foo() {} + access(F) fun bar() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + + t.Run("change field type capability reference auth disjunctive entitlements", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub struct S { + pub fun foo() {} + pub fun bar() {} + } + + pub var a: Capability<&S>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) entitlement F + + access(all) struct S { + access(E) fun foo() {} + access(F) fun bar() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) }) } diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index ac84651e1..2975c3de4 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -19,8 +19,12 @@ package stdlib import ( + "fmt" + "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/common/orderedmap" + "github.com/onflow/cadence/runtime/errors" "github.com/onflow/cadence/runtime/sema" ) @@ -89,8 +93,125 @@ func (validator *LegacyContractUpdateValidator) report(err error) { validator.underlyingUpdateValidator.report(err) } -func (validator *LegacyContractUpdateValidator) checkEntitlementsUpgrade(oldType *ast.ReferenceType, newType *ast.ReferenceType) error { +func (validator *LegacyContractUpdateValidator) idOfQualifiedType(typ *ast.NominalType) common.TypeID { + + qualifiedString := typ.String() + + // working under the assumption that any new program we are validating already typechecks, + // any nominal type must fall into one of three cases: + // 1) type qualified by an import (e.g `C.R` where `C` is an imported type) + // 2) type qualified by the root declaration (e.g `C.R` where `C` is the root contract or contract interface of the new contract) + // 3) unqualified type (e.g. `R`, but declared inside `C`) + // + // in case 3, we prepend the root declaration identifer with a `.` to the type's string to get its qualified name, + // and in 1 and 2 we don't need to do anything + typIdentifier := typ.Identifier.Identifier + rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier + + if typIdentifier != rootIdentifier { // && + // && validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier] == nil + qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString) + + } + return common.NewTypeIDFromQualifiedName(nil, validator.underlyingUpdateValidator.location, qualifiedString) +} + +func (validator *LegacyContractUpdateValidator) getEntitlementType(entitlement *ast.NominalType) *sema.EntitlementType { + typeID := validator.idOfQualifiedType(entitlement) + return validator.newElaboration.EntitlementType(typeID) +} + +func (validator *LegacyContractUpdateValidator) getEntitlementSetAccess(entitlementSet ast.EntitlementSet) sema.EntitlementSetAccess { + var entitlements []*sema.EntitlementType + + for _, entitlement := range entitlementSet.Entitlements() { + entitlements = append(entitlements, validator.getEntitlementType(entitlement)) + } + + entitlementSetKind := sema.Conjunction + if entitlementSet.Separator() == ast.Disjunction { + entitlementSetKind = sema.Disjunction + } + + return sema.NewEntitlementSetAccess(entitlements, entitlementSetKind) +} + +func (validator *LegacyContractUpdateValidator) getCompositeType(composite *ast.NominalType) *sema.CompositeType { + typeID := validator.idOfQualifiedType(composite) + return validator.newElaboration.CompositeType(typeID) +} + +func (validator *LegacyContractUpdateValidator) getInterfaceType(intf *ast.NominalType) *sema.InterfaceType { + typeID := validator.idOfQualifiedType(intf) + return validator.newElaboration.InterfaceType(typeID) +} + +func (validator *LegacyContractUpdateValidator) getIntersectedInterfaces(intersections *ast.IntersectionType) (intfs []*sema.InterfaceType) { + for _, intf := range intersections.Types { + intfs = append(intfs, validator.getInterfaceType(intf)) + } + return +} + +func (validator *LegacyContractUpdateValidator) requirePermitsAccess( + expected sema.Access, + found sema.EntitlementSetAccess, + foundType ast.Type, +) error { + if !found.PermitsAccess(expected) { + return &AuthorizationMismatchError{ + FoundAuthorization: found, + ExpectedAuthorization: expected, + Range: ast.NewUnmeteredRangeFromPositioned(foundType), + } + } + return nil +} + +func (validator *LegacyContractUpdateValidator) checkEntitlementsUpgrade( + oldType *ast.ReferenceType, + newType *ast.ReferenceType, +) error { newAuthorization := newType.Authorization + newEntitlementSet, isEntitlementsSet := newAuthorization.(ast.EntitlementSet) + foundEntitlementSet := validator.getEntitlementSetAccess(newEntitlementSet) + + // if the new authorization is not an entitlements set, there's nothing to check here + if !isEntitlementsSet { + return nil + } + + switch newReferencedType := newType.Type.(type) { + // a lone nominal type must be a composite + case *ast.NominalType: + compositeType := validator.getCompositeType(newReferencedType) + + expectedAccess := sema.UnauthorizedAccess + + if compositeType != nil { + supportedEntitlements := compositeType.SupportedEntitlements() + expectedAccess = sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction) + } + + return validator.requirePermitsAccess(expectedAccess, foundEntitlementSet, newReferencedType) + + // a reference to an intersection (or restricted) type is granted entitlements based on the intersected interfaces, + // ignoring the legacy restricted type + case *ast.IntersectionType: + interfaces := validator.getIntersectedInterfaces(newReferencedType) + + supportedEntitlements := orderedmap.New[sema.EntitlementOrderedSet](0) + + for _, interfaceType := range interfaces { + supportedEntitlements.SetAll(interfaceType.SupportedEntitlements()) + } + + expectedEntitlementSet := sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction) + + return validator.requirePermitsAccess(expectedEntitlementSet, foundEntitlementSet, newReferencedType) + } + + return nil } func (validator *LegacyContractUpdateValidator) checkTypeUpgradability(oldType ast.Type, newType ast.Type) error { @@ -179,5 +300,37 @@ func (validator *LegacyContractUpdateValidator) checkField(oldField *ast.FieldDe Err: err, Range: ast.NewUnmeteredRangeFromPositioned(newField.TypeAnnotation), }) +} + +// AuthorizationMismatchError is reported during a contract upgrade, +// when a field value is given authorization that is more powerful +// than that which the migration would grant it +type AuthorizationMismatchError struct { + ExpectedAuthorization sema.Access + FoundAuthorization sema.Access + ast.Range +} + +var _ errors.UserError = &AuthorizationMismatchError{} +var _ errors.SecondaryError = &AuthorizationMismatchError{} + +func (*AuthorizationMismatchError) IsUserError() {} + +func (e *AuthorizationMismatchError) Error() string { + return "mismatching authorization" +} + +func (e *AuthorizationMismatchError) SecondaryError() string { + if e.ExpectedAuthorization == sema.PrimitiveAccess(ast.AccessAll) { + return fmt.Sprintf( + "The entitlements migration would not grant this value any entitlements, but the annotation present is `%s`", + e.FoundAuthorization.QualifiedString(), + ) + } + return fmt.Sprintf( + "The entitlements migration would only grant this value `%s`, but the annotation present is `%s`", + e.ExpectedAuthorization.QualifiedString(), + e.FoundAuthorization.QualifiedString(), + ) } From f0960d4aaef9c6b5ff5b5f55b54334e22646676e Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 31 Jan 2024 16:48:33 -0500 Subject: [PATCH 07/17] verification for adding entitlements on intersection -> intersection upgrades --- ...legacy_contract_upgrade_validation_test.go | 248 ++++++++++++++++++ 1 file changed, 248 insertions(+) diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go index 7f6c88b3b..ea5d828d6 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -434,6 +434,254 @@ func TestContractUpgradeFieldType(t *testing.T) { }) } +func TestContractUpgradeIntersectionAuthorization(t *testing.T) { + + t.Parallel() + + t.Run("change field type capability reference auth allowed intersection", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub struct interface I { + pub fun foo() + } + pub struct S:I { + pub fun foo() {} + } + + pub var a: Capability<&{I}>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + + access(all) struct interface I { + access(E) fun foo() + } + + access(all) struct S:I { + access(E) fun foo() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + + t.Run("change field type capability reference auth allowed too many entitlements", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub struct interface I {} + pub struct S:I { + pub fun foo() {} + } + + pub var a: Capability<&{I}>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + + access(all) struct interface I {} + + access(all) struct S:I { + access(E) fun foo() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "all", "E") + }) + + t.Run("change field type capability reference auth allowed multiple intersected", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub struct interface I { + pub fun bar() + } + pub struct interface J { + pub fun foo() + } + pub struct S:I, J { + pub fun foo() {} + pub fun bar() {} + } + + pub var a: Capability<&{I, J}>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) entitlement F + + access(all) struct interface I { + access(E) fun foo() + } + access(all) struct interface J { + access(F) fun bar() + } + + access(all) struct S:I, J { + access(E) fun foo() {} + access(F) fun bar() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + + t.Run("change field type capability reference auth allowed multiple intersected fewer entitlements", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub struct interface I { + pub fun bar() + } + pub struct interface J { + pub fun foo() + } + pub struct S:I, J { + pub fun foo() {} + pub fun bar() {} + } + + pub var a: Capability<&{I, J}>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) entitlement F + + access(all) struct interface I { + access(E) fun foo() + } + access(all) struct interface J { + access(F) fun bar() + } + + access(all) struct S:I, J { + access(E) fun foo() {} + access(F) fun bar() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + + t.Run("change field type capability reference auth multiple intersected with too many entitlements", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub struct interface I { + pub fun bar() + } + pub struct interface J { + pub fun foo() + } + pub struct S:I, J { + pub fun foo() {} + pub fun bar() {} + } + + pub var a: Capability<&{I, J}>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) entitlement F + + access(all) struct interface I { + access(E) fun foo() + } + access(all) struct interface J {} + + access(all) struct S:I, J { + access(E) fun foo() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F") + }) + +} + func TestContractUpgradeIntersectionFieldType(t *testing.T) { t.Parallel() From b0e98ce8fb85326446eb3347e5e6e3a298d95d49 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 31 Jan 2024 17:13:04 -0500 Subject: [PATCH 08/17] add support for entitlements of restricted -> composite upgrade --- ...legacy_contract_upgrade_validation_test.go | 3 +- .../legacy_contract_upgrade_validator.go | 82 ++++++++++++------- 2 files changed, 53 insertions(+), 32 deletions(-) diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go index ea5d828d6..173b5fe1c 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -969,10 +969,9 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { } ` - // TODO: this should not be allowed, as the migration will convert `&R{I}` to `auth(E) &R` err := testContractUpdate(t, oldCode, newCode) cause := getSingleContractUpdateErrorCause(t, err, "Test") - assertFieldTypeMismatchError(t, cause, "Test", "a", "Capability", "Capability") + assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F") }) } diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index 2975c3de4..f32110208 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -31,7 +31,8 @@ import ( type LegacyContractUpdateValidator struct { TypeComparator - newElaboration *sema.Elaboration + newElaboration *sema.Elaboration + currentRestrictedTypeUpgradeRestrictions []*ast.NominalType underlyingUpdateValidator *ContractUpdateValidator } @@ -146,8 +147,8 @@ func (validator *LegacyContractUpdateValidator) getInterfaceType(intf *ast.Nomin return validator.newElaboration.InterfaceType(typeID) } -func (validator *LegacyContractUpdateValidator) getIntersectedInterfaces(intersections *ast.IntersectionType) (intfs []*sema.InterfaceType) { - for _, intf := range intersections.Types { +func (validator *LegacyContractUpdateValidator) getIntersectedInterfaces(intersection []*ast.NominalType) (intfs []*sema.InterfaceType) { + for _, intf := range intersection { intfs = append(intfs, validator.getInterfaceType(intf)) } return @@ -168,6 +169,40 @@ func (validator *LegacyContractUpdateValidator) requirePermitsAccess( return nil } +func (validator *LegacyContractUpdateValidator) expectedAuthorizationOfComposite(composite *ast.NominalType) sema.Access { + // if this field is set, we are currently upgrading a formerly legacy restricted type into a reference to a composite + // in this case, the expected entitlements are based not on the underlying composite type, + // but instead the types previously in the restriction set + if validator.currentRestrictedTypeUpgradeRestrictions != nil { + return validator.expectedAuthorizationOfIntersection(validator.currentRestrictedTypeUpgradeRestrictions) + } + + compositeType := validator.getCompositeType(composite) + + if compositeType == nil { + return sema.UnauthorizedAccess + } + + supportedEntitlements := compositeType.SupportedEntitlements() + return sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction) +} + +func (validator *LegacyContractUpdateValidator) expectedAuthorizationOfIntersection(intersectionTypes []*ast.NominalType) sema.Access { + + // a reference to an intersection (or restricted) type is granted entitlements based on the intersected interfaces, + // ignoring the legacy restricted type, as an intersection type appearing in the new contract means it must have originally + // been a restricted type with no legacy type + interfaces := validator.getIntersectedInterfaces(intersectionTypes) + + supportedEntitlements := orderedmap.New[sema.EntitlementOrderedSet](0) + + for _, interfaceType := range interfaces { + supportedEntitlements.SetAll(interfaceType.SupportedEntitlements()) + } + + return sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction) +} + func (validator *LegacyContractUpdateValidator) checkEntitlementsUpgrade( oldType *ast.ReferenceType, newType *ast.ReferenceType, @@ -184,31 +219,12 @@ func (validator *LegacyContractUpdateValidator) checkEntitlementsUpgrade( switch newReferencedType := newType.Type.(type) { // a lone nominal type must be a composite case *ast.NominalType: - compositeType := validator.getCompositeType(newReferencedType) - - expectedAccess := sema.UnauthorizedAccess - - if compositeType != nil { - supportedEntitlements := compositeType.SupportedEntitlements() - expectedAccess = sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction) - } - + expectedAccess := validator.expectedAuthorizationOfComposite(newReferencedType) return validator.requirePermitsAccess(expectedAccess, foundEntitlementSet, newReferencedType) - // a reference to an intersection (or restricted) type is granted entitlements based on the intersected interfaces, - // ignoring the legacy restricted type case *ast.IntersectionType: - interfaces := validator.getIntersectedInterfaces(newReferencedType) - - supportedEntitlements := orderedmap.New[sema.EntitlementOrderedSet](0) - - for _, interfaceType := range interfaces { - supportedEntitlements.SetAll(interfaceType.SupportedEntitlements()) - } - - expectedEntitlementSet := sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction) - - return validator.requirePermitsAccess(expectedEntitlementSet, foundEntitlementSet, newReferencedType) + expectedAccess := validator.expectedAuthorizationOfIntersection(newReferencedType.Types) + return validator.requirePermitsAccess(expectedAccess, foundEntitlementSet, newReferencedType) } return nil @@ -223,13 +239,16 @@ func (validator *LegacyContractUpdateValidator) checkTypeUpgradability(oldType a } case *ast.ReferenceType: if newReference, isReference := newType.(*ast.ReferenceType); isReference { + err := validator.checkTypeUpgradability(oldType.Type, newReference.Type) + if err != nil { + return err + } + if newReference.Authorization != nil { - err := validator.checkEntitlementsUpgrade(oldType, newReference) - if err != nil { - return err - } + return validator.checkEntitlementsUpgrade(oldType, newReference) + } - return validator.checkTypeUpgradability(oldType.Type, newReference.Type) + return nil } case *ast.IntersectionType: // intersection types cannot be upgraded unless they have a legacy restricted type, @@ -237,7 +256,9 @@ func (validator *LegacyContractUpdateValidator) checkTypeUpgradability(oldType a if oldType.LegacyRestrictedType == nil { break } + validator.currentRestrictedTypeUpgradeRestrictions = oldType.Types return validator.checkTypeUpgradability(oldType.LegacyRestrictedType, newType) + case *ast.VariableSizedType: if newVariableSizedType, isVariableSizedType := newType.(*ast.VariableSizedType); isVariableSizedType { return validator.checkTypeUpgradability(oldType.Type, newVariableSizedType.Type) @@ -289,6 +310,7 @@ func (validator *LegacyContractUpdateValidator) checkField(oldField *ast.FieldDe oldType := oldField.TypeAnnotation.Type newType := newField.TypeAnnotation.Type + validator.currentRestrictedTypeUpgradeRestrictions = nil err := validator.checkTypeUpgradability(oldType, newType) if err == nil { return From 06b3c310d479f34aa986b409a576f9b360a36d74 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 31 Jan 2024 17:14:39 -0500 Subject: [PATCH 09/17] add todo comment --- runtime/stdlib/legacy_contract_upgrade_validator.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index f32110208..744c4c026 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -109,7 +109,8 @@ func (validator *LegacyContractUpdateValidator) idOfQualifiedType(typ *ast.Nomin typIdentifier := typ.Identifier.Identifier rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier - if typIdentifier != rootIdentifier { // && + if typIdentifier != rootIdentifier { + // TODO: add and test support for qualifying types imported from other contracts // && validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier] == nil qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString) From 532565d2098e048a81f7727d6d2d74e1aea9d3bb Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Fri, 2 Feb 2024 09:43:53 -0500 Subject: [PATCH 10/17] enable removal of conformances in contract update --- runtime/contract_update_validation_test.go | 6 +-- runtime/stdlib/contract_update_validation.go | 44 ++++++++------------ 2 files changed, 18 insertions(+), 32 deletions(-) diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index dcf8c5259..401d34625 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -2044,11 +2044,7 @@ func TestRuntimeContractUpdateConformanceChanges(t *testing.T) { ` err := testDeployAndUpdate(t, "Test", oldCode, newCode) - RequireError(t, err) - - cause := getSingleContractUpdateErrorCause(t, err, "Test") - - assertConformanceMismatchError(t, cause, "Foo") + require.NoError(t, err) }) t.Run("Change conformance order", func(t *testing.T) { diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index 38a6d3b53..24c2d6136 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -166,7 +166,7 @@ func checkDeclarationUpdatability( if newDecl, ok := newDeclaration.(*ast.CompositeDeclaration); ok { if oldDecl, ok := oldDeclaration.(*ast.CompositeDeclaration); ok { - checkConformances(validator, oldDecl, newDecl) + checkConformance(validator, oldDecl, newDecl) } } } @@ -363,45 +363,35 @@ func checkEnumCases( } } -func checkConformances( +func checkConformance( validator UpdateValidator, oldDecl *ast.CompositeDeclaration, newDecl *ast.CompositeDeclaration, ) { + // at this point the declaration kinds are known to be the same + if oldDecl.DeclarationKind() != common.DeclarationKindEnum { + return + } + // Here it is assumed enums will always have one and only one conformance. // This is enforced by the checker. // Therefore, below check for multiple conformances is only applicable // for non-enum type composite declarations. i.e: structs, resources, etc. - oldConformances := oldDecl.Conformances - newConformances := newDecl.Conformances - - // All the existing conformances must have a match. Order is not important. - // Having extra new conformance is OK. See: https://github.com/onflow/cadence/issues/1394 - for _, oldConformance := range oldConformances { - found := false - for index, newConformance := range newConformances { - err := oldConformance.CheckEqual(newConformance, validator) - if err == nil { - found = true - - // Remove the matched conformance, so we don't have to check it again. - // i.e: optimization - newConformances = append(newConformances[:index], newConformances[index+1:]...) - break - } - } + oldConformance := oldDecl.Conformances[0] + newConformance := newDecl.Conformances[0] - if !found { - validator.report(&ConformanceMismatchError{ - DeclName: newDecl.Identifier.Identifier, - Range: ast.NewUnmeteredRangeFromPositioned(newDecl.Identifier), - }) + err := oldConformance.CheckEqual(newConformance, validator) - return - } + if err == nil { + return } + + validator.report(&ConformanceMismatchError{ + DeclName: newDecl.Identifier.Identifier, + Range: ast.NewUnmeteredRangeFromPositioned(newDecl.Identifier), + }) } func (validator *ContractUpdateValidator) report(err error) { From 08bfe2d9f294b5d2cafffbb05531996674b76934 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 7 Feb 2024 12:15:00 -0500 Subject: [PATCH 11/17] collect imports in upgrade validator --- runtime/stdlib/legacy_contract_upgrade_validator.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index 744c4c026..f14116ad1 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -80,6 +80,8 @@ func (validator *LegacyContractUpdateValidator) Validate() error { } validator.TypeComparator.RootDeclIdentifier = newRootDecl.DeclarationIdentifier() + validator.TypeComparator.expectedIdentifierImportLocations = collectImports(underlyingValidator.oldProgram) + validator.TypeComparator.foundIdentifierImportLocations = collectImports(underlyingValidator.newProgram) checkDeclarationUpdatability(validator, oldRootDecl, newRootDecl) @@ -109,9 +111,8 @@ func (validator *LegacyContractUpdateValidator) idOfQualifiedType(typ *ast.Nomin typIdentifier := typ.Identifier.Identifier rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier - if typIdentifier != rootIdentifier { - // TODO: add and test support for qualifying types imported from other contracts - // && validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier] == nil + if typIdentifier != rootIdentifier && + validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier] == nil { qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString) } From 9ec4da0a7c6b600842a3a49fe2f6666e3390ce06 Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 7 Feb 2024 12:24:57 -0500 Subject: [PATCH 12/17] fix compile --- runtime/stdlib/account.go | 1 + runtime/stdlib/contract_update_validation.go | 13 +++++++++---- .../legacy_contract_upgrade_validation_test.go | 10 +++++++++- runtime/stdlib/legacy_contract_upgrade_validator.go | 11 ++++++++--- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index edb25dafc..e3334f702 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1621,6 +1621,7 @@ func changeAccountContracts( validator = NewLegacyContractUpdateValidator( location, contractName, + handler, oldProgram, program.Program, program.Elaboration, diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index b95cf5b69..f90a27ba5 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -37,6 +37,7 @@ type UpdateValidator interface { setCurrentDeclaration(ast.Declaration) checkField(oldField *ast.FieldDeclaration, newField *ast.FieldDeclaration) + getAccountContractNames(address common.Address) ([]string, error) } type ContractUpdateValidator struct { @@ -84,6 +85,10 @@ func (validator *ContractUpdateValidator) setCurrentDeclaration(decl ast.Declara validator.currentDecl = decl } +func (validator *ContractUpdateValidator) getAccountContractNames(address common.Address) ([]string, error) { + return validator.accountContractNamesProvider.GetAccountContractNames(address) +} + // Validate validates the contract update, and returns an error if it is an invalid update. func (validator *ContractUpdateValidator) Validate() error { oldRootDecl := getRootDeclaration(validator, validator.oldProgram) @@ -97,8 +102,8 @@ func (validator *ContractUpdateValidator) Validate() error { } validator.TypeComparator.RootDeclIdentifier = newRootDecl.DeclarationIdentifier() - validator.TypeComparator.expectedIdentifierImportLocations = validator.collectImports(validator.oldProgram) - validator.TypeComparator.foundIdentifierImportLocations = validator.collectImports(validator.newProgram) + validator.TypeComparator.expectedIdentifierImportLocations = collectImports(validator, validator.oldProgram) + validator.TypeComparator.foundIdentifierImportLocations = collectImports(validator, validator.newProgram) if validator.hasErrors() { return validator.getContractUpdateError() @@ -113,7 +118,7 @@ func (validator *ContractUpdateValidator) Validate() error { return nil } -func (validator *ContractUpdateValidator) collectImports(program *ast.Program) map[string]common.Location { +func collectImports(validator UpdateValidator, program *ast.Program) map[string]common.Location { importLocations := map[string]common.Location{} imports := program.ImportDeclarations() @@ -123,7 +128,7 @@ func (validator *ContractUpdateValidator) collectImports(program *ast.Program) m // if there are no identifiers given, the import covers all of them if addressLocation, isAddressLocation := importLocation.(common.AddressLocation); isAddressLocation && len(importDecl.Identifiers) == 0 { - allLocations, err := validator.accountContractNamesProvider.GetAccountContractNames(addressLocation.Address) + allLocations, err := validator.getAccountContractNames(addressLocation.Address) if err != nil { validator.report(err) } diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go index 173b5fe1c..152e7e68a 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -25,6 +25,7 @@ import ( "github.com/onflow/cadence/runtime/parser" "github.com/onflow/cadence/runtime/sema" "github.com/onflow/cadence/runtime/stdlib" + "github.com/onflow/cadence/runtime/tests/runtime_utils" "github.com/onflow/cadence/runtime/tests/utils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -50,7 +51,14 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error { err = checker.Check() require.NoError(t, err) - upgradeValidator := stdlib.NewLegacyContractUpdateValidator(utils.TestLocation, "Test", oldProgram, newProgram, checker.Elaboration) + upgradeValidator := stdlib.NewLegacyContractUpdateValidator( + utils.TestLocation, + "Test", + // TODO: add contract name handling here once we have a way to test imported values + &runtime_utils.TestRuntimeInterface{}, + oldProgram, + newProgram, + checker.Elaboration) return upgradeValidator.Validate() } diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index f14116ad1..0df7d0930 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -42,12 +42,13 @@ type LegacyContractUpdateValidator struct { func NewLegacyContractUpdateValidator( location common.Location, contractName string, + provider AccountContractNamesProvider, oldProgram *ast.Program, newProgram *ast.Program, newElaboration *sema.Elaboration, ) *LegacyContractUpdateValidator { - underlyingValidator := NewContractUpdateValidator(location, contractName, oldProgram, newProgram) + underlyingValidator := NewContractUpdateValidator(location, contractName, provider, oldProgram, newProgram) return &LegacyContractUpdateValidator{ underlyingUpdateValidator: underlyingValidator, @@ -65,6 +66,10 @@ func (validator *LegacyContractUpdateValidator) setCurrentDeclaration(decl ast.D validator.underlyingUpdateValidator.setCurrentDeclaration(decl) } +func (validator *LegacyContractUpdateValidator) getAccountContractNames(address common.Address) ([]string, error) { + return validator.underlyingUpdateValidator.accountContractNamesProvider.GetAccountContractNames(address) +} + // Validate validates the contract update, and returns an error if it is an invalid update. func (validator *LegacyContractUpdateValidator) Validate() error { underlyingValidator := validator.underlyingUpdateValidator @@ -80,8 +85,8 @@ func (validator *LegacyContractUpdateValidator) Validate() error { } validator.TypeComparator.RootDeclIdentifier = newRootDecl.DeclarationIdentifier() - validator.TypeComparator.expectedIdentifierImportLocations = collectImports(underlyingValidator.oldProgram) - validator.TypeComparator.foundIdentifierImportLocations = collectImports(underlyingValidator.newProgram) + validator.TypeComparator.expectedIdentifierImportLocations = collectImports(validator, underlyingValidator.oldProgram) + validator.TypeComparator.foundIdentifierImportLocations = collectImports(validator, underlyingValidator.newProgram) checkDeclarationUpdatability(validator, oldRootDecl, newRootDecl) From 2567340f836f7c708c41906e9e1b2bf43ccf872e Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 7 Feb 2024 12:32:25 -0500 Subject: [PATCH 13/17] add tests for local qualified names --- ...legacy_contract_upgrade_validation_test.go | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go index 152e7e68a..929c0bdc4 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -862,6 +862,39 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { require.NoError(t, err) }) + t.Run("change field type restricted type dict with qualified names", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub resource interface I {} + pub resource R:I {} + + pub var a: @{Int: R{I}} + init() { + self.a <- {0: <- create R()} + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) resource interface I {} + access(all) resource R:I {} + + access(all) var a: @{Int: Test.R} + init() { + self.a <- {0: <- create Test.R()} + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + t.Run("change field type restricted reference type", func(t *testing.T) { t.Parallel() @@ -937,6 +970,48 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { require.NoError(t, err) }) + t.Run("change field type restricted entitled reference type with qualified types", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub resource interface I { + pub fun foo() + } + pub resource R:I { + pub fun foo() + } + + pub var a: Capability<&R{I}>? + init() { + self.a = nil + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) entitlement E + access(all) resource interface I { + access(Test.E) fun foo() + } + access(all) resource R:I { + access(Test.E) fun foo() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + + require.NoError(t, err) + }) + t.Run("change field type restricted entitled reference type with too many granted entitlements", func(t *testing.T) { t.Parallel() From ffba5e1f8f782c60fd4c3e950239d84dd0beb92c Mon Sep 17 00:00:00 2001 From: Daniel Sainati Date: Wed, 7 Feb 2024 13:17:59 -0500 Subject: [PATCH 14/17] add tests and fix importing --- runtime/contract_update_test.go | 3 + runtime/interpreter/interpreter.go | 18 ++ runtime/stdlib/account.go | 6 +- ...legacy_contract_upgrade_validation_test.go | 184 +++++++++++++++++- .../legacy_contract_upgrade_validator.go | 29 +-- 5 files changed, 224 insertions(+), 16 deletions(-) diff --git a/runtime/contract_update_test.go b/runtime/contract_update_test.go index 72a505427..cd11e2d43 100644 --- a/runtime/contract_update_test.go +++ b/runtime/contract_update_test.go @@ -716,6 +716,9 @@ func TestRuntimeLegacyContractUpdate(t *testing.T) { OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { return accountCodes[location], nil }, + OnGetAccountContractNames: func(_ Address) ([]string, error) { + return []string{"Foo"}, nil + }, OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { accountCodes[location] = code return nil diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index bf4da92c1..48ef579e1 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -4607,6 +4607,24 @@ func (interpreter *Interpreter) getElaboration(location common.Location) *sema.E return subInterpreter.Program.Elaboration } +func (interpreter *Interpreter) AllElaborations() (elaborations map[common.Location]*sema.Elaboration) { + + elaborations = map[common.Location]*sema.Elaboration{} + + // Ensure the program for this location is loaded, + // so its checker is available + + for location, _ := range interpreter.SharedState.allInterpreters { + subInterpreter := interpreter.EnsureLoaded(location) + if subInterpreter == nil || subInterpreter.Program == nil { + return nil + } + elaborations[location] = subInterpreter.Program.Elaboration + } + + return +} + // GetContractComposite gets the composite value of the contract at the address location. func (interpreter *Interpreter) GetContractComposite(contractLocation common.AddressLocation) (*CompositeValue, error) { contractGlobal := interpreter.Globals.Get(contractLocation.Name) diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index e3334f702..b68199a20 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1586,6 +1586,8 @@ func changeAccountContracts( // Validate the contract update + inter := invocation.Interpreter + if isUpdate { oldCode, err := handler.GetAccountContractCode(location) handleContractUpdateError(err) @@ -1624,7 +1626,7 @@ func changeAccountContracts( handler, oldProgram, program.Program, - program.Elaboration, + inter.AllElaborations(), ) } else { validator = NewContractUpdateValidator( @@ -1639,8 +1641,6 @@ func changeAccountContracts( handleContractUpdateError(err) } - inter := invocation.Interpreter - err = updateAccountContractCode( handler, location, diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go index 929c0bdc4..e78361b5b 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -21,6 +21,9 @@ package stdlib_test import ( "testing" + "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/runtime/ast" + "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/old_parser" "github.com/onflow/cadence/runtime/parser" "github.com/onflow/cadence/runtime/sema" @@ -54,11 +57,71 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error { upgradeValidator := stdlib.NewLegacyContractUpdateValidator( utils.TestLocation, "Test", - // TODO: add contract name handling here once we have a way to test imported values &runtime_utils.TestRuntimeInterface{}, oldProgram, newProgram, - checker.Elaboration) + map[common.Location]*sema.Elaboration{ + utils.TestLocation: checker.Elaboration, + }) + return upgradeValidator.Validate() +} + +func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newCode, newImport string) error { + oldProgram, err := old_parser.ParseProgram(nil, []byte(oldCode), old_parser.Config{}) + require.NoError(t, err) + + newProgram, err := parser.ParseProgram(nil, []byte(newCode), parser.Config{}) + require.NoError(t, err) + + newImportedProgram, err := parser.ParseProgram(nil, []byte(newImport), parser.Config{}) + require.NoError(t, err) + + importedChecker, err := sema.NewChecker( + newImportedProgram, + utils.ImportedLocation, + nil, + &sema.Config{ + AccessCheckMode: sema.AccessCheckModeStrict, + AttachmentsEnabled: true, + }, + ) + + require.NoError(t, err) + err = importedChecker.Check() + require.NoError(t, err) + + checker, err := sema.NewChecker( + newProgram, + utils.TestLocation, + nil, + &sema.Config{ + AccessCheckMode: sema.AccessCheckModeStrict, + ImportHandler: func(_ *sema.Checker, _ common.Location, _ ast.Range) (sema.Import, error) { + return sema.ElaborationImport{ + Elaboration: importedChecker.Elaboration, + }, nil + }, + AttachmentsEnabled: true, + }) + require.NoError(t, err) + + err = checker.Check() + require.NoError(t, err) + + upgradeValidator := stdlib.NewLegacyContractUpdateValidator( + utils.TestLocation, + "Test", + &runtime_utils.TestRuntimeInterface{ + OnGetAccountContractNames: func(address runtime.Address) ([]string, error) { + return []string{"TestImport"}, nil + }, + }, + oldProgram, + newProgram, + map[common.Location]*sema.Elaboration{ + utils.TestLocation: checker.Elaboration, + utils.ImportedLocation: importedChecker.Elaboration, + }) return upgradeValidator.Validate() } @@ -1012,6 +1075,64 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { require.NoError(t, err) }) + t.Run("change field type restricted entitled reference type with qualified types with imports", func(t *testing.T) { + + t.Parallel() + + const oldImport = ` + pub contract TestImport { + pub resource interface I { + pub fun foo() + } + } + ` + + const oldCode = ` + import TestImport from "imported" + + pub contract Test { + pub resource R:TestImport.I { + pub fun foo() + } + + pub var a: Capability<&R{TestImport.I}>? + init() { + self.a = nil + } + } + ` + + const newImport = ` + access(all) contract TestImport { + access(all) entitlement E + access(all) resource interface I { + access(E) fun foo() + } + } + ` + + const newCode = ` + import TestImport from "imported" + + access(all) contract Test { + access(all) entitlement F + access(all) resource R: TestImport.I { + access(TestImport.E) fun foo() {} + access(Test.F) fun bar() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport) + + require.NoError(t, err) + }) + t.Run("change field type restricted entitled reference type with too many granted entitlements", func(t *testing.T) { t.Parallel() @@ -1057,4 +1178,63 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { cause := getSingleContractUpdateErrorCause(t, err, "Test") assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F") }) + + t.Run("change field type restricted entitled reference type with too many granted entitlements with imports", func(t *testing.T) { + + t.Parallel() + + const oldImport = ` + pub contract TestImport { + pub resource interface I { + pub fun foo() + } + } + ` + + const oldCode = ` + import TestImport from "imported" + + pub contract Test { + pub resource R:TestImport.I { + pub fun foo() + } + + pub var a: Capability<&R{TestImport.I}>? + init() { + self.a = nil + } + } + ` + + const newImport = ` + access(all) contract TestImport { + access(all) entitlement E + access(all) resource interface I { + access(TestImport.E) fun foo() + } + } + ` + + const newCode = ` + import TestImport from "imported" + + access(all) contract Test { + access(all) entitlement F + access(all) resource R: TestImport.I { + access(TestImport.E) fun foo() {} + access(Test.F) fun bar() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F") + }) } diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index 0df7d0930..05674acc7 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -31,7 +31,7 @@ import ( type LegacyContractUpdateValidator struct { TypeComparator - newElaboration *sema.Elaboration + newElaborations map[common.Location]*sema.Elaboration currentRestrictedTypeUpgradeRestrictions []*ast.NominalType underlyingUpdateValidator *ContractUpdateValidator @@ -45,14 +45,14 @@ func NewLegacyContractUpdateValidator( provider AccountContractNamesProvider, oldProgram *ast.Program, newProgram *ast.Program, - newElaboration *sema.Elaboration, + newElaborations map[common.Location]*sema.Elaboration, ) *LegacyContractUpdateValidator { underlyingValidator := NewContractUpdateValidator(location, contractName, provider, oldProgram, newProgram) return &LegacyContractUpdateValidator{ underlyingUpdateValidator: underlyingValidator, - newElaboration: newElaboration, + newElaborations: newElaborations, } } @@ -101,7 +101,7 @@ func (validator *LegacyContractUpdateValidator) report(err error) { validator.underlyingUpdateValidator.report(err) } -func (validator *LegacyContractUpdateValidator) idOfQualifiedType(typ *ast.NominalType) common.TypeID { +func (validator *LegacyContractUpdateValidator) idAndLocationOfQualifiedType(typ *ast.NominalType) (common.TypeID, common.Location) { qualifiedString := typ.String() @@ -115,18 +115,25 @@ func (validator *LegacyContractUpdateValidator) idOfQualifiedType(typ *ast.Nomin // and in 1 and 2 we don't need to do anything typIdentifier := typ.Identifier.Identifier rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier + location := validator.underlyingUpdateValidator.location if typIdentifier != rootIdentifier && validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier] == nil { qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString) + return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location } - return common.NewTypeIDFromQualifiedName(nil, validator.underlyingUpdateValidator.location, qualifiedString) + + if loc := validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier]; loc != nil { + location = loc + } + + return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location } func (validator *LegacyContractUpdateValidator) getEntitlementType(entitlement *ast.NominalType) *sema.EntitlementType { - typeID := validator.idOfQualifiedType(entitlement) - return validator.newElaboration.EntitlementType(typeID) + typeID, location := validator.idAndLocationOfQualifiedType(entitlement) + return validator.newElaborations[location].EntitlementType(typeID) } func (validator *LegacyContractUpdateValidator) getEntitlementSetAccess(entitlementSet ast.EntitlementSet) sema.EntitlementSetAccess { @@ -145,13 +152,13 @@ func (validator *LegacyContractUpdateValidator) getEntitlementSetAccess(entitlem } func (validator *LegacyContractUpdateValidator) getCompositeType(composite *ast.NominalType) *sema.CompositeType { - typeID := validator.idOfQualifiedType(composite) - return validator.newElaboration.CompositeType(typeID) + typeID, location := validator.idAndLocationOfQualifiedType(composite) + return validator.newElaborations[location].CompositeType(typeID) } func (validator *LegacyContractUpdateValidator) getInterfaceType(intf *ast.NominalType) *sema.InterfaceType { - typeID := validator.idOfQualifiedType(intf) - return validator.newElaboration.InterfaceType(typeID) + typeID, location := validator.idAndLocationOfQualifiedType(intf) + return validator.newElaborations[location].InterfaceType(typeID) } func (validator *LegacyContractUpdateValidator) getIntersectedInterfaces(intersection []*ast.NominalType) (intfs []*sema.InterfaceType) { From ecb5e3c6ccb0c40fc32381a92fe039fc94608b25 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Tue, 13 Feb 2024 20:40:37 +0530 Subject: [PATCH 15/17] Refactor --- runtime/interpreter/interpreter.go | 2 +- runtime/stdlib/account.go | 2 +- ...to_v1_contract_upgrade_validation_test.go} | 35 ++++++++++++-- ...v0.42_to_v1_contract_upgrade_validator.go} | 48 +++++++++---------- 4 files changed, 56 insertions(+), 31 deletions(-) rename runtime/stdlib/{legacy_contract_upgrade_validation_test.go => cadence_v0.42_to_v1_contract_upgrade_validation_test.go} (97%) rename runtime/stdlib/{legacy_contract_upgrade_validator.go => cadence_v0.42_to_v1_contract_upgrade_validator.go} (83%) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 48ef579e1..5eb26ff72 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -4614,7 +4614,7 @@ func (interpreter *Interpreter) AllElaborations() (elaborations map[common.Locat // Ensure the program for this location is loaded, // so its checker is available - for location, _ := range interpreter.SharedState.allInterpreters { + for location := range interpreter.SharedState.allInterpreters { //nolint:maprange subInterpreter := interpreter.EnsureLoaded(location) if subInterpreter == nil || subInterpreter.Program == nil { return nil diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index b68199a20..5e9f829a8 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1620,7 +1620,7 @@ func changeAccountContracts( var validator UpdateValidator if legacyContractUpgrade { - validator = NewLegacyContractUpdateValidator( + validator = NewCadenceV042ToV1ContractUpdateValidator( location, contractName, handler, diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go similarity index 97% rename from runtime/stdlib/legacy_contract_upgrade_validation_test.go rename to runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go index e78361b5b..e4968d9b8 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go @@ -21,6 +21,9 @@ package stdlib_test import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/onflow/cadence/runtime" "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" @@ -30,8 +33,6 @@ import ( "github.com/onflow/cadence/runtime/stdlib" "github.com/onflow/cadence/runtime/tests/runtime_utils" "github.com/onflow/cadence/runtime/tests/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func testContractUpdate(t *testing.T, oldCode string, newCode string) error { @@ -54,7 +55,7 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error { err = checker.Check() require.NoError(t, err) - upgradeValidator := stdlib.NewLegacyContractUpdateValidator( + upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( utils.TestLocation, "Test", &runtime_utils.TestRuntimeInterface{}, @@ -108,7 +109,7 @@ func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newC err = checker.Check() require.NoError(t, err) - upgradeValidator := stdlib.NewLegacyContractUpdateValidator( + upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( utils.TestLocation, "Test", &runtime_utils.TestRuntimeInterface{ @@ -217,7 +218,6 @@ func TestContractUpgradeFieldAccess(t *testing.T) { ` err := testContractUpdate(t, oldCode, newCode) - require.NoError(t, err) }) @@ -244,7 +244,32 @@ func TestContractUpgradeFieldAccess(t *testing.T) { ` err := testContractUpdate(t, oldCode, newCode) + require.NoError(t, err) + }) + t.Run("change field access to self", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + pub contract Test { + pub var a: Int + init() { + self.a = 0 + } + } + ` + + const newCode = ` + access(all) contract Test { + access(self) var a: Int + init() { + self.a = 0 + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) require.NoError(t, err) }) } diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go similarity index 83% rename from runtime/stdlib/legacy_contract_upgrade_validator.go rename to runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go index 05674acc7..b1beda402 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go @@ -28,7 +28,7 @@ import ( "github.com/onflow/cadence/runtime/sema" ) -type LegacyContractUpdateValidator struct { +type CadenceV042ToV1ContractUpdateValidator struct { TypeComparator newElaborations map[common.Location]*sema.Elaboration @@ -37,41 +37,41 @@ type LegacyContractUpdateValidator struct { underlyingUpdateValidator *ContractUpdateValidator } -// NewContractUpdateValidator initializes and returns a validator, without performing any validation. +// NewCadenceV042ToV1ContractUpdateValidator initializes and returns a validator, without performing any validation. // Invoke the `Validate()` method of the validator returned, to start validating the contract. -func NewLegacyContractUpdateValidator( +func NewCadenceV042ToV1ContractUpdateValidator( location common.Location, contractName string, provider AccountContractNamesProvider, oldProgram *ast.Program, newProgram *ast.Program, newElaborations map[common.Location]*sema.Elaboration, -) *LegacyContractUpdateValidator { +) *CadenceV042ToV1ContractUpdateValidator { underlyingValidator := NewContractUpdateValidator(location, contractName, provider, oldProgram, newProgram) - return &LegacyContractUpdateValidator{ + return &CadenceV042ToV1ContractUpdateValidator{ underlyingUpdateValidator: underlyingValidator, newElaborations: newElaborations, } } -var _ UpdateValidator = &LegacyContractUpdateValidator{} +var _ UpdateValidator = &CadenceV042ToV1ContractUpdateValidator{} -func (validator *LegacyContractUpdateValidator) getCurrentDeclaration() ast.Declaration { +func (validator *CadenceV042ToV1ContractUpdateValidator) getCurrentDeclaration() ast.Declaration { return validator.underlyingUpdateValidator.getCurrentDeclaration() } -func (validator *LegacyContractUpdateValidator) setCurrentDeclaration(decl ast.Declaration) { +func (validator *CadenceV042ToV1ContractUpdateValidator) setCurrentDeclaration(decl ast.Declaration) { validator.underlyingUpdateValidator.setCurrentDeclaration(decl) } -func (validator *LegacyContractUpdateValidator) getAccountContractNames(address common.Address) ([]string, error) { +func (validator *CadenceV042ToV1ContractUpdateValidator) getAccountContractNames(address common.Address) ([]string, error) { return validator.underlyingUpdateValidator.accountContractNamesProvider.GetAccountContractNames(address) } // Validate validates the contract update, and returns an error if it is an invalid update. -func (validator *LegacyContractUpdateValidator) Validate() error { +func (validator *CadenceV042ToV1ContractUpdateValidator) Validate() error { underlyingValidator := validator.underlyingUpdateValidator oldRootDecl := getRootDeclaration(validator, underlyingValidator.oldProgram) @@ -97,11 +97,11 @@ func (validator *LegacyContractUpdateValidator) Validate() error { return nil } -func (validator *LegacyContractUpdateValidator) report(err error) { +func (validator *CadenceV042ToV1ContractUpdateValidator) report(err error) { validator.underlyingUpdateValidator.report(err) } -func (validator *LegacyContractUpdateValidator) idAndLocationOfQualifiedType(typ *ast.NominalType) (common.TypeID, common.Location) { +func (validator *CadenceV042ToV1ContractUpdateValidator) idAndLocationOfQualifiedType(typ *ast.NominalType) (common.TypeID, common.Location) { qualifiedString := typ.String() @@ -111,7 +111,7 @@ func (validator *LegacyContractUpdateValidator) idAndLocationOfQualifiedType(typ // 2) type qualified by the root declaration (e.g `C.R` where `C` is the root contract or contract interface of the new contract) // 3) unqualified type (e.g. `R`, but declared inside `C`) // - // in case 3, we prepend the root declaration identifer with a `.` to the type's string to get its qualified name, + // in case 3, we prepend the root declaration identifier with a `.` to the type's string to get its qualified name, // and in 1 and 2 we don't need to do anything typIdentifier := typ.Identifier.Identifier rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier @@ -131,12 +131,12 @@ func (validator *LegacyContractUpdateValidator) idAndLocationOfQualifiedType(typ return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location } -func (validator *LegacyContractUpdateValidator) getEntitlementType(entitlement *ast.NominalType) *sema.EntitlementType { +func (validator *CadenceV042ToV1ContractUpdateValidator) getEntitlementType(entitlement *ast.NominalType) *sema.EntitlementType { typeID, location := validator.idAndLocationOfQualifiedType(entitlement) return validator.newElaborations[location].EntitlementType(typeID) } -func (validator *LegacyContractUpdateValidator) getEntitlementSetAccess(entitlementSet ast.EntitlementSet) sema.EntitlementSetAccess { +func (validator *CadenceV042ToV1ContractUpdateValidator) getEntitlementSetAccess(entitlementSet ast.EntitlementSet) sema.EntitlementSetAccess { var entitlements []*sema.EntitlementType for _, entitlement := range entitlementSet.Entitlements() { @@ -151,24 +151,24 @@ func (validator *LegacyContractUpdateValidator) getEntitlementSetAccess(entitlem return sema.NewEntitlementSetAccess(entitlements, entitlementSetKind) } -func (validator *LegacyContractUpdateValidator) getCompositeType(composite *ast.NominalType) *sema.CompositeType { +func (validator *CadenceV042ToV1ContractUpdateValidator) getCompositeType(composite *ast.NominalType) *sema.CompositeType { typeID, location := validator.idAndLocationOfQualifiedType(composite) return validator.newElaborations[location].CompositeType(typeID) } -func (validator *LegacyContractUpdateValidator) getInterfaceType(intf *ast.NominalType) *sema.InterfaceType { +func (validator *CadenceV042ToV1ContractUpdateValidator) getInterfaceType(intf *ast.NominalType) *sema.InterfaceType { typeID, location := validator.idAndLocationOfQualifiedType(intf) return validator.newElaborations[location].InterfaceType(typeID) } -func (validator *LegacyContractUpdateValidator) getIntersectedInterfaces(intersection []*ast.NominalType) (intfs []*sema.InterfaceType) { +func (validator *CadenceV042ToV1ContractUpdateValidator) getIntersectedInterfaces(intersection []*ast.NominalType) (intfs []*sema.InterfaceType) { for _, intf := range intersection { intfs = append(intfs, validator.getInterfaceType(intf)) } return } -func (validator *LegacyContractUpdateValidator) requirePermitsAccess( +func (validator *CadenceV042ToV1ContractUpdateValidator) requirePermitsAccess( expected sema.Access, found sema.EntitlementSetAccess, foundType ast.Type, @@ -183,7 +183,7 @@ func (validator *LegacyContractUpdateValidator) requirePermitsAccess( return nil } -func (validator *LegacyContractUpdateValidator) expectedAuthorizationOfComposite(composite *ast.NominalType) sema.Access { +func (validator *CadenceV042ToV1ContractUpdateValidator) expectedAuthorizationOfComposite(composite *ast.NominalType) sema.Access { // if this field is set, we are currently upgrading a formerly legacy restricted type into a reference to a composite // in this case, the expected entitlements are based not on the underlying composite type, // but instead the types previously in the restriction set @@ -201,7 +201,7 @@ func (validator *LegacyContractUpdateValidator) expectedAuthorizationOfComposite return sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction) } -func (validator *LegacyContractUpdateValidator) expectedAuthorizationOfIntersection(intersectionTypes []*ast.NominalType) sema.Access { +func (validator *CadenceV042ToV1ContractUpdateValidator) expectedAuthorizationOfIntersection(intersectionTypes []*ast.NominalType) sema.Access { // a reference to an intersection (or restricted) type is granted entitlements based on the intersected interfaces, // ignoring the legacy restricted type, as an intersection type appearing in the new contract means it must have originally @@ -217,7 +217,7 @@ func (validator *LegacyContractUpdateValidator) expectedAuthorizationOfIntersect return sema.NewAccessFromEntitlementSet(supportedEntitlements, sema.Conjunction) } -func (validator *LegacyContractUpdateValidator) checkEntitlementsUpgrade( +func (validator *CadenceV042ToV1ContractUpdateValidator) checkEntitlementsUpgrade( oldType *ast.ReferenceType, newType *ast.ReferenceType, ) error { @@ -244,7 +244,7 @@ func (validator *LegacyContractUpdateValidator) checkEntitlementsUpgrade( return nil } -func (validator *LegacyContractUpdateValidator) checkTypeUpgradability(oldType ast.Type, newType ast.Type) error { +func (validator *CadenceV042ToV1ContractUpdateValidator) checkTypeUpgradability(oldType ast.Type, newType ast.Type) error { switch oldType := oldType.(type) { case *ast.OptionalType: @@ -320,7 +320,7 @@ func (validator *LegacyContractUpdateValidator) checkTypeUpgradability(oldType a } -func (validator *LegacyContractUpdateValidator) checkField(oldField *ast.FieldDeclaration, newField *ast.FieldDeclaration) { +func (validator *CadenceV042ToV1ContractUpdateValidator) checkField(oldField *ast.FieldDeclaration, newField *ast.FieldDeclaration) { oldType := oldField.TypeAnnotation.Type newType := newField.TypeAnnotation.Type From 67411cb7fe9f604f6199f0c9124359be87a72990 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Wed, 14 Feb 2024 20:00:54 +0530 Subject: [PATCH 16/17] Revert allowing to remove interface conformance --- runtime/contract_update_validation_test.go | 5 ++- runtime/stdlib/contract_update_validation.go | 47 +++++++++++++------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index d66b07f89..5834a8d76 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -2558,7 +2558,10 @@ func TestRuntimeContractUpdateConformanceChanges(t *testing.T) { ` err := testDeployAndUpdate(t, "Test", oldCode, newCode) - require.NoError(t, err) + RequireError(t, err) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertConformanceMismatchError(t, cause, "Foo") }) t.Run("Change conformance order", func(t *testing.T) { diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index f90a27ba5..4eb0e0333 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -416,29 +416,46 @@ func checkConformance( newDecl *ast.CompositeDeclaration, ) { - // at this point the declaration kinds are known to be the same - if oldDecl.DeclarationKind() != common.DeclarationKindEnum { - return - } - // Here it is assumed enums will always have one and only one conformance. // This is enforced by the checker. // Therefore, below check for multiple conformances is only applicable // for non-enum type composite declarations. i.e: structs, resources, etc. - oldConformance := oldDecl.Conformances[0] - newConformance := newDecl.Conformances[0] + oldConformances := oldDecl.Conformances + newConformances := newDecl.Conformances + + // All the existing conformances must have a match. Order is not important. + // Having extra new conformance is OK. See: https://github.com/onflow/cadence/issues/1394 + + // Note: Removing a conformance is NOT OK. That could lead to type-safety issues. + // e.g: + // - Someone stores an array of type `[{I}]` with `T:I` objects inside. + // - Later T’s conformance to `I` is removed. + // - Now `[{I}]` contains objects if `T` that does not conform to `I`. + + for _, oldConformance := range oldConformances { + found := false + for index, newConformance := range newConformances { + err := oldConformance.CheckEqual(newConformance, validator) + if err == nil { + found = true + + // Remove the matched conformance, so we don't have to check it again. + // i.e: optimization + newConformances = append(newConformances[:index], newConformances[index+1:]...) + break + } + } - err := oldConformance.CheckEqual(newConformance, validator) + if !found { + validator.report(&ConformanceMismatchError{ + DeclName: newDecl.Identifier.Identifier, + Range: ast.NewUnmeteredRangeFromPositioned(newDecl.Identifier), + }) - if err == nil { - return + return + } } - - validator.report(&ConformanceMismatchError{ - DeclName: newDecl.Identifier.Identifier, - Range: ast.NewUnmeteredRangeFromPositioned(newDecl.Identifier), - }) } func (validator *ContractUpdateValidator) report(err error) { From 65faf291ca63f273f63b6332601071e8430faefa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Thu, 22 Feb 2024 08:31:45 -0800 Subject: [PATCH 17/17] use getElaboration, sort locations, panic when elaboration is not available --- runtime/interpreter/interpreter.go | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 7a78de97e..0f973a34b 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -24,6 +24,7 @@ import ( "fmt" "math" "math/big" + "sort" "strconv" "time" @@ -4581,15 +4582,26 @@ func (interpreter *Interpreter) AllElaborations() (elaborations map[common.Locat elaborations = map[common.Location]*sema.Elaboration{} - // Ensure the program for this location is loaded, - // so its checker is available + allInterpreters := interpreter.SharedState.allInterpreters + + locations := make([]common.Location, 0, len(allInterpreters)) + + for location := range allInterpreters { //nolint:maprange + locations = append(locations, location) + } + + sort.Slice(locations, func(i, j int) bool { + a := locations[i] + b := locations[j] + return a.ID() < b.ID() + }) - for location := range interpreter.SharedState.allInterpreters { //nolint:maprange - subInterpreter := interpreter.EnsureLoaded(location) - if subInterpreter == nil || subInterpreter.Program == nil { - return nil + for _, location := range locations { + elaboration := interpreter.getElaboration(location) + if elaboration == nil { + panic(errors.NewUnexpectedError("missing elaboration for location %s", location)) } - elaborations[location] = subInterpreter.Program.Elaboration + elaborations[location] = elaboration } return