Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix location collection in contract update validator #3117

Merged
merged 5 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,48 +67,79 @@ func testContractUpdate(t *testing.T, oldCode string, newCode string) error {
return upgradeValidator.Validate()
}

func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newCode, newImport string) error {
func testContractUpdateWithImports(
t *testing.T,
oldCode string,
newCode string,
newImports map[common.Location]string,
) error {
oldProgram, err := old_parser.ParseProgram(nil, []byte(oldCode), old_parser.Config{})
require.NoError(t, err)

newProgram, err := parser.ParseProgram(nil, []byte(newCode), parser.Config{})
require.NoError(t, err)

newImportedProgram, err := parser.ParseProgram(nil, []byte(newImport), parser.Config{})
require.NoError(t, err)
elaborations := map[common.Location]*sema.Elaboration{}

importedChecker, err := sema.NewChecker(
newImportedProgram,
utils.ImportedLocation,
nil,
&sema.Config{
AccessCheckMode: sema.AccessCheckModeStrict,
AttachmentsEnabled: true,
},
)
for location, code := range newImports {
newImportedProgram, err := parser.ParseProgram(nil, []byte(code), parser.Config{})
require.NoError(t, err)

require.NoError(t, err)
err = importedChecker.Check()
require.NoError(t, err)
importedChecker, err := sema.NewChecker(
newImportedProgram,
location,
nil,
&sema.Config{
AccessCheckMode: sema.AccessCheckModeStrict,
AttachmentsEnabled: true,
},
)

require.NoError(t, err)
err = importedChecker.Check()
require.NoError(t, err)

elaborations[location] = importedChecker.Elaboration
}

checker, err := sema.NewChecker(
newProgram,
utils.TestLocation,
nil,
&sema.Config{
AccessCheckMode: sema.AccessCheckModeStrict,
ImportHandler: func(_ *sema.Checker, _ common.Location, _ ast.Range) (sema.Import, error) {
ImportHandler: func(_ *sema.Checker, location common.Location, _ ast.Range) (sema.Import, error) {
importedElaboration := elaborations[location]
return sema.ElaborationImport{
Elaboration: importedChecker.Elaboration,
Elaboration: importedElaboration,
}, nil
},
LocationHandler: func(identifiers []ast.Identifier, location common.Location) (
locations []sema.ResolvedLocation, err error,
) {
if addressLocation, ok := location.(common.AddressLocation); ok && len(identifiers) == 1 {
location = common.AddressLocation{
Name: identifiers[0].Identifier,
Address: addressLocation.Address,
}
}

locations = append(locations, sema.ResolvedLocation{
Location: location,
Identifiers: identifiers,
})

return
},
AttachmentsEnabled: true,
})
require.NoError(t, err)

err = checker.Check()
require.NoError(t, err)

elaborations[utils.TestLocation] = checker.Elaboration

upgradeValidator := stdlib.NewCadenceV042ToV1ContractUpdateValidator(
utils.TestLocation,
"Test",
Expand All @@ -119,10 +150,8 @@ func testContractUpdateWithImports(t *testing.T, oldCode, oldImport string, newC
},
oldProgram,
newProgram,
map[common.Location]*sema.Elaboration{
utils.TestLocation: checker.Elaboration,
utils.ImportedLocation: importedChecker.Elaboration,
})
elaborations,
)
return upgradeValidator.Validate()
}

Expand Down Expand Up @@ -528,6 +557,66 @@ func TestContractUpgradeFieldType(t *testing.T) {

require.NoError(t, err)
})

t.Run("changing to a non-storable types", func(t *testing.T) {

t.Parallel()

const oldCode = `
access(all) contract Test {
access(all) struct Foo {
access(all) var a: Int
init() {
self.a = 0
}
}
}
`

const newCode = `
access(all) contract Test {
access(all) struct Foo {
access(all) var a: &Int?
init() {
self.a = nil
}
}
}
`

err := testContractUpdate(t, oldCode, newCode)
require.NoError(t, err)
})

t.Run("changing from a non-storable types", func(t *testing.T) {

t.Parallel()

const oldCode = `
access(all) contract Test {
access(all) struct Foo {
access(all) var a: &Int?
init() {
self.a = nil
}
}
}
`

const newCode = `
access(all) contract Test {
access(all) struct Foo {
access(all) var a: Int
init() {
self.a = 0
}
}
}
`

err := testContractUpdate(t, oldCode, newCode)
require.NoError(t, err)
})
}

func TestContractUpgradeIntersectionAuthorization(t *testing.T) {
Expand Down Expand Up @@ -1166,16 +1255,8 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) {

t.Parallel()

const oldImport = `
pub contract TestImport {
pub resource interface I {
pub fun foo()
}
}
`

const oldCode = `
import TestImport from "imported"
import TestImport from 0x01

pub contract Test {
pub resource R:TestImport.I {
Expand All @@ -1199,7 +1280,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) {
`

const newCode = `
import TestImport from "imported"
import TestImport from 0x01

access(all) contract Test {
access(all) entitlement F
Expand All @@ -1215,7 +1296,17 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) {
}
`

err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport)
err := testContractUpdateWithImports(
t,
oldCode,
newCode,
map[common.Location]string{
common.AddressLocation{
Name: "TestImport",
Address: common.MustBytesToAddress([]byte{0x1}),
}: newImport,
},
)

