diff --git a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go index e9effda9b..00ead33d8 100644 --- a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validation_test.go @@ -67,29 +67,40 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error { return upgradeValidator.Validate() } -func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newCode, newImport string) error { +func testContractUpdateWithImports( + t *testing.T, + oldCode string, + newCode string, + newImports map[common.Location]string, +) error { oldProgram, err := old_parser.ParseProgram(nil, []byte(oldCode), old_parser.Config{}) require.NoError(t, err) newProgram, err := parser.ParseProgram(nil, []byte(newCode), parser.Config{}) require.NoError(t, err) - newImportedProgram, err := parser.ParseProgram(nil, []byte(newImport), parser.Config{}) - require.NoError(t, err) + elaborations := map[common.Location]*sema.Elaboration{} - importedChecker, err := sema.NewChecker( - newImportedProgram, - utils.ImportedLocation, - nil, - &sema.Config{ - AccessCheckMode: sema.AccessCheckModeStrict, - AttachmentsEnabled: true, - }, - ) + for location, code := range newImports { + newImportedProgram, err := parser.ParseProgram(nil, []byte(code), parser.Config{}) + require.NoError(t, err) - require.NoError(t, err) - err = importedChecker.Check() - require.NoError(t, err) + importedChecker, err := sema.NewChecker( + newImportedProgram, + location, + nil, + &sema.Config{ + AccessCheckMode: sema.AccessCheckModeStrict, + AttachmentsEnabled: true, + }, + ) + + require.NoError(t, err) + err = importedChecker.Check() + require.NoError(t, err) + + elaborations[location] = importedChecker.Elaboration + } checker, err := sema.NewChecker( newProgram, @@ -97,11 +108,29 @@ func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newC nil, &sema.Config{ AccessCheckMode: sema.AccessCheckModeStrict, - ImportHandler: func(_ *sema.Checker, _ common.Location, _ ast.Range) (sema.Import, error) { + ImportHandler: func(_ *sema.Checker, location common.Location, _ ast.Range) (sema.Import, error) { + importedElaboration := elaborations[location] return sema.ElaborationImport{ - Elaboration: importedChecker.Elaboration, + Elaboration: importedElaboration, }, nil }, + LocationHandler: func(identifiers []ast.Identifier, location common.Location) ( + locations []sema.ResolvedLocation, err error, + ) { + if addressLocation, ok := location.(common.AddressLocation); ok && len(identifiers) == 1 { + location = common.AddressLocation{ + Name: identifiers[0].Identifier, + Address: addressLocation.Address, + } + } + + locations = append(locations, sema.ResolvedLocation{ + Location: location, + Identifiers: identifiers, + }) + + return + }, AttachmentsEnabled: true, }) require.NoError(t, err) @@ -109,6 +138,8 @@ func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newC err = checker.Check() require.NoError(t, err) + elaborations[utils.TestLocation] = checker.Elaboration + upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator( utils.TestLocation, "Test", @@ -119,10 +150,8 @@ func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newC }, oldProgram, newProgram, - map[common.Location]*sema.Elaboration{ - utils.TestLocation: checker.Elaboration, - utils.ImportedLocation: importedChecker.Elaboration, - }) + elaborations, + ) return upgradeValidator.Validate() } @@ -528,6 +557,66 @@ func TestContractUpgradeFieldType(t *testing.T) { require.NoError(t, err) }) + + t.Run("changing to a non-storable types", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + access(all) contract Test { + access(all) struct Foo { + access(all) var a: Int + init() { + self.a = 0 + } + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) struct Foo { + access(all) var a: &Int? + init() { + self.a = nil + } + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + require.NoError(t, err) + }) + + t.Run("changing from a non-storable types", func(t *testing.T) { + + t.Parallel() + + const oldCode = ` + access(all) contract Test { + access(all) struct Foo { + access(all) var a: &Int? + init() { + self.a = nil + } + } + } + ` + + const newCode = ` + access(all) contract Test { + access(all) struct Foo { + access(all) var a: Int + init() { + self.a = 0 + } + } + } + ` + + err := testContractUpdate(t, oldCode, newCode) + require.NoError(t, err) + }) } func TestContractUpgradeIntersectionAuthorization(t *testing.T) { @@ -1166,16 +1255,8 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { t.Parallel() - const oldImport = ` - pub contract TestImport { - pub resource interface I { - pub fun foo() - } - } - ` - const oldCode = ` - import TestImport from "imported" + import TestImport from 0x01 pub contract Test { pub resource R:TestImport.I { @@ -1199,7 +1280,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { ` const newCode = ` - import TestImport from "imported" + import TestImport from 0x01 access(all) contract Test { access(all) entitlement F @@ -1215,7 +1296,17 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { } ` - err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport) + err := testContractUpdateWithImports( + t, + oldCode, + newCode, + map[common.Location]string{ + common.AddressLocation{ + Name: "TestImport", + Address: common.MustBytesToAddress([]byte{0x1}), + }: newImport, + }, + ) require.NoError(t, err) }) @@ -1270,16 +1361,8 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { t.Parallel() - const oldImport = ` - pub contract TestImport { - pub resource interface I { - pub fun foo() - } - } - ` - const oldCode = ` - import TestImport from "imported" + import TestImport from 0x01 pub contract Test { pub resource R:TestImport.I { @@ -1303,7 +1386,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { ` const newCode = ` - import TestImport from "imported" + import TestImport from 0x01 access(all) contract Test { access(all) entitlement F @@ -1319,7 +1402,17 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { } ` - err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport) + err := testContractUpdateWithImports( + t, + oldCode, + newCode, + map[common.Location]string{ + common.AddressLocation{ + Name: "TestImport", + Address: common.MustBytesToAddress([]byte{0x1}), + }: newImport, + }, + ) cause := getSingleContractUpdateErrorCause(t, err, "Test") assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F") diff --git a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go index d7ab8ae08..4d99ee7d7 100644 --- a/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go +++ b/runtime/stdlib/cadence_v0.42_to_v1_contract_upgrade_validator.go @@ -120,14 +120,14 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) idAndLocationOfQualifie rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier location := validator.underlyingUpdateValidator.location - if typIdentifier != rootIdentifier && - validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier] == nil { + foundLocations := validator.TypeComparator.foundIdentifierImportLocations + + if typIdentifier != rootIdentifier && foundLocations[typIdentifier] == nil { qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString) return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location - } - if loc := validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier]; loc != nil { + if loc := foundLocations[typIdentifier]; loc != nil { location = loc } @@ -138,7 +138,11 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) getEntitlementType( entitlement *ast.NominalType, ) *sema.EntitlementType { typeID, location := validator.idAndLocationOfQualifiedType(entitlement) - return validator.newElaborations[location].EntitlementType(typeID) + elaboration, ok := validator.newElaborations[location] + if !ok { + panic(errors.NewUnreachableError()) + } + return elaboration.EntitlementType(typeID) } func (validator *CadenceV042ToV1ContractUpdateValidator) getEntitlementSetAccess( @@ -346,10 +350,27 @@ typeSwitch: } } + // If the new/old type is non-storable, + // then changing the type of this field has no impact to the storage. + if isNonStorableType(oldType) || isNonStorableType(newType) { + return nil + } + return oldType.CheckEqual(newType, validator) } +func isNonStorableType(typ ast.Type) bool { + switch typ := typ.(type) { + case *ast.ReferenceType, *ast.FunctionType: + return true + case *ast.OptionalType: + return isNonStorableType(typ.Type) + default: + return false + } +} + func (validator *CadenceV042ToV1ContractUpdateValidator) checkField(oldField *ast.FieldDeclaration, newField *ast.FieldDeclaration) { oldType := oldField.TypeAnnotation.Type newType := newField.TypeAnnotation.Type diff --git a/runtime/stdlib/contract_update_validation.go b/runtime/stdlib/contract_update_validation.go index 484ec5710..a47b842b8 100644 --- a/runtime/stdlib/contract_update_validation.go +++ b/runtime/stdlib/contract_update_validation.go @@ -131,8 +131,14 @@ func collectImports(validator UpdateValidator, program *ast.Program) map[string] for _, importDecl := range imports { importLocation := importDecl.Location + addressLocation, ok := importLocation.(common.AddressLocation) + if !ok { + // e.g: Crypto + continue + } + // if there are no identifiers given, the import covers all of them - if addressLocation, isAddressLocation := importLocation.(common.AddressLocation); isAddressLocation && len(importDecl.Identifiers) == 0 { + if len(importDecl.Identifiers) == 0 { allLocations, err := validator.getAccountContractNames(addressLocation.Address) if err != nil { validator.report(err) @@ -140,13 +146,20 @@ func collectImports(validator UpdateValidator, program *ast.Program) map[string] for _, identifier := range allLocations { // associate the location of an identifier's import with the location it's being imported from // this assumes that two imports cannot have the same name, which should be prevented by the type checker - importLocations[identifier] = importLocation + importLocations[identifier] = common.AddressLocation{ + Name: identifier, + Address: addressLocation.Address, + } } } else { for _, identifier := range importDecl.Identifiers { - // associate the location of an identifier's import with the location it's being imported from - // this assumes that two imports cannot have the same name, which should be prevented by the type checker - importLocations[identifier.Identifier] = importLocation + name := identifier.Identifier + // associate the location of an identifier's import with the location it's being imported from. + // This assumes that two imports cannot have the same name, which should be prevented by the type checker + importLocations[name] = common.AddressLocation{ + Name: name, + Address: addressLocation.Address, + } } } }