diff --git a/.gas-snapshot b/.gas-snapshot new file mode 100644 index 0000000..9b1c484 --- /dev/null +++ b/.gas-snapshot @@ -0,0 +1,9 @@ +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_100CallPolicies() (gas: 2467172) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_100TransferPolicies() (gas: 2209520) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2336380) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50CallPolicies() (gas: 2259810) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50TransferPolicies() (gas: 2134968) +BenchmarkDefaultValidatorNewTest:test_execute_100CallPolicies() (gas: 2760504) +BenchmarkDefaultValidatorNewTest:test_execute_100TransferPolicies() (gas: 1317589) +BenchmarkDefaultValidatorNewTest:test_execute_50CallPolicies() (gas: 1471450) +BenchmarkDefaultValidatorNewTest:test_execute_50TransferPolicies() (gas: 743908) \ No newline at end of file diff --git a/foundry.toml b/foundry.toml index 61ba0f2..f3050a6 100644 --- a/foundry.toml +++ b/foundry.toml @@ -8,6 +8,10 @@ src = "src" out = "out" libs = ["lib"] +gas_reports = [ + "BenchmarkDefaultValidatorNew", +] + ignored_warnings_from = ["lib", "test"] [fmt] diff --git a/gasreport.txt b/gasreport.txt new file mode 100644 index 0000000..b04be77 --- /dev/null +++ b/gasreport.txt @@ -0,0 +1,26 @@ +No files changed, compilation skipped + +Ran 9 tests for test/BenchmarkDefaultValidatorNew.t.sol:BenchmarkDefaultValidatorNewTest +[PASS] test_createSessionKeyForSigner_100CallPolicies() (gas: 2470875) +[PASS] test_createSessionKeyForSigner_100TransferPolicies() (gas: 2213038) +[PASS] test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2339958) +[PASS] test_createSessionKeyForSigner_50CallPolicies() (gas: 2263389) +[PASS] test_createSessionKeyForSigner_50TransferPolicies() (gas: 2138379) +[PASS] test_execute_100CallPolicies() (gas: 210125) +[PASS] test_execute_100TransferPolicies() (gas: 175656) +[PASS] test_execute_50CallPolicies() (gas: 210126) +[PASS] test_execute_50TransferPolicies() (gas: 175641) +Suite result: ok. 9 passed; 0 failed; 0 skipped; finished in 14.13ms (71.45ms CPU time) + + +Ran 1 test suite in 16.52ms (14.13ms CPU time): 9 tests passed, 0 failed, 0 skipped (9 total tests) +test_createSessionKeyForSigner_100CallPolicies() (gas: 3703 (0.150%)) +test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 3578 (0.153%)) +test_createSessionKeyForSigner_50CallPolicies() (gas: 3579 (0.158%)) +test_createSessionKeyForSigner_100TransferPolicies() (gas: 3518 (0.159%)) +test_createSessionKeyForSigner_50TransferPolicies() (gas: 3411 (0.160%)) +test_execute_50TransferPolicies() (gas: -568267 (-76.389%)) +test_execute_50CallPolicies() (gas: -1261324 (-85.720%)) +test_execute_100TransferPolicies() (gas: -1141933 (-86.668%)) +test_execute_100CallPolicies() (gas: -2550379 (-92.388%)) +Overall gas change: -5504114 (-31.094%) diff --git a/lib/solady b/lib/solady index 8200a70..ab7596b 160000 --- a/lib/solady +++ b/lib/solady @@ -1 +1 @@ -Subproject commit 8200a70e8dc2a77ecb074fc2e99a2a0d36547522 +Subproject commit ab7596b1a21f8a54bec722d54505e8daa40b5760 diff --git a/src/DefaultValidator.sol b/src/DefaultValidator.sol index 75f126e..a7e5f0c 100644 --- a/src/DefaultValidator.sol +++ b/src/DefaultValidator.sol @@ -1,4 +1,4 @@ -// SPDX-License-Identifier: MIT +// SPDX-License-Identifier: Apache-2.0 pragma solidity ^0.8.26; import "account-abstraction-v0.7/core/Helpers.sol"; diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol new file mode 100644 index 0000000..c9ef252 --- /dev/null +++ b/src/DefaultValidatorNew.sol @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import "account-abstraction-v0.7/core/Helpers.sol"; +import {UserOperationLib} from "account-abstraction-v0.7/core/UserOperationLib.sol"; + +import { + IValidator, + MODULE_TYPE_VALIDATOR, + PackedUserOperation, + VALIDATION_FAILED, + VALIDATION_SUCCESS +} from "./interface/IERC7579Module.sol"; + +import {ModularAccount} from "./ModularAccount.sol"; + +import {Execution} from "./interface/IERC7579Account.sol"; +import {ExecutionLib} from "./lib/ExecutionLib.sol"; +import { + CALLTYPE_BATCH, + CALLTYPE_DELEGATECALL, + CALLTYPE_SINGLE, + CALLTYPE_STATIC, + CallType, + ModeCode, + ModeLib +} from "./lib/ModeLib.sol"; + +import {SessionLib} from "./lib/SessionLib.sol"; + +import {Execution, IERC7579Account} from "./interface/IERC7579Account.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {EIP712} from "solady/utils/EIP712.sol"; +import {EnumerableSetLib} from "solady/utils/EnumerableSetLib.sol"; +import {SignatureCheckerLib} from "solady/utils/SignatureCheckerLib.sol"; + +library DefaultValidatorNewStorage { + + /// @custom:storage-location erc7201:account.module.default.validator.new + bytes32 public constant DEFAULT_VALIDATOR_NEW_STORAGE_POSITION = + keccak256(abi.encode(uint256(keccak256("account.module.default.validator.new")) - 1)) & ~bytes32(uint256(0xff)); + + struct Data { + /// @dev Map from smart account address => is initialized + mapping(address => bool) isInitialized; + mapping(address => EnumerableSetLib.AddressSet) allSigners; + mapping(address account => mapping(address signer => bytes32 sessionUid)) sessionIds; + mapping(bytes32 sessionUid => SessionLib.CallSpec[]) callPolicies; + mapping(bytes32 sessionUid => SessionLib.TransferSpec[]) transferPolicies; + mapping(bytes32 sessionUid => SessionLib.SessionStorage sessionState) sessions; + mapping(bytes32 sessionUid => uint256 expiresAt) sessionExpiration; + mapping(bytes32 targetHash => SessionLib.CallSpec) targetCallPolicy; + mapping(bytes32 targetHash => SessionLib.TransferSpec) targetTransferPolicy; + mapping(bytes32 sessionUid => mapping(address target => bool)) anySelectorForTarget; + } + + function data() internal pure returns (Data storage $) { + bytes32 position = DEFAULT_VALIDATOR_NEW_STORAGE_POSITION; + assembly { + $.slot := position + } + } + +} + +contract DefaultValidatorNew is IValidator, EIP712 { + + using ECDSA for bytes32; + using ExecutionLib for bytes; + using ModeLib for ModeCode; + using EnumerableSetLib for EnumerableSetLib.AddressSet; + + using SessionLib for SessionLib.SessionStorage; + + error SessionZeroSigner(); + error SessionExpiresTooSoon(); + + event SessionCreated(address indexed account, bytes32 indexed sessionUid, SessionLib.SessionSpec sessionSpec); + + uint256 private constant _ADMIN_ROLE = 1 << 0; + + bytes4 constant ERC1271_MAGICVALUE = 0x1626ba7e; + bytes4 constant ERC1271_INVALID = 0xffffffff; + + bytes32 private constant MSG_TYPEHASH = keccak256("AccountMessage(bytes message)"); + + function onInstall(bytes calldata) external { + _defaultValidatorStorage().isInitialized[msg.sender] = true; + } + + function onUninstall(bytes calldata) external { + address account = msg.sender; + if (!_isInitialized(account)) { + revert NotInitialized(account); + } + + _defaultValidatorStorage().isInitialized[account] = false; + + address[] memory allSigners = _defaultValidatorStorage().allSigners[account].values(); + delete _defaultValidatorStorage().allSigners[msg.sender]; + } + + function isModuleType(uint256 moduleTypeId) external pure returns (bool) { + return moduleTypeId == MODULE_TYPE_VALIDATOR; + } + + function isInitialized(address smartAccount) external view returns (bool) { + return _isInitialized(smartAccount); + } + + function _isInitialized(address smartAccount) internal view returns (bool) { + return _defaultValidatorStorage().isInitialized[smartAccount]; + } + + function validateUserOp(PackedUserOperation calldata userOp, bytes32 userOpHash) external returns (uint256) { + address account = msg.sender; + + (address signer, bytes memory signature) = abi.decode(userOp.signature, (address, bytes)); + bytes32 hash = SignatureCheckerLib.toEthSignedMessageHash(userOpHash); + + bool isValidSig = SignatureCheckerLib.isValidERC6492SignatureNow(signer, hash, signature); + + if (!isValidSig) { + return VALIDATION_FAILED; + } + + (bool isValid, uint48 validAfter, uint48 validUntil) = _isValidSigner(account, signer, userOp); + + if (!isValid) { + return VALIDATION_FAILED; + } + + return _packValidationData(ValidationData(address(0), validAfter, validUntil)); + } + + function isValidSignatureWithSender(address sender, bytes32 hash, bytes calldata signature) + external + view + returns (bytes4) + { + address account = msg.sender; + + bytes32 targetDigest = getMessageHash(hash); + address signer = ECDSA.recover(targetDigest, signature); + + if ( + signer == ModularAccount(payable(account)).owner() + || ModularAccount(payable(account)).hasAnyRole(signer, _ADMIN_ROLE) + ) { + return ERC1271_MAGICVALUE; + } + + return ERC1271_INVALID; + } + + /** + * @notice Returns the hash of message that should be signed for EIP1271 verification. + * @param _hash The message hash to sign for the EIP-1271 origin verifying contract. + * @return messageHash The digest to sign for EIP-1271 verification. + */ + function getMessageHash(bytes32 _hash) public view returns (bytes32) { + bytes32 messageHash = keccak256(abi.encode(_hash)); + bytes32 typedDataHash = keccak256(abi.encode(MSG_TYPEHASH, messageHash)); + return keccak256(abi.encodePacked("\x19\x01", _domainSeparator(), typedDataHash)); + } + + function _domainNameAndVersion() internal pure override returns (string memory name, string memory version) { + name = "DefaultValidatorNew"; + version = "1"; + } + + // create session key for signer + function createSessionKey(SessionLib.SessionSpec calldata sessionSpec) external { + if (!_isInitialized(msg.sender)) { + revert NotInitialized(msg.sender); + } + if (sessionSpec.signer == address(0)) { + revert SessionZeroSigner(); + } + // Sessions should expire in no less than 60 seconds. + if (sessionSpec.expiresAt <= block.timestamp + 60) { + revert SessionExpiresTooSoon(); + } + + bytes32 sessionUid = keccak256(abi.encodePacked(sessionSpec.signer, msg.sender, block.timestamp)); + + _defaultValidatorStorage().sessionIds[msg.sender][sessionSpec.signer] = sessionUid; + _defaultValidatorStorage().allSigners[msg.sender].add(sessionSpec.signer); + _defaultValidatorStorage().sessionExpiration[sessionUid] = sessionSpec.expiresAt; + + for (uint256 i = 0; i < sessionSpec.callPolicies.length; i++) { + _defaultValidatorStorage().callPolicies[sessionUid].push(sessionSpec.callPolicies[i]); + bytes32 targetHash = keccak256( + abi.encodePacked(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector, sessionUid) + ); + _defaultValidatorStorage().targetCallPolicy[targetHash] = sessionSpec.callPolicies[i]; + + if (sessionSpec.callPolicies[i].selector == bytes4(0x00000000)) { + _defaultValidatorStorage().anySelectorForTarget[sessionUid][sessionSpec.callPolicies[i].target] = true; + } + } + + for (uint256 i = 0; i < sessionSpec.transferPolicies.length; i++) { + _defaultValidatorStorage().transferPolicies[sessionUid].push(sessionSpec.transferPolicies[i]); + bytes32 targetHash = keccak256(abi.encodePacked(sessionSpec.transferPolicies[i].target, sessionUid)); + _defaultValidatorStorage().targetTransferPolicy[targetHash] = sessionSpec.transferPolicies[i]; + } + + emit SessionCreated(msg.sender, sessionUid, sessionSpec); + } + + function getCallPoliciesForSigner(address account, address signer) + external + view + returns (SessionLib.CallSpec[] memory) + { + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[account][signer]; + return _defaultValidatorStorage().callPolicies[sessionUid]; + } + + function getTransferPoliciesForSigner(address account, address signer) + external + view + returns (SessionLib.TransferSpec[] memory) + { + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[account][signer]; + return _defaultValidatorStorage().transferPolicies[sessionUid]; + } + + function getSessionExpirationForSigner(address account, address signer) external view returns (uint256) { + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[account][signer]; + return _defaultValidatorStorage().sessionExpiration[sessionUid]; + } + + function getSessionStateForSigner(address account, address signer) + external + view + returns (SessionLib.SessionState memory) + { + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[account][signer]; + + return _defaultValidatorStorage().sessions[sessionUid].getState( + account, + _defaultValidatorStorage().callPolicies[sessionUid], + _defaultValidatorStorage().transferPolicies[sessionUid] + ); + } + + function _isValidSigner(address _account, address _signer, PackedUserOperation calldata _userOp) + internal + virtual + returns (bool isValid, uint48 validAfter, uint48 validUntil) + { + if ( + _signer == ModularAccount(payable(_account)).owner() + || ModularAccount(payable(_account)).hasAnyRole(_signer, _ADMIN_ROLE) + ) { + return (true, 0, 0); + } + + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[_account][_signer]; + + bytes4 selector = _getFunctionSelector(_userOp.callData); + + if (selector == ModularAccount.execute.selector) { + (ModeCode mode,) = abi.decode(_userOp.callData[4:], (ModeCode, bytes)); + (CallType callType,,,) = mode.decode(); + + if (callType == CALLTYPE_SINGLE) { + uint256 dataOffset = abi.decode(_userOp.callData[36:], (uint256)); + + (address target, uint256 value, bytes calldata callData) = + ExecutionLib.decodeSingle(_userOp.callData[(36 + dataOffset):]); + + isValid = _validatePolicies(target, value, callData, sessionUid); + + return (isValid, validAfter, validUntil); + } else if (callType == CALLTYPE_BATCH) { + uint256 dataOffset = abi.decode(_userOp.callData[36:], (uint256)); + Execution[] memory executions = ExecutionLib.decodeBatch(_userOp.callData[36 + dataOffset:]); + uint256 length = executions.length; + + for (uint256 i = 0; i < length; i++) { + isValid = + _validatePolicies(executions[i].target, executions[i].value, executions[i].data, sessionUid); + + if (!isValid) { + return (isValid, validAfter, validUntil); + } + } + + isValid = true; + return (isValid, validAfter, validUntil); + } else if (callType == CALLTYPE_DELEGATECALL) { + // TODO + // return false for now + isValid = false; + return (isValid, validAfter, validUntil); + } else { + isValid = false; + return (isValid, validAfter, validUntil); + } + } else { + isValid = false; + return (isValid, validAfter, validUntil); + } + } + + function _validatePolicies(address target, uint256 value, bytes memory callData, bytes32 sessionUid) + internal + returns (bool) + { + uint256 sessionExpiresAt = _defaultValidatorStorage().sessionExpiration[sessionUid]; + SessionLib.CallSpec memory callPolicy; + SessionLib.TransferSpec memory transferPolicy; + + bytes4 targetSelector; + if (callData.length >= 4 && callData.length != 12) { + targetSelector = _defaultValidatorStorage().anySelectorForTarget[sessionUid][target] + ? bytes4(0x00000000) + : bytes4(callData[0]) | (bytes4(callData[1]) >> 8) | (bytes4(callData[2]) >> 16) + | (bytes4(callData[3]) >> 24); + bytes32 targetHash = keccak256(abi.encodePacked(target, targetSelector, sessionUid)); + callPolicy = _defaultValidatorStorage().targetCallPolicy[targetHash]; + } else { + bytes32 targetHash = keccak256(abi.encodePacked(target, sessionUid)); + transferPolicy = _defaultValidatorStorage().targetTransferPolicy[targetHash]; + } + return _defaultValidatorStorage().sessions[sessionUid].validate( + target, targetSelector, value, callData, sessionExpiresAt, callPolicy, transferPolicy + ); + } + + function _getFunctionSelector(bytes calldata data) internal pure returns (bytes4 functionSelector) { + require(data.length >= 4, "!Data"); + return bytes4(data[:4]); + } + + function _defaultValidatorStorage() internal pure returns (DefaultValidatorNewStorage.Data storage data) { + data = DefaultValidatorNewStorage.data(); + } + +} diff --git a/src/lib/SessionLib.sol b/src/lib/SessionLib.sol new file mode 100644 index 0000000..fd4b3b9 --- /dev/null +++ b/src/lib/SessionLib.sol @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import {LibBytes} from "solady/utils/LibBytes.sol"; + +library SessionLib { + + using SessionLib for SessionLib.Constraint; + using SessionLib for SessionLib.UsageLimit; + using LibBytes for bytes; + + enum LimitType { + Unlimited, + Lifetime, + Allowance + } + + enum Condition { + Unconstrained, + Equal, + Greater, + Less, + GreaterOrEqual, + LessOrEqual, + NotEqual + } + + struct UsageLimit { + LimitType limitType; + uint256 limit; // ignored if limitType == Unlimited + uint256 period; // ignored if limitType != Allowance + } + + struct Constraint { + Condition condition; + uint64 index; + bytes32 refValue; + UsageLimit limit; + } + + struct CallSpec { + address target; + bytes4 selector; + uint256 maxValuePerUse; + UsageLimit valueLimit; + Constraint[] constraints; + } + + struct TransferSpec { + address target; + uint256 maxValuePerUse; + UsageLimit valueLimit; + } + + struct SessionSpec { + address signer; + uint256 expiresAt; + CallSpec[] callPolicies; + TransferSpec[] transferPolicies; + } + + struct StoredSessionSpec { + uint256 expiresAt; + CallSpec[] callPolicies; + TransferSpec[] transferPolicies; + } + + struct LimitState { + // this might also be limited by a constraint or `maxValuePerUse`, + // which is not reflected here + uint256 remaining; + address target; + // ignored for transfer value + bytes4 selector; + // ignored for transfer and call value + uint256 index; + } + + // Info about remaining session limits and its status + struct SessionState { + LimitState[] transferValue; + LimitState[] callValue; + LimitState[] callParams; + } + + struct UsageTracker { + // Used for LimitType.Lifetime + mapping(address => uint256) lifetimeUsage; + // Used for LimitType.Allowance + // period => used that period + mapping(uint64 => mapping(address => uint256)) allowanceUsage; + } + + struct SessionStorage { + // (target) => transfer value tracker + mapping(address => UsageTracker) transferValue; + // (target, selector) => call value tracker + mapping(address => mapping(bytes4 => UsageTracker)) callValue; + // (target, selector, index) => call parameter tracker + // index is the constraint index in callPolicy, not the parameter index + mapping(address => mapping(bytes4 => mapping(uint256 => UsageTracker))) params; + } + + function checkAndUpdate(UsageLimit memory limit, UsageTracker storage tracker, uint256 value) + internal + returns (bool) + { + if (limit.limitType == LimitType.Lifetime) { + if (tracker.lifetimeUsage[msg.sender] + value > limit.limit) { + // revert SessionLifetimeUsageExceeded(tracker.lifetimeUsage[msg.sender], limit.limit); + return false; + } + tracker.lifetimeUsage[msg.sender] += value; + } else if (limit.limitType == LimitType.Allowance) { + uint64 period = uint64(block.timestamp / limit.period); + + if (tracker.allowanceUsage[period][msg.sender] + value > limit.limit) { + // revert SessionAllowanceExceeded(tracker.allowanceUsage[period][msg.sender], limit.limit, period); + return false; + } + + tracker.allowanceUsage[period][msg.sender] += value; + } + + return true; + } + + function checkAndUpdate(Constraint memory constraint, UsageTracker storage tracker, bytes memory data) + internal + returns (bool) + { + uint256 expectedLength = 4 + constraint.index * 32 + 32; + + if (data.length < expectedLength) { + // revert SessionInvalidDataLength(data.length, expectedLength); + return false; + } + + bytes32 param = data.load(4 + constraint.index * 32); + Condition condition = constraint.condition; + bytes32 refValue = constraint.refValue; + + if ( + (condition == Condition.Equal && param != refValue) || (condition == Condition.Greater && param <= refValue) + || (condition == Condition.Less && param >= refValue) + || (condition == Condition.GreaterOrEqual && param < refValue) + || (condition == Condition.LessOrEqual && param > refValue) + || (condition == Condition.NotEqual && param == refValue) + ) { + // revert SessionConditionFailed(param, refValue, uint8(condition)); + return false; + } + + bool check = constraint.limit.checkAndUpdate(tracker, uint256(param)); + + return check; + } + + function checkCallPolicy( + SessionStorage storage state, + bytes memory data, + address target, + bytes4 selector, + CallSpec memory callPolicy + ) private returns (bool found) { + if (target != callPolicy.target || selector != callPolicy.selector) { + return false; + } + + for (uint256 i = 0; i < callPolicy.constraints.length; i++) { + bool check = callPolicy.constraints[i].checkAndUpdate(state.params[target][selector][i], data); + if (!check) { + // return (false, callPolicy); + return false; + } + } + return true; + } + + function validate( + SessionStorage storage state, + address target, + bytes4 selector, + uint256 value, + bytes memory callData, + uint256 expiresAt, + CallSpec memory callPolicy, + TransferSpec memory transferPolicy + ) internal returns (bool) { + if (expiresAt < block.timestamp) { + return false; + } + + // TODO: fix this + if (callData.length >= 4 && callData.length != 12) { + bool callPolicyValid = checkCallPolicy(state, callData, target, selector, callPolicy); + + if (!callPolicyValid) { + return false; + } + + if (value > callPolicy.maxValuePerUse) { + // revert SessionMaxValueExceeded(value, callPolicy.maxValuePerUse); + return false; + } + + return callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], value); + } else { + if (target != transferPolicy.target) { + return false; + } + + if (value > transferPolicy.maxValuePerUse) { + // revert SessionMaxValueExceeded(value, transferPolicy.maxValuePerUse); + return false; + } + return transferPolicy.valueLimit.checkAndUpdate(state.transferValue[target], value); + } + } + + function remainingLimit(UsageLimit memory limit, UsageTracker storage tracker, address account) + private + view + returns (uint256) + { + if (limit.limitType == LimitType.Unlimited) { + // this might be still limited by `maxValuePerUse` or a constraint + return type(uint256).max; + } + if (limit.limitType == LimitType.Lifetime) { + return limit.limit - tracker.lifetimeUsage[account]; + } + if (limit.limitType == LimitType.Allowance) { + // this is not used during validation, so it's fine to use block.timestamp + uint64 period = uint64(block.timestamp / limit.period); + return limit.limit - tracker.allowanceUsage[period][account]; + } + } + + function getState( + SessionStorage storage session, + address account, + CallSpec[] memory callPolicies, + TransferSpec[] memory transferPolicies + ) internal view returns (SessionState memory) { + LimitState[] memory transferValue; + LimitState[] memory callValue; + LimitState[] memory callParams; + + { + uint256 totalConstraints = 0; + for (uint256 i = 0; i < callPolicies.length; i++) { + totalConstraints += callPolicies[i].constraints.length; + } + + transferValue = new LimitState[](transferPolicies.length); + callValue = new LimitState[](callPolicies.length); + callParams = new LimitState[](totalConstraints); // there will be empty ones at the end + } + uint256 paramLimitIndex = 0; + + { + for (uint256 i = 0; i < transferValue.length; i++) { + TransferSpec memory transferSpec = transferPolicies[i]; + transferValue[i] = LimitState({ + remaining: remainingLimit(transferSpec.valueLimit, session.transferValue[transferSpec.target], account), + target: transferSpec.target, + selector: bytes4(0), + index: 0 + }); + } + } + + for (uint256 i = 0; i < callValue.length; i++) { + CallSpec memory callSpec = callPolicies[i]; + + { + callValue[i] = LimitState({ + remaining: remainingLimit( + callSpec.valueLimit, session.callValue[callSpec.target][callSpec.selector], account + ), + target: callSpec.target, + selector: callSpec.selector, + index: 0 + }); + } + + { + for (uint256 j = 0; j < callSpec.constraints.length; j++) { + if (callSpec.constraints[j].limit.limitType != LimitType.Unlimited) { + callParams[paramLimitIndex++] = LimitState({ + remaining: remainingLimit( + callSpec.constraints[j].limit, + session.params[callSpec.target][callSpec.selector][j], + account + ), + target: callSpec.target, + selector: callSpec.selector, + index: callSpec.constraints[j].index + }); + } + } + } + } + + // shrink array to actual size + assembly { + mstore(callParams, paramLimitIndex) + } + + return SessionState({transferValue: transferValue, callValue: callValue, callParams: callParams}); + } + +} diff --git a/test/BenchmarkDefaultValidatorNew.t.sol b/test/BenchmarkDefaultValidatorNew.t.sol new file mode 100644 index 0000000..e2019c4 --- /dev/null +++ b/test/BenchmarkDefaultValidatorNew.t.sol @@ -0,0 +1,588 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import {UserOperationLib} from "account-abstraction-v0.7/core/UserOperationLib.sol"; +import {IEntryPoint} from "account-abstraction-v0.7/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "account-abstraction-v0.7/interfaces/PackedUserOperation.sol"; +import {Test} from "forge-std/Test.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {SignatureCheckerLib} from "solady/utils/SignatureCheckerLib.sol"; + +import {MockTarget} from "test/mock/MockTarget.sol"; +import {MockValidator} from "test/mock/MockValidator.sol"; +import {EntryPointLib} from "test/util/ERC4337Test.sol"; + +import "lib/forge-std/src/console.sol"; + +import { + CALLTYPE_SINGLE, + CALLTYPE_STATIC, + Execution, + ExecutionLib, + IERC7579Account, + MODULE_TYPE_FALLBACK, + MODULE_TYPE_VALIDATOR, + ModeLib, + ModularAccount +} from "src/ModularAccount.sol"; +import {ModularAccountFactory} from "src/ModularAccountFactory.sol"; + +import {InitializerInstallModule} from "src/interface/IModularAccount.sol"; + +import {DefaultValidatorNew} from "src/DefaultValidatorNew.sol"; + +import {SessionLib} from "src/lib/SessionLib.sol"; + +contract BenchmarkDefaultValidatorNewTest is Test { + + IEntryPoint public entrypoint; + ModularAccountFactory public factory; + ModularAccount public account; + + uint256 accountOwnerPKey = 0x2; + address public factoryOwner = vm.addr(0x1); + address public accountOwner = vm.addr(accountOwnerPKey); + + uint256 signerOnePKey = 1234; + address signerOne = vm.addr(signerOnePKey); + + address public validator; + MockTarget public target; + + function _setup(bytes memory accountSalt) internal { + vm.prank(factoryOwner); + entrypoint = IEntryPoint(EntryPointLib.deploy()); + validator = address(new DefaultValidatorNew()); + target = new MockTarget(); + ModularAccount accountImpl = new ModularAccount(address(entrypoint)); + factory = new ModularAccountFactory(address(entrypoint), factoryOwner, address(accountImpl)); + + InitializerInstallModule[] memory modules = _prepareDefaultValidatorNewInstallData(); + account = ModularAccount(payable(factory.createAccountWithModules(accountOwner, accountSalt, modules))); + } + + // === session key tests + + // create session key + function test_createSessionKeyForSigner_50CallPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x1))); + + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](50); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](0); + + for (uint256 i = 0; i < callPolicies.length; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_createSessionKeyForSigner_100CallPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x2))); + + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](100); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](0); + + for (uint256 i = 0; i < callPolicies.length; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_createSessionKeyForSigner_50TransferPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x3))); + + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](0); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](50); + + for (uint256 i = 0; i < transferPolicies.length; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_createSessionKeyForSigner_100TransferPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x4))); + + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](0); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](100); + + for (uint256 i = 0; i < transferPolicies.length; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_createSessionKeyForSigner_50Call50TransferPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x5))); + + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](50); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](50); + + for (uint256 i = 0; i < callPolicies.length; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + for (uint256 i = 0; i < transferPolicies.length; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // execute + + function test_execute_50CallPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x6))); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.TransferSpec[] memory transferPolicies; + + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](50); + for (uint256 i = 0; i < callPolicies.length - 1; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + SessionLib.Constraint[] memory mockTargetConstraint = new SessionLib.Constraint[](1); + mockTargetConstraint[0].condition = SessionLib.Condition.LessOrEqual; + mockTargetConstraint[0].refValue = bytes32(uint256(100)); + mockTargetConstraint[0].index = 0; + mockTargetConstraint[0].limit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + callPolicies[callPolicies.length - 1].target = address(target); + callPolicies[callPolicies.length - 1].selector = MockTarget.set.selector; + callPolicies[callPolicies.length - 1].maxValuePerUse = 1000; + callPolicies[callPolicies.length - 1].valueLimit = + SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[callPolicies.length - 1].constraints = mockTargetConstraint; + + spec.signer = signerOne; + spec.expiresAt = 100; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + vm.deal(address(account), 2 ether); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42)) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, signerOne); + + userOp.signature = abi.encode(signerOne, userOpSignature); + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_execute_100CallPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x7))); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.TransferSpec[] memory transferPolicies; + + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](100); + for (uint256 i = 0; i < callPolicies.length - 1; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + SessionLib.Constraint[] memory mockTargetConstraint = new SessionLib.Constraint[](1); + mockTargetConstraint[0].condition = SessionLib.Condition.LessOrEqual; + mockTargetConstraint[0].refValue = bytes32(uint256(100)); + mockTargetConstraint[0].index = 0; + mockTargetConstraint[0].limit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + callPolicies[callPolicies.length - 1].target = address(target); + callPolicies[callPolicies.length - 1].selector = MockTarget.set.selector; + callPolicies[callPolicies.length - 1].maxValuePerUse = 1000; + callPolicies[callPolicies.length - 1].valueLimit = + SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[callPolicies.length - 1].constraints = mockTargetConstraint; + + spec.signer = signerOne; + spec.expiresAt = 100; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + vm.deal(address(account), 3 ether); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42)) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, signerOne); + + userOp.signature = abi.encode(signerOne, userOpSignature); + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_execute_50TransferPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x8))); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies; + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](50); + for (uint256 i = 0; i < transferPolicies.length - 1; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + transferPolicies[transferPolicies.length - 1].target = address(target); + transferPolicies[transferPolicies.length - 1].maxValuePerUse = 1000; + transferPolicies[transferPolicies.length - 1].valueLimit = + SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.signer = signerOne; + spec.expiresAt = 100; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + vm.deal(address(account), 2 ether); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 100, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, signerOne); + + userOp.signature = abi.encode(signerOne, userOpSignature); + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_execute_100TransferPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x9))); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies; + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](100); + for (uint256 i = 0; i < transferPolicies.length - 1; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + transferPolicies[transferPolicies.length - 1].target = address(target); + transferPolicies[transferPolicies.length - 1].maxValuePerUse = 1000; + transferPolicies[transferPolicies.length - 1].valueLimit = + SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.signer = signerOne; + spec.expiresAt = 100; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + vm.deal(address(account), 2 ether); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 100, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, signerOne); + + userOp.signature = abi.encode(signerOne, userOpSignature); + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // test utils + + function getNonce(address _account, address _validator) internal returns (uint256 nonce) { + uint192 key = uint192(bytes24(bytes20(address(_validator)))); + nonce = (uint256(uint160(_validator)) << 96) | entrypoint.getNonce(_account, key); + } + + function getDefaultUserOp() internal returns (PackedUserOperation memory userOp) { + userOp = PackedUserOperation({ + sender: address(0), + nonce: 0, + initCode: "", + callData: "", + accountGasLimits: bytes32(abi.encodePacked(uint128(3e6), uint128(2e6))), + preVerificationGas: 2e6, + gasFees: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + paymasterAndData: bytes(""), + signature: abi.encodePacked(hex"41414141") + }); + } + + function _prepareDefaultValidatorNewInstallData() internal returns (InitializerInstallModule[] memory) { + InitializerInstallModule[] memory modules = new InitializerInstallModule[](1); + modules[0] = + InitializerInstallModule({moduleTypeId: MODULE_TYPE_VALIDATOR, module: validator, initData: bytes("")}); + return modules; + } + + function _getCreateSessionKeyOp(SessionLib.SessionSpec memory spec) + internal + returns (PackedUserOperation[] memory userOps) + { + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle( + address(validator), uint256(0), abi.encodeCall(DefaultValidatorNew.createSessionKey, (spec)) + ) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, msgHash); + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + userOp.signature = abi.encode(accountOwner, userOpSignature); + } + + // Create userOps array + userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + } + +} diff --git a/test/DefaultValidatorNew.t.sol b/test/DefaultValidatorNew.t.sol new file mode 100644 index 0000000..1192254 --- /dev/null +++ b/test/DefaultValidatorNew.t.sol @@ -0,0 +1,731 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import {UserOperationLib} from "account-abstraction-v0.7/core/UserOperationLib.sol"; +import {IEntryPoint} from "account-abstraction-v0.7/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "account-abstraction-v0.7/interfaces/PackedUserOperation.sol"; +import {Test} from "forge-std/Test.sol"; + +import {ERC20} from "solady/tokens/ERC20.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {SignatureCheckerLib} from "solady/utils/SignatureCheckerLib.sol"; + +import {MockERC20} from "test/mock/MockERC20.sol"; +import {MockTarget} from "test/mock/MockTarget.sol"; +import {MockValidator} from "test/mock/MockValidator.sol"; +import {EntryPointLib} from "test/util/ERC4337Test.sol"; + +import "lib/forge-std/src/console.sol"; + +import { + CALLTYPE_SINGLE, + CALLTYPE_STATIC, + Execution, + ExecutionLib, + IERC7579Account, + MODULE_TYPE_FALLBACK, + MODULE_TYPE_VALIDATOR, + ModeLib, + ModularAccount +} from "src/ModularAccount.sol"; +import {ModularAccountFactory} from "src/ModularAccountFactory.sol"; + +import {InitializerInstallModule} from "src/interface/IModularAccount.sol"; + +import {DefaultValidatorNew} from "src/DefaultValidatorNew.sol"; + +import {SessionLib} from "src/lib/SessionLib.sol"; + +contract DefaultValidatorNewTest is Test { + + IEntryPoint public entrypoint; + ModularAccountFactory public factory; + ModularAccount public account; + MockTarget public target; + MockERC20 public erc20; + address public validator; + + bytes public accountSalt = abi.encodePacked(uint256(0xdeadbeef)); + + address public factoryOwner = vm.addr(0x1); + uint256 accountOwnerPKey = 0x2; + address public accountOwner = vm.addr(accountOwnerPKey); + + uint256 signerOnePKey = 0x3; + address signerOne = vm.addr(signerOnePKey); + + function setUp() public { + vm.prank(factoryOwner); + entrypoint = IEntryPoint(EntryPointLib.deploy()); + validator = address(new DefaultValidatorNew()); + target = new MockTarget(); + erc20 = new MockERC20(); + ModularAccount accountImpl = new ModularAccount(address(entrypoint)); + factory = new ModularAccountFactory(address(entrypoint), factoryOwner, address(accountImpl)); + + InitializerInstallModule[] memory modules = _prepareDefaultValidatorNewInstallData(); + account = ModularAccount(payable(factory.createAccountWithModules(accountOwner, accountSalt, modules))); + + vm.deal(address(account), 1 ether); + erc20.mint(address(account), 1 ether); + } + + function test_execute_owner() public { + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42)) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory userOps = _getSignedUserOp(accountOwner, accountOwnerPKey, userOp); + + entrypoint.handleOps(userOps, payable(address(account))); + + assertTrue(target.number() == 42); + } + + function test_executeBatch_owner() public { + Execution[] memory executions = new Execution[](2); + executions[0] = Execution({target: address(target), value: 0, data: abi.encodeCall(MockTarget.set, 11)}); + executions[1] = Execution({target: address(target), value: 0, data: abi.encodeCall(MockTarget.setIf, (42, 11))}); + + bytes memory userOpCalldata = + abi.encodeCall(IERC7579Account.execute, (ModeLib.encodeSimpleBatch(), ExecutionLib.encodeBatch(executions))); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory userOps = _getSignedUserOp(accountOwner, accountOwnerPKey, userOp); + + entrypoint.handleOps(userOps, payable(address(account))); + + assertTrue(target.number() == 42); + } + + function test_signature() public { + // sign message + bytes32 message = keccak256("Hello World"); + bytes32 messageHash = DefaultValidatorNew(payable(validator)).getMessageHash(message); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, messageHash); + bytes memory signature = abi.encodePacked(r, s, v); + + // encode validator address + signature + bytes memory sigWithValidator = abi.encode(validator, signature); + + // verify signature + bytes4 returnValue = account.isValidSignature(message, sigWithValidator); + assertEq(returnValue, bytes4(0x1626ba7e)); // Magic value from ERC-1271 + } + + // === signer is admin (non owner) + + // === session key tests + + // create session key + function test_createSessionKey() public { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + _createSessionKey(spec); + + SessionLib.CallSpec[] memory fetchedCallPolicies = + DefaultValidatorNew(validator).getCallPoliciesForSigner(address(account), signerOne); + + SessionLib.TransferSpec[] memory fetchedTransferPolicies = + DefaultValidatorNew(validator).getTransferPoliciesForSigner(address(account), signerOne); + + uint256 expiresAt = DefaultValidatorNew(validator).getSessionExpirationForSigner(address(account), signerOne); + + assertEq(expiresAt, spec.expiresAt); + } + + function test_transferPolicy() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 100, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + assertTrue(address(target).balance == 100); + } + + function test_revert_transferPolicy_crossLimit() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 600, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + function test_revert_transferPolicy_wrongTarget() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(0x1234), 500, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + function test_callPolicy_erc20Transfer() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to target + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(ERC20.transfer, (address(target), 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + assertTrue(erc20.balanceOf(address(target)) == 100); + } + + function test_revert_callPolicy_erc20Transfer_wrongTarget() public { + address wrongTarget = vm.addr(0x789); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to defined target below + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + // === set target as wrongTarget here ========= + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(ERC20.transfer, (wrongTarget, 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + function test_revert_callPolicy_erc20Transfer_wrongSelector() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to defined target below + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + // === set wrong selector here ========= + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(MockERC20.mint, (address(target), 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + function test_revert_callPolicy_erc20Transfer_crossTxLimit() public { + uint256 wrongAmount = 101; + + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to defined target below + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + // === set amount as wrongAmount here ========= + ExecutionLib.encodeSingle( + address(erc20), 0, abi.encodeCall(ERC20.transfer, (address(target), wrongAmount)) + ) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + function test_revert_callPolicy_erc20Transfer_crossTotalLimit() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to defined target below + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(ERC20.transfer, (address(target), 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + // resend the same tx, with 100 amount. + // should fail because remaining total limit is 120 - 100 = 20 + + userOp.nonce = getNonce(address(account), address(validator)); + ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + + SessionLib.SessionState memory state = + DefaultValidatorNew(validator).getSessionStateForSigner(address(account), signerOne); + assertEq(state.callParams[0].remaining, 20); + } + + function test_callPolicy_anySelector() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = bytes4(0x00000000); + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(ERC20.transfer, (address(target), 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + assertTrue(erc20.balanceOf(address(target)) == 100); + } + + function test_sessionPeriodAllowance() public { + uint256 period = 100; + + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 600, period); // <== set period, and limitType as Allowance + + spec.expiresAt = period * 5; + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 500, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + assertTrue(address(target).balance == 500); + + // revert case: send more value in the same period, because period limit is crossed + + userOp.nonce = getNonce(address(account), address(validator)); + ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + + // success when period resets (i.e. block timestamp begins a new period) + vm.warp(period + 1); + entrypoint.handleOps(ops, payable(address(account))); + assertTrue(address(target).balance == 1000); + } + + // test utils + + function getNonce(address _account, address _validator) internal returns (uint256 nonce) { + uint192 key = uint192(bytes24(bytes20(address(_validator)))); + nonce = (uint256(uint160(_validator)) << 96) | entrypoint.getNonce(_account, key); + } + + function getDefaultUserOp() internal returns (PackedUserOperation memory userOp) { + userOp = PackedUserOperation({ + sender: address(0), + nonce: 0, + initCode: "", + callData: "", + accountGasLimits: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + preVerificationGas: 2e6, + gasFees: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + paymasterAndData: bytes(""), + signature: abi.encodePacked(hex"41414141") + }); + } + + function getDefaultSessionSpec(address signer) internal returns (SessionLib.SessionSpec memory spec) { + spec.signer = signer; + spec.expiresAt = 100; + spec.callPolicies = new SessionLib.CallSpec[](0); + spec.transferPolicies = new SessionLib.TransferSpec[](0); + } + + function _prepareDefaultValidatorNewInstallData() internal returns (InitializerInstallModule[] memory) { + InitializerInstallModule[] memory modules = new InitializerInstallModule[](1); + modules[0] = + InitializerInstallModule({moduleTypeId: MODULE_TYPE_VALIDATOR, module: validator, initData: bytes("")}); + return modules; + } + + function _createSessionKey(SessionLib.SessionSpec memory spec) internal { + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle( + address(validator), uint256(0), abi.encodeCall(DefaultValidatorNew.createSessionKey, (spec)) + ) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(accountOwner, accountOwnerPKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + } + + function _getSignedUserOp(address signer, uint256 pkey, PackedUserOperation memory userOp) + internal + returns (PackedUserOperation[] memory) + { + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(pkey, msgHash); + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + userOp.signature = abi.encode(signer, userOpSignature); + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + return userOps; + } + +} diff --git a/test/mock/MockERC20.sol b/test/mock/MockERC20.sol new file mode 100644 index 0000000..28284d1 --- /dev/null +++ b/test/mock/MockERC20.sol @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import "solady/tokens/ERC20.sol"; + +contract MockERC20 is ERC20 { + + function name() public view override returns (string memory) { + return "Token"; + } + + function symbol() public view override returns (string memory) { + return "TKN"; + } + + function mint(address to, uint256 amount) external { + _mint(to, amount); + } + +}