Skip to content

Commit

Permalink
feat: reentrancy protection (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuelJet authored Oct 8, 2024
1 parent 98bd7f6 commit 3aa434b
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 180 deletions.
36 changes: 19 additions & 17 deletions src/components/MultiSigTimelock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ abstract contract MultiSigTimelock is User, IMultiSigTimelock {
ActionType actionType,
address target,
uint256 value
) public validSigner isExecutable validActionType(actionType) {
) public validSigner isExecutable validActionType(actionType) nonReentrant {
if (_isPendingAction) revert PendingActionState(_isPendingAction);

if (actionType == ActionType.ADD_SIGNER && target == address(0)) {
Expand Down Expand Up @@ -208,7 +208,7 @@ abstract contract MultiSigTimelock is User, IMultiSigTimelock {
*/
function executeAction(
uint256 actionId
) public validExecutor validAction(actionId) pendingAction(actionId) {
) public validExecutor validAction(actionId) pendingAction(actionId) nonReentrant {
Action storage action = _actions[actionId];
uint256 executionTimestamp = block.timestamp;
if (!_isMultiSigTimelockElapsed(action.timestamp)) {
Expand All @@ -221,7 +221,9 @@ abstract contract MultiSigTimelock is User, IMultiSigTimelock {
if (!_isSignatoryThresholdMet(action.approvals.current())) {
revert InsufficientSignerApprovals(signatoryThreshold, action.approvals.current());
}
} else {
}

if (_isExecutor(_msgSender()) && action.approvals.current() < signatoryThreshold) {
action.isOverride = true;
}

Expand Down Expand Up @@ -332,6 +334,20 @@ abstract contract MultiSigTimelock is User, IMultiSigTimelock {
return action.signatures.values();
}

/**
* @notice Updates the signatory threshold for the vault.
* Emits the `ThresholdUpdated` event.
* @param newThreshold The new threshold value for signatory approval.
* @dev Only callable by the contract executors.
*/
function _updateSignatoryThreshold(
uint256 newThreshold
) internal validExecutor {
uint256 oldThreshold = signatoryThreshold;
signatoryThreshold = newThreshold;
emit ThresholdUpdated(oldThreshold, newThreshold);
}

/**
* @notice Verifies if the signatory threshold has been met for an action.
* @param approvalCount The current number of approvals for the action.
Expand All @@ -354,20 +370,6 @@ abstract contract MultiSigTimelock is User, IMultiSigTimelock {
return block.timestamp >= initiatedAt + multiSigTimelock;
}

/**
* @notice Updates the signatory threshold for the vault.
* Emits the `ThresholdUpdated` event.
* @param newThreshold The new threshold value for signatory approval.
* @dev Only callable by the contract executors.
*/
function _updateSignatoryThreshold(
uint256 newThreshold
) internal validExecutor {
uint256 oldThreshold = signatoryThreshold;
signatoryThreshold = newThreshold;
emit ThresholdUpdated(oldThreshold, newThreshold);
}

/**
* @notice Updates the multi-sig vault timelock.
* Emits the `MultiSigTimelockUpdated` event.
Expand Down
16 changes: 7 additions & 9 deletions src/components/MultiSigTransaction.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import {IERC20} from '@openzeppelin/contracts/token/ERC20/IERC20.sol';
import {SafeERC20} from '@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol';
import {EnumerableSet} from '@openzeppelin/contracts/utils/structs/EnumerableSet.sol';
import {IMultiSigTransaction} from '../interfaces/IMultiSigTransaction.sol';
import {ERC20Validator} from '../libraries/ERC20Validator.sol';
import {Transaction} from '../utilities/VaultStructs.sol';
import {MultiSigTimelock} from './MultiSigTimelock.sol';
import {SafeMath} from '../libraries/SafeMath.sol';
Expand Down Expand Up @@ -84,8 +83,7 @@ abstract contract MultiSigTransaction is MultiSigTimelock, IMultiSigTransaction
/**
* @inheritdoc IMultiSigTransaction
*/
function depositToken(address token, uint256 amount) external payable {
ERC20Validator.requireValidERC20Token(token);
function depositToken(address token, uint256 amount) external payable nonReentrant {
uint256 allowance = IERC20(token).allowance(_msgSender(), address(this));

if (allowance < amount) {
Expand All @@ -109,14 +107,13 @@ abstract contract MultiSigTransaction is MultiSigTimelock, IMultiSigTransaction
address token,
uint256 value,
bytes memory data
) public validSigner isExecutable {
) public validSigner isExecutable nonReentrant {
AddressUtils.requireValidTransactionReceiver(to);
if (_isPendingTransaction) revert PendingTransactionState(_isPendingTransaction);

if (token == address(0)) {
if (value > getBalance()) revert InsufficientTokenBalance(getBalance(), value);
} else {
ERC20Validator.requireValidERC20Token(token);
if (value > getTokenBalance(token)) revert InsufficientTokenBalance(getTokenBalance(token), value);
}

Expand Down Expand Up @@ -171,7 +168,7 @@ abstract contract MultiSigTransaction is MultiSigTimelock, IMultiSigTransaction
*/
function executeTransaction(
uint256 txId
) public validExecutor validTransaction(txId) pendingTransaction(txId) {
) public validExecutor validTransaction(txId) pendingTransaction(txId) nonReentrant {
uint256 executionTimestamp = block.timestamp;
Transaction storage txn = _transactions[txId];
if (!_isMultiSigTimelockElapsed(txn.timestamp)) {
Expand All @@ -184,7 +181,9 @@ abstract contract MultiSigTransaction is MultiSigTimelock, IMultiSigTransaction
if (!_isSignatoryThresholdMet(txn.approvals.current())) {
revert InsufficientSignerApprovals(signatoryThreshold, txn.approvals.current());
}
} else {
}

if (_isExecutor(_msgSender()) && txn.approvals.current() < signatoryThreshold) {
txn.isOverride = true;
}

Expand Down Expand Up @@ -233,8 +232,7 @@ abstract contract MultiSigTransaction is MultiSigTimelock, IMultiSigTransaction
*/
function getTokenBalance(
address token
) public returns (uint256) {
ERC20Validator.requireValidERC20Token(token);
) public view returns (uint256) {
return IERC20(token).balanceOf(address(this));
}

Expand Down
85 changes: 43 additions & 42 deletions src/components/user/User.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ import {SafeMath} from '../../libraries/SafeMath.sol';
import {RoleType} from '../../utilities/VaultEnums.sol';
import {UserProfile} from '../../utilities/VaultStructs.sol';
import {AddressUtils} from '../../libraries/AddressUtils.sol';
import '@openzeppelin/contracts/utils/ReentrancyGuard.sol';

/**
* @title User Contract
* @author Emmanuel Joseph (JET)
* @dev Manages user profiles and integrates with the Owner role for user administration within the MultiSigVault system.
*/
abstract contract User is OwnerRole, ExecutorRole, SignerRole, IUser {
abstract contract User is ReentrancyGuard, OwnerRole, ExecutorRole, SignerRole, IUser {
using Counters for Counters.Counter;
using AddressUtils for address;

Expand Down Expand Up @@ -75,38 +76,14 @@ abstract contract User is OwnerRole, ExecutorRole, SignerRole, IUser {
_;
}

/**
* @notice Returns the total number of users.
* @return uint256 The total number of users.
* @dev Only callable by the owner.
*/
function totalUsers() public view onlyOwner returns (uint256) {
return _userCount.current();
}

/**
* @notice Returns the user profile object of a given address.
*
* @param user The address of the user whose profile is requested.
* @return UserProfile The user profile object associated with the provided address.
* @dev Requirements:
* - Limited to the owner account.
* - `user` cannot be the zero address.
*/
function getUserProfile(
address user
) public view validUser(user) onlyOwner returns (UserProfile memory) {
return _users[user];
}

/**
* @notice Adds a new executor.
* @param newExecutor The address of the new executor.
* @dev Only callable by the owner.
*/
function addExecutor(
address newExecutor
) public onlyOwner {
) public onlyOwner nonReentrant {
_addExecutor(newExecutor);
_addUser(newExecutor, RoleType.EXECUTOR);
}
Expand All @@ -118,7 +95,7 @@ abstract contract User is OwnerRole, ExecutorRole, SignerRole, IUser {
*/
function updateExecutor(
address newExecutor
) public onlyOwner {
) public onlyOwner nonReentrant {
address oldExecutor = executor();
oldExecutor.requireValidUserAddress();
newExecutor.requireValidUserAddress();
Expand All @@ -135,7 +112,7 @@ abstract contract User is OwnerRole, ExecutorRole, SignerRole, IUser {
* NOTE: Removing executor will leave the contract without an executor,
* thereby disabling any functionality that is only available to the executor.
*/
function removeExecutor() public onlyOwner {
function removeExecutor() public onlyOwner nonReentrant {
address oldExecutor = executor();
oldExecutor.requireValidUserAddress();

Expand All @@ -146,7 +123,7 @@ abstract contract User is OwnerRole, ExecutorRole, SignerRole, IUser {
/**
* @notice Allows an executor to approve the owner override after the timelock has elapsed.
*/
function approveOwnerOverride() public onlyExecutor {
function approveOwnerOverride() public onlyExecutor nonReentrant {
address currentOwner = owner();
address currentExecutor = _msgSender();
uint256 currentTimestamp = block.timestamp;
Expand All @@ -159,14 +136,38 @@ abstract contract User is OwnerRole, ExecutorRole, SignerRole, IUser {
_addUser(currentExecutor, RoleType.OWNER);
}

/**
* @notice Returns the total number of users.
* @return uint256 The total number of users.
* @dev Only callable by the owner.
*/
function totalUsers() public view onlyOwner returns (uint256) {
return _userCount.current();
}

/**
* @notice Returns the user profile object of a given address.
*
* @param user The address of the user whose profile is requested.
* @return UserProfile The user profile object associated with the provided address.
* @dev Requirements:
* - Limited to the owner account.
* - `user` cannot be the zero address.
*/
function getUserProfile(
address user
) public view validUser(user) onlyOwner returns (UserProfile memory) {
return _users[user];
}

/**
* @notice Adds a new signer user.
* @param newSigner The address of the new signer.
* @dev Only callable by the owner.
*/
function _addSigner(
address newSigner
) internal override validExecutor {
) internal override validExecutor nonReentrant {
super._addSigner(newSigner);
_addUser(newSigner, RoleType.SIGNER);
}
Expand All @@ -178,23 +179,12 @@ abstract contract User is OwnerRole, ExecutorRole, SignerRole, IUser {
*/
function _removeSigner(
address signer
) internal override validExecutor {
) internal override validExecutor nonReentrant {
signer.requireValidUserAddress();
super._removeSigner(signer);
_removeUser(signer);
}

/**
* @notice Checks if an address is a user.
* @param user The address to check.
* @return status True if the address is a user, otherwise false.
*/
function _isUser(
address user
) private view returns (bool status) {
status = _users[user].user != address(0);
}

/**
* @notice Returns the total number of valid signers.
* @return uint256 The total number of valid signers.
Expand Down Expand Up @@ -226,4 +216,15 @@ abstract contract User is OwnerRole, ExecutorRole, SignerRole, IUser {
delete _users[user];
_userCount.decrement();
}

/**
* @notice Checks if an address is a user.
* @param user The address to check.
* @return status True if the address is a user, otherwise false.
*/
function _isUser(
address user
) private view returns (bool status) {
status = _users[user].user != address(0);
}
}
16 changes: 8 additions & 8 deletions src/components/user/roles/ExecutorRole.sol
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ abstract contract ExecutorRole is AccessControl, IExecutorRole {
_;
}

/**
* @notice Returns the address of the current executor.
* @return The address of the executor.
*/
function executor() public view returns (address) {
return _executor;
}

/**
* @notice Initiates the owner override process with a timelock.
* @dev Only callable by the executor. The override process will start and only be executable after the timelock period has passed.
Expand All @@ -58,6 +50,14 @@ abstract contract ExecutorRole is AccessControl, IExecutorRole {
emit OwnerOverrideInitiated(_msgSender(), overrideInitiatedAt);
}

/**
* @notice Returns the address of the current executor.
* @return The address of the executor.
*/
function executor() public view returns (address) {
return _executor;
}

/**
* @notice Adds a new executor.
* @param newExecutor The address of the new executor.
Expand Down
30 changes: 15 additions & 15 deletions src/components/user/roles/OwnerRole.sol
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,6 @@ abstract contract OwnerRole is AccessControl, IOwnerRole {
_;
}

/**
* @notice Returns the address of the current owner.
* @return The address of the owner.
*/
function owner() public view returns (address) {
return _owner;
}

/**
* @notice Increases the owner override timelock to a new limit.
* Emits the `OwnerOverrideTimelockIncreased` event.
Expand Down Expand Up @@ -96,14 +88,11 @@ abstract contract OwnerRole is AccessControl, IOwnerRole {
}

/**
* @notice Checks if an address is the contract owner.
* @param account The address to check.
* @return status True if the address is the contract owner, otherwise false.
* @notice Returns the address of the current owner.
* @return The address of the owner.
*/
function _isOwner(
address account
) internal view returns (bool status) {
status = account == _owner && hasRole(OWNER_ROLE, account);
function owner() public view returns (address) {
return _owner;
}

/**
Expand All @@ -119,4 +108,15 @@ abstract contract OwnerRole is AccessControl, IOwnerRole {
_owner = newOwner;
emit OwnerChanged(oldOwner, newOwner);
}

/**
* @notice Checks if an address is the contract owner.
* @param account The address to check.
* @return status True if the address is the contract owner, otherwise false.
*/
function _isOwner(
address account
) internal view returns (bool status) {
status = account == _owner && hasRole(OWNER_ROLE, account);
}
}
4 changes: 2 additions & 2 deletions src/libraries/AddressUtils.sol
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ library AddressUtils {
*/
function requireValidTransactionReceiver(
address payable receiver
) internal pure {
if (receiver == address(0)) {
) internal view {
if (receiver == address(0) || receiver == address(this)) {
revert InvalidTransactionReceiver(receiver);
}
}
Expand Down
Loading

0 comments on commit 3aa434b

Please sign in to comment.