Skip to content

Commit df6a783

Browse files
committed
Add reentry guard to router
1 parent f17ca47 commit df6a783

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

src/TrailsRouter.sol

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pragma solidity ^0.8.30;
33

44
import {SafeERC20} from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol";
55
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
6+
import {ReentrancyGuard} from "@openzeppelin/contracts/utils/ReentrancyGuard.sol";
67
import {IDelegatedExtension} from "wallet-contracts-v3/modules/interfaces/IDelegatedExtension.sol";
78
import {Tstorish} from "tstorish/Tstorish.sol";
89
import {IMulticall3} from "./interfaces/IMulticall3.sol";
@@ -13,7 +14,7 @@ import {TrailsSentinelLib} from "./libraries/TrailsSentinelLib.sol";
1314
/// @author Miguel Mota, Shun Kakinoki
1415
/// @notice Consolidated router for Trails operations including multicall routing, balance injection, and token sweeping
1516
/// @dev Can be delegatecalled via the Sequence delegated extension module to access wallet storage/balances.
16-
contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
17+
contract TrailsRouter is IDelegatedExtension, ITrailsRouter, ReentrancyGuard, Tstorish {
1718
// -------------------------------------------------------------------------
1819
// Libraries
1920
// -------------------------------------------------------------------------
@@ -55,7 +56,7 @@ contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
5556
// -------------------------------------------------------------------------
5657

5758
/// @inheritdoc ITrailsRouter
58-
function execute(bytes calldata data) public payable returns (IMulticall3.Result[] memory returnResults) {
59+
function execute(bytes calldata data) public payable nonReentrant returns (IMulticall3.Result[] memory returnResults) {
5960
_validateRouterCall(data);
6061
(bool success, bytes memory returnData) = MULTICALL3.delegatecall(data);
6162
if (!success) revert TargetCallFailed(returnData);
@@ -66,6 +67,7 @@ contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
6667
function pullAndExecute(address token, bytes calldata data)
6768
public
6869
payable
70+
nonReentrant
6971
returns (IMulticall3.Result[] memory returnResults)
7072
{
7173
uint256 amount;
@@ -77,13 +79,22 @@ contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
7779
if (amount == 0) revert NoTokensToPull();
7880
}
7981

80-
return pullAmountAndExecute(token, amount, data);
82+
return _pullAmountAndExecute(token, amount, data);
8183
}
8284

8385
/// @inheritdoc ITrailsRouter
8486
function pullAmountAndExecute(address token, uint256 amount, bytes calldata data)
8587
public
8688
payable
89+
nonReentrant
90+
returns (IMulticall3.Result[] memory returnResults)
91+
{
92+
return _pullAmountAndExecute(token, amount, data);
93+
}
94+
95+
/// forge-lint: disable-next-line(mixed-case-function)
96+
function _pullAmountAndExecute(address token, uint256 amount, bytes memory data)
97+
internal
8798
returns (IMulticall3.Result[] memory returnResults)
8899
{
89100
_validateRouterCall(data);
@@ -109,7 +120,7 @@ contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
109120
bytes calldata callData,
110121
uint256 amountOffset,
111122
bytes32 placeholder
112-
) external payable {
123+
) external payable nonReentrant {
113124
uint256 callerBalance;
114125

115126
if (token == address(0)) {
@@ -131,7 +142,7 @@ contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
131142
bytes calldata callData,
132143
uint256 amountOffset,
133144
bytes32 placeholder
134-
) public payable {
145+
) public payable nonReentrant {
135146
uint256 callerBalance = _getSelfBalance(token);
136147
if (callerBalance == 0) {
137148
if (token == address(0)) {
@@ -149,7 +160,12 @@ contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
149160
// -------------------------------------------------------------------------
150161

151162
/// @inheritdoc ITrailsRouter
152-
function sweep(address _token, address _recipient) public payable {
163+
function sweep(address _token, address _recipient) public payable nonReentrant {
164+
_sweep(_token, _recipient);
165+
}
166+
167+
/// forge-lint: disable-next-line(mixed-case-function)
168+
function _sweep(address _token, address _recipient) internal {
153169
uint256 amount = _getSelfBalance(_token);
154170
if (amount > 0) {
155171
if (_token == address(0)) {
@@ -165,6 +181,14 @@ contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
165181
function refundAndSweep(address _token, address _refundRecipient, uint256 _refundAmount, address _sweepRecipient)
166182
public
167183
payable
184+
nonReentrant
185+
{
186+
_refundAndSweep(_token, _refundRecipient, _refundAmount, _sweepRecipient);
187+
}
188+
189+
/// forge-lint: disable-next-line(mixed-case-function)
190+
function _refundAndSweep(address _token, address _refundRecipient, uint256 _refundAmount, address _sweepRecipient)
191+
internal
168192
{
169193
uint256 current = _getSelfBalance(_token);
170194

@@ -194,12 +218,17 @@ contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
194218
}
195219

196220
/// @inheritdoc ITrailsRouter
197-
function validateOpHashAndSweep(bytes32 opHash, address _token, address _recipient) public payable {
221+
function validateOpHashAndSweep(bytes32 opHash, address _token, address _recipient) public payable nonReentrant {
222+
_validateOpHashAndSweep(opHash, _token, _recipient);
223+
}
224+
225+
/// forge-lint: disable-next-line(mixed-case-function)
226+
function _validateOpHashAndSweep(bytes32 opHash, address _token, address _recipient) internal {
198227
uint256 slot = TrailsSentinelLib.successSlot(opHash);
199228
if (_getTstorish(slot) != TrailsSentinelLib.SUCCESS_VALUE) {
200229
revert SuccessSentinelNotSet();
201230
}
202-
sweep(_token, _recipient);
231+
_sweep(_token, _recipient);
203232
}
204233

205234
// -------------------------------------------------------------------------
@@ -234,20 +263,20 @@ contract TrailsRouter is IDelegatedExtension, ITrailsRouter, Tstorish {
234263
// Token Sweeper selectors
235264
if (selector == this.sweep.selector) {
236265
(address token, address recipient) = abi.decode(_data[4:], (address, address));
237-
sweep(token, recipient);
266+
_sweep(token, recipient);
238267
return;
239268
}
240269

241270
if (selector == this.refundAndSweep.selector) {
242271
(address token, address refundRecipient, uint256 refundAmount, address sweepRecipient) =
243272
abi.decode(_data[4:], (address, address, uint256, address));
244-
refundAndSweep(token, refundRecipient, refundAmount, sweepRecipient);
273+
_refundAndSweep(token, refundRecipient, refundAmount, sweepRecipient);
245274
return;
246275
}
247276

248277
if (selector == this.validateOpHashAndSweep.selector) {
249278
(, address token, address recipient) = abi.decode(_data[4:], (bytes32, address, address));
250-
validateOpHashAndSweep(_opHash, token, recipient);
279+
_validateOpHashAndSweep(_opHash, token, recipient);
251280
return;
252281
}
253282

0 commit comments

Comments
 (0)