Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 90 additions & 7 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def __str__(self):
errors += f"\n `{name}`: {value!r} of type {type(value)} is not {pattern.describe()}"

sig = f"{self.func.__name__}{self.sig}"
cause = str(self.__cause__) if self.__cause__ else ""
# remove the leading "_custom_bind_fn()" that comes from the custom bind function generated in SignatureBinder
cause = str(self.__cause__).removeprefix('_custom_bind_fn() ') if self.__cause__ else ""

return self.msg.format(sig=sig, call=call, cause=cause, errors=errors)

Expand Down Expand Up @@ -295,13 +296,83 @@ def from_argument(cls, name: str, annotation: Argument) -> Self:
)


class ReprableVariableName:
"""Holds a string that will be used as a variable name in code to generate a default value for a parameter
in a binding function for a Signature created by SignatureBinder.

Needed because Signature.__repr__, which is used to generate binding function argument list, will call repr() on default values.
"""
def __init__(self, name: str):
self.name = name

def __repr__(self):
"""Return the variable name without quotes."""
return self.name


class SignatureBinder:
"""Given a Signature, builds a callable object that binds arguments to parameters
according to that Signature, returning a dict of parameter names to bound values.

Behaviour of the resulting callable object is equivalent to inspect.Signature.bind,
but is more performant as it uses cpython's argument binding logic directly,
instead of a slower pure-python implementation.

Example::

from ibis.common.annotations import Signature
def fn(a, b: int, c: Foo = Foo()): ...
sig = Signature.from_callable(fn)
binder = SignatureBinder(sig)
binder(1, 2) # returns {'a': 1, 'b': 2, 'c': Foo()}
"""

def __init__(self, signature: Signature):
namespace = {} # a namespace of default variable name -> default value used with exec below
processed_params = []
for name, param in signature.parameters.items():
if param.default is not inspect.Parameter.empty:
# Create a unique variable name for the default value of this parameter,
# and store the actual default value in the namespace under that name.
varname = f'__default_{name}__'
default_val = ReprableVariableName(varname)
namespace[varname] = param.default
else:
default_val = inspect.Parameter.empty

processed_params.append(param.replace(
default=default_val,
annotation=inspect.Parameter.empty
))

# build a new signature with default values replaced with generated variable names
processed_signature = inspect.Signature(parameters=processed_params)
self.bind_fn_str = f'def _custom_bind_fn{processed_signature}:\n return locals()'

exec(compile(self.bind_fn_str, '<string>', 'exec'), namespace)
self._bind_fn = namespace['_custom_bind_fn']

def __call__(self, *args, **kwargs):
return self._bind_fn(*args, **kwargs)

def __repr__(self) -> str:
"""To help with debugging, returns the generated source code of the binding function."""
return self.bind_fn_str


class Signature(inspect.Signature):
"""Validatable signature.

Primarily used in the implementation of `ibis.common.grounds.Annotable`.
"""

__slots__ = ()
__slots__ = ('_patterns', '_binder_fn')

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# prebuild dict of patterns to avoid slow retrieval via property&MappingProxyType
self._patterns = {k: param.annotation.pattern for k, param in self.parameters.items() if hasattr(param.annotation, 'pattern')}
self._binder_fn = SignatureBinder(self)._bind_fn

@classmethod
def merge(cls, *signatures, **annotations):
Expand Down Expand Up @@ -509,15 +580,27 @@ def validate(self, func, args, kwargs):

return this

def validate_fast(self, func, args, kwargs):
"""Faster validation using custom bind function for this signature (instead of Signature.bind)."""
try:
bound_kwargs = self._binder_fn(*args, **kwargs)
except TypeError as err:
raise SignatureValidationError(
"{call} {cause}\n\nExpected signature: {sig}",
sig=self,
func=func,
args=args,
kwargs=kwargs,
) from err

return self.validate_nobind(func, bound_kwargs)

