Skip to content

Commit

Permalink
add tests and fix importing
Browse files Browse the repository at this point in the history
  • Loading branch information
dsainati1 committed Feb 7, 2024
1 parent 2567340 commit ffba5e1
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 16 deletions.
3 changes: 3 additions & 0 deletions runtime/contract_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,9 @@ func TestRuntimeLegacyContractUpdate(t *testing.T) {
OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) {
return accountCodes[location], nil
},
OnGetAccountContractNames: func(_ Address) ([]string, error) {
return []string{"Foo"}, nil
},
OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error {
accountCodes[location] = code
return nil
Expand Down
18 changes: 18 additions & 0 deletions runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4607,6 +4607,24 @@ func (interpreter *Interpreter) getElaboration(location common.Location) *sema.E
return subInterpreter.Program.Elaboration
}

func (interpreter *Interpreter) AllElaborations() (elaborations map[common.Location]*sema.Elaboration) {

elaborations = map[common.Location]*sema.Elaboration{}

// Ensure the program for this location is loaded,
// so its checker is available

for location, _ := range interpreter.SharedState.allInterpreters {

Check failure on line 4617 in runtime/interpreter/interpreter.go

View workflow job for this annotation

GitHub Actions / Lint

File is not `gofmt`-ed with `-s` (gofmt)
subInterpreter := interpreter.EnsureLoaded(location)
if subInterpreter == nil || subInterpreter.Program == nil {
return nil
}

Check warning on line 4621 in runtime/interpreter/interpreter.go

View check run for this annotation

Codecov / codecov/patch

runtime/interpreter/interpreter.go#L4620-L4621

Added lines #L4620 - L4621 were not covered by tests
elaborations[location] = subInterpreter.Program.Elaboration
}

return
}

