Skip to content

Commit b1d4098

Browse files
markusschmausbrandonwillard
authored andcommitted
Convert __init__s to only accept keyword arguments
1 parent 0929c9d commit b1d4098

File tree

13 files changed

+192
-102
lines changed

13 files changed

+192
-102
lines changed

aesara/graph/null_type.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class NullType(Type):
1717
1818
"""
1919

20+
__props__ = ("why_null",)
21+
2022
def __init__(self, why_null="(no explanation given)"):
2123
self.why_null = why_null
2224

aesara/graph/type.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from abc import abstractmethod
2-
from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union
1+
import inspect
2+
from abc import ABCMeta, abstractmethod
3+
from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union, final
34

45
from typing_extensions import Protocol, TypeAlias, runtime_checkable
56

@@ -11,14 +12,27 @@
1112
D = TypeVar("D")
1213

1314

14-
class NewTypeMeta(type):
15-
# pass
15+
class NewTypeMeta(ABCMeta):
16+
__props__: tuple[str, ...]
17+
1618
def __call__(cls, *args, **kwargs):
1719
raise RuntimeError("Use subtype")
1820
# return super().__call__(*args, **kwargs)
1921

2022
def subtype(cls, *args, **kwargs):
21-
return super().__call__(*args, **kwargs)
23+
kwargs = cls.type_parameters(*args, **kwargs)
24+
return super().__call__(**kwargs)
25+
26+
def type_parameters(cls, *args, **kwargs):
27+
if args:
28+
init_args = tuple(inspect.signature(cls.__init__).parameters.keys())[1:]
29+
if cls.__props__[: len(args)] != init_args[: len(args)]:
30+
raise RuntimeError(
31+
f"{cls.__props__=} doesn't match {init_args=} for {args=}"
32+
)
33+
34+
kwargs |= zip(cls.__props__, args)
35+
return kwargs
2236

2337

2438
class Type(Generic[D], metaclass=NewTypeMeta):
@@ -293,6 +307,11 @@ def _props_dict(self):
293307
"""
294308
return {a: getattr(self, a) for a in self.__props__}
295309

310+
@final
311+
def __init__(self, **kwargs):
312+
for k, v in kwargs.items():
313+
setattr(self, k, v)
314+
296315
def __hash__(self):
297316
return hash((type(self), tuple(getattr(self, a) for a in self.__props__)))
298317

