diff --git a/runtime/contract_update_test.go b/runtime/contract_update_test.go index 72a505427..cd11e2d43 100644 --- a/runtime/contract_update_test.go +++ b/runtime/contract_update_test.go @@ -716,6 +716,9 @@ func TestRuntimeLegacyContractUpdate(t *testing.T) { OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { return accountCodes[location], nil }, + OnGetAccountContractNames: func(_ Address) ([]string, error) { + return []string{"Foo"}, nil + }, OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { accountCodes[location] = code return nil diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index bf4da92c1..48ef579e1 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -4607,6 +4607,24 @@ func (interpreter *Interpreter) getElaboration(location common.Location) *sema.E return subInterpreter.Program.Elaboration } +func (interpreter *Interpreter) AllElaborations() (elaborations map[common.Location]*sema.Elaboration) { + + elaborations = map[common.Location]*sema.Elaboration{} + + // Ensure the program for this location is loaded, + // so its checker is available + + for location, _ := range interpreter.SharedState.allInterpreters { + subInterpreter := interpreter.EnsureLoaded(location) + if subInterpreter == nil || subInterpreter.Program == nil { + return nil + } + elaborations[location] = subInterpreter.Program.Elaboration + } + + return +} + // GetContractComposite gets the composite value of the contract at the address location. func (interpreter *Interpreter) GetContractComposite(contractLocation common.AddressLocation) (*CompositeValue, error) { contractGlobal := interpreter.Globals.Get(contractLocation.Name) diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index e3334f702..b68199a20 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1586,6 +1586,8 @@ func changeAccountContracts( // Validate the contract update + inter := invocation.Interpreter + if isUpdate { oldCode, err := handler.GetAccountContractCode(location) handleContractUpdateError(err) @@ -1624,7 +1626,7 @@ func changeAccountContracts( handler, oldProgram, program.Program, - program.Elaboration, + inter.AllElaborations(), ) } else { validator = NewContractUpdateValidator( @@ -1639,8 +1641,6 @@ func changeAccountContracts( handleContractUpdateError(err) } - inter := invocation.Interpreter - err = updateAccountContractCode( handler, location, diff --git a/runtime/stdlib/legacy_contract_upgrade_validation_test.go b/runtime/stdlib/legacy_contract_upgrade_validation_test.go index 929c0bdc4..e78361b5b 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validation_test.go +++ b/runtime/stdlib/legacy_contract_upgrade_validation_test.go @@ -21,6 +21,9 @@ package stdlib_test import ( "testing" + "github.com/onflow/cadence/runtime" + "github.com/onflow/cadence/runtime/ast" + "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/old_parser" "github.com/onflow/cadence/runtime/parser" "github.com/onflow/cadence/runtime/sema" @@ -54,11 +57,71 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error { upgradeValidator := stdlib.NewLegacyContractUpdateValidator( utils.TestLocation, "Test", - // TODO: add contract name handling here once we have a way to test imported values &runtime_utils.TestRuntimeInterface{}, oldProgram, newProgram, - checker.Elaboration) + map[common.Location]*sema.Elaboration{ + utils.TestLocation: checker.Elaboration, + }) + return upgradeValidator.Validate() +} + +func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newCode, newImport string) error { + oldProgram, err := old_parser.ParseProgram(nil, []byte(oldCode), old_parser.Config{}) + require.NoError(t, err) + + newProgram, err := parser.ParseProgram(nil, []byte(newCode), parser.Config{}) + require.NoError(t, err) + + newImportedProgram, err := parser.ParseProgram(nil, []byte(newImport), parser.Config{}) + require.NoError(t, err) + + importedChecker, err := sema.NewChecker( + newImportedProgram, + utils.ImportedLocation, + nil, + &sema.Config{ + AccessCheckMode: sema.AccessCheckModeStrict, + AttachmentsEnabled: true, + }, + ) + + require.NoError(t, err) + err = importedChecker.Check() + require.NoError(t, err) + + checker, err := sema.NewChecker( + newProgram, + utils.TestLocation, + nil, + &sema.Config{ + AccessCheckMode: sema.AccessCheckModeStrict, + ImportHandler: func(_ *sema.Checker, _ common.Location, _ ast.Range) (sema.Import, error) { + return sema.ElaborationImport{ + Elaboration: importedChecker.Elaboration, + }, nil + }, + AttachmentsEnabled: true, + }) + require.NoError(t, err) + + err = checker.Check() + require.NoError(t, err) + + upgradeValidator := stdlib.NewLegacyContractUpdateValidator( + utils.TestLocation, + "Test", + &runtime_utils.TestRuntimeInterface{ + OnGetAccountContractNames: func(address runtime.Address) ([]string, error) { + return []string{"TestImport"}, nil + }, + }, + oldProgram, + newProgram, + map[common.Location]*sema.Elaboration{ + utils.TestLocation: checker.Elaboration, + utils.ImportedLocation: importedChecker.Elaboration, + }) return upgradeValidator.Validate() } @@ -1012,6 +1075,64 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { require.NoError(t, err) }) + t.Run("change field type restricted entitled reference type with qualified types with imports", func(t *testing.T) { + + t.Parallel() + + const oldImport = ` + pub contract TestImport { + pub resource interface I { + pub fun foo() + } + } + ` + + const oldCode = ` + import TestImport from "imported" + + pub contract Test { + pub resource R:TestImport.I { + pub fun foo() + } + + pub var a: Capability<&R{TestImport.I}>? + init() { + self.a = nil + } + } + ` + + const newImport = ` + access(all) contract TestImport { + access(all) entitlement E + access(all) resource interface I { + access(E) fun foo() + } + } + ` + + const newCode = ` + import TestImport from "imported" + + access(all) contract Test { + access(all) entitlement F + access(all) resource R: TestImport.I { + access(TestImport.E) fun foo() {} + access(Test.F) fun bar() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport) + + require.NoError(t, err) + }) + t.Run("change field type restricted entitled reference type with too many granted entitlements", func(t *testing.T) { t.Parallel() @@ -1057,4 +1178,63 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) { cause := getSingleContractUpdateErrorCause(t, err, "Test") assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F") }) + + t.Run("change field type restricted entitled reference type with too many granted entitlements with imports", func(t *testing.T) { + + t.Parallel() + + const oldImport = ` + pub contract TestImport { + pub resource interface I { + pub fun foo() + } + } + ` + + const oldCode = ` + import TestImport from "imported" + + pub contract Test { + pub resource R:TestImport.I { + pub fun foo() + } + + pub var a: Capability<&R{TestImport.I}>? + init() { + self.a = nil + } + } + ` + + const newImport = ` + access(all) contract TestImport { + access(all) entitlement E + access(all) resource interface I { + access(TestImport.E) fun foo() + } + } + ` + + const newCode = ` + import TestImport from "imported" + + access(all) contract Test { + access(all) entitlement F + access(all) resource R: TestImport.I { + access(TestImport.E) fun foo() {} + access(Test.F) fun bar() {} + } + + access(all) var a: Capability? + init() { + self.a = nil + } + } + ` + + err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport) + + cause := getSingleContractUpdateErrorCause(t, err, "Test") + assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F") + }) } diff --git a/runtime/stdlib/legacy_contract_upgrade_validator.go b/runtime/stdlib/legacy_contract_upgrade_validator.go index 0df7d0930..05674acc7 100644 --- a/runtime/stdlib/legacy_contract_upgrade_validator.go +++ b/runtime/stdlib/legacy_contract_upgrade_validator.go @@ -31,7 +31,7 @@ import ( type LegacyContractUpdateValidator struct { TypeComparator - newElaboration *sema.Elaboration + newElaborations map[common.Location]*sema.Elaboration currentRestrictedTypeUpgradeRestrictions []*ast.NominalType underlyingUpdateValidator *ContractUpdateValidator @@ -45,14 +45,14 @@ func NewLegacyContractUpdateValidator( provider AccountContractNamesProvider, oldProgram *ast.Program, newProgram *ast.Program, - newElaboration *sema.Elaboration, + newElaborations map[common.Location]*sema.Elaboration, ) *LegacyContractUpdateValidator { underlyingValidator := NewContractUpdateValidator(location, contractName, provider, oldProgram, newProgram) return &LegacyContractUpdateValidator{ underlyingUpdateValidator: underlyingValidator, - newElaboration: newElaboration, + newElaborations: newElaborations, } } @@ -101,7 +101,7 @@ func (validator *LegacyContractUpdateValidator) report(err error) { validator.underlyingUpdateValidator.report(err) } -func (validator *LegacyContractUpdateValidator) idOfQualifiedType(typ *ast.NominalType) common.TypeID { +func (validator *LegacyContractUpdateValidator) idAndLocationOfQualifiedType(typ *ast.NominalType) (common.TypeID, common.Location) { qualifiedString := typ.String() @@ -115,18 +115,25 @@ func (validator *LegacyContractUpdateValidator) idOfQualifiedType(typ *ast.Nomin // and in 1 and 2 we don't need to do anything typIdentifier := typ.Identifier.Identifier rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier + location := validator.underlyingUpdateValidator.location if typIdentifier != rootIdentifier && validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier] == nil { qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString) + return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location } - return common.NewTypeIDFromQualifiedName(nil, validator.underlyingUpdateValidator.location, qualifiedString) + + if loc := validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier]; loc != nil { + location = loc + } + + return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location } func (validator *LegacyContractUpdateValidator) getEntitlementType(entitlement *ast.NominalType) *sema.EntitlementType { - typeID := validator.idOfQualifiedType(entitlement) - return validator.newElaboration.EntitlementType(typeID) + typeID, location := validator.idAndLocationOfQualifiedType(entitlement) + return validator.newElaborations[location].EntitlementType(typeID) } func (validator *LegacyContractUpdateValidator) getEntitlementSetAccess(entitlementSet ast.EntitlementSet) sema.EntitlementSetAccess { @@ -145,13 +152,13 @@ func (validator *LegacyContractUpdateValidator) getEntitlementSetAccess(entitlem } func (validator *LegacyContractUpdateValidator) getCompositeType(composite *ast.NominalType) *sema.CompositeType { - typeID := validator.idOfQualifiedType(composite) - return validator.newElaboration.CompositeType(typeID) + typeID, location := validator.idAndLocationOfQualifiedType(composite) + return validator.newElaborations[location].CompositeType(typeID) } func (validator *LegacyContractUpdateValidator) getInterfaceType(intf *ast.NominalType) *sema.InterfaceType { - typeID := validator.idOfQualifiedType(intf) - return validator.newElaboration.InterfaceType(typeID) + typeID, location := validator.idAndLocationOfQualifiedType(intf) + return validator.newElaborations[location].InterfaceType(typeID) } func (validator *LegacyContractUpdateValidator) getIntersectedInterfaces(intersection []*ast.NominalType) (intfs []*sema.InterfaceType) {