Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions docs/statements.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,28 @@ raise
-----

The ``raise`` statement triggers an exception and reverts the current call.

.. code-block:: vyper

raise "something went wrong"

The error string is not required. If it is provided, it is limited to 1024 bytes.

Custom errors can also be raised. They share the same syntax as events at module scope and are encoded with a 4-byte selector followed by ABI-encoded arguments:

.. code-block:: vyper

error Unauthorized:
caller: address
expected: address

assert msg.sender == owner, Unauthorized(caller=msg.sender, expected=owner)

Custom errors are included in the generated ABI with ``type: "error"``.

assert
------

.. code-block:: vyper

raise "something went wrong"

The error string is not required. If it is provided, it is limited to 1024 bytes.

assert
------

The ``assert`` statement makes an assertion about a given condition. If the condition evaluates falsely, the transaction is reverted.

Expand Down
53 changes: 53 additions & 0 deletions tests/functional/syntax/test_custom_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from tests.evm_backends.abi import abi_decode
from tests.evm_backends.base_env import ExecutionReverted
from vyper.utils import method_id


def test_custom_error_revert(env, get_contract):
code = """
error Unauthorized:
caller: address

@external
def fail():
raise Unauthorized(caller=msg.sender)
"""

contract = get_contract(code)

with pytest.raises(ExecutionReverted) as excinfo:
contract.fail(sender=env.deployer)

revert_hex = excinfo.value.args[0]
assert revert_hex.startswith("0x")

data = bytes.fromhex(revert_hex[2:])
assert data[:4] == method_id("Unauthorized(address)")

(caller,) = abi_decode("(address)", data[4:])
assert caller == env.deployer


def test_custom_error_dynamic_arg(env, get_contract):
code = """
error Fancy:
note: String[16]
count: uint256

@external
def boom():
raise Fancy(note="hi", count=3)
"""

contract = get_contract(code)

with pytest.raises(ExecutionReverted) as excinfo:
contract.boom(sender=env.deployer)

data = bytes.fromhex(excinfo.value.args[0][2:])
assert data[:4] == method_id("Fancy(string,uint256)")

decoded = abi_decode("(string,uint256)", data[4:])
assert decoded == ("hi", 3)
25 changes: 25 additions & 0 deletions tests/unit/compiler/test_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,31 @@ def foo(s: decimal) -> decimal:
assert out["abi"] == expected_abi


def test_custom_error_abi():
code = """
error Unauthorized:
caller: address

error Simple:
pass

@external
def fail():
raise Unauthorized(caller=msg.sender)
"""

abi = compile_code(code, output_formats=["abi"])["abi"]

unauthorized = next(
item for item in abi if item.get("type") == "error" and item.get("name") == "Unauthorized"
)
assert unauthorized["inputs"][0]["name"] == "caller"
assert unauthorized["inputs"][0]["type"] == "address"

simple = next(item for item in abi if item.get("type") == "error" and item.get("name") == "Simple")
assert simple["inputs"] == []