require.NoError(t, err)
})
Expand Down Expand Up @@ -1270,16 +1361,8 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) {

t.Parallel()

const oldImport = `
pub contract TestImport {
pub resource interface I {
pub fun foo()
}
}
`

const oldCode = `
import TestImport from "imported"
import TestImport from 0x01

pub contract Test {
pub resource R:TestImport.I {
Expand All @@ -1303,7 +1386,7 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) {
`

const newCode = `
import TestImport from "imported"
import TestImport from 0x01

access(all) contract Test {
access(all) entitlement F
Expand All @@ -1319,7 +1402,17 @@ func TestContractUpgradeIntersectionFieldType(t *testing.T) {
}
`

err := testContractUpdateWithImports(t, oldCode, oldImport, newCode, newImport)
err := testContractUpdateWithImports(
t,
oldCode,
newCode,
map[common.Location]string{
common.AddressLocation{
Name: "TestImport",
Address: common.MustBytesToAddress([]byte{0x1}),
}: newImport,
},
)

cause := getSingleContractUpdateErrorCause(t, err, "Test")
assertFieldAuthorizationMismatchError(t, cause, "Test", "a", "E", "E, F")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) idAndLocationOfQualifie
rootIdentifier := validator.TypeComparator.RootDeclIdentifier.Identifier
location := validator.underlyingUpdateValidator.location

if typIdentifier != rootIdentifier &&
validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier] == nil {
foundLocations := validator.TypeComparator.foundIdentifierImportLocations

if typIdentifier != rootIdentifier && foundLocations[typIdentifier] == nil {
qualifiedString = fmt.Sprintf("%s.%s", rootIdentifier, qualifiedString)
return common.NewTypeIDFromQualifiedName(nil, location, qualifiedString), location

}

if loc := validator.TypeComparator.foundIdentifierImportLocations[typ.Identifier.Identifier]; loc != nil {
if loc := foundLocations[typIdentifier]; loc != nil {
SupunS marked this conversation as resolved.
Show resolved Hide resolved
location = loc
}

Expand All @@ -138,7 +138,11 @@ func (validator *CadenceV042ToV1ContractUpdateValidator) getEntitlementType(
entitlement *ast.NominalType,
) *sema.EntitlementType {
typeID, location := validator.idAndLocationOfQualifiedType(entitlement)
return validator.newElaborations[location].EntitlementType(typeID)
elaboration, ok := validator.newElaborations[location]
if !ok {
panic(errors.NewUnreachableError())
}
return elaboration.EntitlementType(typeID)
}

func (validator *CadenceV042ToV1ContractUpdateValidator) getEntitlementSetAccess(
Expand Down Expand Up @@ -346,10 +350,27 @@ typeSwitch:
}
}

// If the new/old type is non-storable,
// then changing the type of this field has no impact to the storage.
if isNonStorableType(oldType) || isNonStorableType(newType) {
return nil
}

return oldType.CheckEqual(newType, validator)

}

func isNonStorableType(typ ast.Type) bool {
switch typ := typ.(type) {
case *ast.ReferenceType, *ast.FunctionType:
return true
case *ast.OptionalType:
return isNonStorableType(typ.Type)
default:
turbolent marked this conversation as resolved.
Show resolved Hide resolved
return false
}
}

func (validator *CadenceV042ToV1ContractUpdateValidator) checkField(oldField *ast.FieldDeclaration, newField *ast.FieldDeclaration) {
oldType := oldField.TypeAnnotation.Type
newType := newField.TypeAnnotation.Type
Expand Down
23 changes: 18 additions & 5 deletions runtime/stdlib/contract_update_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,22 +131,35 @@ func collectImports(validator UpdateValidator, program *ast.Program) map[string]
for _, importDecl := range imports {
importLocation := importDecl.Location

addressLocation, ok := importLocation.(common.AddressLocation)
if !ok {
// e.g: Crypto
continue
}

// if there are no identifiers given, the import covers all of them
if addressLocation, isAddressLocation := importLocation.(common.AddressLocation); isAddressLocation && len(importDecl.Identifiers) == 0 {
if len(importDecl.Identifiers) == 0 {
allLocations, err := validator.getAccountContractNames(addressLocation.Address)
if err != nil {
validator.report(err)
}
for _, identifier := range allLocations {
// associate the location of an identifier's import with the location it's being imported from
// this assumes that two imports cannot have the same name, which should be prevented by the type checker
importLocations[identifier] = importLocation
importLocations[identifier] = common.AddressLocation{
Name: identifier,
Address: addressLocation.Address,
}
}
} else {
for _, identifier := range importDecl.Identifiers {
// associate the location of an identifier's import with the location it's being imported from
// this assumes that two imports cannot have the same name, which should be prevented by the type checker
importLocations[identifier.Identifier] = importLocation
name := identifier.Identifier
// associate the location of an identifier's import with the location it's being imported from.
// This assumes that two imports cannot have the same name, which should be prevented by the type checker
importLocations[name] = common.AddressLocation{
Name: name,
Address: addressLocation.Address,
}
}
}
}
Expand Down
Loading