Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/core/messaging/Gateway.sol
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,14 @@ contract Gateway is Auth, Recoverable, IGateway {
}

/// @inheritdoc IGateway
function withBatch(bytes memory data, uint256 value, address refund) public payable {
require(value <= msg.value, NotEnoughValueForCallback());
function withBatch(bytes memory data, uint256 callbackValue, address refund) public payable {
require(callbackValue <= msg.value, NotEnoughValueForCallback());

bool wasBatching = isBatching;
bool isNested = isBatching;
isBatching = true;

_batcher = msg.sender;
(bool success, bytes memory returnData) = msg.sender.call{value: value}(data);
(bool success, bytes memory returnData) = msg.sender.call{value: callbackValue}(data);
if (!success) {
uint256 length = returnData.length;
require(length != 0, CallFailedWithEmptyRevert());
Expand All @@ -247,16 +247,20 @@ contract Gateway is Auth, Recoverable, IGateway {
// Force the user to call lockCallback()
require(address(_batcher) == address(0), CallbackWasNotLocked());

if (!wasBatching) _endBatching(msg.value - value, refund);
if (isNested) {
_refund(refund, msg.value - callbackValue);
} else {
uint256 cost = _endBatching(msg.value - callbackValue, refund);
_refund(refund, msg.value - callbackValue - cost);
}
}

function _endBatching(uint256 fuel, address refund) internal {
function _endBatching(uint256 fuel, address refund) internal returns (uint256 cost) {
require(isBatching, NoBatched());
bytes32[] memory locators = TransientArrayLib.getBytes32(BATCH_LOCATORS_SLOT);

TransientArrayLib.clear(BATCH_LOCATORS_SLOT);

uint256 cost;
for (uint256 i; i < locators.length; i++) {
(uint16 centrifugeId, PoolId poolId) = _parseLocator(locators[i]);
bytes32 outboundBatchSlot = _outboundBatchSlot(centrifugeId, poolId);
Expand All @@ -270,8 +274,6 @@ contract Gateway is Auth, Recoverable, IGateway {
}

isBatching = false;

_refund(refund, fuel - cost);
}

/// @inheritdoc IGateway
Expand Down
16 changes: 8 additions & 8 deletions test/core/unit/Gateway.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,6 @@ contract GatewayTestEndBatching is GatewayTest {

gateway.endBatching{value: cost + 1234}(REFUND);

assertEq(REFUND.balance, 1234);
assertEq(gateway.batchGasLimit(REMOTE_CENT_ID, POOL_A), 0);
assertEq(gateway.outboundBatch(REMOTE_CENT_ID, POOL_A), new bytes(0));
assertEq(gateway.batchLocatorsLength(), 0);
Expand All @@ -630,7 +629,6 @@ contract GatewayTestEndBatching is GatewayTest {

gateway.endBatching{value: cost + 1234}(REFUND);

assertEq(REFUND.balance, 1234);
assertEq(gateway.batchGasLimit(REMOTE_CENT_ID, POOL_A), 0);
assertEq(gateway.outboundBatch(REMOTE_CENT_ID, POOL_A), new bytes(0));
assertEq(gateway.batchGasLimit(REMOTE_CENT_ID + 1, POOL_A), 0);
Expand All @@ -654,7 +652,6 @@ contract GatewayTestEndBatching is GatewayTest {

gateway.endBatching{value: cost * 2 + 1234}(REFUND);

assertEq(REFUND.balance, 1234);
assertEq(gateway.batchGasLimit(REMOTE_CENT_ID, POOL_A), 0);
assertEq(gateway.outboundBatch(REMOTE_CENT_ID, POOL_A), new bytes(0));
assertEq(gateway.batchGasLimit(REMOTE_CENT_ID, POOL_0), 0);
Expand Down Expand Up @@ -845,9 +842,10 @@ contract IntegrationMock is Test {
gateway = gateway_;
}

function _nested() external {
function _nested(address refund) external payable {
gateway.lockCallback();
gateway.withBatch(abi.encodeWithSelector(this._success.selector, false, 2), address(0));
assertEq(msg.value, 1234);
gateway.withBatch{value: msg.value}(abi.encodeWithSelector(this._success.selector, false, 2), refund);
}

function _emptyError() external {
Expand All @@ -872,8 +870,8 @@ contract IntegrationMock is Test {
gateway.lockCallback();
}

function callNested(address refund) external {
gateway.withBatch(abi.encodeWithSelector(this._nested.selector), refund);
function callNested(address refund) external payable {
gateway.withBatch{value: msg.value}(abi.encodeWithSelector(this._nested.selector, refund), msg.value, refund);
}

function callEmptyError(address refund) external {
Expand Down Expand Up @@ -957,9 +955,11 @@ contract GatewayTestWithBatch is GatewayTest {

function testWithCallbackNested() public {
vm.prank(ANY);
integration.callNested(REFUND);
vm.deal(ANY, 1234);
integration.callNested{value: 1234}(REFUND);

assertEq(integration.wasCalled(), true);
assertEq(REFUND.balance, 1234); // Refunded by the nested withBatch
}

function testWithCallbackPaid() public {
Expand Down
Loading