def test_struct_abi():
code = """
struct MyStruct:
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/semantics/types/test_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest

from vyper.semantics.types.user import ErrorT
from vyper.utils import method_id_int


ERROR_ID_TESTS = [
("error Unauthorized: pass", "Unauthorized()", method_id_int("Unauthorized()")),
(
"""error InsufficientBalance:
available: uint256
required: uint256
""",
"InsufficientBalance(uint256,uint256)",
method_id_int("InsufficientBalance(uint256,uint256)"),
),
]


@pytest.mark.parametrize("source,signature,selector", ERROR_ID_TESTS)
def test_error_selector(build_node, source, signature, selector):
node = build_node(source)
err = ErrorT.from_ErrorDef(node)

assert err.signature == signature
assert err.selector == selector
8 changes: 8 additions & 0 deletions vyper/ast/grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module: ( DOCSTRING
| variable_def
| enum_def // TODO deprecate at some point in favor of flag
| flag_def
| error_def
| event_def
| function_def
| exports_decl
Expand Down Expand Up @@ -71,6 +72,13 @@ indexed_event_arg: NAME ":" "indexed" "(" type ")"
event_body: _NEWLINE _INDENT (((event_member | indexed_event_arg ) _NEWLINE)+ | _PASS _NEWLINE) _DEDENT
event_def: _EVENT_DECL NAME ":" ( event_body | _PASS )

// Custom errors mirror event syntax (without indexed fields)
_ERROR_DECL: "error"
error_member: NAME ":" type
// Errors which use no args use a pass statement instead
error_body: _NEWLINE _INDENT ((error_member _NEWLINE)+ | _PASS _NEWLINE) _DEDENT
error_def: _ERROR_DECL NAME ":" ( error_body | _PASS )

// TODO deprecate in favor of flag
// Enums
_ENUM_DECL: "enum"
Expand Down
4 changes: 4 additions & 0 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,10 @@ class EventDef(TopLevel):
__slots__ = ("name", "body")


class ErrorDef(TopLevel):
__slots__ = ("name", "body")


class InterfaceDef(TopLevel):
__slots__ = ("name", "body")

Expand Down
4 changes: 4 additions & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class EventDef(VyperNode):
body: list = ...
name: str = ...

class ErrorDef(VyperNode):
body: list = ...
name: str = ...

class InterfaceDef(VyperNode):
body: list = ...
name: str = ...
Expand Down
1 change: 1 addition & 0 deletions vyper/ast/pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def consume(self, token, result):
"flag": "FlagDef",
"enum": "EnumDef",
"event": "EventDef",
"error": "ErrorDef",
"interface": "InterfaceDef",
"struct": "StructDef",
}
Expand Down
61 changes: 60 additions & 1 deletion vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vyper.codegen.expr import Expr
from vyper.codegen.return_ import make_return_stmt
from vyper.exceptions import CodegenPanic, StructureException, TypeCheckFailure, tag_exceptions
from vyper.semantics.types import DArrayT
from vyper.semantics.types import DArrayT, ErrorT, TupleT
from vyper.semantics.types.shortcuts import UINT256_T


Expand Down Expand Up @@ -123,6 +123,10 @@ def _assert_reason(self, test_expr, msg):
["assert_unreachable", test_expr], error_msg="assert unreachable"
)

msg_type = msg._metadata.get("type") if hasattr(msg, "_metadata") else None
if isinstance(msg_type, ErrorT):
return self._assert_custom_error(test_expr, msg, msg_type)

# set constant so that revert reason str is well behaved
try:
tmp = self.context.constancy
Expand Down Expand Up @@ -157,6 +161,61 @@ def _assert_reason(self, test_expr, msg):
ir_node = ["if", ["iszero", test_expr], revert_seq]
return IRnode.from_list(ir_node, error_msg="user revert with reason")

def _custom_error_args(
self, call: vy_ast.Call, error_t: ErrorT
) -> list[vy_ast.VyperNode]:
if len(call.keywords) > 0:
kwarg_lookup = {kw.arg: kw.value for kw in call.keywords}
return [kwarg_lookup[name] for name in error_t.arguments.keys()]

return call.args

def _assert_custom_error(self, test_expr, msg: vy_ast.Call, error_t: ErrorT):
is_raise = test_expr is None

arg_nodes = self._custom_error_args(msg, error_t)
arg_irs = [Expr(arg, self.context).ir_node for arg in arg_nodes]

args_tuple_t = TupleT(tuple(error_t.arguments.values()))
args_as_tuple = IRnode.from_list(["multi", *arg_irs], typ=args_tuple_t)

abi_t = args_tuple_t.abi_type
buflen = abi_t.size_bound() + 32
buf = self.context.new_internal_variable(get_type_for_exact_size(buflen))

if len(arg_irs) == 0:
revert_seq = [
"seq",
["mstore", buf, error_t.selector],
["revert", add_ofst(buf, 28), 4],
]
else:
payload_buf = add_ofst(buf, 32)
encode_buflen = buflen - 32
encoded_length = abi_encode(
payload_buf,
args_as_tuple,
self.context,
bufsz=encode_buflen,
returns_len=True,
)
with encoded_length.cache_when_complex("encoded_len") as (
b1,
encoded_length,
):
revert_seq = [
"seq",
["mstore", buf, error_t.selector],
["revert", add_ofst(buf, 28), ["add", 4, encoded_length]],
]
revert_seq = b1.resolve(revert_seq)

if is_raise:
ir_node = revert_seq
else:
ir_node = ["if", ["iszero", test_expr], revert_seq]
return IRnode.from_list(ir_node, error_msg="user revert with reason")

def parse_Assert(self):
test_expr = Expr.parse_value_expr(self.stmt.test, self.context)

Expand Down
23 changes: 17 additions & 6 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
AddressT,
BoolT,
DArrayT,
ErrorT,
EventT,
FlagT,
HashMapT,
Expand Down Expand Up @@ -369,16 +370,26 @@ def visit_AnnAssign(self, node):
self.expr_visitor.visit(node.target, typ)

def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None:
if isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE":
return

if isinstance(msg_node, vy_ast.Call):
call_type = get_exact_type_from_node(msg_node.func)
if is_type_t(call_type, ErrorT):
self.expr_visitor.visit(msg_node, call_type.typedef)
return

if isinstance(msg_node, vy_ast.Str):
if not msg_node.value.strip():
raise StructureException("Reason string cannot be empty", msg_node)
self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node))
elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"):
try:
validate_expected_type(msg_node, StringT(1024))
except TypeMismatch as e:
raise InvalidType("revert reason must fit within String[1024]") from e
self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node))
return

try:
validate_expected_type(msg_node, StringT(1024))
except TypeMismatch as e:
raise InvalidType("revert reason must fit within String[1024]") from e
self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node))
# CMC 2023-10-19 nice to have: tag UNREACHABLE nodes with a special type

def visit_Assert(self, node):
Expand Down
11 changes: 9 additions & 2 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
)
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace
from vyper.semantics.types import TYPE_T, EventT, FlagT, InterfaceT, StructT, VyperType, is_type_t
from vyper.semantics.types import TYPE_T, ErrorT, EventT, FlagT, InterfaceT, StructT, VyperType, is_type_t
from vyper.semantics.types.function import ContractFunctionT
from vyper.semantics.types.module import ModuleT
from vyper.semantics.types.utils import type_from_annotation
Expand Down Expand Up @@ -172,6 +172,7 @@ def __init__(
self._all_implements: dict[VyperType, vy_ast.VyperNode] = {}

self._events: list[EventT] = []
self._errors: list[ErrorT] = []

self.module_t: Optional[ModuleT] = None

Expand All @@ -194,7 +195,7 @@ def analyze_module_body(self):

# handle some node types using a dependency resolution routine
# which loops, swallowing exceptions until all nodes are processed
type_decls = (vy_ast.FlagDef, vy_ast.StructDef, vy_ast.InterfaceDef, vy_ast.EventDef)
type_decls = (vy_ast.FlagDef, vy_ast.StructDef, vy_ast.InterfaceDef, vy_ast.ErrorDef, vy_ast.EventDef)
self._visit_nodes_looping(type_decls)

# handle functions
Expand Down Expand Up @@ -716,6 +717,12 @@ def visit_FlagDef(self, node):
node._metadata["flag_type"] = obj
self.namespace[node.name] = obj

def visit_ErrorDef(self, node):
obj = ErrorT.from_ErrorDef(node)
node._metadata["error_type"] = obj
self.namespace[node.name] = obj
self._errors.append(obj)

def visit_EventDef(self, node):
obj = EventT.from_EventDef(node)
node._metadata["event_type"] = obj
Expand Down
2 changes: 2 additions & 0 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,8 @@ def types_from_Name(self, node):
try:
t = self.namespace[node.id]
# when this is a type, we want to lower it
if isinstance(t, TYPE_T):
return [t]
if isinstance(t, VyperType):
# TYPE_T is used to handle cases where a type can occur in call or
# attribute conditions, like Flag.foo or MyStruct({...})
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .module import InterfaceT, ModuleT
from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT
from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT
from .user import EventT, FlagT, StructT
from .user import ErrorT, EventT, FlagT, StructT


def _get_primitive_types():
Expand Down
Loading
Loading