// GetContractComposite gets the composite value of the contract at the address location.
func (interpreter *Interpreter) GetContractComposite(contractLocation common.AddressLocation) (*CompositeValue, error) {
contractGlobal := interpreter.Globals.Get(contractLocation.Name)
Expand Down
6 changes: 3 additions & 3 deletions runtime/stdlib/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,8 @@ func changeAccountContracts(

// Validate the contract update

inter := invocation.Interpreter

if isUpdate {
oldCode, err := handler.GetAccountContractCode(location)
handleContractUpdateError(err)
Expand Down Expand Up @@ -1624,7 +1626,7 @@ func changeAccountContracts(
handler,
oldProgram,
program.Program,
program.Elaboration,
inter.AllElaborations(),
)
} else {
validator = NewContractUpdateValidator(
Expand All @@ -1639,8 +1641,6 @@ func changeAccountContracts(
handleContractUpdateError(err)
}

inter := invocation.Interpreter

err = updateAccountContractCode(
handler,
location,
Expand Down
184 changes: 182 additions & 2 deletions runtime/stdlib/legacy_contract_upgrade_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ package stdlib_test
import (
"testing"

Check failure on line 23 in runtime/stdlib/legacy_contract_upgrade_validation_test.go

View workflow job for this annotation

GitHub Actions / Lint

File is not `goimports`-ed with -local github.com/onflow/cadence (goimports)
"github.com/onflow/cadence/runtime"
"github.com/onflow/cadence/runtime/ast"
"github.com/onflow/cadence/runtime/common"
"github.com/onflow/cadence/runtime/old_parser"
"github.com/onflow/cadence/runtime/parser"
"github.com/onflow/cadence/runtime/sema"
Expand Down Expand Up @@ -54,11 +57,71 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error {
upgradeValidator := stdlib.NewLegacyContractUpdateValidator(
utils.TestLocation,
"Test",
// TODO: add contract name handling here once we have a way to test imported values
&runtime_utils.TestRuntimeInterface{},
oldProgram,
newProgram,
checker.Elaboration)
map[common.Location]*sema.Elaboration{
utils.TestLocation: checker.Elaboration,
})
return upgradeValidator.Validate()
}

func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newCode, newImport string) error {
oldProgram, err := old_parser.ParseProgram(nil, []byte(oldCode), old_parser.Config{})
require.NoError(t, err)

newProgram, err := parser.ParseProgram(nil, []byte(newCode), parser.Config{})
require.NoError(t, err)

newImportedProgram, err := parser.ParseProgram(nil, []byte(newImport), parser.Config{})
require.NoError(t, err)

importedChecker, err := sema.NewChecker(
newImportedProgram,
utils.ImportedLocation,
nil,
&sema.Config{
AccessCheckMode: sema.AccessCheckModeStrict,
AttachmentsEnabled: true,
},
)

require.NoError(t, err)
err = importedChecker.Check()
require.NoError(t, err)

checker, err := sema.NewChecker(
newProgram,
utils.TestLocation,
nil,
&sema.Config{
AccessCheckMode: sema.AccessCheckModeStrict,
ImportHandler: func(_ *sema.Checker, _ common.Location, _ ast.Range) (sema.Import, error) {
return sema.ElaborationImport{
Elaboration: importedChecker.Elaboration,
}, nil
},
AttachmentsEnabled: true,
})
require.NoError(t, err)

err = checker.Check()
require.NoError(t, err)

upgradeValidator := stdlib.NewLegacyContractUpdateValidator(
utils.TestLocation,
"Test",
&runtime_utils.TestRuntimeInterface{
OnGetAccountContractNames: func(address runtime.Address) ([]string, error) {
return []string{"TestImport"}, nil
},
},
oldProgram,
newProgram,
map[common.Location]*sema.Elaboration{
utils.TestLocation: checker.Elaboration,
utils.ImportedLocation: importedChecker.Elaboration,
})
return upgradeValidator.Validate()
}

Expand Down Expand Up @@ -1012,6 +1075,64 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) {
require.NoError(t, err)
})

t.Run("change field type restricted entitled reference type with qualified types with imports", func(t *testing.T) {

t.Parallel()

const oldImport = `
pub contract TestImport {
pub resource interface I {
pub fun foo()
}
}
`

const oldCode = `
import TestImport from "imported"
pub contract Test {
pub resource R:TestImport.I {
pub fun foo()
}
pub var a: Capability<&R{TestImport.I}>?
init() {
self.a = nil
}
}
`

const newImport = `
access(all) contract TestImport {
access(all) entitlement E
access(all) resource interface I {
access(E) fun foo()
}
}
`

const newCode = `
import TestImport from "imported"
access(all) contract Test {
access(all) entitlement F
access(all) resource R: TestImport.I {
access(TestImport.E) fun foo() {}
access(Test.F) fun bar() {}
}
access(all) var a: Capability<auth(TestImport.E) &Test.R>?
init() {
self.a = nil
}
}
`

err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport)

require.NoError(t, err)
})

t.Run("change field type restricted entitled reference type with too many granted entitlements", func(t *testing.T) {

t.Parallel()
Expand Down Expand Up @@ -1057,4 +1178,63 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) {
cause := getSingleContractUpdateErrorCause(t, err, "Test")
assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F")
})

t.Run("change field type restricted entitled reference type with too many granted entitlements with imports", func(t *testing.T) {

t.Parallel()

const oldImport = `
pub contract TestImport {
pub resource interface I {
pub fun foo()
}
}
`

const oldCode = `
import TestImport from "imported"
pub contract Test {
pub resource R:TestImport.I {
pub fun foo()
}
pub var a: Capability<&R{TestImport.I}>?
init() {
self.a = nil
}
}
`

const newImport = `
access(all) contract TestImport {
access(all) entitlement E
access(all) resource interface I {
access(TestImport.E) fun foo()
}
}
`

const newCode = `
import TestImport from "imported"
access(all) contract Test {
access(all) entitlement F
access(all) resource R: TestImport.I {
access(TestImport.E) fun foo() {}
access(Test.F) fun bar() {}
}
access(all) var a: Capability<auth(TestImport.E, Test.F) &Test.R>?
init() {
self.a = nil
}
}
`

err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport)

cause := getSingleContractUpdateErrorCause(t, err, "Test")
assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F")
})
}
29 changes: 18 additions & 11 deletions runtime/stdlib/legacy_contract_upgrade_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
type LegacyContractUpdateValidator struct {
TypeComparator

newElaboration *sema.Elaboration
newElaborations map[common.Location]*sema.Elaboration
currentRestrictedTypeUpgradeRestrictions []*ast.NominalType

underlyingUpdateValidator *ContractUpdateValidator
Expand All @@ -45,14 +45,14 @@ func NewLegacyContractUpdateValidator(
provider AccountContractNamesProvider,
oldProgram *ast.Program,
newProgram *ast.Program,
newElaboration *sema.Elaboration,
newElaborations map[common.Location]*sema.Elaboration,
) *LegacyContractUpdateValidator {

underlyingValidator := NewContractUpdateValidator(location, contractName, provider, oldProgram, newProgram)

return &LegacyContractUpdateValidator{
underlyingUpdateValidator: underlyingValidator,
newElaboration: newElaboration,
newElaborations: newElaborations,
}
}

Expand Down Expand Up @@ -101,7 +101,7 @@ func (validator *LegacyContractUpdateValidator) report(err error) {
validator.underlyingUpdateValidator.report(err)
}

func (validator *LegacyContractUpdateValidator) idOfQualifiedType(typ *ast.NominalType) common.TypeID {
func (validator *LegacyContractUpdateValidator) idAndLocationOfQualifiedType(typ *ast.NominalType) (common.TypeID, common.Location) {

qualifiedString := typ.String()

Expand All @@ -115,18 +115,25 @@ func (validator *LegacyContractUpdateValidator) idOfQualifiedType(typ *ast.Nomin
// and in 1 and 2 we don't need to do anything
typIdentifier := typ.Identifier.Identifier
rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier
location := validator.underlyingUpdateValidator.location

if typIdentifier != rootIdentifier &&
validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier] == nil {
qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString)
return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location

}
return common.NewTypeIDFromQualifiedName(nil, validator.underlyingUpdateValidator.location, qualifiedString)

if loc := validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier]; loc != nil {
location = loc
}

return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location
}

