From d9b781ee9f5359e1b42e51e1bac5bbc235cebc38 Mon Sep 17 00:00:00 2001 From: Yash Date: Thu, 27 Feb 2025 22:03:00 +0530 Subject: [PATCH 01/15] New session key validator --- lib/solady | 2 +- src/DefaultValidator.sol | 2 +- src/DefaultValidatorNew.sol | 305 ++++++++++++++++++++++++++++++++ src/lib/SessionLib.sol | 344 ++++++++++++++++++++++++++++++++++++ 4 files changed, 651 insertions(+), 2 deletions(-) create mode 100644 src/DefaultValidatorNew.sol create mode 100644 src/lib/SessionLib.sol diff --git a/lib/solady b/lib/solady index 8200a70..ab7596b 160000 --- a/lib/solady +++ b/lib/solady @@ -1 +1 @@ -Subproject commit 8200a70e8dc2a77ecb074fc2e99a2a0d36547522 +Subproject commit ab7596b1a21f8a54bec722d54505e8daa40b5760 diff --git a/src/DefaultValidator.sol b/src/DefaultValidator.sol index 75f126e..a7e5f0c 100644 --- a/src/DefaultValidator.sol +++ b/src/DefaultValidator.sol @@ -1,4 +1,4 @@ -// SPDX-License-Identifier: MIT +// SPDX-License-Identifier: Apache-2.0 pragma solidity ^0.8.26; import "account-abstraction-v0.7/core/Helpers.sol"; diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol new file mode 100644 index 0000000..8f98e20 --- /dev/null +++ b/src/DefaultValidatorNew.sol @@ -0,0 +1,305 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import "account-abstraction-v0.7/core/Helpers.sol"; +import {UserOperationLib} from "account-abstraction-v0.7/core/UserOperationLib.sol"; + +import { + IValidator, + MODULE_TYPE_VALIDATOR, + PackedUserOperation, + VALIDATION_FAILED, + VALIDATION_SUCCESS +} from "./interface/IERC7579Module.sol"; + +import {ModularAccount} from "./ModularAccount.sol"; + +import {Execution} from "./interface/IERC7579Account.sol"; +import {ExecutionLib} from "./lib/ExecutionLib.sol"; +import { + CALLTYPE_BATCH, + CALLTYPE_DELEGATECALL, + CALLTYPE_SINGLE, + CALLTYPE_STATIC, + CallType, + ModeCode, + ModeLib +} from "./lib/ModeLib.sol"; + +import {SessionLib} from "./lib/SessionLib.sol"; + +import {Execution, IERC7579Account} from "./interface/IERC7579Account.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {EIP712} from "solady/utils/EIP712.sol"; +import {EnumerableSetLib} from "solady/utils/EnumerableSetLib.sol"; +import {SignatureCheckerLib} from "solady/utils/SignatureCheckerLib.sol"; + +library DefaultValidatorNewStorage { + + /// @custom:storage-location erc7201:account.module.default.validator.new + bytes32 public constant DEFAULT_VALIDATOR_NEW_STORAGE_POSITION = + keccak256(abi.encode(uint256(keccak256("account.module.default.validator.new")) - 1)) & ~bytes32(uint256(0xff)); + + struct Data { + /// @dev Map from smart account address => is initialized + mapping(address => bool) isInitialized; + mapping(address => EnumerableSetLib.AddressSet) allSigners; + mapping(address account => uint256 openSessions) sessionCounter; + mapping(bytes32 sessionHash => SessionLib.SessionStorage sessionState) sessions; + mapping(address account => mapping(address signer => SessionLib.StoredSessionSpec)) sessionSpec; + } + + function data() internal pure returns (Data storage $) { + bytes32 position = DEFAULT_VALIDATOR_NEW_STORAGE_POSITION; + assembly { + $.slot := position + } + } + +} + +contract DefaultValidatorNew is IValidator, EIP712 { + + using ECDSA for bytes32; + using ExecutionLib for bytes; + using ModeLib for ModeCode; + using EnumerableSetLib for EnumerableSetLib.AddressSet; + + using SessionLib for SessionLib.SessionStorage; + + error SessionZeroSigner(); + error SessionUnlimitedFees(); + error SessionAlreadyExists(); + error SessionExpiresTooSoon(); + + event SessionCreated(address indexed account, bytes32 indexed sessionHash, SessionLib.SessionSpec sessionSpec); // new + event SessionRevoked(address indexed account, bytes32 indexed sessionHash); // new + + uint256 private constant _ADMIN_ROLE = 1 << 0; + + bytes4 constant ERC1271_MAGICVALUE = 0x1626ba7e; + bytes4 constant ERC1271_INVALID = 0xffffffff; + + bytes32 private constant MSG_TYPEHASH = keccak256("AccountMessage(bytes message)"); + + function onInstall(bytes calldata) external { + _defaultValidatorStorage().isInitialized[msg.sender] = true; + } + + function onUninstall(bytes calldata) external { + address account = msg.sender; + if (!_isInitialized(account)) { + revert NotInitialized(account); + } + + _defaultValidatorStorage().isInitialized[account] = false; + + address[] memory allSigners = _defaultValidatorStorage().allSigners[account].values(); + delete _defaultValidatorStorage().allSigners[msg.sender]; + + uint256 len = allSigners.length; + for (uint256 i = 0; i < len; i += 1) { + delete _defaultValidatorStorage().sessionSpec[account][allSigners[i]]; + } + } + + function isModuleType(uint256 moduleTypeId) external pure returns (bool) { + return moduleTypeId == MODULE_TYPE_VALIDATOR; + } + + function isInitialized(address smartAccount) external view returns (bool) { + return _isInitialized(smartAccount); + } + + function _isInitialized(address smartAccount) internal view returns (bool) { + return _defaultValidatorStorage().isInitialized[smartAccount]; + } + + function validateUserOp(PackedUserOperation calldata userOp, bytes32 userOpHash) external returns (uint256) { + address account = msg.sender; + + bytes32 hash = SignatureCheckerLib.toEthSignedMessageHash(userOpHash); + address signer = hash.recover(userOp.signature); + (bool isValid, uint48 validAfter, uint48 validUntil) = _isValidSigner(account, signer, userOp); + + if (!isValid) { + return VALIDATION_FAILED; + } + + return _packValidationData(ValidationData(address(0), validAfter, validUntil)); + } + + /// @notice This module should not be used to validate signatures (including EIP-1271), + /// as a signature by itself does not have enough information to validate it against a session. + function isValidSignatureWithSender(address sender, bytes32 hash, bytes calldata signature) + external + view + returns (bytes4) + { + address account = msg.sender; + bytes32 targetDigest = getMessageHash(hash); + address signer = ECDSA.recover(targetDigest, signature); + + if ( + signer == ModularAccount(payable(account)).owner() + || ModularAccount(payable(account)).hasAnyRole(signer, _ADMIN_ROLE) + ) { + return ERC1271_MAGICVALUE; + } + + return ERC1271_INVALID; + } + + /** + * @notice Returns the hash of message that should be signed for EIP1271 verification. + * @param _hash The message hash to sign for the EIP-1271 origin verifying contract. + * @return messageHash The digest to sign for EIP-1271 verification. + */ + function getMessageHash(bytes32 _hash) public view returns (bytes32) { + bytes32 messageHash = keccak256(abi.encode(_hash)); + bytes32 typedDataHash = keccak256(abi.encode(MSG_TYPEHASH, messageHash)); + return keccak256(abi.encodePacked("\x19\x01", _domainSeparator(), typedDataHash)); + } + + function _domainNameAndVersion() internal pure override returns (string memory name, string memory version) { + name = "DefaultValidatorNew"; + version = "1"; + } + + // create session key for signer + function createSessionKey(SessionLib.SessionSpec calldata sessionSpec) external { + SessionLib.StoredSessionSpec memory spec = SessionLib.StoredSessionSpec({ + expiresAt: sessionSpec.expiresAt, + feeLimit: sessionSpec.feeLimit, + callPolicies: sessionSpec.callPolicies, + transferPolicies: sessionSpec.transferPolicies + }); + bytes32 sessionHash = keccak256(abi.encode(sessionSpec.signer, spec)); + + if (!_isInitialized(msg.sender)) { + revert NotInitialized(msg.sender); + } + if (sessionSpec.signer == address(0)) { + revert SessionZeroSigner(); + } + if (sessionSpec.feeLimit.limitType == SessionLib.LimitType.Unlimited) { + revert SessionUnlimitedFees(); + } + if (_defaultValidatorStorage().sessions[sessionHash].status[msg.sender] != SessionLib.Status.NotInitialized) { + revert SessionAlreadyExists(); + } + // Sessions should expire in no less than 60 seconds. + if (sessionSpec.expiresAt <= block.timestamp + 60) { + revert SessionExpiresTooSoon(); + } + + _defaultValidatorStorage().sessionCounter[msg.sender]++; + _defaultValidatorStorage().sessions[sessionHash].status[msg.sender] = SessionLib.Status.Active; + _defaultValidatorStorage().allSigners[msg.sender].add(sessionSpec.signer); + + SessionLib.StoredSessionSpec storage newSessionSpec = _defaultValidatorStorage().sessionSpec[msg.sender][sessionSpec.signer]; + + newSessionSpec.expiresAt = sessionSpec.expiresAt; + newSessionSpec.feeLimit = sessionSpec.feeLimit; + + delete newSessionSpec.callPolicies; + for (uint256 i = 0; i < sessionSpec.callPolicies.length; i++) { + newSessionSpec.callPolicies.push(sessionSpec.callPolicies[i]); + } + + delete newSessionSpec.transferPolicies; + for (uint256 i = 0; i < sessionSpec.transferPolicies.length; i++) { + newSessionSpec.transferPolicies.push(sessionSpec.transferPolicies[i]); + } + + emit SessionCreated(msg.sender, sessionHash, sessionSpec); + } + + function getSessionKeyForSigner(address account, address signer) + external + view + returns (SessionLib.SessionState memory) + { + SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; + bytes32 sessionHash = keccak256(abi.encode(signer, sessionSpec)); + + return _defaultValidatorStorage().sessions[sessionHash].getState(account, sessionSpec); + } + + function _isValidSigner(address _account, address _signer, PackedUserOperation calldata _userOp) + internal + virtual + returns (bool isValid, uint48 validAfter, uint48 validUntil) + { + if ( + _signer == ModularAccount(payable(_account)).owner() + || ModularAccount(payable(_account)).hasAnyRole(_signer, _ADMIN_ROLE) + ) { + return (true, 0, 0); + } + + SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[_account][_signer]; + bytes32 sessionHash = keccak256(abi.encode(_signer, sessionSpec)); + + bytes4 selector = _getFunctionSelector(_userOp.callData); + + if (selector == ModularAccount.execute.selector) { + (ModeCode mode,) = abi.decode(_userOp.callData[4:], (ModeCode, bytes)); + (CallType callType,,,) = mode.decode(); + + if (callType == CALLTYPE_SINGLE) { + uint256 dataOffset = abi.decode(_userOp.callData[36:], (uint256)); + + (address target, uint256 value, bytes calldata callData) = + ExecutionLib.decodeSingle(_userOp.callData[(36 + dataOffset):]); + + isValid = + _defaultValidatorStorage().sessions[sessionHash].validate(target, value, callData, sessionSpec); + + (address paymaster,,) = UserOperationLib.unpackPaymasterStaticFields(_userOp.paymasterAndData); + if (paymaster == address(0)) { + // TODO: do we need to validate fee limit? + } + + return (isValid, validAfter, validUntil); + } else if (callType == CALLTYPE_BATCH) { + uint256 dataOffset = abi.decode(_userOp.callData[36:], (uint256)); + Execution[] memory executions = ExecutionLib.decodeBatch(_userOp.callData[36 + dataOffset:]); + uint256 length = executions.length; + + for (uint256 i = 0; i < length; i++) { + SessionLib.SessionStorage storage state = _defaultValidatorStorage().sessions[sessionHash]; + isValid = state.validate(executions[i].target, executions[i].value, executions[i].data, sessionSpec); + + if (!isValid) { + return (isValid, validAfter, validUntil); + } + } + + isValid = true; + return (isValid, validAfter, validUntil); + } else if (callType == CALLTYPE_DELEGATECALL) { + // TODO + // return false for now + isValid = false; + return (isValid, validAfter, validUntil); + } else { + isValid = false; + return (isValid, validAfter, validUntil); + } + } else { + isValid = false; + return (isValid, validAfter, validUntil); + } + } + + function _getFunctionSelector(bytes calldata data) internal pure returns (bytes4 functionSelector) { + require(data.length >= 4, "!Data"); + return bytes4(data[:4]); + } + + function _defaultValidatorStorage() internal pure returns (DefaultValidatorNewStorage.Data storage data) { + data = DefaultValidatorNewStorage.data(); + } + +} diff --git a/src/lib/SessionLib.sol b/src/lib/SessionLib.sol new file mode 100644 index 0000000..28293e4 --- /dev/null +++ b/src/lib/SessionLib.sol @@ -0,0 +1,344 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import {LibBytes} from "solady/utils/LibBytes.sol"; + +library SessionLib { + + using SessionLib for SessionLib.Constraint; + using SessionLib for SessionLib.UsageLimit; + using LibBytes for bytes; + + enum Status { + NotInitialized, + Active, + Closed + } + + enum LimitType { + Unlimited, + Lifetime, + Allowance + } + + enum Condition { + Unconstrained, + Equal, + Greater, + Less, + GreaterOrEqual, + LessOrEqual, + NotEqual + } + + struct UsageLimit { + LimitType limitType; + uint256 limit; // ignored if limitType == Unlimited + uint256 period; // ignored if limitType != Allowance + } + + struct Constraint { + Condition condition; + uint64 index; + bytes32 refValue; + UsageLimit limit; + } + + struct CallSpec { + address target; + bytes4 selector; + uint256 maxValuePerUse; + UsageLimit valueLimit; + Constraint[] constraints; + } + + struct TransferSpec { + address target; + uint256 maxValuePerUse; + UsageLimit valueLimit; + } + + struct SessionSpec { + address signer; + uint256 expiresAt; + UsageLimit feeLimit; + CallSpec[] callPolicies; + TransferSpec[] transferPolicies; + } + + struct StoredSessionSpec { + uint256 expiresAt; + UsageLimit feeLimit; + CallSpec[] callPolicies; + TransferSpec[] transferPolicies; + } + + struct LimitState { + // this might also be limited by a constraint or `maxValuePerUse`, + // which is not reflected here + uint256 remaining; + address target; + // ignored for transfer value + bytes4 selector; + // ignored for transfer and call value + uint256 index; + } + + // Info about remaining session limits and its status + struct SessionState { + Status status; + uint256 feesRemaining; + LimitState[] transferValue; + LimitState[] callValue; + LimitState[] callParams; + } + + struct UsageTracker { + // Used for LimitType.Lifetime + mapping(address => uint256) lifetimeUsage; + // Used for LimitType.Allowance + // period => used that period + mapping(uint64 => mapping(address => uint256)) allowanceUsage; + } + + struct SessionStorage { + mapping(address => Status) status; + UsageTracker fee; + // (target) => transfer value tracker + mapping(address => UsageTracker) transferValue; + // (target, selector) => call value tracker + mapping(address => mapping(bytes4 => UsageTracker)) callValue; + // (target, selector, index) => call parameter tracker + // index is the constraint index in callPolicy, not the parameter index + mapping(address => mapping(bytes4 => mapping(uint256 => UsageTracker))) params; + } + + function checkAndUpdate(UsageLimit memory limit, UsageTracker storage tracker, uint256 value) + internal + returns (bool) + { + if (limit.limitType == LimitType.Lifetime) { + if (tracker.lifetimeUsage[msg.sender] + value > limit.limit) { + // revert SessionLifetimeUsageExceeded(tracker.lifetimeUsage[msg.sender], limit.limit); + return false; + } + tracker.lifetimeUsage[msg.sender] += value; + } else if (limit.limitType == LimitType.Allowance) { + uint64 period = uint64(block.timestamp / limit.period); + + if (tracker.allowanceUsage[period][msg.sender] + value > limit.limit) { + // revert SessionAllowanceExceeded(tracker.allowanceUsage[period][msg.sender], limit.limit, period); + return false; + } + + tracker.allowanceUsage[period][msg.sender] += value; + } + + return true; + } + + function checkAndUpdate(Constraint memory constraint, UsageTracker storage tracker, bytes memory data) + internal + returns (bool) + { + uint256 expectedLength = 4 + constraint.index * 32 + 32; + + if (data.length < expectedLength) { + // revert SessionInvalidDataLength(data.length, expectedLength); + return false; + } + + bytes32 param = data.load(4 + constraint.index * 32); + Condition condition = constraint.condition; + bytes32 refValue = constraint.refValue; + + if ( + (condition == Condition.Equal && param != refValue) || (condition == Condition.Greater && param <= refValue) + || (condition == Condition.Less && param >= refValue) + || (condition == Condition.GreaterOrEqual && param < refValue) + || (condition == Condition.LessOrEqual && param > refValue) + || (condition == Condition.NotEqual && param == refValue) + ) { + // revert SessionConditionFailed(param, refValue, uint8(condition)); + return false; + } + + bool check = constraint.limit.checkAndUpdate(tracker, uint256(param)); + + return check; + } + + function checkCallPolicy( + SessionStorage storage state, + bytes memory data, + address target, + bytes4 selector, + CallSpec[] memory callPolicies + ) private returns (bool found, CallSpec memory) { + CallSpec memory callPolicy; + + for (uint256 i = 0; i < callPolicies.length; i++) { + if (callPolicies[i].target == target && callPolicies[i].selector == selector) { + callPolicy = callPolicies[i]; + found = true; + break; + } + } + + if (!found) { + return (found, callPolicy); + } + + for (uint256 i = 0; i < callPolicy.constraints.length; i++) { + bool check = callPolicy.constraints[i].checkAndUpdate(state.params[target][selector][i], data); + if (!check) { + return (false, callPolicy); + } + } + + return (true, callPolicy); + } + + // TODO: do we need this? + function validateFeeLimit() internal returns (bool) {} + + function validate( + SessionStorage storage state, + address target, + uint256 value, + bytes memory callData, + StoredSessionSpec memory spec + ) internal returns (bool) { + if (state.status[msg.sender] != Status.Active) { + // revert SessionNotActive(); + return false; + } + + // TODO: check timestamps + + if (callData.length >= 4) { + // bytes4 selector = bytes4(callData[:4]); + bytes4 selector = bytes4(callData[0]) | (bytes4(callData[1]) >> 8) | (bytes4(callData[2]) >> 16) + | (bytes4(callData[3]) >> 24); + (bool found, CallSpec memory callPolicy) = + checkCallPolicy(state, callData, target, selector, spec.callPolicies); + + if (!found) { + return false; + } + + if (value > callPolicy.maxValuePerUse) { + // revert SessionMaxValueExceeded(value, callPolicy.maxValuePerUse); + return false; + } + + callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], value); + } else { + TransferSpec memory transferPolicy; + bool found = false; + + for (uint256 i = 0; i < spec.transferPolicies.length; i++) { + if (spec.transferPolicies[i].target == target) { + transferPolicy = spec.transferPolicies[i]; + found = true; + break; + } + } + + if (!found) { + // revert SessionTransferPolicyViolated(target); + return false; + } + + if (value > transferPolicy.maxValuePerUse) { + // revert SessionMaxValueExceeded(value, transferPolicy.maxValuePerUse); + return false; + } + transferPolicy.valueLimit.checkAndUpdate(state.transferValue[target], value); + } + } + + function remainingLimit(UsageLimit memory limit, UsageTracker storage tracker, address account) + private + view + returns (uint256) + { + if (limit.limitType == LimitType.Unlimited) { + // this might be still limited by `maxValuePerUse` or a constraint + return type(uint256).max; + } + if (limit.limitType == LimitType.Lifetime) { + return limit.limit - tracker.lifetimeUsage[account]; + } + if (limit.limitType == LimitType.Allowance) { + // this is not used during validation, so it's fine to use block.timestamp + uint64 period = uint64(block.timestamp / limit.period); + return limit.limit - tracker.allowanceUsage[period][account]; + } + } + + function getState(SessionStorage storage session, address account, StoredSessionSpec memory spec) + internal + view + returns (SessionState memory) + { + uint256 totalConstraints = 0; + for (uint256 i = 0; i < spec.callPolicies.length; i++) { + totalConstraints += spec.callPolicies[i].constraints.length; + } + + LimitState[] memory transferValue = new LimitState[](spec.transferPolicies.length); + LimitState[] memory callValue = new LimitState[](spec.callPolicies.length); + LimitState[] memory callParams = new LimitState[](totalConstraints); // there will be empty ones at the end + uint256 paramLimitIndex = 0; + + for (uint256 i = 0; i < transferValue.length; i++) { + TransferSpec memory transferSpec = spec.transferPolicies[i]; + transferValue[i] = LimitState({ + remaining: remainingLimit(transferSpec.valueLimit, session.transferValue[transferSpec.target], account), + target: transferSpec.target, + selector: bytes4(0), + index: 0 + }); + } + + for (uint256 i = 0; i < callValue.length; i++) { + CallSpec memory callSpec = spec.callPolicies[i]; + callValue[i] = LimitState({ + remaining: remainingLimit( + callSpec.valueLimit, session.callValue[callSpec.target][callSpec.selector], account + ), + target: callSpec.target, + selector: callSpec.selector, + index: 0 + }); + + for (uint256 j = 0; j < callSpec.constraints.length; j++) { + if (callSpec.constraints[j].limit.limitType != LimitType.Unlimited) { + callParams[paramLimitIndex++] = LimitState({ + remaining: remainingLimit( + callSpec.constraints[j].limit, session.params[callSpec.target][callSpec.selector][j], account + ), + target: callSpec.target, + selector: callSpec.selector, + index: callSpec.constraints[j].index + }); + } + } + } + + // shrink array to actual size + assembly { + mstore(callParams, paramLimitIndex) + } + + return SessionState({ + status: session.status[account], + feesRemaining: remainingLimit(spec.feeLimit, session.fee, account), + transferValue: transferValue, + callValue: callValue, + callParams: callParams + }); + } + +} From 9bdfc4bd75e07863067596ca8539a683c3709f3e Mon Sep 17 00:00:00 2001 From: Yash Date: Thu, 27 Feb 2025 22:53:02 +0530 Subject: [PATCH 02/15] validate expiration timestamp --- src/DefaultValidatorNew.sol | 3 ++- src/lib/SessionLib.sol | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol index 8f98e20..0c91993 100644 --- a/src/DefaultValidatorNew.sol +++ b/src/DefaultValidatorNew.sol @@ -197,7 +197,8 @@ contract DefaultValidatorNew is IValidator, EIP712 { _defaultValidatorStorage().sessions[sessionHash].status[msg.sender] = SessionLib.Status.Active; _defaultValidatorStorage().allSigners[msg.sender].add(sessionSpec.signer); - SessionLib.StoredSessionSpec storage newSessionSpec = _defaultValidatorStorage().sessionSpec[msg.sender][sessionSpec.signer]; + SessionLib.StoredSessionSpec storage newSessionSpec = + _defaultValidatorStorage().sessionSpec[msg.sender][sessionSpec.signer]; newSessionSpec.expiresAt = sessionSpec.expiresAt; newSessionSpec.feeLimit = sessionSpec.feeLimit; diff --git a/src/lib/SessionLib.sol b/src/lib/SessionLib.sol index 28293e4..27110c9 100644 --- a/src/lib/SessionLib.sol +++ b/src/lib/SessionLib.sol @@ -214,7 +214,9 @@ library SessionLib { return false; } - // TODO: check timestamps + if(spec.expiresAt < block.timestamp) { + return false; + } if (callData.length >= 4) { // bytes4 selector = bytes4(callData[:4]); From cbf8aaa7ec83985620fd8b1354eef3fd1da8264f Mon Sep 17 00:00:00 2001 From: Yash Date: Thu, 27 Feb 2025 23:32:19 +0530 Subject: [PATCH 03/15] wip tests --- src/DefaultValidatorNew.sol | 42 +++++- src/lib/SessionLib.sol | 2 +- test/DefaultValidatorNew.t.sol | 257 +++++++++++++++++++++++++++++++++ 3 files changed, 299 insertions(+), 2 deletions(-) create mode 100644 test/DefaultValidatorNew.t.sol diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol index 0c91993..f7f204a 100644 --- a/src/DefaultValidatorNew.sol +++ b/src/DefaultValidatorNew.sol @@ -216,7 +216,47 @@ contract DefaultValidatorNew is IValidator, EIP712 { emit SessionCreated(msg.sender, sessionHash, sessionSpec); } - function getSessionKeyForSigner(address account, address signer) + function getCallPoliciesForSigner(address account, address signer) + external + view + returns (SessionLib.CallSpec[] memory) + { + SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; + + return sessionSpec.callPolicies; + } + + function getTransferPoliciesForSigner(address account, address signer) + external + view + returns (SessionLib.TransferSpec[] memory) + { + SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; + + return sessionSpec.transferPolicies; + } + + function getFeeLimitForSigner(address account, address signer) + external + view + returns (SessionLib.UsageLimit memory) + { + SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; + + return sessionSpec.feeLimit; + } + + function getSessionExpirationForSigner(address account, address signer) + external + view + returns (uint256) + { + SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; + + return sessionSpec.expiresAt; + } + + function getSessionStateForSigner(address account, address signer) external view returns (SessionLib.SessionState memory) diff --git a/src/lib/SessionLib.sol b/src/lib/SessionLib.sol index 27110c9..5962a9a 100644 --- a/src/lib/SessionLib.sol +++ b/src/lib/SessionLib.sol @@ -214,7 +214,7 @@ library SessionLib { return false; } - if(spec.expiresAt < block.timestamp) { + if (spec.expiresAt < block.timestamp) { return false; } diff --git a/test/DefaultValidatorNew.t.sol b/test/DefaultValidatorNew.t.sol new file mode 100644 index 0000000..19f3f62 --- /dev/null +++ b/test/DefaultValidatorNew.t.sol @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import {UserOperationLib} from "account-abstraction-v0.7/core/UserOperationLib.sol"; +import {IEntryPoint} from "account-abstraction-v0.7/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "account-abstraction-v0.7/interfaces/PackedUserOperation.sol"; +import {Test} from "forge-std/Test.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {SignatureCheckerLib} from "solady/utils/SignatureCheckerLib.sol"; + +import {MockTarget} from "test/mock/MockTarget.sol"; +import {MockValidator} from "test/mock/MockValidator.sol"; +import {EntryPointLib} from "test/util/ERC4337Test.sol"; + +import "lib/forge-std/src/console.sol"; + +import { + CALLTYPE_SINGLE, + CALLTYPE_STATIC, + Execution, + ExecutionLib, + IERC7579Account, + MODULE_TYPE_FALLBACK, + MODULE_TYPE_VALIDATOR, + ModeLib, + ModularAccount +} from "src/ModularAccount.sol"; +import {ModularAccountFactory} from "src/ModularAccountFactory.sol"; + +import {InitializerInstallModule} from "src/interface/IModularAccount.sol"; + +import {DefaultValidatorNew} from "src/DefaultValidatorNew.sol"; + +import {SessionLib} from "src/lib/SessionLib.sol"; + +contract DefaultValidatorNewTest is Test { + + IEntryPoint public entrypoint; + ModularAccountFactory public factory; + ModularAccount public account; + + bytes public accountSalt = abi.encodePacked(uint256(0xdeadbeef)); + + uint256 accountOwnerPKey = 0x2; + address public factoryOwner = vm.addr(0x1); + address public accountOwner = vm.addr(accountOwnerPKey); + + address public validator; + MockTarget public target; + + function setUp() public { + vm.prank(factoryOwner); + entrypoint = IEntryPoint(EntryPointLib.deploy()); + validator = address(new DefaultValidatorNew()); + target = new MockTarget(); + ModularAccount accountImpl = new ModularAccount(address(entrypoint)); + factory = new ModularAccountFactory(address(entrypoint), factoryOwner, address(accountImpl)); + + InitializerInstallModule[] memory modules = _prepareDefaultValidatorNewInstallData(); + account = ModularAccount(payable(factory.createAccountWithModules(accountOwner, accountSalt, modules))); + } + + function test_execute_owner() public { + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42)) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, msgHash); + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, accountOwner); + + userOp.signature = userOpSignature; + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.deal(address(account), 1 ether); + entrypoint.handleOps(userOps, payable(address(account))); + + assertTrue(target.number() == 42); + } + + function test_executeBatch_owner() public { + Execution[] memory executions = new Execution[](2); + executions[0] = Execution({target: address(target), value: 0, data: abi.encodeCall(MockTarget.set, 11)}); + executions[1] = Execution({target: address(target), value: 0, data: abi.encodeCall(MockTarget.setIf, (42, 11))}); + + bytes memory userOpCalldata = + abi.encodeCall(IERC7579Account.execute, (ModeLib.encodeSimpleBatch(), ExecutionLib.encodeBatch(executions))); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, msgHash); + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, accountOwner); + + userOp.signature = userOpSignature; + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.deal(address(account), 1 ether); + entrypoint.handleOps(userOps, payable(address(account))); + + assertTrue(target.number() == 42); + } + + function test_signature() public { + // sign message + bytes32 message = keccak256("Hello World"); + bytes32 messageHash = DefaultValidatorNew(payable(validator)).getMessageHash(message); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, messageHash); + bytes memory signature = abi.encodePacked(r, s, v); + + // encode validator address + signature + bytes memory sigWithValidator = abi.encode(validator, signature); + + // verify signature + bytes4 returnValue = account.isValidSignature(message, sigWithValidator); + assertEq(returnValue, bytes4(0x1626ba7e)); // Magic value from ERC-1271 + } + + // === signer is admin (non owner) + + // === session key tests + + // create session key + function test_createSessionKeyForSigner() public { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](0); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](0); + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + + SessionLib.CallSpec[] memory fetchedCallPolicies = + DefaultValidatorNew(validator).getCallPoliciesForSigner(address(account), vm.addr(1234)); + + SessionLib.TransferSpec[] memory fetchedTransferPolicies = + DefaultValidatorNew(validator).getTransferPoliciesForSigner(address(account), vm.addr(1234)); + + SessionLib.UsageLimit memory fetchedFeeLimit = + DefaultValidatorNew(validator).getFeeLimitForSigner(address(account), vm.addr(1234)); + + uint256 expiresAt = DefaultValidatorNew(validator).getSessionExpirationForSigner(address(account), vm.addr(1234)); + + assertEq(uint256(fetchedFeeLimit.limitType), uint256(SessionLib.LimitType.Lifetime)); + assertEq(expiresAt, 100); + } + + // test utils + + function getNonce(address _account, address _validator) internal returns (uint256 nonce) { + uint192 key = uint192(bytes24(bytes20(address(_validator)))); + nonce = (uint256(uint160(_validator)) << 96) | entrypoint.getNonce(_account, key); + } + + function getDefaultUserOp() internal returns (PackedUserOperation memory userOp) { + userOp = PackedUserOperation({ + sender: address(0), + nonce: 0, + initCode: "", + callData: "", + accountGasLimits: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + preVerificationGas: 2e6, + gasFees: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + paymasterAndData: bytes(""), + signature: abi.encodePacked(hex"41414141") + }); + } + + function _prepareDefaultValidatorNewInstallData() internal returns (InitializerInstallModule[] memory) { + InitializerInstallModule[] memory modules = new InitializerInstallModule[](1); + modules[0] = + InitializerInstallModule({moduleTypeId: MODULE_TYPE_VALIDATOR, module: validator, initData: bytes("")}); + return modules; + } + + function _createSessionKey(SessionLib.SessionSpec memory spec) internal { + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle( + address(validator), uint256(0), abi.encodeCall(DefaultValidatorNew.createSessionKey, (spec)) + ) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, msgHash); + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + userOp.signature = userOpSignature; + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.deal(address(account), 1 ether); + entrypoint.handleOps(userOps, payable(address(account))); + } + +} From 08c584889c8c715e77e5ec74cc56693885b73ec0 Mon Sep 17 00:00:00 2001 From: Yash Date: Fri, 28 Feb 2025 00:57:05 +0530 Subject: [PATCH 04/15] test --- src/DefaultValidatorNew.sol | 8 ++-- src/lib/SessionLib.sol | 6 +-- test/DefaultValidatorNew.t.sol | 67 ++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 7 deletions(-) diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol index f7f204a..b4555e7 100644 --- a/src/DefaultValidatorNew.sol +++ b/src/DefaultValidatorNew.sol @@ -297,10 +297,10 @@ contract DefaultValidatorNew is IValidator, EIP712 { isValid = _defaultValidatorStorage().sessions[sessionHash].validate(target, value, callData, sessionSpec); - (address paymaster,,) = UserOperationLib.unpackPaymasterStaticFields(_userOp.paymasterAndData); - if (paymaster == address(0)) { - // TODO: do we need to validate fee limit? - } + // (address paymaster,,) = UserOperationLib.unpackPaymasterStaticFields(_userOp.paymasterAndData); + // if (paymaster == address(0)) { + // // TODO: do we need to validate fee limit? + // } return (isValid, validAfter, validUntil); } else if (callType == CALLTYPE_BATCH) { diff --git a/src/lib/SessionLib.sol b/src/lib/SessionLib.sol index 5962a9a..0aafc8e 100644 --- a/src/lib/SessionLib.sol +++ b/src/lib/SessionLib.sol @@ -218,7 +218,7 @@ library SessionLib { return false; } - if (callData.length >= 4) { + if (callData.length >= 4 && callData.length != 12 ) { // TODO: fix this // bytes4 selector = bytes4(callData[:4]); bytes4 selector = bytes4(callData[0]) | (bytes4(callData[1]) >> 8) | (bytes4(callData[2]) >> 16) | (bytes4(callData[3]) >> 24); @@ -234,7 +234,7 @@ library SessionLib { return false; } - callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], value); + return callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], value); } else { TransferSpec memory transferPolicy; bool found = false; @@ -256,7 +256,7 @@ library SessionLib { // revert SessionMaxValueExceeded(value, transferPolicy.maxValuePerUse); return false; } - transferPolicy.valueLimit.checkAndUpdate(state.transferValue[target], value); + return transferPolicy.valueLimit.checkAndUpdate(state.transferValue[target], value); } } diff --git a/test/DefaultValidatorNew.t.sol b/test/DefaultValidatorNew.t.sol index 19f3f62..c104ed2 100644 --- a/test/DefaultValidatorNew.t.sol +++ b/test/DefaultValidatorNew.t.sol @@ -45,6 +45,9 @@ contract DefaultValidatorNewTest is Test { address public factoryOwner = vm.addr(0x1); address public accountOwner = vm.addr(accountOwnerPKey); + uint256 signerOnePKey = 1234; + address signerOne = vm.addr(signerOnePKey); + address public validator; MockTarget public target; @@ -189,6 +192,70 @@ contract DefaultValidatorNewTest is Test { assertEq(expiresAt, 100); } + function test_execute_transferPolicy() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies; + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 1000; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.signer = signerOne; + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(target), 100, "") + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, signerOne); + + userOp.signature = userOpSignature; + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.deal(address(account), 1 ether); + entrypoint.handleOps(userOps, payable(address(account))); + + assertTrue(address(target).balance == 100); + } + // test utils function getNonce(address _account, address _validator) internal returns (uint256 nonce) { From 42beb00fb83f70d17b59162b202cc1374a323599 Mon Sep 17 00:00:00 2001 From: Yash Date: Fri, 28 Feb 2025 00:57:46 +0530 Subject: [PATCH 05/15] fmt --- src/DefaultValidatorNew.sol | 6 +----- src/lib/SessionLib.sol | 3 ++- test/DefaultValidatorNew.t.sol | 9 +++------ 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol index b4555e7..323a8da 100644 --- a/src/DefaultValidatorNew.sol +++ b/src/DefaultValidatorNew.sol @@ -246,11 +246,7 @@ contract DefaultValidatorNew is IValidator, EIP712 { return sessionSpec.feeLimit; } - function getSessionExpirationForSigner(address account, address signer) - external - view - returns (uint256) - { + function getSessionExpirationForSigner(address account, address signer) external view returns (uint256) { SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; return sessionSpec.expiresAt; diff --git a/src/lib/SessionLib.sol b/src/lib/SessionLib.sol index 0aafc8e..9a60f43 100644 --- a/src/lib/SessionLib.sol +++ b/src/lib/SessionLib.sol @@ -218,7 +218,8 @@ library SessionLib { return false; } - if (callData.length >= 4 && callData.length != 12 ) { // TODO: fix this + if (callData.length >= 4 && callData.length != 12) { + // TODO: fix this // bytes4 selector = bytes4(callData[:4]); bytes4 selector = bytes4(callData[0]) | (bytes4(callData[1]) >> 8) | (bytes4(callData[2]) >> 16) | (bytes4(callData[3]) >> 24); diff --git a/test/DefaultValidatorNew.t.sol b/test/DefaultValidatorNew.t.sol index c104ed2..812ae05 100644 --- a/test/DefaultValidatorNew.t.sol +++ b/test/DefaultValidatorNew.t.sol @@ -186,7 +186,8 @@ contract DefaultValidatorNewTest is Test { SessionLib.UsageLimit memory fetchedFeeLimit = DefaultValidatorNew(validator).getFeeLimitForSigner(address(account), vm.addr(1234)); - uint256 expiresAt = DefaultValidatorNew(validator).getSessionExpirationForSigner(address(account), vm.addr(1234)); + uint256 expiresAt = + DefaultValidatorNew(validator).getSessionExpirationForSigner(address(account), vm.addr(1234)); assertEq(uint256(fetchedFeeLimit.limitType), uint256(SessionLib.LimitType.Lifetime)); assertEq(expiresAt, 100); @@ -218,11 +219,7 @@ contract DefaultValidatorNewTest is Test { // 2. prepare and send User Op bytes memory userOpCalldata = abi.encodeCall( - IERC7579Account.execute, - ( - ModeLib.encodeSimpleSingle(), - ExecutionLib.encodeSingle(address(target), 100, "") - ) + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 100, "")) ); uint256 nonce = getNonce(address(account), address(validator)); From 8910e1652537cfdac3e0afafb5a7f2f31b219eff Mon Sep 17 00:00:00 2001 From: Yash Date: Fri, 28 Feb 2025 20:31:13 +0530 Subject: [PATCH 06/15] gas benchmarks --- .gas-snapshot | 5 + foundry.toml | 4 + gasreport.txt | 18 ++ src/DefaultValidatorNew.sol | 2 +- src/lib/SessionLib.sol | 2 +- test/BenchmarkDefaultValidatorNew.t.sol | 281 ++++++++++++++++++++++++ 6 files changed, 310 insertions(+), 2 deletions(-) create mode 100644 .gas-snapshot create mode 100644 gasreport.txt create mode 100644 test/BenchmarkDefaultValidatorNew.t.sol diff --git a/.gas-snapshot b/.gas-snapshot new file mode 100644 index 0000000..2fb0fc4 --- /dev/null +++ b/.gas-snapshot @@ -0,0 +1,5 @@ +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_100CallPolicies() (gas: 2467123) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_100TransferPolicies() (gas: 2209472) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2336299) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50CallPolicies() (gas: 2259807) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50TransferPolicies() (gas: 2134908) \ No newline at end of file diff --git a/foundry.toml b/foundry.toml index 61ba0f2..f3050a6 100644 --- a/foundry.toml +++ b/foundry.toml @@ -8,6 +8,10 @@ src = "src" out = "out" libs = ["lib"] +gas_reports = [ + "BenchmarkDefaultValidatorNew", +] + ignored_warnings_from = ["lib", "test"] [fmt] diff --git a/gasreport.txt b/gasreport.txt new file mode 100644 index 0000000..904deeb --- /dev/null +++ b/gasreport.txt @@ -0,0 +1,18 @@ +No files changed, compilation skipped + +Ran 5 tests for test/BenchmarkDefaultValidatorNew.t.sol:BenchmarkDefaultValidatorNewTest +[PASS] test_createSessionKeyForSigner_100CallPolicies() (gas: 2467123) +[PASS] test_createSessionKeyForSigner_100TransferPolicies() (gas: 2209472) +[PASS] test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2336299) +[PASS] test_createSessionKeyForSigner_50CallPolicies() (gas: 2259807) +[PASS] test_createSessionKeyForSigner_50TransferPolicies() (gas: 2134908) +Suite result: ok. 5 passed; 0 failed; 0 skipped; finished in 11.46ms (35.55ms CPU time) + + +Ran 1 test suite in 12.88ms (11.46ms CPU time): 5 tests passed, 0 failed, 0 skipped (5 total tests) +test_createSessionKeyForSigner_100CallPolicies() (gas: 0 (0.000%)) +test_createSessionKeyForSigner_100TransferPolicies() (gas: 0 (0.000%)) +test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 0 (0.000%)) +test_createSessionKeyForSigner_50CallPolicies() (gas: 0 (0.000%)) +test_createSessionKeyForSigner_50TransferPolicies() (gas: 0 (0.000%)) +Overall gas change: 0 (0.000%) diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol index 323a8da..211aec0 100644 --- a/src/DefaultValidatorNew.sol +++ b/src/DefaultValidatorNew.sol @@ -119,7 +119,7 @@ contract DefaultValidatorNew is IValidator, EIP712 { address account = msg.sender; bytes32 hash = SignatureCheckerLib.toEthSignedMessageHash(userOpHash); - address signer = hash.recover(userOp.signature); + address signer = hash.recover(userOp.signature); // TODO; recover other sigs - 1271, 6492 (bool isValid, uint48 validAfter, uint48 validUntil) = _isValidSigner(account, signer, userOp); if (!isValid) { diff --git a/src/lib/SessionLib.sol b/src/lib/SessionLib.sol index 9a60f43..3f757c4 100644 --- a/src/lib/SessionLib.sol +++ b/src/lib/SessionLib.sol @@ -61,7 +61,7 @@ library SessionLib { struct SessionSpec { address signer; uint256 expiresAt; - UsageLimit feeLimit; + UsageLimit feeLimit; // TODO: remove CallSpec[] callPolicies; TransferSpec[] transferPolicies; } diff --git a/test/BenchmarkDefaultValidatorNew.t.sol b/test/BenchmarkDefaultValidatorNew.t.sol new file mode 100644 index 0000000..a64d523 --- /dev/null +++ b/test/BenchmarkDefaultValidatorNew.t.sol @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import {UserOperationLib} from "account-abstraction-v0.7/core/UserOperationLib.sol"; +import {IEntryPoint} from "account-abstraction-v0.7/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "account-abstraction-v0.7/interfaces/PackedUserOperation.sol"; +import {Test} from "forge-std/Test.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {SignatureCheckerLib} from "solady/utils/SignatureCheckerLib.sol"; + +import {MockTarget} from "test/mock/MockTarget.sol"; +import {MockValidator} from "test/mock/MockValidator.sol"; +import {EntryPointLib} from "test/util/ERC4337Test.sol"; + +import "lib/forge-std/src/console.sol"; + +import { + CALLTYPE_SINGLE, + CALLTYPE_STATIC, + Execution, + ExecutionLib, + IERC7579Account, + MODULE_TYPE_FALLBACK, + MODULE_TYPE_VALIDATOR, + ModeLib, + ModularAccount +} from "src/ModularAccount.sol"; +import {ModularAccountFactory} from "src/ModularAccountFactory.sol"; + +import {InitializerInstallModule} from "src/interface/IModularAccount.sol"; + +import {DefaultValidatorNew} from "src/DefaultValidatorNew.sol"; + +import {SessionLib} from "src/lib/SessionLib.sol"; + +contract BenchmarkDefaultValidatorNewTest is Test { + + IEntryPoint public entrypoint; + ModularAccountFactory public factory; + ModularAccount public account; + + bytes public accountSalt = abi.encodePacked(uint256(0xdeadbeef)); + + uint256 accountOwnerPKey = 0x2; + address public factoryOwner = vm.addr(0x1); + address public accountOwner = vm.addr(accountOwnerPKey); + + uint256 signerOnePKey = 1234; + address signerOne = vm.addr(signerOnePKey); + + address public validator; + MockTarget public target; + + function setUp() public { + vm.pauseGasMetering(); + vm.prank(factoryOwner); + entrypoint = IEntryPoint(EntryPointLib.deploy()); + validator = address(new DefaultValidatorNew()); + target = new MockTarget(); + ModularAccount accountImpl = new ModularAccount(address(entrypoint)); + factory = new ModularAccountFactory(address(entrypoint), factoryOwner, address(accountImpl)); + + InitializerInstallModule[] memory modules = _prepareDefaultValidatorNewInstallData(); + account = ModularAccount(payable(factory.createAccountWithModules(accountOwner, accountSalt, modules))); + + vm.resumeGasMetering(); + } + + // === session key tests + + // create session key + function test_createSessionKeyForSigner_50CallPolicies() public { + vm.pauseGasMetering(); + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](50); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](0); + + for(uint256 i = 0; i < callPolicies.length; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_createSessionKeyForSigner_100CallPolicies() public { + vm.pauseGasMetering(); + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](100); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](0); + + for(uint256 i = 0; i < callPolicies.length; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_createSessionKeyForSigner_50TransferPolicies() public { + vm.pauseGasMetering(); + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](0); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](50); + + for(uint256 i = 0; i < transferPolicies.length; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_createSessionKeyForSigner_100TransferPolicies() public { + vm.pauseGasMetering(); + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](0); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](100); + + for(uint256 i = 0; i < transferPolicies.length; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_createSessionKeyForSigner_50Call50TransferPolicies() public { + vm.pauseGasMetering(); + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](50); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](50); + + for(uint256 i = 0; i < callPolicies.length; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + for(uint256 i = 0; i < transferPolicies.length; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + + spec.signer = vm.addr(1234); + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + + vm.deal(address(account), 1 ether); + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // test utils + + function getNonce(address _account, address _validator) internal returns (uint256 nonce) { + uint192 key = uint192(bytes24(bytes20(address(_validator)))); + nonce = (uint256(uint160(_validator)) << 96) | entrypoint.getNonce(_account, key); + } + + function getDefaultUserOp() internal returns (PackedUserOperation memory userOp) { + userOp = PackedUserOperation({ + sender: address(0), + nonce: 0, + initCode: "", + callData: "", + accountGasLimits: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + preVerificationGas: 2e6, + gasFees: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + paymasterAndData: bytes(""), + signature: abi.encodePacked(hex"41414141") + }); + } + + function _prepareDefaultValidatorNewInstallData() internal returns (InitializerInstallModule[] memory) { + InitializerInstallModule[] memory modules = new InitializerInstallModule[](1); + modules[0] = + InitializerInstallModule({moduleTypeId: MODULE_TYPE_VALIDATOR, module: validator, initData: bytes("")}); + return modules; + } + + function _getCreateSessionKeyOp(SessionLib.SessionSpec memory spec) internal returns (PackedUserOperation[] memory userOps) { + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle( + address(validator), uint256(0), abi.encodeCall(DefaultValidatorNew.createSessionKey, (spec)) + ) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, msgHash); + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + userOp.signature = userOpSignature; + } + + // Create userOps array + userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + } + +} From 385bf8cebecd18bae3f3f74d8c7f1bc70e630dc9 Mon Sep 17 00:00:00 2001 From: Yash Date: Fri, 28 Feb 2025 20:31:49 +0530 Subject: [PATCH 07/15] fmt --- test/BenchmarkDefaultValidatorNew.t.sol | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/test/BenchmarkDefaultValidatorNew.t.sol b/test/BenchmarkDefaultValidatorNew.t.sol index a64d523..d7026f4 100644 --- a/test/BenchmarkDefaultValidatorNew.t.sol +++ b/test/BenchmarkDefaultValidatorNew.t.sol @@ -77,7 +77,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](50); SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](0); - for(uint256 i = 0; i < callPolicies.length; i++) { + for (uint256 i = 0; i < callPolicies.length; i++) { callPolicies[i].target = vm.addr(i + 1); callPolicies[i].selector = 0x12345678; callPolicies[i].maxValuePerUse = i; @@ -106,7 +106,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](100); SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](0); - for(uint256 i = 0; i < callPolicies.length; i++) { + for (uint256 i = 0; i < callPolicies.length; i++) { callPolicies[i].target = vm.addr(i + 1); callPolicies[i].selector = 0x12345678; callPolicies[i].maxValuePerUse = i; @@ -135,7 +135,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](0); SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](50); - for(uint256 i = 0; i < transferPolicies.length; i++) { + for (uint256 i = 0; i < transferPolicies.length; i++) { transferPolicies[i].target = vm.addr(i + 1); transferPolicies[i].maxValuePerUse = i; transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); @@ -162,7 +162,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](0); SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](100); - for(uint256 i = 0; i < transferPolicies.length; i++) { + for (uint256 i = 0; i < transferPolicies.length; i++) { transferPolicies[i].target = vm.addr(i + 1); transferPolicies[i].maxValuePerUse = i; transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); @@ -189,7 +189,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](50); SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](50); - for(uint256 i = 0; i < callPolicies.length; i++) { + for (uint256 i = 0; i < callPolicies.length; i++) { callPolicies[i].target = vm.addr(i + 1); callPolicies[i].selector = 0x12345678; callPolicies[i].maxValuePerUse = i; @@ -197,7 +197,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { callPolicies[i].constraints = new SessionLib.Constraint[](1); } - for(uint256 i = 0; i < transferPolicies.length; i++) { + for (uint256 i = 0; i < transferPolicies.length; i++) { transferPolicies[i].target = vm.addr(i + 1); transferPolicies[i].maxValuePerUse = i; transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); @@ -244,7 +244,10 @@ contract BenchmarkDefaultValidatorNewTest is Test { return modules; } - function _getCreateSessionKeyOp(SessionLib.SessionSpec memory spec) internal returns (PackedUserOperation[] memory userOps) { + function _getCreateSessionKeyOp(SessionLib.SessionSpec memory spec) + internal + returns (PackedUserOperation[] memory userOps) + { bytes memory userOpCalldata = abi.encodeCall( IERC7579Account.execute, ( From 665b12fea7dd586ecf4acc9664017e05c4eca5a9 Mon Sep 17 00:00:00 2001 From: Yash Date: Fri, 28 Feb 2025 21:39:53 +0530 Subject: [PATCH 08/15] gas report --- .gas-snapshot | 14 +- gasreport.txt | 36 +-- test/BenchmarkDefaultValidatorNew.t.sol | 319 +++++++++++++++++++++++- 3 files changed, 343 insertions(+), 26 deletions(-) diff --git a/.gas-snapshot b/.gas-snapshot index 2fb0fc4..9b1c484 100644 --- a/.gas-snapshot +++ b/.gas-snapshot @@ -1,5 +1,9 @@ -BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_100CallPolicies() (gas: 2467123) -BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_100TransferPolicies() (gas: 2209472) -BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2336299) -BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50CallPolicies() (gas: 2259807) -BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50TransferPolicies() (gas: 2134908) \ No newline at end of file +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_100CallPolicies() (gas: 2467172) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_100TransferPolicies() (gas: 2209520) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2336380) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50CallPolicies() (gas: 2259810) +BenchmarkDefaultValidatorNewTest:test_createSessionKeyForSigner_50TransferPolicies() (gas: 2134968) +BenchmarkDefaultValidatorNewTest:test_execute_100CallPolicies() (gas: 2760504) +BenchmarkDefaultValidatorNewTest:test_execute_100TransferPolicies() (gas: 1317589) +BenchmarkDefaultValidatorNewTest:test_execute_50CallPolicies() (gas: 1471450) +BenchmarkDefaultValidatorNewTest:test_execute_50TransferPolicies() (gas: 743908) \ No newline at end of file diff --git a/gasreport.txt b/gasreport.txt index 904deeb..01bab83 100644 --- a/gasreport.txt +++ b/gasreport.txt @@ -1,18 +1,26 @@ No files changed, compilation skipped -Ran 5 tests for test/BenchmarkDefaultValidatorNew.t.sol:BenchmarkDefaultValidatorNewTest -[PASS] test_createSessionKeyForSigner_100CallPolicies() (gas: 2467123) -[PASS] test_createSessionKeyForSigner_100TransferPolicies() (gas: 2209472) -[PASS] test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2336299) -[PASS] test_createSessionKeyForSigner_50CallPolicies() (gas: 2259807) -[PASS] test_createSessionKeyForSigner_50TransferPolicies() (gas: 2134908) -Suite result: ok. 5 passed; 0 failed; 0 skipped; finished in 11.46ms (35.55ms CPU time) +Ran 9 tests for test/BenchmarkDefaultValidatorNew.t.sol:BenchmarkDefaultValidatorNewTest +[PASS] test_createSessionKeyForSigner_100CallPolicies() (gas: 2467149) +[PASS] test_createSessionKeyForSigner_100TransferPolicies() (gas: 2209521) +[PASS] test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2336358) +[PASS] test_createSessionKeyForSigner_50CallPolicies() (gas: 2259844) +[PASS] test_createSessionKeyForSigner_50TransferPolicies() (gas: 2134946) +[PASS] test_execute_100CallPolicies() (gas: 2760545) +[PASS] test_execute_100TransferPolicies() (gas: 1317589) +[PASS] test_execute_50CallPolicies() (gas: 1471450) +[PASS] test_execute_50TransferPolicies() (gas: 743850) +Suite result: ok. 9 passed; 0 failed; 0 skipped; finished in 12.74ms (69.57ms CPU time) -Ran 1 test suite in 12.88ms (11.46ms CPU time): 5 tests passed, 0 failed, 0 skipped (5 total tests) -test_createSessionKeyForSigner_100CallPolicies() (gas: 0 (0.000%)) -test_createSessionKeyForSigner_100TransferPolicies() (gas: 0 (0.000%)) -test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 0 (0.000%)) -test_createSessionKeyForSigner_50CallPolicies() (gas: 0 (0.000%)) -test_createSessionKeyForSigner_50TransferPolicies() (gas: 0 (0.000%)) -Overall gas change: 0 (0.000%) +Ran 1 test suite in 14.95ms (12.74ms CPU time): 9 tests passed, 0 failed, 0 skipped (9 total tests) +test_execute_100TransferPolicies() (gas: 0 (0.000%)) +test_execute_50CallPolicies() (gas: 0 (0.000%)) +test_createSessionKeyForSigner_100TransferPolicies() (gas: 1 (0.000%)) +test_createSessionKeyForSigner_100CallPolicies() (gas: -23 (-0.001%)) +test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: -22 (-0.001%)) +test_createSessionKeyForSigner_50TransferPolicies() (gas: -22 (-0.001%)) +test_execute_100CallPolicies() (gas: 41 (0.001%)) +test_createSessionKeyForSigner_50CallPolicies() (gas: 34 (0.002%)) +test_execute_50TransferPolicies() (gas: -58 (-0.008%)) +Overall gas change: -49 (-0.000%) diff --git a/test/BenchmarkDefaultValidatorNew.t.sol b/test/BenchmarkDefaultValidatorNew.t.sol index d7026f4..7f80c37 100644 --- a/test/BenchmarkDefaultValidatorNew.t.sol +++ b/test/BenchmarkDefaultValidatorNew.t.sol @@ -39,8 +39,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { ModularAccountFactory public factory; ModularAccount public account; - bytes public accountSalt = abi.encodePacked(uint256(0xdeadbeef)); - uint256 accountOwnerPKey = 0x2; address public factoryOwner = vm.addr(0x1); address public accountOwner = vm.addr(accountOwnerPKey); @@ -51,8 +49,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { address public validator; MockTarget public target; - function setUp() public { - vm.pauseGasMetering(); + function _setup(bytes memory accountSalt) internal { vm.prank(factoryOwner); entrypoint = IEntryPoint(EntryPointLib.deploy()); validator = address(new DefaultValidatorNew()); @@ -62,8 +59,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { InitializerInstallModule[] memory modules = _prepareDefaultValidatorNewInstallData(); account = ModularAccount(payable(factory.createAccountWithModules(accountOwner, accountSalt, modules))); - - vm.resumeGasMetering(); } // === session key tests @@ -71,6 +66,8 @@ contract BenchmarkDefaultValidatorNewTest is Test { // create session key function test_createSessionKeyForSigner_50CallPolicies() public { vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x1))); + SessionLib.SessionSpec memory spec; SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); @@ -100,6 +97,8 @@ contract BenchmarkDefaultValidatorNewTest is Test { function test_createSessionKeyForSigner_100CallPolicies() public { vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x2))); + SessionLib.SessionSpec memory spec; SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); @@ -129,6 +128,8 @@ contract BenchmarkDefaultValidatorNewTest is Test { function test_createSessionKeyForSigner_50TransferPolicies() public { vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x3))); + SessionLib.SessionSpec memory spec; SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); @@ -156,6 +157,8 @@ contract BenchmarkDefaultValidatorNewTest is Test { function test_createSessionKeyForSigner_100TransferPolicies() public { vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x4))); + SessionLib.SessionSpec memory spec; SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); @@ -183,6 +186,8 @@ contract BenchmarkDefaultValidatorNewTest is Test { function test_createSessionKeyForSigner_50Call50TransferPolicies() public { vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x5))); + SessionLib.SessionSpec memory spec; SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); @@ -216,6 +221,306 @@ contract BenchmarkDefaultValidatorNewTest is Test { entrypoint.handleOps(userOps, payable(address(account))); } + // execute + + function test_execute_50CallPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x6))); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.TransferSpec[] memory transferPolicies; + + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](50); + for (uint256 i = 0; i < callPolicies.length - 1; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + SessionLib.Constraint[] memory mockTargetConstraint = new SessionLib.Constraint[](1); + mockTargetConstraint[0].condition = SessionLib.Condition.LessOrEqual; + mockTargetConstraint[0].refValue = bytes32(uint256(100)); + mockTargetConstraint[0].index = 0; + mockTargetConstraint[0].limit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + callPolicies[callPolicies.length - 1].target = address(target); + callPolicies[callPolicies.length - 1].selector = MockTarget.set.selector; + callPolicies[callPolicies.length - 1].maxValuePerUse = 1000; + callPolicies[callPolicies.length - 1].valueLimit = + SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[callPolicies.length - 1].constraints = mockTargetConstraint; + + spec.signer = signerOne; + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + vm.deal(address(account), 2 ether); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42))) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, signerOne); + + userOp.signature = userOpSignature; + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_execute_100CallPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x7))); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.TransferSpec[] memory transferPolicies; + + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](100); + for (uint256 i = 0; i < callPolicies.length - 1; i++) { + callPolicies[i].target = vm.addr(i + 1); + callPolicies[i].selector = 0x12345678; + callPolicies[i].maxValuePerUse = i; + callPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[i].constraints = new SessionLib.Constraint[](1); + } + + SessionLib.Constraint[] memory mockTargetConstraint = new SessionLib.Constraint[](1); + mockTargetConstraint[0].condition = SessionLib.Condition.LessOrEqual; + mockTargetConstraint[0].refValue = bytes32(uint256(100)); + mockTargetConstraint[0].index = 0; + mockTargetConstraint[0].limit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + callPolicies[callPolicies.length - 1].target = address(target); + callPolicies[callPolicies.length - 1].selector = MockTarget.set.selector; + callPolicies[callPolicies.length - 1].maxValuePerUse = 1000; + callPolicies[callPolicies.length - 1].valueLimit = + SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + callPolicies[callPolicies.length - 1].constraints = mockTargetConstraint; + + spec.signer = signerOne; + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + vm.deal(address(account), 3 ether); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42))) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, signerOne); + + userOp.signature = userOpSignature; + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_execute_50TransferPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x8))); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies; + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](50); + for (uint256 i = 0; i < transferPolicies.length - 1; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + transferPolicies[transferPolicies.length - 1].target = address(target); + transferPolicies[transferPolicies.length - 1].maxValuePerUse = 1000; + transferPolicies[transferPolicies.length - 1].valueLimit = + SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.signer = signerOne; + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + vm.deal(address(account), 2 ether); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 100, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, signerOne); + + userOp.signature = userOpSignature; + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + + function test_execute_100TransferPolicies() public { + vm.pauseGasMetering(); + _setup(abi.encodePacked(uint256(0x9))); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec; + + SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + SessionLib.CallSpec[] memory callPolicies; + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](100); + for (uint256 i = 0; i < transferPolicies.length - 1; i++) { + transferPolicies[i].target = vm.addr(i + 1); + transferPolicies[i].maxValuePerUse = i; + transferPolicies[i].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + } + transferPolicies[transferPolicies.length - 1].target = address(target); + transferPolicies[transferPolicies.length - 1].maxValuePerUse = 1000; + transferPolicies[transferPolicies.length - 1].valueLimit = + SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.signer = signerOne; + spec.expiresAt = 100; + spec.feeLimit = feeLimit; + spec.callPolicies = callPolicies; + spec.transferPolicies = transferPolicies; + + PackedUserOperation[] memory userOps = _getCreateSessionKeyOp(spec); + vm.deal(address(account), 2 ether); + entrypoint.handleOps(userOps, payable(address(account))); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 100, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + address recoveredSigner = ECDSA.recover(msgHash, v, r, s); + assertEq(recoveredSigner, signerOne); + + userOp.signature = userOpSignature; + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + vm.resumeGasMetering(); + entrypoint.handleOps(userOps, payable(address(account))); + } + // test utils function getNonce(address _account, address _validator) internal returns (uint256 nonce) { @@ -229,7 +534,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { nonce: 0, initCode: "", callData: "", - accountGasLimits: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + accountGasLimits: bytes32(abi.encodePacked(uint128(3e6), uint128(2e6))), preVerificationGas: 2e6, gasFees: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), paymasterAndData: bytes(""), From 253a4abe5cce0d0320166fcd6acd1d3745803050 Mon Sep 17 00:00:00 2001 From: Yash Date: Tue, 4 Mar 2025 23:33:06 +0530 Subject: [PATCH 09/15] recover 1271 sig in validateUserOp --- src/DefaultValidatorNew.sol | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol index 211aec0..7a4094c 100644 --- a/src/DefaultValidatorNew.sol +++ b/src/DefaultValidatorNew.sol @@ -118,8 +118,15 @@ contract DefaultValidatorNew is IValidator, EIP712 { function validateUserOp(PackedUserOperation calldata userOp, bytes32 userOpHash) external returns (uint256) { address account = msg.sender; + (address signer, bytes memory signature) = abi.decode(userOp.signature, (address, bytes)); bytes32 hash = SignatureCheckerLib.toEthSignedMessageHash(userOpHash); - address signer = hash.recover(userOp.signature); // TODO; recover other sigs - 1271, 6492 + + bool isValidSig = SignatureCheckerLib.isValidSignatureNow(signer, hash, signature); + + if (!isValidSig) { + return VALIDATION_FAILED; + } + (bool isValid, uint48 validAfter, uint48 validUntil) = _isValidSigner(account, signer, userOp); if (!isValid) { @@ -137,6 +144,7 @@ contract DefaultValidatorNew is IValidator, EIP712 { returns (bytes4) { address account = msg.sender; + bytes32 targetDigest = getMessageHash(hash); address signer = ECDSA.recover(targetDigest, signature); From 599a00e45ce736462c13b716a0dddeb6a5d89fbb Mon Sep 17 00:00:00 2001 From: Yash Date: Wed, 5 Mar 2025 20:34:51 +0530 Subject: [PATCH 10/15] tests --- test/DefaultValidatorNew.t.sol | 458 ++++++++++++++++++++++++++------- test/mock/MockERC20.sol | 20 ++ 2 files changed, 390 insertions(+), 88 deletions(-) create mode 100644 test/mock/MockERC20.sol diff --git a/test/DefaultValidatorNew.t.sol b/test/DefaultValidatorNew.t.sol index 812ae05..200bf71 100644 --- a/test/DefaultValidatorNew.t.sol +++ b/test/DefaultValidatorNew.t.sol @@ -5,9 +5,12 @@ import {UserOperationLib} from "account-abstraction-v0.7/core/UserOperationLib.s import {IEntryPoint} from "account-abstraction-v0.7/interfaces/IEntryPoint.sol"; import {PackedUserOperation} from "account-abstraction-v0.7/interfaces/PackedUserOperation.sol"; import {Test} from "forge-std/Test.sol"; + +import {ERC20} from "solady/tokens/ERC20.sol"; import {ECDSA} from "solady/utils/ECDSA.sol"; import {SignatureCheckerLib} from "solady/utils/SignatureCheckerLib.sol"; +import {MockERC20} from "test/mock/MockERC20.sol"; import {MockTarget} from "test/mock/MockTarget.sol"; import {MockValidator} from "test/mock/MockValidator.sol"; import {EntryPointLib} from "test/util/ERC4337Test.sol"; @@ -38,29 +41,33 @@ contract DefaultValidatorNewTest is Test { IEntryPoint public entrypoint; ModularAccountFactory public factory; ModularAccount public account; + MockTarget public target; + MockERC20 public erc20; + address public validator; bytes public accountSalt = abi.encodePacked(uint256(0xdeadbeef)); - uint256 accountOwnerPKey = 0x2; address public factoryOwner = vm.addr(0x1); + uint256 accountOwnerPKey = 0x2; address public accountOwner = vm.addr(accountOwnerPKey); - uint256 signerOnePKey = 1234; + uint256 signerOnePKey = 0x3; address signerOne = vm.addr(signerOnePKey); - address public validator; - MockTarget public target; - function setUp() public { vm.prank(factoryOwner); entrypoint = IEntryPoint(EntryPointLib.deploy()); validator = address(new DefaultValidatorNew()); target = new MockTarget(); + erc20 = new MockERC20(); ModularAccount accountImpl = new ModularAccount(address(entrypoint)); factory = new ModularAccountFactory(address(entrypoint), factoryOwner, address(accountImpl)); InitializerInstallModule[] memory modules = _prepareDefaultValidatorNewInstallData(); account = ModularAccount(payable(factory.createAccountWithModules(accountOwner, accountSalt, modules))); + + vm.deal(address(account), 1 ether); + erc20.mint(address(account), 1 ether); } function test_execute_owner() public { @@ -79,25 +86,8 @@ contract DefaultValidatorNewTest is Test { userOp.nonce = nonce; userOp.callData = userOpCalldata; - // Sign UserOp - { - bytes32 opHash = entrypoint.getUserOpHash(userOp); - bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); - - (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, msgHash); - bytes memory userOpSignature = abi.encodePacked(r, s, v); - - address recoveredSigner = ECDSA.recover(msgHash, v, r, s); - assertEq(recoveredSigner, accountOwner); - - userOp.signature = userOpSignature; - } - - // Create userOps array - PackedUserOperation[] memory userOps = new PackedUserOperation[](1); - userOps[0] = userOp; + PackedUserOperation[] memory userOps = _getSignedUserOp(accountOwner, accountOwnerPKey, userOp); - vm.deal(address(account), 1 ether); entrypoint.handleOps(userOps, payable(address(account))); assertTrue(target.number() == 42); @@ -118,25 +108,8 @@ contract DefaultValidatorNewTest is Test { userOp.nonce = nonce; userOp.callData = userOpCalldata; - // Sign UserOp - { - bytes32 opHash = entrypoint.getUserOpHash(userOp); - bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); - - (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, msgHash); - bytes memory userOpSignature = abi.encodePacked(r, s, v); + PackedUserOperation[] memory userOps = _getSignedUserOp(accountOwner, accountOwnerPKey, userOp); - address recoveredSigner = ECDSA.recover(msgHash, v, r, s); - assertEq(recoveredSigner, accountOwner); - - userOp.signature = userOpSignature; - } - - // Create userOps array - PackedUserOperation[] memory userOps = new PackedUserOperation[](1); - userOps[0] = userOp; - - vm.deal(address(account), 1 ether); entrypoint.handleOps(userOps, payable(address(account))); assertTrue(target.number() == 42); @@ -162,55 +135,37 @@ contract DefaultValidatorNewTest is Test { // === session key tests // create session key - function test_createSessionKeyForSigner() public { - SessionLib.SessionSpec memory spec; - - SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); - SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](0); - SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](0); - - spec.signer = vm.addr(1234); - spec.expiresAt = 100; - spec.feeLimit = feeLimit; - spec.callPolicies = callPolicies; - spec.transferPolicies = transferPolicies; + function test_createSessionKey() public { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); _createSessionKey(spec); SessionLib.CallSpec[] memory fetchedCallPolicies = - DefaultValidatorNew(validator).getCallPoliciesForSigner(address(account), vm.addr(1234)); + DefaultValidatorNew(validator).getCallPoliciesForSigner(address(account), signerOne); SessionLib.TransferSpec[] memory fetchedTransferPolicies = - DefaultValidatorNew(validator).getTransferPoliciesForSigner(address(account), vm.addr(1234)); + DefaultValidatorNew(validator).getTransferPoliciesForSigner(address(account), signerOne); SessionLib.UsageLimit memory fetchedFeeLimit = - DefaultValidatorNew(validator).getFeeLimitForSigner(address(account), vm.addr(1234)); + DefaultValidatorNew(validator).getFeeLimitForSigner(address(account), signerOne); - uint256 expiresAt = - DefaultValidatorNew(validator).getSessionExpirationForSigner(address(account), vm.addr(1234)); + uint256 expiresAt = DefaultValidatorNew(validator).getSessionExpirationForSigner(address(account), signerOne); - assertEq(uint256(fetchedFeeLimit.limitType), uint256(SessionLib.LimitType.Lifetime)); - assertEq(expiresAt, 100); + assertEq(uint256(fetchedFeeLimit.limitType), uint256(spec.feeLimit.limitType)); + assertEq(expiresAt, spec.expiresAt); } - function test_execute_transferPolicy() public { + function test_transferPolicy() public { // 1. create session key { - SessionLib.SessionSpec memory spec; - - SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); - SessionLib.CallSpec[] memory callPolicies; + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); transferPolicies[0].target = address(target); - transferPolicies[0].maxValuePerUse = 1000; + transferPolicies[0].maxValuePerUse = 500; transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); - spec.signer = signerOne; - spec.expiresAt = 100; - spec.feeLimit = feeLimit; - spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; _createSessionKey(spec); @@ -229,28 +184,339 @@ contract DefaultValidatorNewTest is Test { userOp.nonce = nonce; userOp.callData = userOpCalldata; - // Sign UserOp + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + assertTrue(address(target).balance == 100); + } + + function test_revert_transferPolicy_crossLimit() public { + // 1. create session key + { - bytes32 opHash = entrypoint.getUserOpHash(userOp); - bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); - (uint8 v, bytes32 r, bytes32 s) = vm.sign(signerOnePKey, msgHash); // sign with session key signer private key - bytes memory userOpSignature = abi.encodePacked(r, s, v); + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); - address recoveredSigner = ECDSA.recover(msgHash, v, r, s); - assertEq(recoveredSigner, signerOne); + spec.transferPolicies = transferPolicies; - userOp.signature = userOpSignature; + _createSessionKey(spec); } - // Create userOps array - PackedUserOperation[] memory userOps = new PackedUserOperation[](1); - userOps[0] = userOp; + // 2. prepare and send User Op - vm.deal(address(account), 1 ether); - entrypoint.handleOps(userOps, payable(address(account))); + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 600, "")) + ); - assertTrue(address(target).balance == 100); + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + function test_callPolicy_erc20Transfer() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to target + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(ERC20.transfer, (address(target), 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + assertTrue(erc20.balanceOf(address(target)) == 100); + } + + function test_revert_callPolicy_erc20Transfer_wrongTarget() public { + address wrongTarget = vm.addr(0x789); + + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to defined target below + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + // === set target as wrongTarget here ========= + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(ERC20.transfer, (wrongTarget, 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + function test_revert_callPolicy_erc20Transfer_crossTxLimit() public { + uint256 wrongAmount = 101; + + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to defined target below + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + // === set amount as wrongAmount here ========= + ExecutionLib.encodeSingle( + address(erc20), 0, abi.encodeCall(ERC20.transfer, (address(target), wrongAmount)) + ) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + function test_revert_callPolicy_erc20Transfer_crossTotalLimit() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to defined target below + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(ERC20.transfer, (address(target), 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + // resend the same tx, with 100 amount. + // should fail because remaining total limit is 120 - 100 = 20 + + userOp.nonce = getNonce(address(account), address(validator)); + ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + + SessionLib.SessionState memory state = + DefaultValidatorNew(validator).getSessionStateForSigner(address(account), signerOne); + assertEq(state.callParams[0].remaining, 20); + } + + function test_sessionPeriodAllowance() public { + uint256 period = 100; + + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 600, period); // <== set period, and limitType as Allowance + + spec.expiresAt = period * 5; + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), 500, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + assertTrue(address(target).balance == 500); + + // revert case: send more value in the same period, because period limit is crossed + + userOp.nonce = getNonce(address(account), address(validator)); + ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + + // success when period resets (i.e. block timestamp begins a new period) + vm.warp(period + 1); + entrypoint.handleOps(ops, payable(address(account))); + assertTrue(address(target).balance == 1000); } // test utils @@ -274,6 +540,14 @@ contract DefaultValidatorNewTest is Test { }); } + function getDefaultSessionSpec(address signer) internal returns (SessionLib.SessionSpec memory spec) { + spec.signer = signer; + spec.expiresAt = 100; + spec.feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + spec.callPolicies = new SessionLib.CallSpec[](0); + spec.transferPolicies = new SessionLib.TransferSpec[](0); + } + function _prepareDefaultValidatorNewInstallData() internal returns (InitializerInstallModule[] memory) { InitializerInstallModule[] memory modules = new InitializerInstallModule[](1); modules[0] = @@ -299,23 +573,31 @@ contract DefaultValidatorNewTest is Test { userOp.nonce = nonce; userOp.callData = userOpCalldata; + PackedUserOperation[] memory ops = _getSignedUserOp(accountOwner, accountOwnerPKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + } + + function _getSignedUserOp(address signer, uint256 pkey, PackedUserOperation memory userOp) + internal + returns (PackedUserOperation[] memory) + { // Sign UserOp { bytes32 opHash = entrypoint.getUserOpHash(userOp); bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); - (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, msgHash); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(pkey, msgHash); bytes memory userOpSignature = abi.encodePacked(r, s, v); - userOp.signature = userOpSignature; + userOp.signature = abi.encode(signer, userOpSignature); } // Create userOps array PackedUserOperation[] memory userOps = new PackedUserOperation[](1); userOps[0] = userOp; - vm.deal(address(account), 1 ether); - entrypoint.handleOps(userOps, payable(address(account))); + return userOps; } } diff --git a/test/mock/MockERC20.sol b/test/mock/MockERC20.sol new file mode 100644 index 0000000..28284d1 --- /dev/null +++ b/test/mock/MockERC20.sol @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import "solady/tokens/ERC20.sol"; + +contract MockERC20 is ERC20 { + + function name() public view override returns (string memory) { + return "Token"; + } + + function symbol() public view override returns (string memory) { + return "TKN"; + } + + function mint(address to, uint256 amount) external { + _mint(to, amount); + } + +} From 76e9c100f9b6684e4519e0cf4beba4a034c9db6a Mon Sep 17 00:00:00 2001 From: Yash Date: Wed, 5 Mar 2025 22:34:43 +0530 Subject: [PATCH 11/15] refactor - target hashes, sessionUid, new mappings etc --- src/DefaultValidatorNew.sol | 136 +++++++++++------------- src/lib/SessionLib.sol | 182 +++++++++++++-------------------- test/DefaultValidatorNew.t.sol | 5 - 3 files changed, 134 insertions(+), 189 deletions(-) diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol index 7a4094c..1b334bf 100644 --- a/src/DefaultValidatorNew.sol +++ b/src/DefaultValidatorNew.sol @@ -44,9 +44,13 @@ library DefaultValidatorNewStorage { /// @dev Map from smart account address => is initialized mapping(address => bool) isInitialized; mapping(address => EnumerableSetLib.AddressSet) allSigners; - mapping(address account => uint256 openSessions) sessionCounter; - mapping(bytes32 sessionHash => SessionLib.SessionStorage sessionState) sessions; - mapping(address account => mapping(address signer => SessionLib.StoredSessionSpec)) sessionSpec; + mapping(address account => mapping(address signer => bytes32 sessionUid)) sessionIds; + mapping(bytes32 sessionUid => SessionLib.CallSpec[]) callPolicies; + mapping(bytes32 sessionUid => SessionLib.TransferSpec[]) transferPolicies; + mapping(bytes32 sessionUid => SessionLib.SessionStorage sessionState) sessions; + mapping(bytes32 sessionUid => uint256 expiresAt) sessionExpiration; + mapping(bytes32 targetHash => SessionLib.CallSpec) targetCallPolicy; + mapping(bytes32 targetHash => SessionLib.TransferSpec) targetTransferPolicy; } function data() internal pure returns (Data storage $) { @@ -68,12 +72,9 @@ contract DefaultValidatorNew is IValidator, EIP712 { using SessionLib for SessionLib.SessionStorage; error SessionZeroSigner(); - error SessionUnlimitedFees(); - error SessionAlreadyExists(); error SessionExpiresTooSoon(); - event SessionCreated(address indexed account, bytes32 indexed sessionHash, SessionLib.SessionSpec sessionSpec); // new - event SessionRevoked(address indexed account, bytes32 indexed sessionHash); // new + event SessionCreated(address indexed account, bytes32 indexed sessionUid, SessionLib.SessionSpec sessionSpec); uint256 private constant _ADMIN_ROLE = 1 << 0; @@ -96,11 +97,6 @@ contract DefaultValidatorNew is IValidator, EIP712 { address[] memory allSigners = _defaultValidatorStorage().allSigners[account].values(); delete _defaultValidatorStorage().allSigners[msg.sender]; - - uint256 len = allSigners.length; - for (uint256 i = 0; i < len; i += 1) { - delete _defaultValidatorStorage().sessionSpec[account][allSigners[i]]; - } } function isModuleType(uint256 moduleTypeId) external pure returns (bool) { @@ -126,7 +122,7 @@ contract DefaultValidatorNew is IValidator, EIP712 { if (!isValidSig) { return VALIDATION_FAILED; } - + (bool isValid, uint48 validAfter, uint48 validUntil) = _isValidSigner(account, signer, userOp); if (!isValid) { @@ -136,8 +132,6 @@ contract DefaultValidatorNew is IValidator, EIP712 { return _packValidationData(ValidationData(address(0), validAfter, validUntil)); } - /// @notice This module should not be used to validate signatures (including EIP-1271), - /// as a signature by itself does not have enough information to validate it against a session. function isValidSignatureWithSender(address sender, bytes32 hash, bytes calldata signature) external view @@ -176,52 +170,38 @@ contract DefaultValidatorNew is IValidator, EIP712 { // create session key for signer function createSessionKey(SessionLib.SessionSpec calldata sessionSpec) external { - SessionLib.StoredSessionSpec memory spec = SessionLib.StoredSessionSpec({ - expiresAt: sessionSpec.expiresAt, - feeLimit: sessionSpec.feeLimit, - callPolicies: sessionSpec.callPolicies, - transferPolicies: sessionSpec.transferPolicies - }); - bytes32 sessionHash = keccak256(abi.encode(sessionSpec.signer, spec)); - if (!_isInitialized(msg.sender)) { revert NotInitialized(msg.sender); } if (sessionSpec.signer == address(0)) { revert SessionZeroSigner(); } - if (sessionSpec.feeLimit.limitType == SessionLib.LimitType.Unlimited) { - revert SessionUnlimitedFees(); - } - if (_defaultValidatorStorage().sessions[sessionHash].status[msg.sender] != SessionLib.Status.NotInitialized) { - revert SessionAlreadyExists(); - } // Sessions should expire in no less than 60 seconds. if (sessionSpec.expiresAt <= block.timestamp + 60) { revert SessionExpiresTooSoon(); } - _defaultValidatorStorage().sessionCounter[msg.sender]++; - _defaultValidatorStorage().sessions[sessionHash].status[msg.sender] = SessionLib.Status.Active; - _defaultValidatorStorage().allSigners[msg.sender].add(sessionSpec.signer); + bytes32 sessionUid = keccak256(abi.encodePacked(sessionSpec.signer, msg.sender, block.timestamp)); - SessionLib.StoredSessionSpec storage newSessionSpec = - _defaultValidatorStorage().sessionSpec[msg.sender][sessionSpec.signer]; - - newSessionSpec.expiresAt = sessionSpec.expiresAt; - newSessionSpec.feeLimit = sessionSpec.feeLimit; + _defaultValidatorStorage().sessionIds[msg.sender][sessionSpec.signer] = sessionUid; + _defaultValidatorStorage().allSigners[msg.sender].add(sessionSpec.signer); + _defaultValidatorStorage().sessionExpiration[sessionUid] = sessionSpec.expiresAt; - delete newSessionSpec.callPolicies; for (uint256 i = 0; i < sessionSpec.callPolicies.length; i++) { - newSessionSpec.callPolicies.push(sessionSpec.callPolicies[i]); + _defaultValidatorStorage().callPolicies[sessionUid].push(sessionSpec.callPolicies[i]); + bytes32 targetHash = keccak256( + abi.encodePacked(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector, sessionUid) + ); + _defaultValidatorStorage().targetCallPolicy[targetHash] = sessionSpec.callPolicies[i]; } - delete newSessionSpec.transferPolicies; for (uint256 i = 0; i < sessionSpec.transferPolicies.length; i++) { - newSessionSpec.transferPolicies.push(sessionSpec.transferPolicies[i]); + _defaultValidatorStorage().transferPolicies[sessionUid].push(sessionSpec.transferPolicies[i]); + bytes32 targetHash = keccak256(abi.encodePacked(sessionSpec.transferPolicies[i].target, sessionUid)); + _defaultValidatorStorage().targetTransferPolicy[targetHash] = sessionSpec.transferPolicies[i]; } - emit SessionCreated(msg.sender, sessionHash, sessionSpec); + emit SessionCreated(msg.sender, sessionUid, sessionSpec); } function getCallPoliciesForSigner(address account, address signer) @@ -229,9 +209,8 @@ contract DefaultValidatorNew is IValidator, EIP712 { view returns (SessionLib.CallSpec[] memory) { - SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; - - return sessionSpec.callPolicies; + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[account][signer]; + return _defaultValidatorStorage().callPolicies[sessionUid]; } function getTransferPoliciesForSigner(address account, address signer) @@ -239,25 +218,13 @@ contract DefaultValidatorNew is IValidator, EIP712 { view returns (SessionLib.TransferSpec[] memory) { - SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; - - return sessionSpec.transferPolicies; - } - - function getFeeLimitForSigner(address account, address signer) - external - view - returns (SessionLib.UsageLimit memory) - { - SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; - - return sessionSpec.feeLimit; + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[account][signer]; + return _defaultValidatorStorage().transferPolicies[sessionUid]; } function getSessionExpirationForSigner(address account, address signer) external view returns (uint256) { - SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; - - return sessionSpec.expiresAt; + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[account][signer]; + return _defaultValidatorStorage().sessionExpiration[sessionUid]; } function getSessionStateForSigner(address account, address signer) @@ -265,10 +232,13 @@ contract DefaultValidatorNew is IValidator, EIP712 { view returns (SessionLib.SessionState memory) { - SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[account][signer]; - bytes32 sessionHash = keccak256(abi.encode(signer, sessionSpec)); + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[account][signer]; - return _defaultValidatorStorage().sessions[sessionHash].getState(account, sessionSpec); + return _defaultValidatorStorage().sessions[sessionUid].getState( + account, + _defaultValidatorStorage().callPolicies[sessionUid], + _defaultValidatorStorage().transferPolicies[sessionUid] + ); } function _isValidSigner(address _account, address _signer, PackedUserOperation calldata _userOp) @@ -283,8 +253,7 @@ contract DefaultValidatorNew is IValidator, EIP712 { return (true, 0, 0); } - SessionLib.StoredSessionSpec memory sessionSpec = _defaultValidatorStorage().sessionSpec[_account][_signer]; - bytes32 sessionHash = keccak256(abi.encode(_signer, sessionSpec)); + bytes32 sessionUid = _defaultValidatorStorage().sessionIds[_account][_signer]; bytes4 selector = _getFunctionSelector(_userOp.callData); @@ -298,13 +267,7 @@ contract DefaultValidatorNew is IValidator, EIP712 { (address target, uint256 value, bytes calldata callData) = ExecutionLib.decodeSingle(_userOp.callData[(36 + dataOffset):]); - isValid = - _defaultValidatorStorage().sessions[sessionHash].validate(target, value, callData, sessionSpec); - - // (address paymaster,,) = UserOperationLib.unpackPaymasterStaticFields(_userOp.paymasterAndData); - // if (paymaster == address(0)) { - // // TODO: do we need to validate fee limit? - // } + isValid = _validatePolicies(target, value, callData, sessionUid); return (isValid, validAfter, validUntil); } else if (callType == CALLTYPE_BATCH) { @@ -313,8 +276,8 @@ contract DefaultValidatorNew is IValidator, EIP712 { uint256 length = executions.length; for (uint256 i = 0; i < length; i++) { - SessionLib.SessionStorage storage state = _defaultValidatorStorage().sessions[sessionHash]; - isValid = state.validate(executions[i].target, executions[i].value, executions[i].data, sessionSpec); + isValid = + _validatePolicies(executions[i].target, executions[i].value, executions[i].data, sessionUid); if (!isValid) { return (isValid, validAfter, validUntil); @@ -338,6 +301,29 @@ contract DefaultValidatorNew is IValidator, EIP712 { } } + function _validatePolicies(address target, uint256 value, bytes memory callData, bytes32 sessionUid) + internal + returns (bool) + { + uint256 sessionExpiresAt = _defaultValidatorStorage().sessionExpiration[sessionUid]; + SessionLib.CallSpec memory callPolicy; + SessionLib.TransferSpec memory transferPolicy; + + bytes4 targetSelector; + if (callData.length >= 4 && callData.length != 12) { + targetSelector = bytes4(callData[0]) | (bytes4(callData[1]) >> 8) | (bytes4(callData[2]) >> 16) + | (bytes4(callData[3]) >> 24); + bytes32 targetHash = keccak256(abi.encodePacked(target, targetSelector, sessionUid)); + callPolicy = _defaultValidatorStorage().targetCallPolicy[targetHash]; + } else { + bytes32 targetHash = keccak256(abi.encodePacked(target, sessionUid)); + transferPolicy = _defaultValidatorStorage().targetTransferPolicy[targetHash]; + } + return _defaultValidatorStorage().sessions[sessionUid].validate( + target, targetSelector, value, callData, sessionExpiresAt, callPolicy, transferPolicy + ); + } + function _getFunctionSelector(bytes calldata data) internal pure returns (bytes4 functionSelector) { require(data.length >= 4, "!Data"); return bytes4(data[:4]); diff --git a/src/lib/SessionLib.sol b/src/lib/SessionLib.sol index 3f757c4..a3eab48 100644 --- a/src/lib/SessionLib.sol +++ b/src/lib/SessionLib.sol @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 pragma solidity ^0.8.26; +import "lib/forge-std/src/console.sol"; import {LibBytes} from "solady/utils/LibBytes.sol"; library SessionLib { @@ -9,12 +10,6 @@ library SessionLib { using SessionLib for SessionLib.UsageLimit; using LibBytes for bytes; - enum Status { - NotInitialized, - Active, - Closed - } - enum LimitType { Unlimited, Lifetime, @@ -61,14 +56,12 @@ library SessionLib { struct SessionSpec { address signer; uint256 expiresAt; - UsageLimit feeLimit; // TODO: remove CallSpec[] callPolicies; TransferSpec[] transferPolicies; } struct StoredSessionSpec { uint256 expiresAt; - UsageLimit feeLimit; CallSpec[] callPolicies; TransferSpec[] transferPolicies; } @@ -86,8 +79,6 @@ library SessionLib { // Info about remaining session limits and its status struct SessionState { - Status status; - uint256 feesRemaining; LimitState[] transferValue; LimitState[] callValue; LimitState[] callParams; @@ -102,8 +93,6 @@ library SessionLib { } struct SessionStorage { - mapping(address => Status) status; - UsageTracker fee; // (target) => transfer value tracker mapping(address => UsageTracker) transferValue; // (target, selector) => call value tracker @@ -173,60 +162,37 @@ library SessionLib { bytes memory data, address target, bytes4 selector, - CallSpec[] memory callPolicies - ) private returns (bool found, CallSpec memory) { - CallSpec memory callPolicy; - - for (uint256 i = 0; i < callPolicies.length; i++) { - if (callPolicies[i].target == target && callPolicies[i].selector == selector) { - callPolicy = callPolicies[i]; - found = true; - break; - } - } - - if (!found) { - return (found, callPolicy); - } - + CallSpec memory callPolicy + ) private returns (bool found) { for (uint256 i = 0; i < callPolicy.constraints.length; i++) { bool check = callPolicy.constraints[i].checkAndUpdate(state.params[target][selector][i], data); if (!check) { - return (false, callPolicy); + // return (false, callPolicy); + return false; } } - - return (true, callPolicy); + return true; } - // TODO: do we need this? - function validateFeeLimit() internal returns (bool) {} - function validate( SessionStorage storage state, address target, + bytes4 selector, uint256 value, bytes memory callData, - StoredSessionSpec memory spec + uint256 expiresAt, + CallSpec memory callPolicy, + TransferSpec memory transferPolicy ) internal returns (bool) { - if (state.status[msg.sender] != Status.Active) { - // revert SessionNotActive(); - return false; - } - - if (spec.expiresAt < block.timestamp) { + if (expiresAt < block.timestamp) { return false; } + // TODO: fix this if (callData.length >= 4 && callData.length != 12) { - // TODO: fix this - // bytes4 selector = bytes4(callData[:4]); - bytes4 selector = bytes4(callData[0]) | (bytes4(callData[1]) >> 8) | (bytes4(callData[2]) >> 16) - | (bytes4(callData[3]) >> 24); - (bool found, CallSpec memory callPolicy) = - checkCallPolicy(state, callData, target, selector, spec.callPolicies); - - if (!found) { + bool callPolicyValid = checkCallPolicy(state, callData, target, selector, callPolicy); + + if (!callPolicyValid) { return false; } @@ -237,19 +203,7 @@ library SessionLib { return callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], value); } else { - TransferSpec memory transferPolicy; - bool found = false; - - for (uint256 i = 0; i < spec.transferPolicies.length; i++) { - if (spec.transferPolicies[i].target == target) { - transferPolicy = spec.transferPolicies[i]; - found = true; - break; - } - } - - if (!found) { - // revert SessionTransferPolicyViolated(target); + if (target != transferPolicy.target) { return false; } @@ -280,52 +234,68 @@ library SessionLib { } } - function getState(SessionStorage storage session, address account, StoredSessionSpec memory spec) - internal - view - returns (SessionState memory) - { - uint256 totalConstraints = 0; - for (uint256 i = 0; i < spec.callPolicies.length; i++) { - totalConstraints += spec.callPolicies[i].constraints.length; - } + function getState( + SessionStorage storage session, + address account, + CallSpec[] memory callPolicies, + TransferSpec[] memory transferPolicies + ) internal view returns (SessionState memory) { + LimitState[] memory transferValue; + LimitState[] memory callValue; + LimitState[] memory callParams; + + { + uint256 totalConstraints = 0; + for (uint256 i = 0; i < callPolicies.length; i++) { + totalConstraints += callPolicies[i].constraints.length; + } - LimitState[] memory transferValue = new LimitState[](spec.transferPolicies.length); - LimitState[] memory callValue = new LimitState[](spec.callPolicies.length); - LimitState[] memory callParams = new LimitState[](totalConstraints); // there will be empty ones at the end + transferValue = new LimitState[](transferPolicies.length); + callValue = new LimitState[](callPolicies.length); + callParams = new LimitState[](totalConstraints); // there will be empty ones at the end + } uint256 paramLimitIndex = 0; - for (uint256 i = 0; i < transferValue.length; i++) { - TransferSpec memory transferSpec = spec.transferPolicies[i]; - transferValue[i] = LimitState({ - remaining: remainingLimit(transferSpec.valueLimit, session.transferValue[transferSpec.target], account), - target: transferSpec.target, - selector: bytes4(0), - index: 0 - }); + { + for (uint256 i = 0; i < transferValue.length; i++) { + TransferSpec memory transferSpec = transferPolicies[i]; + transferValue[i] = LimitState({ + remaining: remainingLimit(transferSpec.valueLimit, session.transferValue[transferSpec.target], account), + target: transferSpec.target, + selector: bytes4(0), + index: 0 + }); + } } for (uint256 i = 0; i < callValue.length; i++) { - CallSpec memory callSpec = spec.callPolicies[i]; - callValue[i] = LimitState({ - remaining: remainingLimit( - callSpec.valueLimit, session.callValue[callSpec.target][callSpec.selector], account - ), - target: callSpec.target, - selector: callSpec.selector, - index: 0 - }); - - for (uint256 j = 0; j < callSpec.constraints.length; j++) { - if (callSpec.constraints[j].limit.limitType != LimitType.Unlimited) { - callParams[paramLimitIndex++] = LimitState({ - remaining: remainingLimit( - callSpec.constraints[j].limit, session.params[callSpec.target][callSpec.selector][j], account - ), - target: callSpec.target, - selector: callSpec.selector, - index: callSpec.constraints[j].index - }); + CallSpec memory callSpec = callPolicies[i]; + + { + callValue[i] = LimitState({ + remaining: remainingLimit( + callSpec.valueLimit, session.callValue[callSpec.target][callSpec.selector], account + ), + target: callSpec.target, + selector: callSpec.selector, + index: 0 + }); + } + + { + for (uint256 j = 0; j < callSpec.constraints.length; j++) { + if (callSpec.constraints[j].limit.limitType != LimitType.Unlimited) { + callParams[paramLimitIndex++] = LimitState({ + remaining: remainingLimit( + callSpec.constraints[j].limit, + session.params[callSpec.target][callSpec.selector][j], + account + ), + target: callSpec.target, + selector: callSpec.selector, + index: callSpec.constraints[j].index + }); + } } } } @@ -335,13 +305,7 @@ library SessionLib { mstore(callParams, paramLimitIndex) } - return SessionState({ - status: session.status[account], - feesRemaining: remainingLimit(spec.feeLimit, session.fee, account), - transferValue: transferValue, - callValue: callValue, - callParams: callParams - }); + return SessionState({transferValue: transferValue, callValue: callValue, callParams: callParams}); } } diff --git a/test/DefaultValidatorNew.t.sol b/test/DefaultValidatorNew.t.sol index 200bf71..3a6d87d 100644 --- a/test/DefaultValidatorNew.t.sol +++ b/test/DefaultValidatorNew.t.sol @@ -146,12 +146,8 @@ contract DefaultValidatorNewTest is Test { SessionLib.TransferSpec[] memory fetchedTransferPolicies = DefaultValidatorNew(validator).getTransferPoliciesForSigner(address(account), signerOne); - SessionLib.UsageLimit memory fetchedFeeLimit = - DefaultValidatorNew(validator).getFeeLimitForSigner(address(account), signerOne); - uint256 expiresAt = DefaultValidatorNew(validator).getSessionExpirationForSigner(address(account), signerOne); - assertEq(uint256(fetchedFeeLimit.limitType), uint256(spec.feeLimit.limitType)); assertEq(expiresAt, spec.expiresAt); } @@ -543,7 +539,6 @@ contract DefaultValidatorNewTest is Test { function getDefaultSessionSpec(address signer) internal returns (SessionLib.SessionSpec memory spec) { spec.signer = signer; spec.expiresAt = 100; - spec.feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); spec.callPolicies = new SessionLib.CallSpec[](0); spec.transferPolicies = new SessionLib.TransferSpec[](0); } From 0e11c0c455c0698d13d988facd495331f1a3fc9c Mon Sep 17 00:00:00 2001 From: Yash Date: Wed, 5 Mar 2025 22:39:05 +0530 Subject: [PATCH 12/15] fix tests --- test/BenchmarkDefaultValidatorNew.t.sol | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/test/BenchmarkDefaultValidatorNew.t.sol b/test/BenchmarkDefaultValidatorNew.t.sol index 7f80c37..b24d412 100644 --- a/test/BenchmarkDefaultValidatorNew.t.sol +++ b/test/BenchmarkDefaultValidatorNew.t.sol @@ -84,7 +84,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { spec.signer = vm.addr(1234); spec.expiresAt = 100; - spec.feeLimit = feeLimit; spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; @@ -115,7 +114,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { spec.signer = vm.addr(1234); spec.expiresAt = 100; - spec.feeLimit = feeLimit; spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; @@ -144,7 +142,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { spec.signer = vm.addr(1234); spec.expiresAt = 100; - spec.feeLimit = feeLimit; spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; @@ -173,7 +170,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { spec.signer = vm.addr(1234); spec.expiresAt = 100; - spec.feeLimit = feeLimit; spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; @@ -210,7 +206,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { spec.signer = vm.addr(1234); spec.expiresAt = 100; - spec.feeLimit = feeLimit; spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; @@ -259,7 +254,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { spec.signer = signerOne; spec.expiresAt = 100; - spec.feeLimit = feeLimit; spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; @@ -292,7 +286,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { address recoveredSigner = ECDSA.recover(msgHash, v, r, s); assertEq(recoveredSigner, signerOne); - userOp.signature = userOpSignature; + userOp.signature = abi.encode(signerOne, userOpSignature); } // Create userOps array @@ -339,7 +333,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { spec.signer = signerOne; spec.expiresAt = 100; - spec.feeLimit = feeLimit; spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; @@ -372,7 +365,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { address recoveredSigner = ECDSA.recover(msgHash, v, r, s); assertEq(recoveredSigner, signerOne); - userOp.signature = userOpSignature; + userOp.signature = abi.encode(signerOne, userOpSignature); } // Create userOps array @@ -408,7 +401,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { spec.signer = signerOne; spec.expiresAt = 100; - spec.feeLimit = feeLimit; spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; @@ -441,7 +433,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { address recoveredSigner = ECDSA.recover(msgHash, v, r, s); assertEq(recoveredSigner, signerOne); - userOp.signature = userOpSignature; + userOp.signature = abi.encode(signerOne, userOpSignature); } // Create userOps array @@ -477,7 +469,6 @@ contract BenchmarkDefaultValidatorNewTest is Test { spec.signer = signerOne; spec.expiresAt = 100; - spec.feeLimit = feeLimit; spec.callPolicies = callPolicies; spec.transferPolicies = transferPolicies; @@ -510,7 +501,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { address recoveredSigner = ECDSA.recover(msgHash, v, r, s); assertEq(recoveredSigner, signerOne); - userOp.signature = userOpSignature; + userOp.signature = abi.encode(signerOne, userOpSignature); } // Create userOps array @@ -578,7 +569,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, msgHash); bytes memory userOpSignature = abi.encodePacked(r, s, v); - userOp.signature = userOpSignature; + userOp.signature = abi.encode(accountOwner, userOpSignature); } // Create userOps array From 3563d0b544b9e176c1209bb718d9d0ed83fad82c Mon Sep 17 00:00:00 2001 From: Yash Date: Wed, 5 Mar 2025 22:40:48 +0530 Subject: [PATCH 13/15] gasreport after refactor --- gasreport.txt | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/gasreport.txt b/gasreport.txt index 01bab83..b04be77 100644 --- a/gasreport.txt +++ b/gasreport.txt @@ -1,26 +1,26 @@ No files changed, compilation skipped Ran 9 tests for test/BenchmarkDefaultValidatorNew.t.sol:BenchmarkDefaultValidatorNewTest -[PASS] test_createSessionKeyForSigner_100CallPolicies() (gas: 2467149) -[PASS] test_createSessionKeyForSigner_100TransferPolicies() (gas: 2209521) -[PASS] test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2336358) -[PASS] test_createSessionKeyForSigner_50CallPolicies() (gas: 2259844) -[PASS] test_createSessionKeyForSigner_50TransferPolicies() (gas: 2134946) -[PASS] test_execute_100CallPolicies() (gas: 2760545) -[PASS] test_execute_100TransferPolicies() (gas: 1317589) -[PASS] test_execute_50CallPolicies() (gas: 1471450) -[PASS] test_execute_50TransferPolicies() (gas: 743850) -Suite result: ok. 9 passed; 0 failed; 0 skipped; finished in 12.74ms (69.57ms CPU time) +[PASS] test_createSessionKeyForSigner_100CallPolicies() (gas: 2470875) +[PASS] test_createSessionKeyForSigner_100TransferPolicies() (gas: 2213038) +[PASS] test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 2339958) +[PASS] test_createSessionKeyForSigner_50CallPolicies() (gas: 2263389) +[PASS] test_createSessionKeyForSigner_50TransferPolicies() (gas: 2138379) +[PASS] test_execute_100CallPolicies() (gas: 210125) +[PASS] test_execute_100TransferPolicies() (gas: 175656) +[PASS] test_execute_50CallPolicies() (gas: 210126) +[PASS] test_execute_50TransferPolicies() (gas: 175641) +Suite result: ok. 9 passed; 0 failed; 0 skipped; finished in 14.13ms (71.45ms CPU time) -Ran 1 test suite in 14.95ms (12.74ms CPU time): 9 tests passed, 0 failed, 0 skipped (9 total tests) -test_execute_100TransferPolicies() (gas: 0 (0.000%)) -test_execute_50CallPolicies() (gas: 0 (0.000%)) -test_createSessionKeyForSigner_100TransferPolicies() (gas: 1 (0.000%)) -test_createSessionKeyForSigner_100CallPolicies() (gas: -23 (-0.001%)) -test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: -22 (-0.001%)) -test_createSessionKeyForSigner_50TransferPolicies() (gas: -22 (-0.001%)) -test_execute_100CallPolicies() (gas: 41 (0.001%)) -test_createSessionKeyForSigner_50CallPolicies() (gas: 34 (0.002%)) -test_execute_50TransferPolicies() (gas: -58 (-0.008%)) -Overall gas change: -49 (-0.000%) +Ran 1 test suite in 16.52ms (14.13ms CPU time): 9 tests passed, 0 failed, 0 skipped (9 total tests) +test_createSessionKeyForSigner_100CallPolicies() (gas: 3703 (0.150%)) +test_createSessionKeyForSigner_50Call50TransferPolicies() (gas: 3578 (0.153%)) +test_createSessionKeyForSigner_50CallPolicies() (gas: 3579 (0.158%)) +test_createSessionKeyForSigner_100TransferPolicies() (gas: 3518 (0.159%)) +test_createSessionKeyForSigner_50TransferPolicies() (gas: 3411 (0.160%)) +test_execute_50TransferPolicies() (gas: -568267 (-76.389%)) +test_execute_50CallPolicies() (gas: -1261324 (-85.720%)) +test_execute_100TransferPolicies() (gas: -1141933 (-86.668%)) +test_execute_100CallPolicies() (gas: -2550379 (-92.388%)) +Overall gas change: -5504114 (-31.094%) From 04768921e0823e31b912026dde9f7a04b2bcf53c Mon Sep 17 00:00:00 2001 From: Yash Date: Wed, 5 Mar 2025 23:30:34 +0530 Subject: [PATCH 14/15] support any selector feature --- src/DefaultValidatorNew.sol | 11 +- src/lib/SessionLib.sol | 5 +- test/BenchmarkDefaultValidatorNew.t.sol | 16 ++- test/DefaultValidatorNew.t.sol | 133 ++++++++++++++++++++++++ 4 files changed, 158 insertions(+), 7 deletions(-) diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol index 1b334bf..75686e0 100644 --- a/src/DefaultValidatorNew.sol +++ b/src/DefaultValidatorNew.sol @@ -51,6 +51,7 @@ library DefaultValidatorNewStorage { mapping(bytes32 sessionUid => uint256 expiresAt) sessionExpiration; mapping(bytes32 targetHash => SessionLib.CallSpec) targetCallPolicy; mapping(bytes32 targetHash => SessionLib.TransferSpec) targetTransferPolicy; + mapping(bytes32 sessionUid => mapping(address target => bool)) anySelectorForTarget; } function data() internal pure returns (Data storage $) { @@ -193,6 +194,10 @@ contract DefaultValidatorNew is IValidator, EIP712 { abi.encodePacked(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector, sessionUid) ); _defaultValidatorStorage().targetCallPolicy[targetHash] = sessionSpec.callPolicies[i]; + + if (sessionSpec.callPolicies[i].selector == bytes4(0x00000000)) { + _defaultValidatorStorage().anySelectorForTarget[sessionUid][sessionSpec.callPolicies[i].target] = true; + } } for (uint256 i = 0; i < sessionSpec.transferPolicies.length; i++) { @@ -311,8 +316,10 @@ contract DefaultValidatorNew is IValidator, EIP712 { bytes4 targetSelector; if (callData.length >= 4 && callData.length != 12) { - targetSelector = bytes4(callData[0]) | (bytes4(callData[1]) >> 8) | (bytes4(callData[2]) >> 16) - | (bytes4(callData[3]) >> 24); + targetSelector = _defaultValidatorStorage().anySelectorForTarget[sessionUid][target] + ? bytes4(0x00000000) + : bytes4(callData[0]) | (bytes4(callData[1]) >> 8) | (bytes4(callData[2]) >> 16) + | (bytes4(callData[3]) >> 24); bytes32 targetHash = keccak256(abi.encodePacked(target, targetSelector, sessionUid)); callPolicy = _defaultValidatorStorage().targetCallPolicy[targetHash]; } else { diff --git a/src/lib/SessionLib.sol b/src/lib/SessionLib.sol index a3eab48..fd4b3b9 100644 --- a/src/lib/SessionLib.sol +++ b/src/lib/SessionLib.sol @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 pragma solidity ^0.8.26; -import "lib/forge-std/src/console.sol"; import {LibBytes} from "solady/utils/LibBytes.sol"; library SessionLib { @@ -164,6 +163,10 @@ library SessionLib { bytes4 selector, CallSpec memory callPolicy ) private returns (bool found) { + if (target != callPolicy.target || selector != callPolicy.selector) { + return false; + } + for (uint256 i = 0; i < callPolicy.constraints.length; i++) { bool check = callPolicy.constraints[i].checkAndUpdate(state.params[target][selector][i], data); if (!check) { diff --git a/test/BenchmarkDefaultValidatorNew.t.sol b/test/BenchmarkDefaultValidatorNew.t.sol index b24d412..e2019c4 100644 --- a/test/BenchmarkDefaultValidatorNew.t.sol +++ b/test/BenchmarkDefaultValidatorNew.t.sol @@ -183,7 +183,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { function test_createSessionKeyForSigner_50Call50TransferPolicies() public { vm.pauseGasMetering(); _setup(abi.encodePacked(uint256(0x5))); - + SessionLib.SessionSpec memory spec; SessionLib.UsageLimit memory feeLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); @@ -265,7 +265,11 @@ contract BenchmarkDefaultValidatorNewTest is Test { // 2. prepare and send User Op bytes memory userOpCalldata = abi.encodeCall( - IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42))) + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42)) + ) ); uint256 nonce = getNonce(address(account), address(validator)); @@ -344,7 +348,11 @@ contract BenchmarkDefaultValidatorNewTest is Test { // 2. prepare and send User Op bytes memory userOpCalldata = abi.encodeCall( - IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42))) + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(target), uint256(0), abi.encodeCall(MockTarget.set, 42)) + ) ); uint256 nonce = getNonce(address(account), address(validator)); @@ -447,7 +455,7 @@ contract BenchmarkDefaultValidatorNewTest is Test { function test_execute_100TransferPolicies() public { vm.pauseGasMetering(); _setup(abi.encodePacked(uint256(0x9))); - + // 1. create session key { diff --git a/test/DefaultValidatorNew.t.sol b/test/DefaultValidatorNew.t.sol index 3a6d87d..1192254 100644 --- a/test/DefaultValidatorNew.t.sol +++ b/test/DefaultValidatorNew.t.sol @@ -222,6 +222,41 @@ contract DefaultValidatorNewTest is Test { entrypoint.handleOps(ops, payable(address(account))); } + function test_revert_transferPolicy_wrongTarget() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, (ModeLib.encodeSimpleSingle(), ExecutionLib.encodeSingle(address(0x1234), 500, "")) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + function test_callPolicy_erc20Transfer() public { // 1. create session key @@ -336,6 +371,62 @@ contract DefaultValidatorNewTest is Test { entrypoint.handleOps(ops, payable(address(account))); } + function test_revert_callPolicy_erc20Transfer_wrongSelector() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = ERC20.transfer.selector; + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + // set constraints + SessionLib.Constraint[] memory constraints = new SessionLib.Constraint[](2); + // constraint 0: can only transfer to defined target below + constraints[0].index = 0; + constraints[0].refValue = bytes32(uint256(uint160(address(target)))); + constraints[0].condition = SessionLib.Condition.Equal; + + // constraint 1: can transfer 100 tokens per transaction, and 120 total + constraints[1].index = 1; + constraints[1].refValue = bytes32(uint256(100)); + constraints[1].condition = SessionLib.Condition.LessOrEqual; + constraints[1].limit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 120, 0); + + callPolicies[0].constraints = constraints; + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + // === set wrong selector here ========= + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(MockERC20.mint, (address(target), 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + function test_revert_callPolicy_erc20Transfer_crossTxLimit() public { uint256 wrongAmount = 101; @@ -463,6 +554,48 @@ contract DefaultValidatorNewTest is Test { assertEq(state.callParams[0].remaining, 20); } + function test_callPolicy_anySelector() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + // create call policy + SessionLib.CallSpec[] memory callPolicies = new SessionLib.CallSpec[](1); + callPolicies[0].target = address(erc20); + callPolicies[0].selector = bytes4(0x00000000); + callPolicies[0].maxValuePerUse = 0; + callPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Lifetime, 0, 0); + + spec.callPolicies = callPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall( + IERC7579Account.execute, + ( + ModeLib.encodeSimpleSingle(), + ExecutionLib.encodeSingle(address(erc20), 0, abi.encodeCall(ERC20.transfer, (address(target), 100))) + ) + ); + + uint256 nonce = getNonce(address(account), address(validator)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + assertTrue(erc20.balanceOf(address(target)) == 100); + } + function test_sessionPeriodAllowance() public { uint256 period = 100; From 35e0c138ccf7527379116c93ba4152891a90639f Mon Sep 17 00:00:00 2001 From: Yash Date: Wed, 5 Mar 2025 23:39:16 +0530 Subject: [PATCH 15/15] 6492 sig --- src/DefaultValidatorNew.sol | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DefaultValidatorNew.sol b/src/DefaultValidatorNew.sol index 75686e0..c9ef252 100644 --- a/src/DefaultValidatorNew.sol +++ b/src/DefaultValidatorNew.sol @@ -118,7 +118,7 @@ contract DefaultValidatorNew is IValidator, EIP712 { (address signer, bytes memory signature) = abi.decode(userOp.signature, (address, bytes)); bytes32 hash = SignatureCheckerLib.toEthSignedMessageHash(userOpHash); - bool isValidSig = SignatureCheckerLib.isValidSignatureNow(signer, hash, signature); + bool isValidSig = SignatureCheckerLib.isValidERC6492SignatureNow(signer, hash, signature); if (!isValidSig) { return VALIDATION_FAILED;