aesara/link/c/params_type.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,9 @@ class ParamsType(CType):
343343
344344
"""
345345

346-
def __init__(self, **kwargs):
346+
@classmethod
347+
def type_parameters(cls, **kwargs):
348+
params = dict()
347349
if len(kwargs) == 0:
348350
raise ValueError("Cannot create ParamsType from empty data.")
349351

@@ -366,14 +368,14 @@ def __init__(self, **kwargs):
366368
% (attribute_name, type_name)
367369
)
368370

369-
self.length = len(kwargs)
370-
self.fields = tuple(sorted(kwargs.keys()))
371-
self.types = tuple(kwargs[field] for field in self.fields)
372-
self.name = self.generate_struct_name()
371+
params["length"] = len(kwargs)
372+
params["fields"] = tuple(sorted(kwargs.keys()))
373+
params["types"] = tuple(kwargs[field] for field in params["fields"])
374+
params["name"] = cls.generate_struct_name(params)
373375

374-
self.__const_to_enum = {}
375-
self.__alias_to_enum = {}
376-
enum_types = [t for t in self.types if isinstance(t, EnumType)]
376+
params["_const_to_enum"] = {}
377+
params["_alias_to_enum"] = {}
378+
enum_types = [t for t in params["types"] if isinstance(t, EnumType)]
377379
if enum_types:
378380
# We don't want same enum names in different enum types.
379381
if sum(len(t) for t in enum_types) != len(
@@ -398,35 +400,40 @@ def __init__(self, **kwargs):
398400
)
399401
# We map each enum name to the enum type in which it is defined.
400402
# We will then use this dict to find enum value when looking for enum name in ParamsType object directly.
401-
self.__const_to_enum = {
403+
params["_const_to_enum"] = {
402404
enum_name: enum_type
403405
for enum_type in enum_types
404406
for enum_name in enum_type
405407
}
406-
self.__alias_to_enum = {
408+
params["_alias_to_enum"] = {
407409
alias: enum_type
408410
for enum_type in enum_types
409411
for alias in enum_type.aliases
410412
}
411413

414+
return params
415+
412416
def __setstate__(self, state):
413417
# NB:
414418
# I have overridden __getattr__ to make enum constants available through
415419
# the ParamsType when it contains enum types. To do that, I use some internal
416-
# attributes: self.__const_to_enum and self.__alias_to_enum. These attributes
420+
# attributes: self._const_to_enum and self._alias_to_enum. These attributes
417421
# are normally found by Python without need to call getattr(), but when the
418422
# ParamsType is unpickled, it seems gettatr() may be called at a point before
419-
# __const_to_enum or __alias_to_enum are unpickled, so that gettatr() can't find
423+
# _const_to_enum or _alias_to_enum are unpickled, so that gettatr() can't find
420424
# those attributes, and then loop infinitely.
421425
# For this reason, I must add this trivial implementation of __setstate__()
422426
# to avoid errors when unpickling.
423427
self.__dict__.update(state)
424428

425429
def __getattr__(self, key):
426430
# Now we can access value of each enum defined inside enum types wrapped into the current ParamsType.
427-
if key in self.__const_to_enum:
428-
return self.__const_to_enum[key][key]
429-
return super().__getattr__(self, key)
431+
# const_to_enum = super().__getattribute__("_const_to_enum")
432+
if not key.startswith("__"):
433+
const_to_enum = self._const_to_enum
434+
if key in const_to_enum:
435+
return const_to_enum[key][key]
436+
raise AttributeError(f"'{self}' object has no attribute '{key}'")
430437

431438
def __repr__(self):
432439
return "ParamsType<%s>" % ", ".join(
@@ -446,13 +453,14 @@ def __eq__(self, other):
446453
def __hash__(self):
447454
return hash((type(self),) + self.fields + self.types)
448455

449-
def generate_struct_name(self):
450-
# This method tries to generate an unique name for the current instance.
456+
@staticmethod
457+
def generate_struct_name(params):
458+
# This method tries to generate a unique name for the current instance.
451459
# This name is intended to be used as struct name in C code and as constant
452460
# definition to check if a similar ParamsType has already been created
453461
# (see c_support_code() below).
454-
fields_string = ",".join(self.fields).encode("utf-8")
455-
types_string = ",".join(str(t) for t in self.types).encode("utf-8")
462+
fields_string = ",".join(params["fields"]).encode("utf-8")
463+
types_string = ",".join(str(t) for t in params["types"]).encode("utf-8")
456464
fields_hex = hashlib.sha256(fields_string).hexdigest()
457465
types_hex = hashlib.sha256(types_string).hexdigest()
458466
return f"_Params_{fields_hex}_{types_hex}"
@@ -510,7 +518,7 @@ def get_enum(self, key):
510518
print(wrapper.TWO)
511519
512520
"""
513-
return self.__const_to_enum[key][key]
521+
return self._const_to_enum[key][key]
514522

515523
def enum_from_alias(self, alias):
516524
"""
@@ -547,10 +555,11 @@ def enum_from_alias(self, alias):
547555
method to do that.
548556
549557
"""
558+
alias_to_enum = self._alias_to_enum
550559
return (
551-
self.__alias_to_enum[alias].fromalias(alias)
552-
if alias in self.__alias_to_enum
553-
else self.__const_to_enum[alias][alias]
560+
alias_to_enum[alias].fromalias(alias)
561+
if alias in alias_to_enum
562+
else self._const_to_enum[alias][alias]
554563
)
555564

556565
def get_params(self, *objects, **kwargs) -> Params:

0 commit comments

Comments
 (0)