From 125cd59a16e99c9c0941519473cb17356c420495 Mon Sep 17 00:00:00 2001 From: Sammy Date: Tue, 24 Sep 2024 14:16:38 -0400 Subject: [PATCH] move lockable coin from types to keeper package --- x/axelarnet/keeper/message_route.go | 3 +- x/axelarnet/keeper/msg_server.go | 13 ++- x/axelarnet/message_handler.go | 40 ++++---- x/axelarnet/module.go | 5 +- x/axelarnet/types/evm_translator_test.go | 9 +- x/axelarnet/types/expected_keepers.go | 1 + x/axelarnet/types/mock/expected_keepers.go | 62 ++++++++++++ x/nexus/exported/types.go | 8 ++ .../coin.go => keeper/lockable_coin.go} | 96 +++++++++++-------- 9 files changed, 159 insertions(+), 78 deletions(-) rename x/nexus/{types/coin.go => keeper/lockable_coin.go} (57%) diff --git a/x/axelarnet/keeper/message_route.go b/x/axelarnet/keeper/message_route.go index e66a19bb1..29703975c 100644 --- a/x/axelarnet/keeper/message_route.go +++ b/x/axelarnet/keeper/message_route.go @@ -9,7 +9,6 @@ import ( "github.com/axelarnetwork/axelar-core/x/axelarnet/types" nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" - nexustypes "github.com/axelarnetwork/axelar-core/x/nexus/types" ) // for IBC execution @@ -78,7 +77,7 @@ func escrowAssetToMessageSender( return asset, bankK.SendCoins(ctx, sender, types.AxelarIBCAccount, sdk.NewCoins(asset)) case nexus.TypeGeneralMessageWithToken: - coin, err := nexustypes.NewCoin(ctx, nexusK, ibcK, bankK, *msg.Asset) + coin, err := nexusK.NewLockableCoin(ctx, ibcK, bankK, *msg.Asset) if err != nil { return sdk.Coin{}, err } diff --git a/x/axelarnet/keeper/msg_server.go b/x/axelarnet/keeper/msg_server.go index 50f891860..5574c6642 100644 --- a/x/axelarnet/keeper/msg_server.go +++ b/x/axelarnet/keeper/msg_server.go @@ -16,7 +16,6 @@ import ( "github.com/axelarnetwork/axelar-core/x/axelarnet/exported" "github.com/axelarnetwork/axelar-core/x/axelarnet/types" nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" - nexustypes "github.com/axelarnetwork/axelar-core/x/nexus/types" tss "github.com/axelarnetwork/axelar-core/x/tss/exported" "github.com/axelarnetwork/utils/funcs" ) @@ -80,7 +79,7 @@ func (s msgServer) CallContract(c context.Context, req *types.CallContractReques }) if req.Fee != nil { - normalizedCoin, err := nexustypes.NewCoin(ctx, s.nexus, s.ibcK, s.bank, req.Fee.Amount) + normalizedCoin, err := s.nexus.NewLockableCoin(ctx, s.ibcK, s.bank, req.Fee.Amount) if err != nil { return nil, sdkerrors.Wrap(err, "unrecognized fee denom") } @@ -94,7 +93,7 @@ func (s msgServer) CallContract(c context.Context, req *types.CallContractReques MessageID: msgID, Recipient: req.Fee.Recipient, Fee: req.Fee.Amount, - Asset: normalizedCoin.GetDenom(), + Asset: normalizedCoin.GetCoin().Denom, } if req.Fee.RefundRecipient != nil { feePaidEvent.RefundRecipient = req.Fee.RefundRecipient.String() @@ -186,7 +185,7 @@ func (s msgServer) ConfirmDeposit(c context.Context, req *types.ConfirmDepositRe return nil, fmt.Errorf("recipient chain '%s' is not activated", recipient.Chain.Name) } - normalizedCoin, err := nexustypes.NewCoin(ctx, s.nexus, s.ibcK, s.bank, coin) + normalizedCoin, err := s.nexus.NewLockableCoin(ctx, s.ibcK, s.bank, coin) if err != nil { return nil, err } @@ -195,7 +194,7 @@ func (s msgServer) ConfirmDeposit(c context.Context, req *types.ConfirmDepositRe return nil, err } - transferID, err := s.nexus.EnqueueForTransfer(ctx, depositAddr, normalizedCoin.Coin) + transferID, err := s.nexus.EnqueueForTransfer(ctx, depositAddr, normalizedCoin.GetCoin()) if err != nil { return nil, err } @@ -404,7 +403,7 @@ func (s msgServer) RouteIBCTransfers(c context.Context, _ *types.RouteIBCTransfe return nil, err } for _, p := range pendingTransfers { - coin, err := nexustypes.NewCoin(ctx, s.nexus, s.ibcK, s.bank, p.Asset) + coin, err := s.nexus.NewLockableCoin(ctx, s.ibcK, s.bank, p.Asset) if err != nil { s.Logger(ctx).Error(fmt.Sprintf("failed to route IBC transfer %s: %s", p.String(), err)) continue @@ -500,7 +499,7 @@ func (s msgServer) RouteMessage(c context.Context, req *types.RouteMessageReques } func transfer(ctx sdk.Context, k Keeper, n types.Nexus, b types.BankKeeper, ibc types.IBCKeeper, recipient sdk.AccAddress, coin sdk.Coin) error { - c, err := nexustypes.NewCoin(ctx, n, ibc, b, coin) + c, err := n.NewLockableCoin(ctx, ibc, b, coin) if err != nil { return err } diff --git a/x/axelarnet/message_handler.go b/x/axelarnet/message_handler.go index 5211c1ad4..45c5f194b 100644 --- a/x/axelarnet/message_handler.go +++ b/x/axelarnet/message_handler.go @@ -19,7 +19,6 @@ import ( "github.com/axelarnetwork/axelar-core/x/axelarnet/keeper" "github.com/axelarnetwork/axelar-core/x/axelarnet/types" nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" - nexustypes "github.com/axelarnetwork/axelar-core/x/nexus/types" tss "github.com/axelarnetwork/axelar-core/x/tss/exported" "github.com/axelarnetwork/utils/funcs" ) @@ -185,7 +184,7 @@ func OnRecvMessage(ctx sdk.Context, k keeper.Keeper, ibcK keeper.IBCKeeper, n ty return ack } -func validateMessage(ctx sdk.Context, ibcK keeper.IBCKeeper, n types.Nexus, ibcPath string, msg Message, token nexustypes.Coin) error { +func validateMessage(ctx sdk.Context, ibcK keeper.IBCKeeper, n types.Nexus, ibcPath string, msg Message, token nexus.LockableCoin) error { // validate source chain srcChainName, srcChainFound := ibcK.GetChainNameByIBCPath(ctx, ibcPath) if !srcChainFound { @@ -198,7 +197,7 @@ func validateMessage(ctx sdk.Context, ibcK keeper.IBCKeeper, n types.Nexus, ibcP } if msg.Fee != nil { - err := validateFee(ctx, n, token.Coin, nexus.MessageType(msg.Type), srcChain, *msg.Fee) + err := validateFee(ctx, n, token.GetCoin(), nexus.MessageType(msg.Type), srcChain, *msg.Fee) if err != nil { return err } @@ -231,12 +230,12 @@ func validateMessage(ctx sdk.Context, ibcK keeper.IBCKeeper, n types.Nexus, ibcP return fmt.Errorf("unrecognized destination chain %s", destChainName) } - if !n.IsAssetRegistered(ctx, srcChain, token.GetDenom()) { - return fmt.Errorf("asset %s is not registered on chain %s", token.GetDenom(), srcChain.Name) + if !n.IsAssetRegistered(ctx, srcChain, token.GetCoin().Denom) { + return fmt.Errorf("asset %s is not registered on chain %s", token.GetCoin().Denom, srcChain.Name) } - if !n.IsAssetRegistered(ctx, destChain, token.GetDenom()) { - return fmt.Errorf("asset %s is not registered on chain %s", token.GetDenom(), destChain.Name) + if !n.IsAssetRegistered(ctx, destChain, token.GetCoin().Denom) { + return fmt.Errorf("asset %s is not registered on chain %s", token.GetCoin().Denom, destChain.Name) } return nil default: @@ -244,7 +243,7 @@ func validateMessage(ctx sdk.Context, ibcK keeper.IBCKeeper, n types.Nexus, ibcP } } -func handleMessage(ctx sdk.Context, n types.Nexus, b types.BankKeeper, sourceAddress nexus.CrossChainAddress, msg Message, token nexustypes.Coin) error { +func handleMessage(ctx sdk.Context, n types.Nexus, b types.BankKeeper, sourceAddress nexus.CrossChainAddress, msg Message, token nexus.LockableCoin) error { id, txID, nonce := n.GenerateMessageID(ctx) // ignore token for call contract @@ -285,7 +284,7 @@ func handleMessage(ctx sdk.Context, n types.Nexus, b types.BankKeeper, sourceAdd return n.SetNewMessage(ctx, m) } -func handleMessageWithToken(ctx sdk.Context, n types.Nexus, b types.BankKeeper, sourceAddress nexus.CrossChainAddress, msg Message, token nexustypes.Coin) error { +func handleMessageWithToken(ctx sdk.Context, n types.Nexus, b types.BankKeeper, sourceAddress nexus.CrossChainAddress, msg Message, token nexus.LockableCoin) error { id, txID, nonce := n.GenerateMessageID(ctx) token, err := deductFee(ctx, b, msg.Fee, token, id) @@ -299,6 +298,7 @@ func handleMessageWithToken(ctx sdk.Context, n types.Nexus, b types.BankKeeper, destChain := funcs.MustOk(n.GetChain(ctx, nexus.ChainName(msg.DestinationChain))) recipient := nexus.CrossChainAddress{Chain: destChain, Address: msg.DestinationAddress} + coin := token.GetCoin() m := nexus.NewGeneralMessage( id, sourceAddress, @@ -306,7 +306,7 @@ func handleMessageWithToken(ctx sdk.Context, n types.Nexus, b types.BankKeeper, crypto.Keccak256Hash(msg.Payload).Bytes(), txID, nonce, - &token.Coin, + &coin, ) events.Emit(ctx, &types.ContractCallWithTokenSubmitted{ @@ -317,13 +317,13 @@ func handleMessageWithToken(ctx sdk.Context, n types.Nexus, b types.BankKeeper, ContractAddress: m.GetDestinationAddress(), PayloadHash: m.PayloadHash, Payload: msg.Payload, - Asset: token.Coin, + Asset: token.GetCoin(), }) return n.SetNewMessage(ctx, m) } -func handleTokenSent(ctx sdk.Context, n types.Nexus, sourceAddress nexus.CrossChainAddress, msg Message, token nexustypes.Coin) error { +func handleTokenSent(ctx sdk.Context, n types.Nexus, sourceAddress nexus.CrossChainAddress, msg Message, token nexus.LockableCoin) error { destChain := funcs.MustOk(n.GetChain(ctx, nexus.ChainName(msg.DestinationChain))) crossChainAddr := nexus.CrossChainAddress{Chain: destChain, Address: msg.DestinationAddress} @@ -331,7 +331,7 @@ func handleTokenSent(ctx sdk.Context, n types.Nexus, sourceAddress nexus.CrossCh return err } - transferID, err := n.EnqueueTransfer(ctx, sourceAddress.Chain, crossChainAddr, token.Coin) + transferID, err := n.EnqueueTransfer(ctx, sourceAddress.Chain, crossChainAddr, token.GetCoin()) if err != nil { return err } @@ -342,7 +342,7 @@ func handleTokenSent(ctx sdk.Context, n types.Nexus, sourceAddress nexus.CrossCh SourceChain: sourceAddress.Chain.Name, DestinationAddress: crossChainAddr.Address, DestinationChain: crossChainAddr.Chain.Name, - Asset: token.Coin, + Asset: token.GetCoin(), }) return nil @@ -351,7 +351,7 @@ func handleTokenSent(ctx sdk.Context, n types.Nexus, sourceAddress nexus.CrossCh // extractTokenFromPacketData get normalized token from ICS20 packet // panic if unable to unmarshal packet data -func extractTokenFromPacketData(ctx sdk.Context, ibcK keeper.IBCKeeper, n types.Nexus, b types.BankKeeper, packet ibcexported.PacketI) (nexustypes.Coin, error) { +func extractTokenFromPacketData(ctx sdk.Context, ibcK keeper.IBCKeeper, n types.Nexus, b types.BankKeeper, packet ibcexported.PacketI) (nexus.LockableCoin, error) { data := funcs.Must(types.ToICS20Packet(packet)) // parse the transfer amount @@ -389,11 +389,11 @@ func extractTokenFromPacketData(ctx sdk.Context, ibcK keeper.IBCKeeper, n types. denom = denomTrace.IBCDenom() } - return nexustypes.NewCoin(ctx, n, ibcK, b, sdk.NewCoin(denom, amount)) + return n.NewLockableCoin(ctx, ibcK, b, sdk.NewCoin(denom, amount)) } // deductFee pays the fee and returns the updated transfer amount with the fee deducted -func deductFee(ctx sdk.Context, b types.BankKeeper, fee *Fee, token nexustypes.Coin, msgID string) (nexustypes.Coin, error) { +func deductFee(ctx sdk.Context, b types.BankKeeper, fee *Fee, token nexus.LockableCoin, msgID string) (nexus.LockableCoin, error) { if fee == nil { return token, nil } @@ -406,7 +406,7 @@ func deductFee(ctx sdk.Context, b types.BankKeeper, fee *Fee, token nexustypes.C MessageID: msgID, Recipient: recipient, Fee: coin, - Asset: token.GetDenom(), + Asset: token.GetCoin().Denom, } if fee.RefundRecipient != nil { feePaidEvent.RefundRecipient = *fee.RefundRecipient @@ -414,9 +414,7 @@ func deductFee(ctx sdk.Context, b types.BankKeeper, fee *Fee, token nexustypes.C events.Emit(ctx, &feePaidEvent) // subtract fee from transfer value - token.Amount = token.Amount.Sub(feeAmount) - - return token, b.SendCoins(ctx, types.AxelarIBCAccount, recipient, sdk.NewCoins(coin)) + return token.Sub(sdk.NewCoin(token.GetCoin().Denom, feeAmount)), b.SendCoins(ctx, types.AxelarIBCAccount, recipient, sdk.NewCoins(coin)) } // validateReceiver rejects uppercase GMP account address diff --git a/x/axelarnet/module.go b/x/axelarnet/module.go index d84162ea2..be94430ed 100644 --- a/x/axelarnet/module.go +++ b/x/axelarnet/module.go @@ -29,7 +29,6 @@ import ( "github.com/axelarnetwork/axelar-core/x/axelarnet/keeper" "github.com/axelarnetwork/axelar-core/x/axelarnet/types" nexus "github.com/axelarnetwork/axelar-core/x/nexus/exported" - nexustypes "github.com/axelarnetwork/axelar-core/x/nexus/types" "github.com/axelarnetwork/utils/funcs" ) @@ -343,7 +342,7 @@ func (m AxelarnetIBCModule) setRoutedPacketFailed(ctx sdk.Context, packet channe // check if the packet is Axelar routed cross chain transfer transferID, ok := getSeqIDMapping(ctx, m.keeper, port, channel, sequence) if ok { - coin, err := nexustypes.NewCoin(ctx, m.nexus, m.ibcK, m.bank, funcs.MustOk(m.keeper.GetTransfer(ctx, transferID)).Token) + coin, err := m.nexus.NewLockableCoin(ctx, m.ibcK, m.bank, funcs.MustOk(m.keeper.GetTransfer(ctx, transferID)).Token) if err != nil { return err } @@ -369,7 +368,7 @@ func (m AxelarnetIBCModule) setRoutedPacketFailed(ctx sdk.Context, packet channe // check if the packet is Axelar routed general message messageID, ok := getSeqMessageIDMapping(ctx, m.keeper, port, channel, sequence) if ok { - coin, err := nexustypes.NewCoin(ctx, m.nexus, m.ibcK, m.bank, extractTokenFromAckOrTimeoutPacket(packet)) + coin, err := m.nexus.NewLockableCoin(ctx, m.ibcK, m.bank, extractTokenFromAckOrTimeoutPacket(packet)) if err != nil { return err } diff --git a/x/axelarnet/types/evm_translator_test.go b/x/axelarnet/types/evm_translator_test.go index 1547d5984..82a96d150 100644 --- a/x/axelarnet/types/evm_translator_test.go +++ b/x/axelarnet/types/evm_translator_test.go @@ -32,7 +32,6 @@ var ( addressType = funcs.Must(abi.NewType("address", "address", nil)) stringType = funcs.Must(abi.NewType("string", "string", nil)) bytesType = funcs.Must(abi.NewType("bytes", "bytes", nil)) - bytes32Type = funcs.Must(abi.NewType("bytes32", "bytes32", nil)) uint8Type = funcs.Must(abi.NewType("uint8", "uint8", nil)) uint64Type = funcs.Must(abi.NewType("uint64", "uint64", nil)) uint64ArrayType = funcs.Must(abi.NewType("uint64[]", "uint64[]", nil)) @@ -173,15 +172,15 @@ func TestConstructWasmMessageV1Large(t *testing.T) { assert.True(t, ok) assert.Equal(t, boolTrue, actualBool) - arrayInterface, ok := executeMsg["string_array"].([]interface{}) + arrayInterface, _ := executeMsg["string_array"].([]interface{}) actualStringArray := slices.Map(arrayInterface, func(t interface{}) string { return t.(string) }) assert.Equal(t, stringArray, actualStringArray) - arrayInterface, ok = executeMsg["uint64_array"].([]interface{}) + arrayInterface, _ = executeMsg["uint64_array"].([]interface{}) uint64StrArray := slices.Map(arrayInterface, func(t interface{}) string { return t.(json.Number).String() }) assert.Equal(t, uint64Array, slices.Map(uint64StrArray, func(t string) uint64 { return funcs.Must(strconv.ParseUint(t, 10, 64)) })) - arrayInterface, ok = executeMsg["uint64_array_nested"].([]interface{}) + arrayInterface, _ = executeMsg["uint64_array_nested"].([]interface{}) actualUint64NestedArray := slices.Map(arrayInterface, func(inner interface{}) []uint64 { return slices.Map(inner.([]interface{}), func(t interface{}) uint64 { return funcs.Must(strconv.ParseUint(t.(json.Number).String(), 10, 64)) @@ -189,7 +188,7 @@ func TestConstructWasmMessageV1Large(t *testing.T) { }) assert.Equal(t, uint64NestedArray, actualUint64NestedArray) - arrayInterface, ok = executeMsg["string_array_nested"].([]interface{}) + arrayInterface, _ = executeMsg["string_array_nested"].([]interface{}) actualStringNestedArray := slices.Map(arrayInterface, func(inner interface{}) []string { return slices.Map(inner.([]interface{}), func(t interface{}) string { return t.(string) diff --git a/x/axelarnet/types/expected_keepers.go b/x/axelarnet/types/expected_keepers.go index c13a9cdf7..b79b28568 100644 --- a/x/axelarnet/types/expected_keepers.go +++ b/x/axelarnet/types/expected_keepers.go @@ -87,6 +87,7 @@ type Nexus interface { SetMessageFailed(ctx sdk.Context, id string) error GenerateMessageID(ctx sdk.Context) (string, []byte, uint64) ValidateAddress(ctx sdk.Context, address nexus.CrossChainAddress) error + NewLockableCoin(ctx sdk.Context, ibc nexustypes.IBCKeeper, bank nexustypes.BankKeeper, coin sdk.Coin) (nexus.LockableCoin, error) } // BankKeeper defines the expected interface contract the vesting module requires diff --git a/x/axelarnet/types/mock/expected_keepers.go b/x/axelarnet/types/mock/expected_keepers.go index 760e6cecc..af78ed715 100644 --- a/x/axelarnet/types/mock/expected_keepers.go +++ b/x/axelarnet/types/mock/expected_keepers.go @@ -751,6 +751,9 @@ var _ axelarnettypes.Nexus = &NexusMock{} // LoggerFunc: func(ctx cosmossdktypes.Context) log.Logger { // panic("mock out the Logger method") // }, +// NewLockableCoinFunc: func(ctx cosmossdktypes.Context, ibc nexustypes.IBCKeeper, bank nexustypes.BankKeeper, coin cosmossdktypes.Coin) (github_com_axelarnetwork_axelar_core_x_nexus_exported.LockableCoin, error) { +// panic("mock out the NewLockableCoin method") +// }, // RateLimitTransferFunc: func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.ChainName, asset cosmossdktypes.Coin, direction github_com_axelarnetwork_axelar_core_x_nexus_exported.TransferDirection) error { // panic("mock out the RateLimitTransfer method") // }, @@ -884,6 +887,9 @@ type NexusMock struct { // LoggerFunc mocks the Logger method. LoggerFunc func(ctx cosmossdktypes.Context) log.Logger + // NewLockableCoinFunc mocks the NewLockableCoin method. + NewLockableCoinFunc func(ctx cosmossdktypes.Context, ibc nexustypes.IBCKeeper, bank nexustypes.BankKeeper, coin cosmossdktypes.Coin) (github_com_axelarnetwork_axelar_core_x_nexus_exported.LockableCoin, error) + // RateLimitTransferFunc mocks the RateLimitTransfer method. RateLimitTransferFunc func(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.ChainName, asset cosmossdktypes.Coin, direction github_com_axelarnetwork_axelar_core_x_nexus_exported.TransferDirection) error @@ -1128,6 +1134,17 @@ type NexusMock struct { // Ctx is the ctx argument value. Ctx cosmossdktypes.Context } + // NewLockableCoin holds details about calls to the NewLockableCoin method. + NewLockableCoin []struct { + // Ctx is the ctx argument value. + Ctx cosmossdktypes.Context + // Ibc is the ibc argument value. + Ibc nexustypes.IBCKeeper + // Bank is the bank argument value. + Bank nexustypes.BankKeeper + // Coin is the coin argument value. + Coin cosmossdktypes.Coin + } // RateLimitTransfer holds details about calls to the RateLimitTransfer method. RateLimitTransfer []struct { // Ctx is the ctx argument value. @@ -1269,6 +1286,7 @@ type NexusMock struct { lockIsWasmConnectionActivated sync.RWMutex lockLinkAddresses sync.RWMutex lockLogger sync.RWMutex + lockNewLockableCoin sync.RWMutex lockRateLimitTransfer sync.RWMutex lockRegisterAsset sync.RWMutex lockRegisterFee sync.RWMutex @@ -2328,6 +2346,50 @@ func (mock *NexusMock) LoggerCalls() []struct { return calls } +// NewLockableCoin calls NewLockableCoinFunc. +func (mock *NexusMock) NewLockableCoin(ctx cosmossdktypes.Context, ibc nexustypes.IBCKeeper, bank nexustypes.BankKeeper, coin cosmossdktypes.Coin) (github_com_axelarnetwork_axelar_core_x_nexus_exported.LockableCoin, error) { + if mock.NewLockableCoinFunc == nil { + panic("NexusMock.NewLockableCoinFunc: method is nil but Nexus.NewLockableCoin was just called") + } + callInfo := struct { + Ctx cosmossdktypes.Context + Ibc nexustypes.IBCKeeper + Bank nexustypes.BankKeeper + Coin cosmossdktypes.Coin + }{ + Ctx: ctx, + Ibc: ibc, + Bank: bank, + Coin: coin, + } + mock.lockNewLockableCoin.Lock() + mock.calls.NewLockableCoin = append(mock.calls.NewLockableCoin, callInfo) + mock.lockNewLockableCoin.Unlock() + return mock.NewLockableCoinFunc(ctx, ibc, bank, coin) +} + +// NewLockableCoinCalls gets all the calls that were made to NewLockableCoin. +// Check the length with: +// +// len(mockedNexus.NewLockableCoinCalls()) +func (mock *NexusMock) NewLockableCoinCalls() []struct { + Ctx cosmossdktypes.Context + Ibc nexustypes.IBCKeeper + Bank nexustypes.BankKeeper + Coin cosmossdktypes.Coin +} { + var calls []struct { + Ctx cosmossdktypes.Context + Ibc nexustypes.IBCKeeper + Bank nexustypes.BankKeeper + Coin cosmossdktypes.Coin + } + mock.lockNewLockableCoin.RLock() + calls = mock.calls.NewLockableCoin + mock.lockNewLockableCoin.RUnlock() + return calls +} + // RateLimitTransfer calls RateLimitTransferFunc. func (mock *NexusMock) RateLimitTransfer(ctx cosmossdktypes.Context, chain github_com_axelarnetwork_axelar_core_x_nexus_exported.ChainName, asset cosmossdktypes.Coin, direction github_com_axelarnetwork_axelar_core_x_nexus_exported.TransferDirection) error { if mock.RateLimitTransferFunc == nil { diff --git a/x/nexus/exported/types.go b/x/nexus/exported/types.go index 594bcfd58..f52a4d1bf 100644 --- a/x/nexus/exported/types.go +++ b/x/nexus/exported/types.go @@ -21,6 +21,14 @@ import ( //go:generate moq -out ./mock/types.go -pkg mock . MaintainerState +type LockableCoin interface { + GetCoin() sdk.Coin + GetOriginalCoin(ctx sdk.Context) sdk.Coin + Sub(sdk.Coin) LockableCoin + Lock(ctx sdk.Context, fromAddr sdk.AccAddress) error + Unlock(ctx sdk.Context, toAddr sdk.AccAddress) error +} + // AddressValidator defines a function that implements address verification upon a request to link addresses type AddressValidator func(ctx sdk.Context, address CrossChainAddress) error diff --git a/x/nexus/types/coin.go b/x/nexus/keeper/lockable_coin.go similarity index 57% rename from x/nexus/types/coin.go rename to x/nexus/keeper/lockable_coin.go index 63d8bdbfb..3b84a6828 100644 --- a/x/nexus/types/coin.go +++ b/x/nexus/keeper/lockable_coin.go @@ -1,4 +1,4 @@ -package types +package keeper import ( "fmt" @@ -9,37 +9,43 @@ import ( axelarnet "github.com/axelarnetwork/axelar-core/x/axelarnet/exported" "github.com/axelarnetwork/axelar-core/x/nexus/exported" + "github.com/axelarnetwork/axelar-core/x/nexus/types" "github.com/axelarnetwork/utils/funcs" ) -// Coin provides functionality to lock and release coins -type Coin struct { +// NewLockableCoin creates a new lockable coin +func (k Keeper) NewLockableCoin(ctx sdk.Context, ibc types.IBCKeeper, bank types.BankKeeper, coin sdk.Coin) (exported.LockableCoin, error) { + return newLockableCoin(ctx, k, ibc, bank, coin) +} + +// lockableCoin provides functionality to lock and release coins +type lockableCoin struct { sdk.Coin - coinType CoinType - nexus Nexus - ibc IBCKeeper - bank BankKeeper + coinType types.CoinType + nexus types.Nexus + ibc types.IBCKeeper + bank types.BankKeeper } -// NewCoin creates a coin struct, assign a coin type and normalize the denom if it's a ICS20 token -func NewCoin(ctx sdk.Context, nexus Nexus, ibc IBCKeeper, bank BankKeeper, coin sdk.Coin) (Coin, error) { +// newLockableCoin creates a coin struct, assign a coin type and normalize the denom if it's a ICS20 token +func newLockableCoin(ctx sdk.Context, nexus types.Nexus, ibc types.IBCKeeper, bank types.BankKeeper, coin sdk.Coin) (lockableCoin, error) { coinType, err := getCoinType(ctx, nexus, coin.GetDenom()) if err != nil { - return Coin{}, err + return lockableCoin{}, err } // If coin type is ICS20, we need to normalize it to convert from 'ibc/{hash}' // to native asset denom so that nexus could recognize it - if coinType == ICS20 { + if coinType == types.ICS20 { denomTrace, err := ibc.ParseIBCDenom(ctx, coin.GetDenom()) if err != nil { - return Coin{}, err + return lockableCoin{}, err } coin = sdk.NewCoin(denomTrace.GetBaseDenom(), coin.Amount) } - c := Coin{ + c := lockableCoin{ Coin: coin, coinType: coinType, nexus: nexus, @@ -47,26 +53,36 @@ func NewCoin(ctx sdk.Context, nexus Nexus, ibc IBCKeeper, bank BankKeeper, coin bank: bank, } if _, err := c.getOriginalCoin(ctx); err != nil { - return Coin{}, err + return lockableCoin{}, err } return c, nil } +func (c lockableCoin) GetCoin() sdk.Coin { + return c.Coin +} + // GetOriginalCoin returns the original coin -func (c Coin) GetOriginalCoin(ctx sdk.Context) sdk.Coin { +func (c lockableCoin) GetOriginalCoin(ctx sdk.Context) sdk.Coin { // NOTE: must not fail since it's already checked in NewCoin return funcs.Must(c.getOriginalCoin(ctx)) } +func (c lockableCoin) Sub(coin sdk.Coin) exported.LockableCoin { + c.Coin = c.Coin.Sub(coin) + + return c +} + // Lock locks the given coin from the given address -func (c Coin) Lock(ctx sdk.Context, fromAddr sdk.AccAddress) error { +func (c lockableCoin) Lock(ctx sdk.Context, fromAddr sdk.AccAddress) error { coin := c.GetOriginalCoin(ctx) switch c.coinType { - case ICS20, Native: + case types.ICS20, types.Native: return lock(ctx, c.bank, fromAddr, coin) - case External: + case types.External: return burn(ctx, c.bank, fromAddr, coin) default: return fmt.Errorf("unrecognized coin type %d", c.coinType) @@ -74,32 +90,32 @@ func (c Coin) Lock(ctx sdk.Context, fromAddr sdk.AccAddress) error { } // Unlock unlocks the given coin to the given address -func (c Coin) Unlock(ctx sdk.Context, toAddr sdk.AccAddress) error { +func (c lockableCoin) Unlock(ctx sdk.Context, toAddr sdk.AccAddress) error { coin := c.GetOriginalCoin(ctx) switch c.coinType { - case ICS20, Native: + case types.ICS20, types.Native: return unlock(ctx, c.bank, toAddr, coin) - case External: + case types.External: return mint(ctx, c.bank, toAddr, coin) default: return fmt.Errorf("unrecognized coin type %d", c.coinType) } } -func (c Coin) getOriginalCoin(ctx sdk.Context) (sdk.Coin, error) { +func (c lockableCoin) getOriginalCoin(ctx sdk.Context) (sdk.Coin, error) { switch c.coinType { - case ICS20: + case types.ICS20: return c.toICS20(ctx) - case Native, External: + case types.Native, types.External: return c.Coin, nil default: return sdk.Coin{}, fmt.Errorf("unrecognized coin type %d", c.coinType) } } -func (c Coin) toICS20(ctx sdk.Context) (sdk.Coin, error) { - if c.coinType != ICS20 { +func (c lockableCoin) toICS20(ctx sdk.Context) (sdk.Coin, error) { + if c.coinType != types.ICS20 { return sdk.Coin{}, fmt.Errorf("%s is not ICS20 token", c.GetDenom()) } @@ -122,53 +138,53 @@ func (c Coin) toICS20(ctx sdk.Context) (sdk.Coin, error) { return sdk.NewCoin(trace.IBCDenom(), c.Amount), nil } -func lock(ctx sdk.Context, bank BankKeeper, fromAddr sdk.AccAddress, coin sdk.Coin) error { +func lock(ctx sdk.Context, bank types.BankKeeper, fromAddr sdk.AccAddress, coin sdk.Coin) error { return bank.SendCoins(ctx, fromAddr, exported.GetEscrowAddress(coin.GetDenom()), sdk.NewCoins(coin)) } -func unlock(ctx sdk.Context, bank BankKeeper, toAddr sdk.AccAddress, coin sdk.Coin) error { +func unlock(ctx sdk.Context, bank types.BankKeeper, toAddr sdk.AccAddress, coin sdk.Coin) error { return bank.SendCoins(ctx, exported.GetEscrowAddress(coin.GetDenom()), toAddr, sdk.NewCoins(coin)) } -func burn(ctx sdk.Context, bank BankKeeper, fromAddr sdk.AccAddress, coin sdk.Coin) error { +func burn(ctx sdk.Context, bank types.BankKeeper, fromAddr sdk.AccAddress, coin sdk.Coin) error { coins := sdk.NewCoins(coin) - if err := bank.SendCoinsFromAccountToModule(ctx, fromAddr, ModuleName, coins); err != nil { + if err := bank.SendCoinsFromAccountToModule(ctx, fromAddr, types.ModuleName, coins); err != nil { return err } // NOTE: should never fail since the coin is just transfered to the module // account before the burn - funcs.MustNoErr(bank.BurnCoins(ctx, ModuleName, coins)) + funcs.MustNoErr(bank.BurnCoins(ctx, types.ModuleName, coins)) return nil } -func mint(ctx sdk.Context, bank BankKeeper, toAddr sdk.AccAddress, coin sdk.Coin) error { +func mint(ctx sdk.Context, bank types.BankKeeper, toAddr sdk.AccAddress, coin sdk.Coin) error { coins := sdk.NewCoins(coin) - if err := bank.MintCoins(ctx, ModuleName, coins); err != nil { + if err := bank.MintCoins(ctx, types.ModuleName, coins); err != nil { return err } // NOTE: should never fail since the coin is just minted to the module // account before the transfer - funcs.MustNoErr(bank.SendCoinsFromModuleToAccount(ctx, ModuleName, toAddr, coins)) + funcs.MustNoErr(bank.SendCoinsFromModuleToAccount(ctx, types.ModuleName, toAddr, coins)) return nil } -func getCoinType(ctx sdk.Context, nexus Nexus, denom string) (CoinType, error) { +func getCoinType(ctx sdk.Context, nexus types.Nexus, denom string) (types.CoinType, error) { switch { // check if the format of token denomination is 'ibc/{hash}' case isIBCDenom(denom): - return ICS20, nil + return types.ICS20, nil case isNativeAssetOnAxelarnet(ctx, nexus, denom): - return Native, nil + return types.Native, nil case nexus.IsAssetRegistered(ctx, axelarnet.Axelarnet, denom): - return External, nil + return types.External, nil default: - return Unrecognized, fmt.Errorf("unrecognized coin %s", denom) + return types.Unrecognized, fmt.Errorf("unrecognized coin %s", denom) } } @@ -190,7 +206,7 @@ func isIBCDenom(denom string) bool { return true } -func isNativeAssetOnAxelarnet(ctx sdk.Context, nexus Nexus, denom string) bool { +func isNativeAssetOnAxelarnet(ctx sdk.Context, nexus types.Nexus, denom string) bool { chain, ok := nexus.GetChainByNativeAsset(ctx, denom) return ok && chain.Name.Equals(axelarnet.Axelarnet.Name)