func (validator *LegacyContractUpdateValidator) getEntitlementType(entitlement *ast.NominalType) *sema.EntitlementType {
typeID := validator.idOfQualifiedType(entitlement)
return validator.newElaboration.EntitlementType(typeID)
typeID, location := validator.idAndLocationOfQualifiedType(entitlement)
return validator.newElaborations[location].EntitlementType(typeID)
}

func (validator *LegacyContractUpdateValidator) getEntitlementSetAccess(entitlementSet ast.EntitlementSet) sema.EntitlementSetAccess {
Expand All @@ -145,13 +152,13 @@ func (validator *LegacyContractUpdateValidator) getEntitlementSetAccess(entitlem
}

func (validator *LegacyContractUpdateValidator) getCompositeType(composite *ast.NominalType) *sema.CompositeType {
typeID := validator.idOfQualifiedType(composite)
return validator.newElaboration.CompositeType(typeID)
typeID, location := validator.idAndLocationOfQualifiedType(composite)
return validator.newElaborations[location].CompositeType(typeID)
}

func (validator *LegacyContractUpdateValidator) getInterfaceType(intf *ast.NominalType) *sema.InterfaceType {
typeID := validator.idOfQualifiedType(intf)
return validator.newElaboration.InterfaceType(typeID)
typeID, location := validator.idAndLocationOfQualifiedType(intf)
return validator.newElaborations[location].InterfaceType(typeID)
}

func (validator *LegacyContractUpdateValidator) getIntersectedInterfaces(intersection []*ast.NominalType) (intfs []*sema.InterfaceType) {
Expand Down

0 comments on commit ffba5e1

Please sign in to comment.