diff --git a/src/4337/SessionLib.sol b/src/4337/SessionLib.sol new file mode 100644 index 0000000..ddd1078 --- /dev/null +++ b/src/4337/SessionLib.sol @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import {LibBytes} from "solady/utils/LibBytes.sol"; + +library SessionLib { + + using SessionLib for SessionLib.Constraint; + using SessionLib for SessionLib.UsageLimit; + using LibBytes for bytes; + + enum LimitType { + Unlimited, + Lifetime, + Allowance + } + + enum Condition { + Unconstrained, + Equal, + Greater, + Less, + GreaterOrEqual, + LessOrEqual, + NotEqual + } + + struct UsageLimit { + LimitType limitType; + uint256 limit; // ignored if limitType == Unlimited + uint256 period; // ignored if limitType != Allowance + } + + struct Constraint { + Condition condition; + uint64 index; + bytes32 refValue; + UsageLimit limit; + } + + struct CallSpec { + address target; + bytes4 selector; + uint256 maxValuePerUse; + UsageLimit valueLimit; + Constraint[] constraints; + } + + struct TransferSpec { + address target; + uint256 maxValuePerUse; + UsageLimit valueLimit; + } + + struct SessionSpec { + address signer; + uint256 expiresAt; + CallSpec[] callPolicies; + TransferSpec[] transferPolicies; + } + + struct StoredSessionSpec { + uint256 expiresAt; + CallSpec[] callPolicies; + TransferSpec[] transferPolicies; + } + + struct LimitState { + // this might also be limited by a constraint or `maxValuePerUse`, + // which is not reflected here + uint256 remaining; + address target; + // ignored for transfer value + bytes4 selector; + // ignored for transfer and call value + uint256 index; + } + + // Info about remaining session limits and its status + struct SessionState { + LimitState[] transferValue; + LimitState[] callValue; + LimitState[] callParams; + } + + struct UsageTracker { + // Used for LimitType.Lifetime + uint256 lifetimeUsage; + // Used for LimitType.Allowance + // period => used that period + mapping(uint64 => uint256) allowanceUsage; + } + + struct SessionStorage { + // (target) => transfer value tracker + mapping(address => UsageTracker) transferValue; + // (target, selector) => call value tracker + mapping(address => mapping(bytes4 => UsageTracker)) callValue; + // (target, selector, index) => call parameter tracker + // index is the constraint index in callPolicy, not the parameter index + mapping(address => mapping(bytes4 => mapping(uint256 => UsageTracker))) params; + } + + function checkAndUpdate(UsageLimit memory limit, UsageTracker storage tracker, uint256 value) + internal + returns (bool) + { + if (limit.limitType == LimitType.Lifetime) { + if (tracker.lifetimeUsage + value > limit.limit) { + // revert SessionLifetimeUsageExceeded(tracker.lifetimeUsage[msg.sender], limit.limit); + return false; + } + tracker.lifetimeUsage += value; + } else if (limit.limitType == LimitType.Allowance) { + uint64 period = uint64(block.timestamp / limit.period); + + if (tracker.allowanceUsage[period] + value > limit.limit) { + // revert SessionAllowanceExceeded(tracker.allowanceUsage[period][msg.sender], limit.limit, period); + return false; + } + + tracker.allowanceUsage[period] += value; + } + + return true; + } + + function checkAndUpdate(Constraint memory constraint, UsageTracker storage tracker, bytes memory data) + internal + returns (bool) + { + uint256 expectedLength = 4 + constraint.index * 32 + 32; + + if (data.length < expectedLength) { + // revert SessionInvalidDataLength(data.length, expectedLength); + return false; + } + + bytes32 param = data.load(4 + constraint.index * 32); + Condition condition = constraint.condition; + bytes32 refValue = constraint.refValue; + + if ( + (condition == Condition.Equal && param != refValue) || (condition == Condition.Greater && param <= refValue) + || (condition == Condition.Less && param >= refValue) + || (condition == Condition.GreaterOrEqual && param < refValue) + || (condition == Condition.LessOrEqual && param > refValue) + || (condition == Condition.NotEqual && param == refValue) + ) { + // revert SessionConditionFailed(param, refValue, uint8(condition)); + return false; + } + + bool check = constraint.limit.checkAndUpdate(tracker, uint256(param)); + + return check; + } + + function checkCallPolicy( + SessionStorage storage state, + bytes memory data, + address target, + bytes4 selector, + CallSpec memory callPolicy + ) private returns (bool found) { + if (target != callPolicy.target || selector != callPolicy.selector) { + return false; + } + + for (uint256 i = 0; i < callPolicy.constraints.length; i++) { + bool check = callPolicy.constraints[i].checkAndUpdate(state.params[target][selector][i], data); + if (!check) { + // return (false, callPolicy); + return false; + } + } + return true; + } + + function validate( + SessionStorage storage state, + address target, + bytes4 selector, + uint256 value, + bytes memory callData, + uint256 expiresAt, + CallSpec memory callPolicy, + TransferSpec memory transferPolicy + ) internal returns (bool) { + if (expiresAt < block.timestamp) { + return false; + } + + if (callData.length >= 4) { + bool callPolicyValid = checkCallPolicy(state, callData, target, selector, callPolicy); + + if (!callPolicyValid) { + return false; + } + + if (value > callPolicy.maxValuePerUse) { + // revert SessionMaxValueExceeded(value, callPolicy.maxValuePerUse); + return false; + } + + return callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], value); + } else { + if (target != transferPolicy.target) { + return false; + } + + if (value > transferPolicy.maxValuePerUse) { + // revert SessionMaxValueExceeded(value, transferPolicy.maxValuePerUse); + return false; + } + return transferPolicy.valueLimit.checkAndUpdate(state.transferValue[target], value); + } + } + + function remainingLimit(UsageLimit memory limit, UsageTracker storage tracker) private view returns (uint256) { + if (limit.limitType == LimitType.Unlimited) { + // this might be still limited by `maxValuePerUse` or a constraint + return type(uint256).max; + } + if (limit.limitType == LimitType.Lifetime) { + return limit.limit - tracker.lifetimeUsage; + } + if (limit.limitType == LimitType.Allowance) { + // this is not used during validation, so it's fine to use block.timestamp + uint64 period = uint64(block.timestamp / limit.period); + return limit.limit - tracker.allowanceUsage[period]; + } + } + + function getState( + SessionStorage storage session, + CallSpec[] memory callPolicies, + TransferSpec[] memory transferPolicies + ) internal view returns (SessionState memory) { + LimitState[] memory transferValue; + LimitState[] memory callValue; + LimitState[] memory callParams; + + { + uint256 totalConstraints = 0; + for (uint256 i = 0; i < callPolicies.length; i++) { + totalConstraints += callPolicies[i].constraints.length; + } + + transferValue = new LimitState[](transferPolicies.length); + callValue = new LimitState[](callPolicies.length); + callParams = new LimitState[](totalConstraints); // there will be empty ones at the end + } + uint256 paramLimitIndex = 0; + + { + for (uint256 i = 0; i < transferValue.length; i++) { + TransferSpec memory transferSpec = transferPolicies[i]; + transferValue[i] = LimitState({ + remaining: remainingLimit(transferSpec.valueLimit, session.transferValue[transferSpec.target]), + target: transferSpec.target, + selector: bytes4(0), + index: 0 + }); + } + } + + for (uint256 i = 0; i < callValue.length; i++) { + CallSpec memory callSpec = callPolicies[i]; + + { + callValue[i] = LimitState({ + remaining: remainingLimit(callSpec.valueLimit, session.callValue[callSpec.target][callSpec.selector]), + target: callSpec.target, + selector: callSpec.selector, + index: 0 + }); + } + + { + for (uint256 j = 0; j < callSpec.constraints.length; j++) { + if (callSpec.constraints[j].limit.limitType != LimitType.Unlimited) { + callParams[paramLimitIndex++] = LimitState({ + remaining: remainingLimit( + callSpec.constraints[j].limit, session.params[callSpec.target][callSpec.selector][j] + ), + target: callSpec.target, + selector: callSpec.selector, + index: callSpec.constraints[j].index + }); + } + } + } + } + + // shrink array to actual size + assembly { + mstore(callParams, paramLimitIndex) + } + + return SessionState({transferValue: transferValue, callValue: callValue, callParams: callParams}); + } + +} diff --git a/src/4337/SimpleAccount.sol b/src/4337/SimpleAccount.sol new file mode 100644 index 0000000..e8fe79f --- /dev/null +++ b/src/4337/SimpleAccount.sol @@ -0,0 +1,483 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import "account-abstraction-v0.7/core/Helpers.sol"; +import {UserOperationLib} from "account-abstraction-v0.7/core/UserOperationLib.sol"; +import {IEntryPoint} from "account-abstraction-v0.7/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "account-abstraction-v0.7/interfaces/PackedUserOperation.sol"; + +import {Receiver} from "solady/accounts/Receiver.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {EnumerableSetLib} from "solady/utils/EnumerableSetLib.sol"; +import {SignatureCheckerLib} from "solady/utils/SignatureCheckerLib.sol"; +import {EIP712} from "solady/utils/EIP712.sol"; + +import {IERC4337Account} from "../interface/IERC4337Account.sol"; +import {OwnableRoles} from "../ownable/OwnableRoles.sol"; +import {SessionLib} from "./SessionLib.sol"; + +library SimpleAccountStorage { + + /// @custom:storage-location erc7201:simple.account + bytes32 public constant SIMPLE_ACCOUNT_STORAGE_POSITION = + keccak256(abi.encode(uint256(keccak256("simple.account")) - 1)) & ~bytes32(uint256(0xff)); + + struct Data { + /// @dev The set of all admins of the wallet. + EnumerableSetLib.AddressSet allAdmins; + bool isInitialized; + EnumerableSetLib.AddressSet allSigners; + mapping(address signer => bytes32 sessionUid) sessionIds; + mapping(bytes32 sessionUid => SessionLib.CallSpec[]) callPolicies; + mapping(bytes32 sessionUid => SessionLib.TransferSpec[]) transferPolicies; + mapping(bytes32 sessionUid => SessionLib.SessionStorage sessionState) sessions; + mapping(bytes32 sessionUid => uint256 expiresAt) sessionExpiration; + mapping(bytes32 targetHash => SessionLib.CallSpec) targetCallPolicy; + mapping(bytes32 targetHash => SessionLib.TransferSpec) targetTransferPolicy; + mapping(bytes32 sessionUid => mapping(address target => bool)) anySelectorForTarget; + } + + function data() internal pure returns (Data storage $) { + bytes32 position = SIMPLE_ACCOUNT_STORAGE_POSITION; + assembly { + $.slot := position + } + } + +} + +contract SimpleAccount is IERC4337Account, EIP712, Receiver, OwnableRoles { + + using EnumerableSetLib for EnumerableSetLib.AddressSet; + using ECDSA for bytes32; + + using SessionLib for SessionLib.SessionStorage; + + error InvalidCaller(); + error ExecutionFailed(); + error SessionZeroSigner(); + error SessionExpiresTooSoon(); + + event SessionCreated(bytes32 indexed sessionUid, SessionLib.SessionSpec sessionSpec); + event Received(address sender, uint256 amount); + + uint256 private constant _ADMIN_ROLE = 1 << 0; + + IEntryPoint public immutable entrypoint; + + modifier onlyEntryPoint() { + if (msg.sender != address(entrypoint)) { + revert InvalidCaller(); + } + _; + } + + modifier onlyEntryPointOrSelf() { + if (msg.sender != address(entrypoint) && msg.sender != address(this)) { + revert InvalidCaller(); + } + _; + } + + receive() external payable override { + emit Received(msg.sender, msg.value); + } + + /// + /// Initializer + /// + constructor(address _entrypoint) payable { + entrypoint = IEntryPoint(_entrypoint); + _disableERC4337ImplementationInitializer(); + } + + /// @dev Automatically initializes the owner for the implementation. This blocks someone + /// from initializing the implementation and doing a delegatecall to SELFDESTRUCT. + /// Proxies to the implementation will still be able to initialize as per normal. + function _disableERC4337ImplementationInitializer() internal virtual { + // Note that `Ownable._guardInitializeOwner` has been and must be overridden + // to return true, to block double-initialization. We'll initialize to `address(1)`, + // so that it's easier to verify that the implementation has been initialized. + _initializeOwner(address(1)); + } + + /// @dev To prevent double-initialization (reuses the owner storage slot for efficiency). + function _guardInitializeOwner() internal pure virtual override returns (bool) { + return true; + } + + function initialize(address owner) external payable { + _initializeOwner(owner); + } + + /// + /// ERC-4337 + /// + function execute(address _target, uint256 _value, bytes calldata _calldata) external virtual onlyEntryPointOrSelf { + _call(_target, _value, _calldata); + } + + function executeBatch(address[] calldata _target, uint256[] calldata _value, bytes[] calldata _calldata) + external + virtual + onlyEntryPointOrSelf + { + require(_target.length == _calldata.length && _target.length == _value.length, "Account: wrong array lengths."); + for (uint256 i = 0; i < _target.length; i++) { + _call(_target[i], _value[i], _calldata[i]); + } + } + + function _call(address _target, uint256 value, bytes memory _calldata) + internal + virtual + returns (bytes memory result) + { + bool success; + (success, result) = _target.call{value: value}(_calldata); + if (!success) { + assembly { + revert(add(result, 32), mload(result)) + } + } + } + + // /** + // * @dev ERC-1271 isValidSignature + // * This function is intended to be used to validate a smart account signature + // * and may forward the call to a validator module + // * + // * @param hash The hash of the data that is signed + // * @param data The data that is signed + // */ + // function isValidSignature(bytes32 hash, bytes calldata data) external view returns (bytes4) { + // (address validator, bytes memory signature) = abi.decode(data, (address, bytes)); + + // if (!_isValidatorInstalled(validator)) { + // revert InvalidModule(validator); + // } + + // return IValidator(validator).isValidSignatureWithSender(msg.sender, hash, signature); + // } + + // /** + // * @dev Returns the account id of the smart account + // * @return accountImplementationId the account id of the smart account + // * the accountId should be structured like so: + // * "vendorname.accountname.semver" + // */ + // function accountId() external pure returns (string memory) { + // return "thirdweb.modular.v0.0.1"; + // } + + function _domainNameAndVersion() internal pure override returns (string memory name, string memory version) { + name = "SimpleAccount"; + version = "1"; + } + + /// + /// Session Key + /// + function createSessionKey(SessionLib.SessionSpec calldata sessionSpec) external { + if (sessionSpec.signer == address(0)) { + revert SessionZeroSigner(); + } + // Sessions should expire in no less than 60 seconds. + if (sessionSpec.expiresAt <= block.timestamp + 60) { + revert SessionExpiresTooSoon(); + } + + bytes32 sessionUid = keccak256(abi.encodePacked(sessionSpec.signer, block.timestamp)); + + _accountStorage().sessionIds[sessionSpec.signer] = sessionUid; + _accountStorage().allSigners.add(sessionSpec.signer); + _accountStorage().sessionExpiration[sessionUid] = sessionSpec.expiresAt; + + for (uint256 i = 0; i < sessionSpec.callPolicies.length; i++) { + _accountStorage().callPolicies[sessionUid].push(sessionSpec.callPolicies[i]); + bytes32 targetHash = keccak256( + abi.encodePacked(sessionSpec.callPolicies[i].target, sessionSpec.callPolicies[i].selector, sessionUid) + ); + _accountStorage().targetCallPolicy[targetHash] = sessionSpec.callPolicies[i]; + + if (sessionSpec.callPolicies[i].selector == bytes4(0x00000000)) { + _accountStorage().anySelectorForTarget[sessionUid][sessionSpec.callPolicies[i].target] = true; + } + } + + for (uint256 i = 0; i < sessionSpec.transferPolicies.length; i++) { + _accountStorage().transferPolicies[sessionUid].push(sessionSpec.transferPolicies[i]); + bytes32 targetHash = keccak256(abi.encodePacked(sessionSpec.transferPolicies[i].target, sessionUid)); + _accountStorage().targetTransferPolicy[targetHash] = sessionSpec.transferPolicies[i]; + } + + emit SessionCreated(sessionUid, sessionSpec); + } + + function getCallPoliciesForSigner(address signer) + external + view + returns (SessionLib.CallSpec[] memory) + { + bytes32 sessionUid = _accountStorage().sessionIds[signer]; + return _accountStorage().callPolicies[sessionUid]; + } + + function getTransferPoliciesForSigner(address signer) + external + view + returns (SessionLib.TransferSpec[] memory) + { + bytes32 sessionUid = _accountStorage().sessionIds[signer]; + return _accountStorage().transferPolicies[sessionUid]; + } + + function getSessionExpirationForSigner(address signer) external view returns (uint256) { + bytes32 sessionUid = _accountStorage().sessionIds[signer]; + return _accountStorage().sessionExpiration[sessionUid]; + } + + function getSessionStateForSigner(address signer) + external + view + returns (SessionLib.SessionState memory) + { + bytes32 sessionUid = _accountStorage().sessionIds[signer]; + + return _accountStorage().sessions[sessionUid].getState( + _accountStorage().callPolicies[sessionUid], + _accountStorage().transferPolicies[sessionUid] + ); + } + + /// + /// IERC4337Account + /// + + function validateUserOp(PackedUserOperation calldata userOp, bytes32 userOpHash, uint256 missingAccountFunds) + external + payable + onlyEntryPoint + returns (uint256 validationData) + { + // bubble up the return value of the validator module + validationData = _validateSignature(userOp, userOpHash); + + assembly { + if missingAccountFunds { + pop(call(gas(), caller(), missingAccountFunds, callvalue(), callvalue(), callvalue(), callvalue())) + //ignore failure (its EntryPoint's job to verify, not account.) + } + } + } + + function _validateSignature(PackedUserOperation calldata userOp, bytes32 userOpHash) + internal + virtual + returns (uint256 validationData) + { + (address signer, bytes memory signature) = abi.decode(userOp.signature, (address, bytes)); + bytes32 hash = SignatureCheckerLib.toEthSignedMessageHash(userOpHash); + + bool isValidSig = SignatureCheckerLib.isValidERC6492SignatureNow(signer, hash, signature); + + if (!isValidSig) { + return SIG_VALIDATION_FAILED; + } + + (bool isValid, uint48 validAfter, uint48 validUntil) = _isValidSigner(signer, userOp); + + if (!isValid) { + return SIG_VALIDATION_FAILED; + } + + return _packValidationData(ValidationData(address(0), validAfter, validUntil)); + } + + function _isValidSigner(address _signer, PackedUserOperation calldata _userOp) + internal + virtual + returns (bool isValid, uint48 validAfter, uint48 validUntil) + { + if (_signer == owner() || hasAnyRole(_signer, _ADMIN_ROLE)) { + return (true, 0, 0); + } + + bytes32 sessionUid = _accountStorage().sessionIds[_signer]; + + bytes4 selector = _getFunctionSelector(_userOp.callData); + + if (selector == this.execute.selector) { + (address target, uint256 value, bytes memory callData) = decodeExecuteCalldata(_userOp.callData); + isValid = _validatePolicies(target, value, callData, sessionUid); + + return (isValid, validAfter, validUntil); + } else if (selector == this.executeBatch.selector) { + (address[] memory targets, uint256[] memory values, bytes[] memory callData) = decodeExecuteBatchCalldata(_userOp.callData); + uint256 length = targets.length; + + for (uint256 i = 0; i < length; i++) { + isValid = _validatePolicies(targets[i], values[i], callData[i], sessionUid); + + if (!isValid) { + return (isValid, validAfter, validUntil); + } + } + + isValid = true; + return (isValid, validAfter, validUntil); + } + } + + function _validatePolicies(address target, uint256 value, bytes memory callData, bytes32 sessionUid) + internal + returns (bool) + { + uint256 sessionExpiresAt = _accountStorage().sessionExpiration[sessionUid]; + SessionLib.CallSpec memory callPolicy; + SessionLib.TransferSpec memory transferPolicy; + + bytes4 targetSelector; + if (callData.length >= 4) { + targetSelector = _accountStorage().anySelectorForTarget[sessionUid][target] + ? bytes4(0x00000000) + : bytes4(callData[0]) | (bytes4(callData[1]) >> 8) | (bytes4(callData[2]) >> 16) + | (bytes4(callData[3]) >> 24); + bytes32 targetHash = keccak256(abi.encodePacked(target, targetSelector, sessionUid)); + callPolicy = _accountStorage().targetCallPolicy[targetHash]; + } else { + bytes32 targetHash = keccak256(abi.encodePacked(target, sessionUid)); + transferPolicy = _accountStorage().targetTransferPolicy[targetHash]; + } + return _accountStorage().sessions[sessionUid].validate( + target, targetSelector, value, callData, sessionExpiresAt, callPolicy, transferPolicy + ); + } + + function decodeExecuteCalldata(bytes calldata data) internal pure returns (address _target, uint256 _value, bytes memory _callData ) { + (_target, _value, _callData) = abi.decode(data[4:], (address, uint256, bytes)); + } + + function decodeExecuteBatchCalldata(bytes calldata data) + internal + pure + returns (address[] memory _targets, uint256[] memory _values, bytes[] memory _callData) + { + require(data.length >= 4 + 32 + 32 + 32, "!Data"); + + (_targets, _values, _callData) = abi.decode(data[4:], (address[], uint256[], bytes[])); + } + + function executeUserOp(PackedUserOperation calldata userOp, bytes32 /*userOpHash*/ ) + external + payable + onlyEntryPoint + { + bytes calldata callData = userOp.callData[4:]; + (bool success,) = address(this).delegatecall(callData); + if (!success) { + revert ExecutionFailed(); + } + } + + function _getFunctionSelector(bytes calldata data) internal pure returns (bytes4 functionSelector) { + require(data.length >= 4, "!Data"); + return bytes4(data[:4]); + } + + /// + /// Balance (solady) + /// + + /// @dev Returns the account's balance on the EntryPoint. + function getDeposit() public view virtual returns (uint256 result) { + address ep = address(entrypoint); + /// @solidity memory-safe-assembly + assembly { + mstore(0x20, address()) // Store the `account` argument. + mstore(0x00, 0x70a08231) // `balanceOf(address)`. + result := + mul( + // Returns 0 if the EntryPoint does not exist. + mload(0x20), + and( + // The arguments of `and` are evaluated from right to left. + gt(returndatasize(), 0x1f), // At least 32 bytes returned. + staticcall(gas(), ep, 0x1c, 0x24, 0x20, 0x20) + ) + ) + } + } + + /// @dev Deposit more funds for this account in the EntryPoint. + function addDeposit() public payable virtual { + address ep = address(entrypoint); + /// @solidity memory-safe-assembly + assembly { + // The EntryPoint has balance accounting logic in the `receive()` function. + // forgefmt: disable-next-item + if iszero( + mul( + extcodesize(ep), + call( + gas(), + ep, + callvalue(), + codesize(), + 0x00, + codesize(), + 0x00 + ) + ) + ) { + revert(codesize(), 0x00) // For gas estimation. + } + } + } + + /// @dev Withdraw ETH from the account's deposit on the EntryPoint. + function withdrawDepositTo(address to, uint256 amount) public payable virtual onlyOwner { + address ep = address(entrypoint); + /// @solidity memory-safe-assembly + assembly { + mstore(0x14, to) // Store the `to` argument. + mstore(0x34, amount) // Store the `amount` argument. + mstore(0x00, 0x205c2878000000000000000000000000) // `withdrawTo(address,uint256)`. + if iszero(mul(extcodesize(ep), call(gas(), ep, 0, 0x10, 0x44, codesize(), 0x00))) { + returndatacopy(mload(0x40), 0x00, returndatasize()) + revert(mload(0x40), returndatasize()) + } + mstore(0x34, 0) // Restore the part of the free memory pointer that was overwritten. + } + } + + /// @notice Returns all admins of the account. + function getAllAdmins() external view returns (address[] memory) { + return _accountStorage().allAdmins.values(); + } + + /// @dev Grants the roles directly without authorization guard. + /// Each bit of `roles` represents the role to turn on. + function _grantRoles(address user, uint256 roles) internal override { + super._updateRoles(user, roles, true); + + if (roles & _ADMIN_ROLE != 0) { + _accountStorage().allAdmins.add(user); + } + } + + /// @dev Removes the roles directly without authorization guard. + /// Each bit of `roles` represents the role to turn off. + function _removeRoles(address user, uint256 roles) internal override { + super._updateRoles(user, roles, false); + + if (roles & _ADMIN_ROLE != 0) { + _accountStorage().allAdmins.remove(user); + } + } + + function _accountStorage() internal pure returns (SimpleAccountStorage.Data storage data) { + data = SimpleAccountStorage.data(); + } + +} diff --git a/src/4337/SimpleAccountFactory.sol b/src/4337/SimpleAccountFactory.sol new file mode 100644 index 0000000..180186c --- /dev/null +++ b/src/4337/SimpleAccountFactory.sol @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import {SimpleAccount} from "./SimpleAccount.sol"; +import {LibClone} from "solady/utils/LibClone.sol"; + +contract SimpleAccountFactory { + + error InitializationFailed(); + + address public immutable entrypoint; + address public immutable accountImplementation; + + constructor(address _entrypoint) { + entrypoint = _entrypoint; + accountImplementation = address(new SimpleAccount(_entrypoint)); + } + + function createAccount(address owner, bytes memory salt) public payable returns (address) { + (bool alreadyDeployed, address account) = _createAccount(owner, salt); + + if (!alreadyDeployed) { + bytes memory fullInitData = abi.encodeWithSelector(SimpleAccount.initialize.selector, owner); + (bool success,) = account.call{value: msg.value}(fullInitData); + if (!success) { + revert InitializationFailed(); + } + } + + return account; + } + + function getAddress(address owner, bytes memory salt) public view virtual returns (address account) { + account = _getAccount(owner, salt); + } + + function _getAccountSalt(address owner, bytes memory salt) internal pure virtual returns (bytes32) { + return keccak256(abi.encodePacked(owner, salt)); + } + + function _getAccount(address owner, bytes memory salt) internal view virtual returns (address account) { + address impl = accountImplementation; + bytes32 accountSalt = _getAccountSalt(owner, salt); + + account = LibClone.predictDeterministicAddress(impl, accountSalt, address(this)); + } + + function _createAccount(address owner, bytes memory salt) + internal + returns (bool alreadyDeployed, address account) + { + account = _getAccount(owner, salt); + + if (account.code.length > 0) { + return (true, account); + } + + bytes32 accountSalt = _getAccountSalt(owner, salt); + account = LibClone.cloneDeterministic(accountImplementation, accountSalt); + } + +} diff --git a/test/SimpleAccount.t.sol b/test/SimpleAccount.t.sol new file mode 100644 index 0000000..afc2a55 --- /dev/null +++ b/test/SimpleAccount.t.sol @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.26; + +import {UserOperationLib} from "account-abstraction-v0.7/core/UserOperationLib.sol"; +import {IEntryPoint} from "account-abstraction-v0.7/interfaces/IEntryPoint.sol"; +import {PackedUserOperation} from "account-abstraction-v0.7/interfaces/PackedUserOperation.sol"; +import {Test} from "forge-std/Test.sol"; +import {ECDSA} from "solady/utils/ECDSA.sol"; +import {SignatureCheckerLib} from "solady/utils/SignatureCheckerLib.sol"; + +import {MockERC20} from "test/mock/MockERC20.sol"; +import {MockTarget} from "test/mock/MockTarget.sol"; +import {MockValidator} from "test/mock/MockValidator.sol"; +import {EntryPointLib} from "test/utils/ERC4337Test.sol"; + +import "lib/forge-std/src/console.sol"; + +import {SimpleAccount} from "src/4337/SimpleAccount.sol"; +import {SimpleAccountFactory} from "src/4337/SimpleAccountFactory.sol"; + +import {SessionLib} from "src/4337/SessionLib.sol"; +import {DefaultValidator, SessionKey, SessionKeyParams, SessionKeyType} from "src/DefaultValidator.sol"; + +contract SimpleAccountTest is Test { + + IEntryPoint public entrypoint; + SimpleAccountFactory public factory; + SimpleAccount public account; + + bytes public accountSalt = abi.encodePacked(uint256(0xdeadbeef)); + + uint256 accountOwnerPKey = 0x2; + address public accountOwner = vm.addr(accountOwnerPKey); + uint256 signerOnePKey = 0x3; + address signerOne = vm.addr(signerOnePKey); + + MockTarget public target; + MockERC20 public erc20; + + function setUp() public { + entrypoint = IEntryPoint(EntryPointLib.deploy()); + target = new MockTarget(); + erc20 = new MockERC20(); + factory = new SimpleAccountFactory(address(entrypoint)); + account = SimpleAccount(payable(factory.createAccount(accountOwner, accountSalt))); + + vm.deal(address(account), 1 ether); + erc20.mint(address(account), 1 ether); + } + + function test_execute_owner() public { + bytes memory userOpCalldata = + abi.encodeCall(SimpleAccount.execute, (address(target), uint256(0), abi.encodeCall(MockTarget.set, 42))); + + uint256 nonce = getNonce(address(account)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory userOps = _getSignedUserOp(accountOwner, accountOwnerPKey, userOp); + entrypoint.handleOps(userOps, payable(address(account))); + + assertTrue(target.number() == 42); + } + + function test_executeBatch_owner() public { + address[] memory targets = new address[](2); + uint256[] memory values = new uint256[](2); + bytes[] memory callData = new bytes[](2); + + targets[0] = address(target); + targets[1] = address(target); + callData[0] = abi.encodeCall(MockTarget.set, 11); + callData[1] = abi.encodeCall(MockTarget.setIf, (42, 11)); + + bytes memory userOpCalldata = abi.encodeCall(SimpleAccount.executeBatch, (targets, values, callData)); + + uint256 nonce = getNonce(address(account)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory userOps = _getSignedUserOp(accountOwner, accountOwnerPKey, userOp); + + entrypoint.handleOps(userOps, payable(address(account))); + + assertTrue(target.number() == 42); + } + + // function test_signature() public { + // // sign message + // bytes32 message = keccak256("Hello World"); + // bytes32 messageHash = DefaultValidator(payable(validator)).getMessageHash(message); + // (uint8 v, bytes32 r, bytes32 s) = vm.sign(accountOwnerPKey, messageHash); + // bytes memory signature = abi.encodePacked(r, s, v); + + // // encode validator address + signature + // bytes memory sigWithValidator = abi.encode(validator, signature); + + // // verify signature + // bytes4 returnValue = account.isValidSignature(message, sigWithValidator); + // assertEq(returnValue, bytes4(0x1626ba7e)); // Magic value from ERC-1271 + // } + + // // === signer is admin (non owner) + + // === session key tests + + function test_createSessionKey() public { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + _createSessionKey(spec); + + SessionLib.CallSpec[] memory fetchedCallPolicies = account.getCallPoliciesForSigner(signerOne); + + SessionLib.TransferSpec[] memory fetchedTransferPolicies = account.getTransferPoliciesForSigner(signerOne); + + uint256 expiresAt = account.getSessionExpirationForSigner(signerOne); + + assertEq(expiresAt, spec.expiresAt); + } + + function test_transferPolicy() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall(SimpleAccount.execute, (address(target), 100, "")); + + uint256 nonce = getNonce(address(account)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + + assertTrue(address(target).balance == 100); + } + + function test_revert_transferPolicy_crossLimit() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall(SimpleAccount.execute, (address(target), 600, "")); + + uint256 nonce = getNonce(address(account)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + function test_revert_transferPolicy_wrongTarget() public { + // 1. create session key + + { + SessionLib.SessionSpec memory spec = getDefaultSessionSpec(signerOne); + + SessionLib.TransferSpec[] memory transferPolicies = new SessionLib.TransferSpec[](1); + transferPolicies[0].target = address(target); + transferPolicies[0].maxValuePerUse = 500; + transferPolicies[0].valueLimit = SessionLib.UsageLimit(SessionLib.LimitType.Allowance, 500, 10); + + spec.transferPolicies = transferPolicies; + + _createSessionKey(spec); + } + + // 2. prepare and send User Op + + bytes memory userOpCalldata = abi.encodeCall(SimpleAccount.execute, (address(0x1234), 500, "")); + + uint256 nonce = getNonce(address(account)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(signerOne, signerOnePKey, userOp); + + vm.expectRevert(); + entrypoint.handleOps(ops, payable(address(account))); + } + + // test utils + + function getNonce(address _account) internal returns (uint256 nonce) { + nonce = entrypoint.getNonce(_account, 0); + } + + function getDefaultUserOp() internal returns (PackedUserOperation memory userOp) { + userOp = PackedUserOperation({ + sender: address(0), + nonce: 0, + initCode: "", + callData: "", + accountGasLimits: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + preVerificationGas: 2e6, + gasFees: bytes32(abi.encodePacked(uint128(2e6), uint128(2e6))), + paymasterAndData: bytes(""), + signature: abi.encodePacked(hex"41414141") + }); + } + + function getDefaultSessionSpec(address signer) internal returns (SessionLib.SessionSpec memory spec) { + spec.signer = signer; + spec.expiresAt = 100; + spec.callPolicies = new SessionLib.CallSpec[](0); + spec.transferPolicies = new SessionLib.TransferSpec[](0); + } + + function _createSessionKey(SessionLib.SessionSpec memory spec) internal { + bytes memory userOpCalldata = abi.encodeCall( + SimpleAccount.execute, + (address(account), uint256(0), abi.encodeCall(SimpleAccount.createSessionKey, (spec))) + ); + + uint256 nonce = getNonce(address(account)); + + PackedUserOperation memory userOp = getDefaultUserOp(); + userOp.sender = address(account); + userOp.nonce = nonce; + userOp.callData = userOpCalldata; + + PackedUserOperation[] memory ops = _getSignedUserOp(accountOwner, accountOwnerPKey, userOp); + + entrypoint.handleOps(ops, payable(address(account))); + } + + function _getSignedUserOp(address signer, uint256 pkey, PackedUserOperation memory userOp) + internal + returns (PackedUserOperation[] memory) + { + // Sign UserOp + { + bytes32 opHash = entrypoint.getUserOpHash(userOp); + bytes32 msgHash = SignatureCheckerLib.toEthSignedMessageHash(opHash); + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(pkey, msgHash); + bytes memory userOpSignature = abi.encodePacked(r, s, v); + + userOp.signature = abi.encode(signer, userOpSignature); + } + + // Create userOps array + PackedUserOperation[] memory userOps = new PackedUserOperation[](1); + userOps[0] = userOp; + + return userOps; + } + +}