def validate_nobind(self, func, kwargs):
"""Validate the arguments against the signature without binding."""
this, errors = {}, []
for name, param in self.parameters.items():
value = kwargs.get(name, param.default)
if value is EMPTY:
raise TypeError(f"missing required argument `{name!r}`")
for name, pattern in self._patterns.items():
value = kwargs[name]

pattern = param.annotation.pattern
result = pattern.match(value, this)
if result is NoMatch:
errors.append((name, value, pattern))
Expand Down
4 changes: 2 additions & 2 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ class Annotable(Abstract, metaclass=AnnotableMeta):
@classmethod
def __create__(cls, *args: Any, **kwargs: Any) -> Self:
# construct the instance by passing only validated keyword arguments
kwargs = cls.__signature__.validate(cls, args, kwargs)
return super().__create__(**kwargs)
validated_kwargs = cls.__signature__.validate_fast(cls, args, kwargs)
return super().__create__(**validated_kwargs)

@classmethod
def __recreate__(cls, kwargs: Any) -> Self:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Example('1', '2', '3', '4', '5', []) has failed due to the following errors:
Example(a='1', b='2', c='3', d='4', e='5', f=[]) has failed due to the following errors:
`a`: '1' of type <class 'str'> is not an int
`b`: '2' of type <class 'str'> is not an int
`d`: '4' of type <class 'str'> is not either None or a float
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Example('1', '2', '3', '4', '5', []) has failed due to the following errors:
Example(a='1', b='2', c='3', d='4', e='5', f=[]) has failed due to the following errors:
`a`: '1' of type <class 'str'> is not an int
`b`: '2' of type <class 'str'> is not an int
`d`: '4' of type <class 'str'> is not either None or a float
Expand Down
6 changes: 3 additions & 3 deletions ibis/common/tests/test_grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ class Test2(Test):
c = is_int
args = varargs(is_int)

with pytest.raises(ValidationError, match="missing a required argument: 'c'"):
with pytest.raises(ValidationError, match="missing 1 required positional argument: 'c'"):
Test2(1, 2)

a = Test2(1, 2, 3)
Expand Down Expand Up @@ -578,7 +578,7 @@ class Test2(Test):
c = is_int
options = varkwargs(is_int)

with pytest.raises(ValidationError, match="missing a required argument: 'c'"):
with pytest.raises(ValidationError, match="missing 1 required positional argument: 'c'"):
Test2(1, 2)

a = Test2(1, 2, c=3)
Expand Down Expand Up @@ -858,7 +858,7 @@ class Flexible(Annotable):


def test_annotable_attribute():
with pytest.raises(ValidationError, match="too many positional arguments"):
with pytest.raises(ValidationError, match="takes 1 positional argument but 2 were given"):
BaseValue(1, 2)

v = BaseValue(1)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Literal(1) missing a required argument: 'dtype'
Literal(1) missing 1 required positional argument: 'dtype'

Expected signature: Literal(value: Annotated[Any, Not(pattern=InstanceOf(type=<class 'Deferred'>))], dtype: DataType)
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Literal(1, Int8(nullable=True), 'foo') too many positional arguments
Literal(1, Int8(nullable=True), 'foo') takes 2 positional arguments but 3 were given

Expected signature: Literal(value: Annotated[Any, Not(pattern=InstanceOf(type=<class 'Deferred'>))], dtype: DataType)
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Literal(1, Int8(nullable=True), dtype=Int16(nullable=True)) multiple values for argument 'dtype'
Literal(1, Int8(nullable=True), dtype=Int16(nullable=True)) got multiple values for argument 'dtype'

Expected signature: Literal(value: Annotated[Any, Not(pattern=InstanceOf(type=<class 'Deferred'>))], dtype: DataType)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Literal(1, 4) has failed due to the following errors:
Literal(value=1, dtype=4) has failed due to the following errors:
`dtype`: 4 of type <class 'int'> is not coercible to a DataType

Expected signature: Literal(value: Annotated[Any, Not(pattern=InstanceOf(type=<class 'Deferred'>))], dtype: DataType)
Loading