From 02f14355fe8e25ada1aa73129a238885466eab54 Mon Sep 17 00:00:00 2001 From: Markus Schmaus Date: Fri, 23 Sep 2022 14:35:30 +0200 Subject: [PATCH 01/21] Turn `HasDataType` and `HasShape` into `Protocol`s --- aesara/graph/type.py | 22 +++++++++++++++------- aesara/link/c/cmodule.py | 2 +- aesara/scalar/basic.py | 5 +++-- aesara/sparse/type.py | 3 +-- aesara/tensor/type.py | 8 ++++++-- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/aesara/graph/type.py b/aesara/graph/type.py index e08f40e09a..37860914e3 100644 --- a/aesara/graph/type.py +++ b/aesara/graph/type.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union -from typing_extensions import TypeAlias +from typing_extensions import Protocol, TypeAlias, runtime_checkable from aesara.graph import utils from aesara.graph.basic import Constant, Variable @@ -262,14 +262,22 @@ def values_eq_approx(cls, a: D, b: D) -> bool: return cls.values_eq(a, b) -class HasDataType: - """A mixin for a type that has a :attr:`dtype` attribute.""" +DataType = str - dtype: str +@runtime_checkable +class HasDataType(Protocol): + """A protocol matching any class with :attr:`dtype` attribute.""" -class HasShape: - """A mixin for a type that has :attr:`shape` and :attr:`ndim` attributes.""" + dtype: DataType + + +ShapeType = Tuple[Optional[int], ...] + + +@runtime_checkable +class HasShape(Protocol): + """A protocol matching any class that has :attr:`shape` and :attr:`ndim` attributes.""" ndim: int - shape: Tuple[Optional[int], ...] + shape: ShapeType diff --git a/aesara/link/c/cmodule.py b/aesara/link/c/cmodule.py index 58102b7303..ccf5852ea0 100644 --- a/aesara/link/c/cmodule.py +++ b/aesara/link/c/cmodule.py @@ -2441,7 +2441,7 @@ def linking_patch(lib_dirs: List[str], libs: List[str]) -> List[str]: if sys.platform != "win32": return [f"-l{l}" for l in libs] - def sort_key(lib): # type: ignore + def sort_key(lib): name, *numbers, extension = lib.split(".") return (extension == "dll", tuple(map(int, numbers))) diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py index 6764f4a16e..ec0f21e36d 100644 --- a/aesara/scalar/basic.py +++ b/aesara/scalar/basic.py @@ -28,7 +28,7 @@ from aesara.graph.basic import Apply, Constant, Variable, clone, list_of_nodes from aesara.graph.fg import FunctionGraph from aesara.graph.rewriting.basic import MergeOptimizer -from aesara.graph.type import HasDataType, HasShape +from aesara.graph.type import DataType from aesara.graph.utils import MetaObject, MethodNotDefined from aesara.link.c.op import COp from aesara.link.c.type import CType @@ -268,7 +268,7 @@ def convert(x, dtype=None): return x_ -class ScalarType(CType, HasDataType, HasShape): +class ScalarType(CType): """ Internal class, should not be used by clients. 
@@ -284,6 +284,7 @@ class ScalarType(CType, HasDataType, HasShape): __props__ = ("dtype",) ndim = 0 shape = () + dtype: DataType def __init__(self, dtype): if isinstance(dtype, str) and dtype == "floatX": diff --git a/aesara/sparse/type.py b/aesara/sparse/type.py index e2ce91d64c..048e69e111 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -7,7 +7,6 @@ import aesara from aesara import scalar as aes from aesara.graph.basic import Variable -from aesara.graph.type import HasDataType from aesara.tensor.type import DenseTensorType, TensorType @@ -33,7 +32,7 @@ def _is_sparse(x): return isinstance(x, scipy.sparse.spmatrix) -class SparseTensorType(TensorType, HasDataType): +class SparseTensorType(TensorType): """A `Type` for sparse tensors. Notes diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 06cf964142..2affadb154 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -8,7 +8,7 @@ from aesara import scalar as aes from aesara.configdefaults import config from aesara.graph.basic import Variable -from aesara.graph.type import HasDataType, HasShape +from aesara.graph.type import DataType, ShapeType from aesara.graph.utils import MetaType from aesara.link.c.type import CType from aesara.misc.safe_asarray import _asarray @@ -48,11 +48,15 @@ } -class TensorType(CType[np.ndarray], HasDataType, HasShape): +class TensorType(CType[np.ndarray]): r"""Symbolic `Type` representing `numpy.ndarray`\s.""" __props__: Tuple[str, ...] = ("dtype", "shape") + ndim: int + shape: ShapeType + dtype: DataType + dtype_specs_map = dtype_specs_map context_name = "cpu" filter_checks_isfinite = False From 0929c9dc1d152b881538df8cb1abcb98e132b8c0 Mon Sep 17 00:00:00 2001 From: Markus Schmaus Date: Sat, 24 Sep 2022 09:49:53 +0200 Subject: [PATCH 02/21] Prepare switch from class to metaclass for `Type` --- aesara/breakpoint.py | 2 +- aesara/gradient.py | 8 +- aesara/graph/null_type.py | 2 +- aesara/graph/type.py | 57 +++++- aesara/link/c/params_type.py | 2 +- aesara/link/c/type.py | 10 +- aesara/raise_op.py | 6 +- aesara/sandbox/multinomial.py | 4 +- aesara/sandbox/rng_mrg.py | 4 +- aesara/scalar/basic.py | 8 +- aesara/scalar/sharedvar.py | 2 +- aesara/scan/op.py | 22 +-- aesara/sparse/basic.py | 90 +++++---- aesara/sparse/rewriting.py | 10 +- aesara/sparse/sandbox/sp2.py | 2 +- aesara/sparse/sharedvar.py | 2 +- aesara/sparse/type.py | 2 +- aesara/tensor/basic.py | 34 ++-- aesara/tensor/blas.py | 2 +- aesara/tensor/blas_c.py | 4 +- aesara/tensor/elemwise.py | 14 +- aesara/tensor/extra_ops.py | 23 ++- aesara/tensor/fft.py | 8 +- aesara/tensor/fourier.py | 2 +- aesara/tensor/io.py | 8 +- aesara/tensor/math.py | 8 +- aesara/tensor/nnet/abstract_conv.py | 8 +- aesara/tensor/nnet/basic.py | 8 +- aesara/tensor/nnet/batchnorm.py | 29 +-- aesara/tensor/nnet/conv3d2d.py | 8 +- aesara/tensor/nnet/corr.py | 18 +- aesara/tensor/nnet/corr3d.py | 14 +- aesara/tensor/nnet/neighbours.py | 2 +- aesara/tensor/nnet/sigm.py | 2 +- aesara/tensor/random/op.py | 6 +- aesara/tensor/random/type.py | 4 +- aesara/tensor/rewriting/basic.py | 2 +- aesara/tensor/shape.py | 12 +- aesara/tensor/sharedvar.py | 4 +- aesara/tensor/signal/pool.py | 20 +- aesara/tensor/sort.py | 6 +- aesara/tensor/subtensor.py | 22 ++- aesara/tensor/type.py | 197 +++++++++---------- aesara/tensor/type_other.py | 6 +- aesara/tensor/var.py | 6 +- aesara/typed_list/basic.py | 4 +- aesara/typed_list/type.py | 2 +- doc/aesara_installer_for_anaconda.bat | 86 ++++---- doc/extending/extending_aesara_solution_1.py | 6 +- 
doc/extending/other_ops.rst | 2 +- tests/compile/function/test_pfunc.py | 4 +- tests/compile/test_builders.py | 2 +- tests/compile/test_debugmode.py | 4 +- tests/compile/test_shared.py | 16 +- tests/graph/rewriting/test_basic.py | 22 +-- tests/graph/rewriting/test_kanren.py | 2 +- tests/graph/rewriting/test_unify.py | 16 +- tests/graph/test_basic.py | 14 +- tests/graph/test_compute_test_value.py | 2 +- tests/graph/test_destroyhandler.py | 4 +- tests/graph/test_features.py | 4 +- tests/graph/test_fg.py | 4 +- tests/graph/test_op.py | 14 +- tests/graph/test_types.py | 22 +-- tests/graph/utils.py | 10 +- tests/link/c/test_basic.py | 2 +- tests/link/c/test_op.py | 6 +- tests/link/c/test_params_type.py | 111 ++++++----- tests/link/c/test_type.py | 30 +-- tests/link/numba/test_basic.py | 2 +- tests/link/test_link.py | 2 +- tests/scalar/test_basic.py | 4 +- tests/scalar/test_type.py | 8 +- tests/sparse/test_basic.py | 86 ++++---- tests/sparse/test_type.py | 8 +- tests/tensor/nnet/speed_test_conv.py | 2 +- tests/tensor/nnet/test_abstract_conv.py | 4 +- tests/tensor/nnet/test_batchnorm.py | 2 +- tests/tensor/rewriting/test_basic.py | 16 +- tests/tensor/rewriting/test_elemwise.py | 14 +- tests/tensor/rewriting/test_math.py | 16 +- tests/tensor/rewriting/test_shape.py | 2 +- tests/tensor/rewriting/test_subtensor.py | 2 +- tests/tensor/signal/test_conv.py | 4 +- tests/tensor/signal/test_pool.py | 2 +- tests/tensor/test_basic.py | 52 ++--- tests/tensor/test_casting.py | 6 +- tests/tensor/test_elemwise.py | 74 +++---- tests/tensor/test_extra_ops.py | 32 +-- tests/tensor/test_io.py | 8 +- tests/tensor/test_math.py | 18 +- tests/tensor/test_merge.py | 2 +- tests/tensor/test_shape.py | 10 +- tests/tensor/test_subtensor.py | 8 +- tests/tensor/test_type.py | 78 ++++---- tests/tensor/test_type_other.py | 6 +- tests/tensor/test_var.py | 20 +- tests/tensor/utils.py | 6 +- tests/test_gradient.py | 6 +- tests/test_printing.py | 2 +- tests/test_rop.py | 20 +- tests/typed_list/test_basic.py | 142 ++++++------- tests/typed_list/test_rewriting.py | 24 +-- tests/typed_list/test_type.py | 78 +++++--- 104 files changed, 1008 insertions(+), 857 deletions(-) diff --git a/aesara/breakpoint.py b/aesara/breakpoint.py index c0973538cf..0f68485863 100644 --- a/aesara/breakpoint.py +++ b/aesara/breakpoint.py @@ -143,7 +143,7 @@ def perform(self, node, inputs, output_storage): output_storage[i][0] = inputs[i + 1] def grad(self, inputs, output_gradients): - return [DisconnectedType()()] + output_gradients + return [DisconnectedType.subtype()()] + output_gradients def infer_shape(self, fgraph, inputs, input_shapes): # Return the shape of every input but the condition (first input) diff --git a/aesara/gradient.py b/aesara/gradient.py index 51b2bb77ad..bc8fd67117 100644 --- a/aesara/gradient.py +++ b/aesara/gradient.py @@ -90,7 +90,7 @@ def grad_not_implemented(op, x_pos, x, comment=""): """ return ( - NullType( + NullType.subtype( ( "This variable is Null because the grad method for " f"input {x_pos} ({x}) of the {op} op is not implemented. {comment}" @@ -113,7 +113,7 @@ def grad_undefined(op, x_pos, x, comment=""): """ return ( - NullType( + NullType.subtype( ( "This variable is Null because the grad method for " f"input {x_pos} ({x}) of the {op} op is not implemented. 
{comment}" @@ -158,7 +158,7 @@ def __str__(self): return "DisconnectedType" -disconnected_type = DisconnectedType() +disconnected_type = DisconnectedType.subtype() def Rop( @@ -1803,7 +1803,7 @@ def verify_grad( ) tensor_pt = [ - aesara.tensor.type.TensorType( + aesara.tensor.type.TensorType.subtype( aesara.tensor.as_tensor_variable(p).dtype, aesara.tensor.as_tensor_variable(p).broadcastable, )(name=f"input {i}") diff --git a/aesara/graph/null_type.py b/aesara/graph/null_type.py index 7487253156..eae0c04c14 100644 --- a/aesara/graph/null_type.py +++ b/aesara/graph/null_type.py @@ -42,4 +42,4 @@ def __str__(self): return "NullType" -null_type = NullType() +null_type = NullType.subtype() diff --git a/aesara/graph/type.py b/aesara/graph/type.py index 37860914e3..91ce307ae6 100644 --- a/aesara/graph/type.py +++ b/aesara/graph/type.py @@ -5,13 +5,23 @@ from aesara.graph import utils from aesara.graph.basic import Constant, Variable -from aesara.graph.utils import MetaObject +from aesara.graph.utils import MetaType D = TypeVar("D") -class Type(MetaObject, Generic[D]): +class NewTypeMeta(type): + # pass + def __call__(cls, *args, **kwargs): + raise RuntimeError("Use subtype") + # return super().__call__(*args, **kwargs) + + def subtype(cls, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +class Type(Generic[D], metaclass=NewTypeMeta): """ Interface specification for variable type instances. @@ -35,6 +45,12 @@ class Type(MetaObject, Generic[D]): The `Type` that will be created by a call to `Type.make_constant`. """ + __props__: tuple[str, ...] = () + + @classmethod + def create(cls, **kwargs): + MetaType(f"{cls.__name__}[{kwargs}]", (cls,), kwargs) + def in_same_class(self, otype: "Type") -> Optional[bool]: """Determine if another `Type` represents a subset from the same "class" of types represented by `self`. @@ -214,7 +230,7 @@ def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type: def clone(self, *args, **kwargs) -> "Type": """Clone a copy of this type with the given arguments/keyword values, if any.""" - return type(self)(*args, **kwargs) + return type(self).subtype(*args, **kwargs) def __call__(self, name: Optional[Text] = None) -> variable_type: """Return a new `Variable` instance of Type `self`. @@ -261,6 +277,41 @@ def values_eq_approx(cls, a: D, b: D) -> bool: """ return cls.values_eq(a, b) + def _props(self): + """ + Tuple of properties of all attributes + """ + return tuple(getattr(self, a) for a in self.__props__) + + def _props_dict(self): + """This return a dict of all ``__props__`` key-> value. + + This is useful in optimization to swap op that should have the + same props. This help detect error that the new op have at + least all the original props. 
+ + """ + return {a: getattr(self, a) for a in self.__props__} + + def __hash__(self): + return hash((type(self), tuple(getattr(self, a) for a in self.__props__))) + + def __eq__(self, other): + return type(self) == type(other) and tuple( + getattr(self, a) for a in self.__props__ + ) == tuple(getattr(other, a) for a in self.__props__) + + def __str__(self): + if self.__props__ is None or len(self.__props__) == 0: + return f"{self.__class__.__name__}()" + else: + return "{}{{{}}}".format( + self.__class__.__name__, + ", ".join( + "{}={!r}".format(p, getattr(self, p)) for p in self.__props__ + ), + ) + DataType = str diff --git a/aesara/link/c/params_type.py b/aesara/link/c/params_type.py index c48db53fc5..08d5937254 100644 --- a/aesara/link/c/params_type.py +++ b/aesara/link/c/params_type.py @@ -626,7 +626,7 @@ def extended(self, **kwargs): """ self_to_dict = {self.fields[i]: self.types[i] for i in range(self.length)} self_to_dict.update(kwargs) - return ParamsType(**self_to_dict) + return ParamsType.subtype(**self_to_dict) # Returns a Params object with expected attributes or (in strict mode) checks that data has expected attributes. def filter(self, data, strict=False, allow_downcast=None): diff --git a/aesara/link/c/type.py b/aesara/link/c/type.py index 33632fa1a6..9b29d5355e 100644 --- a/aesara/link/c/type.py +++ b/aesara/link/c/type.py @@ -115,7 +115,7 @@ def __str__(self): return self.__class__.__name__ -generic = Generic() +generic = Generic.subtype() _cdata_type = None @@ -497,7 +497,10 @@ def __repr__(self): def __getattr__(self, key): if key in self: return self[key] - return CType.__getattr__(self, key) + else: + raise AttributeError( + f"{self.__class__.__name__} object has no attribute or enum value {key}" + ) def __setattr__(self, key, value): if key in self: @@ -530,6 +533,9 @@ def __eq__(self, other): and all(self.aliases[a] == other.aliases[a] for a in self.aliases) ) + def __ne__(self, other): + return not self == other + # EnumType should be used to create constants available in both Python and C code. # However, for convenience, we make sure EnumType can have a value, like other common types, # such that it could be used as-is as an op param. 
diff --git a/aesara/raise_op.py b/aesara/raise_op.py index 766f0df534..2d851aaa65 100644 --- a/aesara/raise_op.py +++ b/aesara/raise_op.py @@ -22,7 +22,7 @@ def __hash__(self): return hash(type(self)) -exception_type = ExceptionType() +exception_type = ExceptionType.subtype() class CheckAndRaise(COp): @@ -38,7 +38,7 @@ class CheckAndRaise(COp): view_map = {0: [0]} check_input = False - params_type = ParamsType(exc_type=exception_type) + params_type = ParamsType.subtype(exc_type=exception_type) def __init__(self, exc_type, msg=""): @@ -100,7 +100,7 @@ def perform(self, node, inputs, outputs, params): raise self.exc_type(self.msg) def grad(self, input, output_gradients): - return output_gradients + [DisconnectedType()()] * (len(input) - 1) + return output_gradients + [DisconnectedType.subtype()()] * (len(input) - 1) def connection_pattern(self, node): return [[1]] + [[0]] * (len(node.inputs) - 1) diff --git a/aesara/sandbox/multinomial.py b/aesara/sandbox/multinomial.py index fc72ca8c6d..d2bba95ae4 100644 --- a/aesara/sandbox/multinomial.py +++ b/aesara/sandbox/multinomial.py @@ -71,7 +71,7 @@ def c_code(self, node, name, ins, outs, sub): if self.odtype == "auto": t = f"PyArray_TYPE({pvals})" else: - t = ScalarType(self.odtype).dtype_specs()[1] + t = ScalarType.subtype(self.odtype).dtype_specs()[1] if t.startswith("aesara_complex"): t = t.replace("aesara_complex", "NPY_COMPLEX") else: @@ -263,7 +263,7 @@ def c_code(self, node, name, ins, outs, sub): if self.odtype == "auto": t = "NPY_INT64" else: - t = ScalarType(self.odtype).dtype_specs()[1] + t = ScalarType.subtype(self.odtype).dtype_specs()[1] if t.startswith("aesara_complex"): t = t.replace("aesara_complex", "NPY_COMPLEX") else: diff --git a/aesara/sandbox/rng_mrg.py b/aesara/sandbox/rng_mrg.py index a1afeb3d5d..6bafe4c0e8 100644 --- a/aesara/sandbox/rng_mrg.py +++ b/aesara/sandbox/rng_mrg.py @@ -325,7 +325,7 @@ def mrg_next_value(rstate, new_rstate, NORM, mask, offset): class mrg_uniform_base(Op): # TODO : need description for class, parameter __props__ = ("output_type", "inplace") - params_type = ParamsType( + params_type = ParamsType.subtype( inplace=bool_t, # following params will come from self.output_type. 
# NB: As output object may not be allocated in C code, @@ -392,7 +392,7 @@ def new(cls, rstate, ndim, dtype, size): v_size = as_tensor_variable(size) if ndim is None: ndim = get_vector_length(v_size) - op = cls(TensorType(dtype, (False,) * ndim)) + op = cls(TensorType.subtype(dtype, (False,) * ndim)) return op(rstate, v_size) def perform(self, node, inp, out, params): diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py index ec0f21e36d..e21679500c 100644 --- a/aesara/scalar/basic.py +++ b/aesara/scalar/basic.py @@ -298,7 +298,7 @@ def __init__(self, dtype): def clone(self, dtype=None, **kwargs): if dtype is None: dtype = self.dtype - return type(self)(dtype) + return type(self).subtype(dtype) @staticmethod def may_share_memory(a, b): @@ -679,7 +679,7 @@ def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType: """ if dtype not in cache: - cache[dtype] = ScalarType(dtype=dtype) + cache[dtype] = ScalarType.subtype(dtype=dtype) return cache[dtype] @@ -2405,13 +2405,13 @@ def grad(self, inputs, gout): (gz,) = gout if y.type in continuous_types: # x is disconnected because the elements of x are not used - return DisconnectedType()(), gz + return DisconnectedType.subtype()(), gz else: # when y is discrete, we assume the function can be extended # to deal with real-valued inputs by rounding them to the # nearest integer. f(x+eps) thus equals f(x) so the gradient # is zero, not disconnected or undefined - return DisconnectedType()(), y.zeros_like() + return DisconnectedType.subtype()(), y.zeros_like() second = Second(transfer_type(1), name="second") diff --git a/aesara/scalar/sharedvar.py b/aesara/scalar/sharedvar.py index 1b6d97ff40..d407699794 100644 --- a/aesara/scalar/sharedvar.py +++ b/aesara/scalar/sharedvar.py @@ -50,7 +50,7 @@ def shared(value, name=None, strict=False, allow_downcast=None): dtype = str(dtype) value = getattr(np, dtype)(value) - scalar_type = ScalarType(dtype=dtype) + scalar_type = ScalarType.subtype(dtype=dtype) rval = ScalarSharedVariable( type=scalar_type, value=value, diff --git a/aesara/scan/op.py b/aesara/scan/op.py index f8a8ad01ef..d13435ca44 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -793,7 +793,7 @@ def __init__( self.output_types = [] def tensorConstructor(shape, dtype): - return TensorType(dtype=dtype, shape=shape) + return TensorType.subtype(dtype=dtype, shape=shape) if typeConstructor is None: typeConstructor = tensorConstructor @@ -3033,7 +3033,7 @@ def compute_all_gradients(known_grads): if not isinstance(outputs, (list, tuple)): outputs = [outputs] # Re-order the gradients correctly - gradients = [DisconnectedType()()] + gradients = [DisconnectedType.subtype()()] offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs for p, (x, t) in enumerate( @@ -3057,7 +3057,7 @@ def compute_all_gradients(known_grads): else: gradients.append(x[::-1]) elif t == "disconnected": - gradients.append(DisconnectedType()()) + gradients.append(DisconnectedType.subtype()()) elif t == "through_shared": gradients.append( grad_undefined( @@ -3066,7 +3066,7 @@ def compute_all_gradients(known_grads): ) else: # t contains the "why_null" string of a NullType - gradients.append(NullType(t)()) + gradients.append(NullType.subtype(t)()) end = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot for p, (x, t) in enumerate(zip(outputs[:end], type_outs[:end])): @@ -3085,7 +3085,7 @@ def compute_all_gradients(known_grads): else: gradients.append(x[::-1]) elif t == "disconnected": - gradients.append(DisconnectedType()()) + 
gradients.append(DisconnectedType.subtype()()) elif t == "through_shared": gradients.append( grad_undefined( @@ -3097,7 +3097,7 @@ def compute_all_gradients(known_grads): ) else: # t contains the "why_null" string of a NullType - gradients.append(NullType(t)()) + gradients.append(NullType.subtype(t)()) start = len(gradients) node = outs[0].owner @@ -3108,7 +3108,7 @@ def compute_all_gradients(known_grads): if not isinstance(dC_dout.type, DisconnectedType) and connected: disconnected = False if disconnected: - gradients.append(DisconnectedType()()) + gradients.append(DisconnectedType.subtype()()) else: gradients.append( grad_undefined( @@ -3117,7 +3117,7 @@ def compute_all_gradients(known_grads): ) start = len(gradients) - gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)] + gradients += [DisconnectedType.subtype()() for _ in range(info.n_nit_sot)] begin = end end = begin + n_sitsot_outs @@ -3125,7 +3125,7 @@ def compute_all_gradients(known_grads): if t == "connected": gradients.append(x[-1]) elif t == "disconnected": - gradients.append(DisconnectedType()()) + gradients.append(DisconnectedType.subtype()()) elif t == "through_shared": gradients.append( grad_undefined( @@ -3137,7 +3137,7 @@ def compute_all_gradients(known_grads): ) else: # t contains the "why_null" string of a NullType - gradients.append(NullType(t)()) + gradients.append(NullType.subtype(t)()) # Mask disconnected gradients # Ideally we would want to assert that the gradients we are @@ -3153,7 +3153,7 @@ def compute_all_gradients(known_grads): ): disconnected = False if disconnected: - gradients[idx] = DisconnectedType()() + gradients[idx] = DisconnectedType.subtype()() return gradients def R_op(self, inputs, eval_points): diff --git a/aesara/sparse/basic.py b/aesara/sparse/basic.py index 6f5bb22b0a..8a872efa1e 100644 --- a/aesara/sparse/basic.py +++ b/aesara/sparse/basic.py @@ -201,7 +201,9 @@ def constant(x, name=None): raise TypeError("sparse.constant must be called on a " "scipy.sparse.spmatrix") try: return SparseConstant( - SparseTensorType(format=x.format, dtype=x.dtype), x.copy(), name=name + SparseTensorType.subtype(format=x.format, dtype=x.dtype), + x.copy(), + name=name, ) except TypeError: raise TypeError(f"Could not convert {x} to SparseTensorType", type(x)) @@ -501,11 +503,11 @@ def __repr__(self): SparseTensorType.constant_type = SparseConstant -# for more dtypes, call SparseTensorType(format, dtype) +# for more dtypes, call SparseTensorType.subtype(format, dtype) def matrix(format, name=None, dtype=None): if dtype is None: dtype = config.floatX - type = SparseTensorType(format=format, dtype=dtype) + type = SparseTensorType.subtype(format=format, dtype=dtype) return type(name) @@ -521,12 +523,12 @@ def bsr_matrix(name=None, dtype=None): return matrix("bsr", name, dtype) -csc_dmatrix = SparseTensorType(format="csc", dtype="float64") -csr_dmatrix = SparseTensorType(format="csr", dtype="float64") -bsr_dmatrix = SparseTensorType(format="bsr", dtype="float64") -csc_fmatrix = SparseTensorType(format="csc", dtype="float32") -csr_fmatrix = SparseTensorType(format="csr", dtype="float32") -bsr_fmatrix = SparseTensorType(format="bsr", dtype="float32") +csc_dmatrix = SparseTensorType.subtype(format="csc", dtype="float64") +csr_dmatrix = SparseTensorType.subtype(format="csr", dtype="float64") +bsr_dmatrix = SparseTensorType.subtype(format="bsr", dtype="float64") +csc_fmatrix = SparseTensorType.subtype(format="csc", dtype="float32") +csr_fmatrix = SparseTensorType.subtype(format="csr", dtype="float32") 
+bsr_fmatrix = SparseTensorType.subtype(format="bsr", dtype="float32") all_dtypes = list(SparseTensorType.dtype_specs_map.keys()) complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"] @@ -592,7 +594,7 @@ def make_node(self, csm): csm = as_sparse_variable(csm) assert csm.format in ("csr", "csc") - data = TensorType(dtype=csm.type.dtype, shape=(False,))() + data = TensorType.subtype(dtype=csm.type.dtype, shape=(False,))() return Apply(self, [csm], [data, ivector(), ivector(), ivector()]) def perform(self, node, inputs, out): @@ -733,7 +735,7 @@ def make_node(self, data, indices, indptr, shape): return Apply( self, [data, indices, indptr, shape], - [SparseTensorType(dtype=data.type.dtype, format=self.format)()], + [SparseTensorType.subtype(dtype=data.type.dtype, format=self.format)()], ) def perform(self, node, inputs, outputs): @@ -776,9 +778,9 @@ def grad(self, inputs, gout): ) return [ g_data, - DisconnectedType()(), - DisconnectedType()(), - DisconnectedType()(), + DisconnectedType.subtype()(), + DisconnectedType.subtype()(), + DisconnectedType.subtype()(), ] def infer_shape(self, fgraph, node, shapes): @@ -888,7 +890,9 @@ def make_node(self, x): x = as_sparse_variable(x) assert x.format in ("csr", "csc") return Apply( - self, [x], [SparseTensorType(dtype=self.out_type, format=x.format)()] + self, + [x], + [SparseTensorType.subtype(dtype=self.out_type, format=x.format)()], ) def perform(self, node, inputs, outputs): @@ -994,7 +998,7 @@ def make_node(self, x): return Apply( self, [x], - [TensorType(dtype=x.type.dtype, shape=(False, False))()], + [TensorType.subtype(dtype=x.type.dtype, shape=(False, False))()], ) def perform(self, node, inputs, outputs): @@ -1076,7 +1080,9 @@ def make_node(self, x): assert x.ndim == 2 return Apply( - self, [x], [SparseTensorType(dtype=x.type.dtype, format=self.format)()] + self, + [x], + [SparseTensorType.subtype(dtype=x.type.dtype, format=self.format)()], ) def perform(self, node, inputs, outputs): @@ -1510,7 +1516,7 @@ def make_node(self, x): self, [x], [ - SparseTensorType( + SparseTensorType.subtype( dtype=x.type.dtype, format=self.format_map[x.type.format] )() ], @@ -1757,7 +1763,7 @@ def make_node(self, x): if self.axis is not None: b = (False,) - z = TensorType(shape=b, dtype=x.dtype)() + z = TensorType.subtype(shape=b, dtype=x.dtype)() return Apply(self, [x], [z]) def perform(self, node, inputs, outputs): @@ -1918,7 +1924,9 @@ def make_node(self, diag): if diag.type.ndim != 1: raise TypeError("data argument must be a vector", diag.type) - return Apply(self, [diag], [SparseTensorType(dtype=diag.dtype, format="csc")()]) + return Apply( + self, [diag], [SparseTensorType.subtype(dtype=diag.dtype, format="csc")()] + ) def perform(self, node, inputs, outputs): (z,) = outputs @@ -2039,7 +2047,9 @@ def make_node(self, x, y): assert y.format in ("csr", "csc") out_dtype = aes.upcast(x.type.dtype, y.type.dtype) return Apply( - self, [x, y], [SparseTensorType(dtype=out_dtype, format=x.type.format)()] + self, + [x, y], + [SparseTensorType.subtype(dtype=out_dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -2097,7 +2107,9 @@ def make_node(self, x, y): if x.type.format != y.type.format: raise NotImplementedError() return Apply( - self, [x, y], [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()] + self, + [x, y], + [SparseTensorType.subtype(dtype=x.type.dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -2138,7 +2150,7 @@ def make_node(self, x, y): return Apply( self, [x, y], - 
[TensorType(dtype=out_dtype, shape=y.type.broadcastable)()], + [TensorType.subtype(dtype=out_dtype, shape=y.type.broadcastable)()], ) def perform(self, node, inputs, outputs): @@ -2198,7 +2210,9 @@ def make_node(self, x, y): if x.type.dtype != y.type.dtype: raise NotImplementedError() return Apply( - self, [x, y], [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()] + self, + [x, y], + [SparseTensorType.subtype(dtype=x.type.dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -2314,7 +2328,9 @@ def make_node(self, x, y): assert y.format in ("csr", "csc") out_dtype = aes.upcast(x.type.dtype, y.type.dtype) return Apply( - self, [x, y], [SparseTensorType(dtype=out_dtype, format=x.type.format)()] + self, + [x, y], + [SparseTensorType.subtype(dtype=out_dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -2357,7 +2373,7 @@ def make_node(self, x, y): # Broadcasting of the sparse matrix is not supported. # We support nd == 0 used by grad of SpSum() assert y.type.ndim in (0, 2) - out = SparseTensorType(dtype=dtype, format=x.type.format)() + out = SparseTensorType.subtype(dtype=dtype, format=x.type.format)() return Apply(self, [x, y], [out]) def perform(self, node, inputs, outputs): @@ -2463,7 +2479,9 @@ def make_node(self, x, y): f"Got {x.type.dtype} and {y.type.dtype}." ) return Apply( - self, [x, y], [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()] + self, + [x, y], + [SparseTensorType.subtype(dtype=x.type.dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -2579,7 +2597,9 @@ def make_node(self, x, y): if x.type.format != y.type.format: raise NotImplementedError() return Apply( - self, [x, y], [SparseTensorType(dtype="uint8", format=x.type.format)()] + self, + [x, y], + [SparseTensorType.subtype(dtype="uint8", format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -2621,7 +2641,7 @@ def make_node(self, x, y): x, y = as_sparse_variable(x), at.as_tensor_variable(y) assert y.type.ndim == 2 - out = TensorType(dtype="uint8", shape=(False, False))() + out = TensorType.subtype(dtype="uint8", shape=(False, False))() return Apply(self, [x, y], [out]) def perform(self, node, inputs, outputs): @@ -2829,7 +2849,9 @@ def make_node(self, *mat): assert x.format in ("csr", "csc") return Apply( - self, var, [SparseTensorType(dtype=self.dtype, format=self.format)()] + self, + var, + [SparseTensorType.subtype(dtype=self.dtype, format=self.format)()], ) def perform(self, node, block, outputs): @@ -3335,7 +3357,7 @@ def make_node(self, x, y): raise NotImplementedError() inputs = [x, y] # Need to convert? e.g. 
assparse - outputs = [SparseTensorType(dtype=x.type.dtype, format=myformat)()] + outputs = [SparseTensorType.subtype(dtype=x.type.dtype, format=myformat)()] return Apply(self, inputs, outputs) def perform(self, node, inp, out_): @@ -3457,7 +3479,9 @@ def make_node(self, a, b): raise NotImplementedError("non-matrix b") if _is_sparse_variable(b): - return Apply(self, [a, b], [SparseTensorType(a.type.format, dtype_out)()]) + return Apply( + self, [a, b], [SparseTensorType.subtype(a.type.format, dtype_out)()] + ) else: return Apply( self, @@ -4265,7 +4289,7 @@ def grad(self, inputs, grads): gx = g_output gy = aesara.tensor.subtensor.advanced_subtensor1(g_output, *idx_list) - return [gx, gy] + [DisconnectedType()()] * len(idx_list) + return [gx, gy] + [DisconnectedType.subtype()()] * len(idx_list) construct_sparse_from_list = ConstructSparseFromList() diff --git a/aesara/sparse/rewriting.py b/aesara/sparse/rewriting.py index fde57a30ac..a9f9c2cc1b 100644 --- a/aesara/sparse/rewriting.py +++ b/aesara/sparse/rewriting.py @@ -126,7 +126,7 @@ def make_node(self, x, y): # The magic number two here arises because L{scipy.sparse} # objects must be matrices (have dimension 2) assert y.type.ndim == 2 - out = TensorType(dtype=out_dtype, shape=y.type.broadcastable)() + out = TensorType.subtype(dtype=out_dtype, shape=y.type.broadcastable)() return Apply(self, [data, indices, indptr, y], [out]) def c_code(self, node, name, inputs, outputs, sub): @@ -504,7 +504,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ (a_val, a_ind, a_ptr, b) = inputs (z,) = outputs - typenum_z = TensorType(self.dtype_out, []).dtype_specs()[2] + typenum_z = TensorType.subtype(self.dtype_out, []).dtype_specs()[2] if node.inputs[0].type.dtype in ("complex64", "complex128"): raise NotImplementedError("Complex types are not supported for a_val") if node.inputs[3].type.dtype in ("complex64", "complex128"): @@ -1898,9 +1898,9 @@ def c_code(self, node, name, inputs, outputs, sub): typenum_x = node.inputs[0].type.dtype_specs()[2] typenum_y = node.inputs[1].type.dtype_specs()[2] typenum_p = node.inputs[2].type.dtype_specs()[2] - typenum_zd = TensorType(node.outputs[0].dtype, []).dtype_specs()[2] - typenum_zi = TensorType(node.outputs[1].dtype, []).dtype_specs()[2] - typenum_zp = TensorType(node.outputs[2].dtype, []).dtype_specs()[2] + typenum_zd = TensorType.subtype(node.outputs[0].dtype, []).dtype_specs()[2] + typenum_zi = TensorType.subtype(node.outputs[1].dtype, []).dtype_specs()[2] + typenum_zp = TensorType.subtype(node.outputs[2].dtype, []).dtype_specs()[2] rval = """ if (PyArray_NDIM(%(x)s) != 2) { diff --git a/aesara/sparse/sandbox/sp2.py b/aesara/sparse/sandbox/sp2.py index e86db84ae4..9842f5c80b 100644 --- a/aesara/sparse/sandbox/sp2.py +++ b/aesara/sparse/sandbox/sp2.py @@ -110,7 +110,7 @@ def make_node(self, n, p, shape): return Apply( self, [n, p, shape], - [SparseTensorType(dtype=self.dtype, format=self.format)()], + [SparseTensorType.subtype(dtype=self.dtype, format=self.format)()], ) def perform(self, node, inputs, outputs): diff --git a/aesara/sparse/sharedvar.py b/aesara/sparse/sharedvar.py index 47fc365b86..d35f88e228 100644 --- a/aesara/sparse/sharedvar.py +++ b/aesara/sparse/sharedvar.py @@ -23,7 +23,7 @@ def sparse_constructor( if format is None: format = value.format - type = SparseTensorType(format=format, dtype=value.dtype) + type = SparseTensorType.subtype(format=format, dtype=value.dtype) if not borrow: value = copy.deepcopy(value) return SparseTensorSharedVariable( diff --git a/aesara/sparse/type.py 
b/aesara/sparse/type.py index 048e69e111..d8b39d0a80 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -95,7 +95,7 @@ def clone( dtype = self.dtype if shape is None: shape = self.shape - return type(self)(format, dtype, shape=shape, **kwargs) + return type(self).subtype(format, dtype, shape=shape, **kwargs) def filter(self, value, strict=False, allow_downcast=None): if isinstance(value, Variable): diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index 83a127b3c5..50c3c5b104 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -220,7 +220,7 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant: assert x_.ndim == ndim - ttype = TensorType(dtype=x_.dtype, shape=x_.shape) + ttype = TensorType.subtype(dtype=x_.dtype, shape=x_.shape) return TensorConstant(ttype, x_, name=name) @@ -858,7 +858,9 @@ def make_node(self, a): a = as_tensor_variable(a) if a.ndim == 0: raise ValueError("Nonzero only supports non-scalar arrays.") - output = [TensorType(dtype="int64", shape=(False,))() for i in range(a.ndim)] + output = [ + TensorType.subtype(dtype="int64", shape=(False,))() for i in range(a.ndim) + ] return Apply(self, [a], output) def perform(self, node, inp, out_): @@ -989,7 +991,7 @@ def make_node(self, N, M, k): return Apply( self, [N, M, k], - [TensorType(dtype=self.dtype, shape=(False, False))()], + [TensorType.subtype(dtype=self.dtype, shape=(False, False))()], ) def perform(self, node, inp, out_): @@ -1268,7 +1270,7 @@ def make_node(self, n, m, k): return Apply( self, [n, m, k], - [TensorType(dtype=self.dtype, shape=(False, False))()], + [TensorType.subtype(dtype=self.dtype, shape=(False, False))()], ) def perform(self, node, inp, out_): @@ -1402,7 +1404,7 @@ def make_node(self, value, *shape): v.ndim, len(sh), ) - otype = TensorType(dtype=v.dtype, shape=bcast) + otype = TensorType.subtype(dtype=v.dtype, shape=bcast) return Apply(self, [v] + sh, [otype()]) def perform(self, node, inputs, out_): @@ -1510,7 +1512,7 @@ def grad(self, inputs, grads): # the inputs that specify the shape. If you grow the # shape by epsilon, the existing elements do not # change. 
- return [gx] + [DisconnectedType()() for i in inputs[1:]] + return [gx] + [DisconnectedType.subtype()() for i in inputs[1:]] def R_op(self, inputs, eval_points): if eval_points[0] is None: @@ -1644,7 +1646,7 @@ def make_node(self, *inputs): else: dtype = self.dtype - otype = TensorType(dtype, (len(inputs),)) + otype = TensorType.subtype(dtype, (len(inputs),)) return Apply(self, inputs, [otype()]) def perform(self, node, inputs, out_): @@ -1899,7 +1901,7 @@ def make_node(self, x, axis, splits): raise TypeError("`axis` parameter must be an integer scalar") inputs = [x, axis, splits] - out_type = TensorType(dtype=x.dtype, shape=[None] * x.type.ndim) + out_type = TensorType.subtype(dtype=x.dtype, shape=[None] * x.type.ndim) outputs = [out_type() for i in range(self.len_splits)] return Apply(self, inputs, outputs) @@ -1951,7 +1953,7 @@ def grad(self, inputs, g_outputs): # If all the output gradients are disconnected, then so are the inputs if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs): return [ - DisconnectedType()(), + DisconnectedType.subtype()(), grad_undefined(self, 1, axis), grad_undefined(self, 2, n), ] @@ -2940,14 +2942,14 @@ def L_op(self, inputs, outputs, grads): if self.dtype in discrete_dtypes: return [ start.zeros_like(dtype=config.floatX), - DisconnectedType()(), + DisconnectedType.subtype()(), step.zeros_like(dtype=config.floatX), ] else: num_steps_taken = outputs[0].shape[0] return [ gz.sum(), - DisconnectedType()(), + DisconnectedType.subtype()(), (gz * arange(num_steps_taken, dtype=self.dtype)).sum(), ] @@ -3375,7 +3377,7 @@ def make_node(self, x): return Apply( self, [x], - [x.type.__class__(dtype=x.dtype, shape=[False] * (x.ndim - 1))()], + [x.type.__class__.subtype(dtype=x.dtype, shape=[False] * (x.ndim - 1))()], ) def perform(self, node, inputs, outputs): @@ -3796,7 +3798,7 @@ def make_node(self, a, choices): else: bcast.append(False) - o = TensorType(choice.dtype, bcast) + o = TensorType.subtype(choice.dtype, bcast) return Apply(self, [a, choice], [o()]) def perform(self, node, inputs, outputs): @@ -3811,7 +3813,7 @@ class AllocEmpty(COp): """Implement Alloc on the cpu, but without initializing memory.""" __props__ = ("dtype",) - params_type = ParamsType(typecode=int32) + params_type = ParamsType.subtype(typecode=int32) # specify the type of the data def __init__(self, dtype): @@ -3824,7 +3826,7 @@ def typecode(self): def make_node(self, *_shape): _shape, bcast = infer_broadcastable(_shape) - otype = TensorType(dtype=self.dtype, shape=bcast) + otype = TensorType.subtype(dtype=self.dtype, shape=bcast) output = otype() output.tag.values_eq_approx = values_eq_approx_always_true @@ -3903,7 +3905,7 @@ def connection_pattern(self, node): return [[False] for i in node.inputs] def grad(self, inputs, grads): - return [DisconnectedType()() for i in inputs] + return [DisconnectedType.subtype()() for i in inputs] def R_op(self, inputs, eval_points): return [zeros(inputs, self.dtype)] diff --git a/aesara/tensor/blas.py b/aesara/tensor/blas.py index ee478b6a8a..ffbf236496 100644 --- a/aesara/tensor/blas.py +++ b/aesara/tensor/blas.py @@ -898,7 +898,7 @@ class Gemm(GemmRelated): E_float = "gemm requires floating-point dtypes" __props__ = ("inplace",) - params_type = ParamsType( + params_type = ParamsType.subtype( inplace=bool_t, ) check_input = False diff --git a/aesara/tensor/blas_c.py b/aesara/tensor/blas_c.py index c808528b97..acdd5a74bc 100644 --- a/aesara/tensor/blas_c.py +++ b/aesara/tensor/blas_c.py @@ -326,7 +326,7 @@ def ger_c_code(A, a, x, y, Z, 
fail, params): class CGer(BaseBLAS, Ger): - params_type = ParamsType( + params_type = ParamsType.subtype( destructive=bool_t, ) @@ -609,7 +609,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non class CGemv(BaseBLAS, Gemv): - params_type = ParamsType( + params_type = ParamsType.subtype( inplace=bool_t, ) diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py index 34f9ea5459..ea89561b53 100644 --- a/aesara/tensor/elemwise.py +++ b/aesara/tensor/elemwise.py @@ -119,7 +119,7 @@ class DimShuffle(ExternalCOp): @property def params_type(self): - return ParamsType( + return ParamsType.subtype( shuffle=lvector, augment=lvector, transposition=lvector, @@ -209,7 +209,7 @@ def make_node(self, _input): else: out_static_shape.append(input.type.shape[dim_idx]) - output = TensorType(dtype=input.type.dtype, shape=out_static_shape)() + output = TensorType.subtype(dtype=input.type.dtype, shape=out_static_shape)() return Apply(self, [input], [output]) @@ -484,7 +484,7 @@ def make_node(self, *inputs): inputs = [as_tensor_variable(i) for i in inputs] out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs) outputs = [ - TensorType(dtype=dtype, shape=shape)() + TensorType.subtype(dtype=dtype, shape=shape)() for dtype, shape in zip(out_dtypes, out_shapes) ] return Apply(self, inputs, outputs) @@ -1331,7 +1331,9 @@ def make_node(self, input): broadcastable = [x for i, x in enumerate(inp_bdcast) if i not in axis] - output = TensorType(dtype=self._output_dtype(inp_dtype), shape=broadcastable)() + output = TensorType.subtype( + dtype=self._output_dtype(inp_dtype), shape=broadcastable + )() return Apply(op, [input], [output]) @@ -1411,7 +1413,9 @@ def _c_all(self, node, name, inames, onames, sub): if acc_dtype is not None: if acc_dtype == "float16": raise MethodNotDefined("no c_code for float16") - acc_type = TensorType(shape=node.outputs[0].broadcastable, dtype=acc_dtype) + acc_type = TensorType.subtype( + shape=node.outputs[0].broadcastable, dtype=acc_dtype + ) adtype = acc_type.dtype_specs()[1] else: adtype = odtype diff --git a/aesara/tensor/extra_ops.py b/aesara/tensor/extra_ops.py index 54c2339888..129e7a0e62 100644 --- a/aesara/tensor/extra_ops.py +++ b/aesara/tensor/extra_ops.py @@ -109,7 +109,7 @@ class SearchsortedOp(COp): """ - params_type = Generic() + params_type = Generic.subtype() __props__ = ("side",) check_input = False @@ -284,8 +284,8 @@ class CumOp(COp): __props__ = ("axis", "mode") check_input = False - params_type = ParamsType( - c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul")) + params_type = ParamsType.subtype( + c_axis=int_t, mode=EnumList.subtype(("MODE_ADD", "add"), ("MODE_MUL", "mul")) ) def __init__(self, axis=None, mode="add"): @@ -679,7 +679,7 @@ def make_node(self, x, repeats): broadcastable = list(x.broadcastable) broadcastable[self.axis] = False - out_type = TensorType(x.dtype, broadcastable) + out_type = TensorType.subtype(x.dtype, broadcastable) return Apply(self, [x, repeats], [out_type()]) @@ -708,7 +708,10 @@ def grad(self, inputs, gout): shape = [x.shape[k] for k in range(x.ndim)] shape.insert(axis, repeats) - return [gz.reshape(shape, x.ndim + 1).sum(axis=axis), DisconnectedType()()] + return [ + gz.reshape(shape, x.ndim + 1).sum(axis=axis), + DisconnectedType.subtype()(), + ] elif repeats.ndim == 1: # For this implementation, we would need to specify the length # of repeats in order to split gz in the right way to sum @@ -1202,8 +1205,8 @@ def make_node(self, x): b if axis != self_axis else 
False for axis, b in enumerate(x.broadcastable) ] - outputs = [TensorType(shape=broadcastable, dtype=x.dtype)()] - typ = TensorType(shape=[False], dtype="int64") + outputs = [TensorType.subtype(shape=broadcastable, dtype=x.dtype)()] + typ = TensorType.subtype(shape=[False], dtype="int64") if self.return_index: outputs.append(typ()) if self.return_inverse: @@ -1309,7 +1312,7 @@ def make_node(self, indices, dims): self, [indices, dims], [ - TensorType(dtype="int64", shape=(False,) * indices.ndim)() + TensorType.subtype(dtype="int64", shape=(False,) * indices.ndim)() for i in range(at.get_vector_length(dims)) ], ) @@ -1388,7 +1391,7 @@ def make_node(self, *inp): return Apply( self, multi_index + [dims], - [TensorType(dtype="int64", shape=(False,) * multi_index[0].ndim)()], + [TensorType.subtype(dtype="int64", shape=(False,) * multi_index[0].ndim)()], ) def infer_shape(self, fgraph, node, input_shapes): @@ -1615,7 +1618,7 @@ def make_node(self, a, *shape): shape, bcast = at.infer_broadcastable(shape) - out = TensorType(dtype=a.type.dtype, shape=bcast)() + out = TensorType.subtype(dtype=a.type.dtype, shape=bcast)() # Attempt to prevent in-place operations on this view-based output out.tag.indestructible = True diff --git a/aesara/tensor/fft.py b/aesara/tensor/fft.py index 0fcdfbdeec..c7e0b779d9 100644 --- a/aesara/tensor/fft.py +++ b/aesara/tensor/fft.py @@ -15,7 +15,7 @@ class RFFTOp(Op): def output_type(self, inp): # add extra dim for real/imag - return TensorType(inp.dtype, shape=[False] * (inp.type.ndim + 1)) + return TensorType.subtype(inp.dtype, shape=[False] * (inp.type.ndim + 1)) def make_node(self, a, s=None): a = as_tensor_variable(a) @@ -60,7 +60,7 @@ def grad(self, inputs, output_grads): + [slice(None)] ) gout = set_subtensor(gout[idx], gout[idx] * 0.5) - return [irfft_op(gout, s), DisconnectedType()()] + return [irfft_op(gout, s), DisconnectedType.subtype()()] def connection_pattern(self, node): # Specify that shape input parameter has no connection to graph and gradients. @@ -76,7 +76,7 @@ class IRFFTOp(Op): def output_type(self, inp): # remove extra dim for real/imag - return TensorType(inp.dtype, shape=[False] * (inp.type.ndim - 1)) + return TensorType.subtype(inp.dtype, shape=[False] * (inp.type.ndim - 1)) def make_node(self, a, s=None): a = as_tensor_variable(a) @@ -123,7 +123,7 @@ def grad(self, inputs, output_grads): + [slice(None)] ) gf = set_subtensor(gf[idx], gf[idx] * 2) - return [gf, DisconnectedType()()] + return [gf, DisconnectedType.subtype()()] def connection_pattern(self, node): # Specify that shape input parameter has no connection to graph and gradients. 
diff --git a/aesara/tensor/fourier.py b/aesara/tensor/fourier.py index bc069b31e2..fd50a0636e 100644 --- a/aesara/tensor/fourier.py +++ b/aesara/tensor/fourier.py @@ -99,7 +99,7 @@ def make_node(self, a, n, axis): return Apply( self, [a, n, axis], - [TensorType("complex128", a.type.broadcastable)()], + [TensorType.subtype("complex128", a.type.broadcastable)()], ) def infer_shape(self, fgraph, node, in_shapes): diff --git a/aesara/tensor/io.py b/aesara/tensor/io.py index ab670b4bb2..239db36fcf 100644 --- a/aesara/tensor/io.py +++ b/aesara/tensor/io.py @@ -35,7 +35,7 @@ def __init__(self, dtype, broadcastable, mmap_mode=None): def make_node(self, path): if isinstance(path, str): - path = Constant(Generic(), path) + path = Constant(Generic.subtype(), path) return Apply(self, [path], [tensor(self.dtype, shape=self.broadcastable)]) def perform(self, node, inp, out): @@ -136,7 +136,7 @@ def make_node(self): self, [], [ - Variable(Generic(), None), + Variable(Generic.subtype(), None), tensor(self.dtype, shape=self.broadcastable), ], ) @@ -222,7 +222,7 @@ def __init__(self, dest, tag): self.tag = tag def make_node(self, data): - return Apply(self, [data], [Variable(Generic(), None), data.type()]) + return Apply(self, [data], [Variable(Generic.subtype(), None), data.type()]) view_map = {1: [0]} @@ -259,7 +259,7 @@ def __init__(self, tag): self.tag = tag def make_node(self, request, data): - return Apply(self, [request, data], [Variable(Generic(), None)]) + return Apply(self, [request, data], [Variable(Generic.subtype(), None)]) def perform(self, node, inp, out): request = inp[0] diff --git a/aesara/tensor/math.py b/aesara/tensor/math.py index 2b6724aa4b..f82250965e 100644 --- a/aesara/tensor/math.py +++ b/aesara/tensor/math.py @@ -134,7 +134,7 @@ class MaxAndArgmax(COp): nin = 2 # tensor, axis nout = 2 # max val, max idx E_axis = "invalid axis" - params_type = Generic() + params_type = Generic.subtype() __props__ = ("axis",) _f16_ok = True @@ -307,7 +307,7 @@ def grad(self, inp, grads): # if the op is totally disconnected, so are its inputs if g_max_disconnected and g_max_idx_disconnected: - return [DisconnectedType()(), DisconnectedType()()] + return [DisconnectedType.subtype()(), DisconnectedType.subtype()()] # if the max is disconnected but the argmax is not, # the gradient on its inputs is zero @@ -350,7 +350,7 @@ class Argmax(COp): __props__ = ("axis",) _f16_ok = True - params_type = ParamsType(c_axis=aes.int64) + params_type = ParamsType.subtype(c_axis=aes.int64) def __init__(self, axis): if axis is not None: @@ -2931,7 +2931,7 @@ def make_node(self, a, b): out_shape = self._get_output_shape( a, b, (a.type.shape, b.type.shape), validate=True ) - out = TensorType(dtype=self.dtype, shape=out_shape)() + out = TensorType.subtype(dtype=self.dtype, shape=out_shape)() return Apply(self, [a, b], [out]) def perform(self, node, inputs, outputs): diff --git a/aesara/tensor/nnet/abstract_conv.py b/aesara/tensor/nnet/abstract_conv.py index dbfc0b7b69..b3c3b555ad 100644 --- a/aesara/tensor/nnet/abstract_conv.py +++ b/aesara/tensor/nnet/abstract_conv.py @@ -3045,7 +3045,7 @@ def grad(self, inp, grads): d_bottom = bottom.type.filter_variable(d_bottom) d_top = top.type.filter_variable(d_top) - d_height_width = (aesara.gradient.DisconnectedType()(),) + d_height_width = (aesara.gradient.DisconnectedType.subtype()(),) return (d_bottom, d_top) + d_height_width @@ -3104,7 +3104,7 @@ def grad(self, inp, grads): d_bottom = bottom.type.filter_variable(d_bottom) d_top = top.type.filter_variable(d_top) - 
d_depth_height_width = (aesara.gradient.DisconnectedType()(),) + d_depth_height_width = (aesara.gradient.DisconnectedType.subtype()(),) return (d_bottom, d_top) + d_depth_height_width @@ -3415,7 +3415,7 @@ def grad(self, inp, grads): d_weights = weights.type.filter_variable(d_weights) d_top = top.type.filter_variable(d_top) - d_height_width = (aesara.gradient.DisconnectedType()(),) + d_height_width = (aesara.gradient.DisconnectedType.subtype()(),) return (d_weights, d_top) + d_height_width @@ -3474,7 +3474,7 @@ def grad(self, inp, grads): d_weights = weights.type.filter_variable(d_weights) d_top = top.type.filter_variable(d_top) - d_depth_height_width = (aesara.gradient.DisconnectedType()(),) + d_depth_height_width = (aesara.gradient.DisconnectedType.subtype()(),) return (d_weights, d_top) + d_depth_height_width diff --git a/aesara/tensor/nnet/basic.py b/aesara/tensor/nnet/basic.py index af1e3d0bd1..096f57eae5 100644 --- a/aesara/tensor/nnet/basic.py +++ b/aesara/tensor/nnet/basic.py @@ -127,7 +127,7 @@ def L_op(self, inp, outputs, grads): (g_sm,) = grads if isinstance(g_sm.type, DisconnectedType): - return [DisconnectedType()(), DisconnectedType()()] + return [DisconnectedType.subtype()(), DisconnectedType.subtype()()] dx = softmax_grad_legacy(g_sm, outputs[0]) db = at_sum(dx, axis=0) @@ -1340,7 +1340,7 @@ def make_node(self, x, b, y_idx): raise ValueError("y_idx must be 1-d tensor of [u]ints", y_idx.type) # TODO: Is this correct? It used to be y, not y_idx - nll = TensorType(x.type.dtype, y_idx.type.broadcastable).make_variable() + nll = TensorType.subtype(x.type.dtype, y_idx.type.broadcastable).make_variable() # nll = TensorType(x.dtype, y.broadcastable) sm = x.type() am = y_idx.type() @@ -1440,7 +1440,7 @@ def grad(self, inp, grads): def fancy_sum(terms): if len(terms) == 0: - return DisconnectedType()() + return DisconnectedType.subtype()() rval = terms[0] for term in terms[1:]: rval = rval + term @@ -1819,7 +1819,7 @@ def make_node(self, coding_dist, true_one_of_n): return Apply( self, [_coding_dist, _true_one_of_n], - [TensorType(dtype=_coding_dist.dtype, shape=[False])()], + [TensorType.subtype(dtype=_coding_dist.dtype, shape=[False])()], ) def perform(self, node, inp, out): diff --git a/aesara/tensor/nnet/batchnorm.py b/aesara/tensor/nnet/batchnorm.py index 2edf3675e4..0693fc7dc6 100644 --- a/aesara/tensor/nnet/batchnorm.py +++ b/aesara/tensor/nnet/batchnorm.py @@ -22,11 +22,11 @@ class BNComposite(Composite): @config.change_flags(compute_test_value="off") def __init__(self, dtype): self.dtype = dtype - x = aesara.scalar.ScalarType(dtype=dtype).make_variable() - mean = aesara.scalar.ScalarType(dtype=dtype).make_variable() - std = aesara.scalar.ScalarType(dtype=dtype).make_variable() - gamma = aesara.scalar.ScalarType(dtype=dtype).make_variable() - beta = aesara.scalar.ScalarType(dtype=dtype).make_variable() + x = aesara.scalar.ScalarType.subtype(dtype=dtype).make_variable() + mean = aesara.scalar.ScalarType.subtype(dtype=dtype).make_variable() + std = aesara.scalar.ScalarType.subtype(dtype=dtype).make_variable() + gamma = aesara.scalar.ScalarType.subtype(dtype=dtype).make_variable() + beta = aesara.scalar.ScalarType.subtype(dtype=dtype).make_variable() o = add(mul(true_div(sub(x, mean), std), gamma), beta) inputs = [x, mean, std, gamma, beta] outputs = [o] @@ -485,12 +485,12 @@ def L_op(self, inputs, outputs, grads): dy = grads[0] _, x_mean, x_invstd = outputs[:3] disconnected_outputs = [ - aesara.gradient.DisconnectedType()(), # epsilon - aesara.gradient.DisconnectedType()(), + 
aesara.gradient.DisconnectedType.subtype()(), # epsilon + aesara.gradient.DisconnectedType.subtype()(), ] # running_average_factor # Optional running_mean and running_var. for i in range(5, len(inputs)): - disconnected_outputs.append(aesara.gradient.DisconnectedType()()) + disconnected_outputs.append(aesara.gradient.DisconnectedType.subtype()()) return ( AbstractBatchNormTrainGrad(self.axes)( x, dy, scale, x_mean, x_invstd, epsilon @@ -628,7 +628,14 @@ def grad(self, inputs, grads): dvar = -(dy * (x - est_mean)).sum(axes, keepdims=True) * ( scale / (two * est_var_eps * est_std) ) - return [dx, dscale, dbias, dmean, dvar, aesara.gradient.DisconnectedType()()] + return [ + dx, + dscale, + dbias, + dmean, + dvar, + aesara.gradient.DisconnectedType.subtype()(), + ] def connection_pattern(self, node): # Specify that epsilon is not connected to outputs. @@ -735,10 +742,10 @@ def grad(self, inp, grads): g_wrt_scale, g_wrt_x_mean, g_wrt_x_invstd, - aesara.gradient.DisconnectedType()(), + aesara.gradient.DisconnectedType.subtype()(), ] return [ - aesara.gradient.DisconnectedType()() + aesara.gradient.DisconnectedType.subtype()() if (isinstance(r, int) and r == 0) else r for r in results diff --git a/aesara/tensor/nnet/conv3d2d.py b/aesara/tensor/nnet/conv3d2d.py index 044211d6a0..5c740218d3 100644 --- a/aesara/tensor/nnet/conv3d2d.py +++ b/aesara/tensor/nnet/conv3d2d.py @@ -108,7 +108,7 @@ def make_node(self, x, i0, i1): _i1 = at.as_tensor_variable(i1) # TODO: We could produce a more precise static shape output type type_shape = (1 if shape == 1 else None for shape in x.type.shape) - out_type = at.TensorType(x.type.dtype, shape=type_shape) + out_type = at.TensorType.subtype(x.type.dtype, shape=type_shape) return Apply(self, [x, _i0, _i1], [out_type()]) def perform(self, node, inputs, output_storage): @@ -121,7 +121,7 @@ def perform(self, node, inputs, output_storage): def grad(self, inputs, g_outputs): z = at.zeros_like(inputs[0]) gx = inc_diagonal_subtensor(z, inputs[1], inputs[2], g_outputs[0]) - return [gx, DisconnectedType()(), DisconnectedType()()] + return [gx, DisconnectedType.subtype()(), DisconnectedType.subtype()()] def connection_pattern(self, node): rval = [[True], [False], [False]] @@ -167,8 +167,8 @@ def grad(self, inputs, g_outputs): gy = g_outputs[0] return [ gy, - DisconnectedType()(), - DisconnectedType()(), + DisconnectedType.subtype()(), + DisconnectedType.subtype()(), diagonal_subtensor(gy, i0, i1), ] diff --git a/aesara/tensor/nnet/corr.py b/aesara/tensor/nnet/corr.py index c6758a11e1..4a38e53719 100644 --- a/aesara/tensor/nnet/corr.py +++ b/aesara/tensor/nnet/corr.py @@ -60,8 +60,8 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp): _direction: Optional[str] = None - params_type = ParamsType( - direction=EnumList( + params_type = ParamsType.subtype( + direction=EnumList.subtype( ("DIRECTION_FORWARD", "forward"), # 0 ("DIRECTION_BACKPROP_WEIGHTS", "backprop weights"), # 1 ("DIRECTION_BACKPROP_INPUTS", "backprop inputs"), @@ -699,7 +699,7 @@ def make_node(self, img, kern): False, ] dtype = img.type.dtype - return Apply(self, [img, kern], [TensorType(dtype, broadcastable)()]) + return Apply(self, [img, kern], [TensorType.subtype(dtype, broadcastable)()]) def infer_shape(self, fgraph, node, input_shape): imshp = input_shape[0] @@ -787,7 +787,9 @@ def make_node(self, img, topgrad, shape=None): ] dtype = img.type.dtype return Apply( - self, [img, topgrad] + height_width, [TensorType(dtype, broadcastable)()] + self, + [img, topgrad] + height_width, + [TensorType.subtype(dtype, 
broadcastable)()], ) def infer_shape(self, fgraph, node, input_shape): @@ -857,7 +859,7 @@ def grad(self, inp, grads): self.unshared, )(bottom, weights) d_height_width = ( - (aesara.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else () + (aesara.gradient.DisconnectedType.subtype()(),) * 2 if len(inp) == 4 else () ) return (d_bottom, d_top) + d_height_width @@ -915,7 +917,9 @@ def make_node(self, kern, topgrad, shape=None): ] dtype = kern.type.dtype return Apply( - self, [kern, topgrad] + height_width, [TensorType(dtype, broadcastable)()] + self, + [kern, topgrad] + height_width, + [TensorType.subtype(dtype, broadcastable)()], ) def infer_shape(self, fgraph, node, input_shape): @@ -989,7 +993,7 @@ def grad(self, inp, grads): self.unshared, )(bottom, weights) d_height_width = ( - (aesara.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else () + (aesara.gradient.DisconnectedType.subtype()(),) * 2 if len(inp) == 4 else () ) return (d_weights, d_top) + d_height_width diff --git a/aesara/tensor/nnet/corr3d.py b/aesara/tensor/nnet/corr3d.py index dc2585b132..2395832892 100644 --- a/aesara/tensor/nnet/corr3d.py +++ b/aesara/tensor/nnet/corr3d.py @@ -52,8 +52,8 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp): _direction: Optional[str] = None - params_type = ParamsType( - direction=EnumList( + params_type = ParamsType.subtype( + direction=EnumList.subtype( ("DIRECTION_FORWARD", "forward"), # 0 ("DIRECTION_BACKPROP_WEIGHTS", "backprop weights"), # 1 ("DIRECTION_BACKPROP_INPUTS", "backprop inputs"), @@ -639,7 +639,7 @@ def make_node(self, img, kern): False, ] dtype = img.type.dtype - return Apply(self, [img, kern], [TensorType(dtype, broadcastable)()]) + return Apply(self, [img, kern], [TensorType.subtype(dtype, broadcastable)()]) def infer_shape(self, fgraph, node, input_shape): imshp = input_shape[0] @@ -719,7 +719,7 @@ def make_node(self, img, topgrad, shape=None): return Apply( self, [img, topgrad] + height_width_depth, - [TensorType(dtype, broadcastable)()], + [TensorType.subtype(dtype, broadcastable)()], ) def infer_shape(self, fgraph, node, input_shape): @@ -784,7 +784,7 @@ def grad(self, inp, grads): num_groups=self.num_groups, )(bottom, weights) d_height_width_depth = ( - (aesara.gradient.DisconnectedType()(),) * 3 if len(inp) == 5 else () + (aesara.gradient.DisconnectedType.subtype()(),) * 3 if len(inp) == 5 else () ) return (d_bottom, d_top) + d_height_width_depth @@ -842,7 +842,7 @@ def make_node(self, kern, topgrad, shape=None): return Apply( self, [kern, topgrad] + height_width_depth, - [TensorType(dtype, broadcastable)()], + [TensorType.subtype(dtype, broadcastable)()], ) def infer_shape(self, fgraph, node, input_shape): @@ -918,7 +918,7 @@ def grad(self, inp, grads): num_groups=self.num_groups, )(bottom, weights) d_height_width_depth = ( - (aesara.gradient.DisconnectedType()(),) * 3 if len(inp) == 5 else () + (aesara.gradient.DisconnectedType.subtype()(),) * 3 if len(inp) == 5 else () ) return (d_weights, d_top) + d_height_width_depth diff --git a/aesara/tensor/nnet/neighbours.py b/aesara/tensor/nnet/neighbours.py index 0ed3018b65..01d186cf78 100644 --- a/aesara/tensor/nnet/neighbours.py +++ b/aesara/tensor/nnet/neighbours.py @@ -41,7 +41,7 @@ class Images2Neibs(COp): """ __props__ = ("mode",) - BORDER_MODE = EnumList( + BORDER_MODE = EnumList.subtype( ("MODE_VALID", "valid"), ("MODE_HALF", "half"), ("MODE_FULL", "full"), diff --git a/aesara/tensor/nnet/sigm.py b/aesara/tensor/nnet/sigm.py index a351192741..1437f6d88d 100644 --- a/aesara/tensor/nnet/sigm.py +++ 
b/aesara/tensor/nnet/sigm.py @@ -147,7 +147,7 @@ def hard_sigmoid(x): """ # Use the same dtype as determined by "upgrade_to_float", # and perform computation in that dtype. - out_dtype = aes.upgrade_to_float(aes.ScalarType(dtype=x.dtype))[0].dtype + out_dtype = aes.upgrade_to_float(aes.ScalarType.subtype(dtype=x.dtype))[0].dtype slope = constant(0.2, dtype=out_dtype) shift = constant(0.5, dtype=out_dtype) x = (x * slope) + shift diff --git a/aesara/tensor/random/op.py b/aesara/tensor/random/op.py index 86cb676feb..8a6e36afde 100644 --- a/aesara/tensor/random/op.py +++ b/aesara/tensor/random/op.py @@ -336,7 +336,7 @@ def make_node(self, rng, size, dtype, *dist_params): dtype_idx = constant(dtype, dtype="int64") dtype = all_dtypes[dtype_idx.data] - outtype = TensorType(dtype=dtype, shape=bcast) + outtype = TensorType.subtype(dtype=dtype, shape=bcast) out_var = outtype() inputs = (rng, size, dtype_idx) + dist_params outputs = (rng.type(), out_var) @@ -405,7 +405,7 @@ def perform(self, node, inputs, output_storage): class RandomStateConstructor(AbstractRNGConstructor): - random_type = RandomStateType() + random_type = RandomStateType.subtype() random_constructor = "RandomState" @@ -413,7 +413,7 @@ class RandomStateConstructor(AbstractRNGConstructor): class DefaultGeneratorMakerOp(AbstractRNGConstructor): - random_type = RandomGeneratorType() + random_type = RandomGeneratorType.subtype() random_constructor = "default_rng" diff --git a/aesara/tensor/random/type.py b/aesara/tensor/random/type.py index 5c897473dc..456606fc91 100644 --- a/aesara/tensor/random/type.py +++ b/aesara/tensor/random/type.py @@ -119,7 +119,7 @@ def __hash__(self): 1, ) -random_state_type = RandomStateType() +random_state_type = RandomStateType.subtype() class RandomGeneratorType(RandomType[np.random.Generator]): @@ -215,4 +215,4 @@ def __hash__(self): 1, ) -random_generator_type = RandomGeneratorType() +random_generator_type = RandomGeneratorType.subtype() diff --git a/aesara/tensor/rewriting/basic.py b/aesara/tensor/rewriting/basic.py index 7cb095a346..fa6066b72f 100644 --- a/aesara/tensor/rewriting/basic.py +++ b/aesara/tensor/rewriting/basic.py @@ -1153,7 +1153,7 @@ def constant_folding(fgraph, node): # instances appropriate for a given constant. # TODO: Add handling for sparse types. if isinstance(output.type, DenseTensorType): - output_type = TensorType( + output_type = TensorType.subtype( output.type.dtype, tuple(s == 1 for s in data.shape), name=output.type.name, diff --git a/aesara/tensor/shape.py b/aesara/tensor/shape.py index f6ed3590c5..cf56730c81 100644 --- a/aesara/tensor/shape.py +++ b/aesara/tensor/shape.py @@ -65,7 +65,7 @@ def make_node(self, x): x = at.as_tensor_variable(x) if isinstance(x.type, TensorType): - out_var = TensorType("int64", (x.type.ndim,))() + out_var = TensorType.subtype("int64", (x.type.ndim,))() else: out_var = aesara.tensor.type.lvector() @@ -93,7 +93,7 @@ def grad(self, inp, grads): # the elements of the tensor variable do not participate # in the computation of the shape, so they are not really # part of the graph - return [aesara.gradient.DisconnectedType()()] + return [aesara.gradient.DisconnectedType.subtype()()] def R_op(self, inputs, eval_points): return [None] @@ -212,7 +212,7 @@ def __init__(self, i): # using params. 
@property def params_type(self): - return ParamsType(i=aesara.scalar.basic.int64) + return ParamsType.subtype(i=aesara.scalar.basic.int64) def __str__(self): return "%s{%i}" % (self.__class__.__name__, self.i) @@ -463,7 +463,7 @@ def grad(self, inp, grads): x, *shape = inp (gz,) = grads return [specify_shape(gz, shape)] + [ - aesara.gradient.DisconnectedType()() for _ in range(len(shape)) + aesara.gradient.DisconnectedType.subtype()() for _ in range(len(shape)) ] def R_op(self, inputs, eval_points): @@ -584,7 +584,7 @@ class Reshape(COp): check_input = False __props__ = ("ndim",) - params_type = ParamsType(ndim=int32) + params_type = ParamsType.subtype(ndim=int32) # name does not participate because it doesn't affect computations def __init__(self, ndim, name=None): @@ -649,7 +649,7 @@ def connection_pattern(self, node): def grad(self, inp, grads): x, shp = inp (g_out,) = grads - return [reshape(g_out, shape(x), ndim=x.ndim), DisconnectedType()()] + return [reshape(g_out, shape(x), ndim=x.ndim), DisconnectedType.subtype()()] def R_op(self, inputs, eval_points): if eval_points[0] is None: diff --git a/aesara/tensor/sharedvar.py b/aesara/tensor/sharedvar.py index 76d9f3148b..954fa61402 100644 --- a/aesara/tensor/sharedvar.py +++ b/aesara/tensor/sharedvar.py @@ -69,7 +69,7 @@ def tensor_constructor( # if shape is None: shape = (False,) * len(value.shape) - type = TensorType(value.dtype, shape=shape) + type = TensorType.subtype(value.dtype, shape=shape) return TensorSharedVariable( type=type, value=np.array(value, copy=(not borrow)), @@ -118,7 +118,7 @@ def scalar_constructor( dtype = str(dtype) value = _asarray(value, dtype=dtype) - tensor_type = TensorType(dtype=str(value.dtype), shape=[]) + tensor_type = TensorType.subtype(dtype=str(value.dtype), shape=[]) try: # Do not pass the dtype to asarray because we want this to fail if diff --git a/aesara/tensor/signal/pool.py b/aesara/tensor/signal/pool.py index 60a6a0ddad..a6ccc752c8 100755 --- a/aesara/tensor/signal/pool.py +++ b/aesara/tensor/signal/pool.py @@ -310,7 +310,7 @@ class Pool(OpenMPOp): """ __props__ = ("ignore_border", "mode", "ndim") - params_type = ParamsType( + params_type = ParamsType.subtype( ignore_border=bool_t, ) @@ -535,7 +535,7 @@ def make_node(self, x, ws, stride=None, pad=None): raise TypeError("Padding parameters must be ints.") # If the input shape are broadcastable we can have 0 in the output shape broad = x.broadcastable[:-nd] + (False,) * nd - out = TensorType(x.dtype, broad) + out = TensorType.subtype(x.dtype, broad) return Apply(self, [x, ws, stride, pad], [out()]) def perform(self, node, inp, out, params): @@ -602,7 +602,7 @@ def infer_shape(self, fgraph, node, in_shapes): def L_op(self, inputs, outputs, grads): x, ws, stride, pad = inputs (gz,) = grads - disc = [DisconnectedType()() for i in inputs[1:]] + disc = [DisconnectedType.subtype()() for i in inputs[1:]] if self.mode == "max": return [ MaxPoolGrad(ndim=self.ndim, ignore_border=self.ignore_border)( @@ -1248,7 +1248,7 @@ def grad(self, inp, grads): DownsampleFactorMaxGradGrad( ndim=self.ndim, ignore_border=self.ignore_border )(x, maxout, ggx, ws, stride, pad), - ] + [DisconnectedType()() for i in inp[3:]] + ] + [DisconnectedType.subtype()() for i in inp[3:]] def connection_pattern(self, node): return [[1], [1], [1], [0], [0], [0]] @@ -1585,7 +1585,7 @@ def grad(self, inp, grads): Pool(ignore_border=self.ignore_border, ndim=self.ndim, mode=self.mode)( ggx, ws, stride, pad ), - ] + [DisconnectedType()() for i in inp[2:]] + ] + [DisconnectedType.subtype()() 
for i in inp[2:]] def connection_pattern(self, node): return [[1], [1], [0], [0], [0]] @@ -1933,9 +1933,9 @@ def grad(self, inp, grads): MaxPoolGrad(ignore_border=self.ignore_border, ndim=self.ndim)( x, maxout, gz, ws, stride, pad ), - DisconnectedType()(), - DisconnectedType()(), - DisconnectedType()(), + DisconnectedType.subtype()(), + DisconnectedType.subtype()(), + DisconnectedType.subtype()(), ] def connection_pattern(self, node): @@ -2159,7 +2159,7 @@ class MaxPoolRop(OpenMPOp): """ __props__ = ("ignore_border", "mode", "ndim") - params_type = ParamsType( + params_type = ParamsType.subtype( ignore_border=bool_t, ) @@ -2201,7 +2201,7 @@ def make_node(self, x, eval_point, ws, stride=None, pad=None): raise TypeError("Padding parameters must be ints.") # If the input shape are broadcastable we can have 0 in the output shape broad = x.broadcastable[:-nd] + (False,) * nd - out = TensorType(eval_point.dtype, broad) + out = TensorType.subtype(eval_point.dtype, broad) return Apply(self, [x, eval_point, ws, stride, pad], [out()]) def perform(self, node, inp, out, params): diff --git a/aesara/tensor/sort.py b/aesara/tensor/sort.py index 20bd23fcfc..5928b249b9 100644 --- a/aesara/tensor/sort.py +++ b/aesara/tensor/sort.py @@ -178,7 +178,7 @@ def make_node(self, input, axis=-1): return Apply( self, [input, axis], - [TensorType(dtype="int64", shape=input.type.shape)()], + [TensorType.subtype(dtype="int64", shape=input.type.shape)()], ) def perform(self, node, inputs, output_storage): @@ -416,7 +416,9 @@ def make_node(self, inp, kth): if self.return_values: outs.append(inp.type()) if self.return_indices: - outs.append(TensorType(dtype=self.idx_dtype, shape=inp.type.shape)()) + outs.append( + TensorType.subtype(dtype=self.idx_dtype, shape=inp.type.shape)() + ) return Apply(self, [inp, kth], outs) def perform(self, node, inputs, output_storage): diff --git a/aesara/tensor/subtensor.py b/aesara/tensor/subtensor.py index f9abfaf784..8717f75619 100644 --- a/aesara/tensor/subtensor.py +++ b/aesara/tensor/subtensor.py @@ -809,7 +809,7 @@ def grad(self, inputs, grads): # set subtensor here at: # aesara/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor() first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest) - return [first] + [DisconnectedType()()] * len(rest) + return [first] + [DisconnectedType.subtype()()] * len(rest) def connection_pattern(self, node): @@ -1868,7 +1868,7 @@ def grad(self, inputs, grads): gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list) gy = _sum_grad_over_bcasted_dims(y, gy) - return [gx, gy] + [DisconnectedType()()] * len(idx_list) + return [gx, gy] + [DisconnectedType.subtype()()] * len(idx_list) class IncSubtensorPrinter(SubtensorPrinter): @@ -1949,7 +1949,9 @@ def make_node(self, x, ilist): if x_.type.ndim == 0: raise TypeError("cannot index into a scalar") bcast = (ilist_.broadcastable[0],) + x_.broadcastable[1:] - return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype, shape=bcast)()]) + return Apply( + self, [x_, ilist_], [TensorType.subtype(dtype=x.dtype, shape=bcast)()] + ) def perform(self, node, inp, out_): x, i = inp @@ -2009,7 +2011,7 @@ def grad(self, inputs, grads): else: gx = x.zeros_like() rval1 = [advanced_inc_subtensor1(gx, gz, ilist)] - return rval1 + [DisconnectedType()()] * (len(inputs) - 1) + return rval1 + [DisconnectedType.subtype()()] * (len(inputs) - 1) def R_op(self, inputs, eval_points): if eval_points[0] is None: @@ -2134,7 +2136,7 @@ class AdvancedIncSubtensor1(COp): __props__ = ("inplace", "set_instead_of_inc") check_input 
= False - params_type = ParamsType(inplace=aes.bool, set_instead_of_inc=aes.bool) + params_type = ParamsType.subtype(inplace=aes.bool, set_instead_of_inc=aes.bool) def __init__(self, inplace=False, set_instead_of_inc=False): self.inplace = bool(inplace) @@ -2493,7 +2495,7 @@ def grad(self, inputs, grads): gy = advanced_subtensor1(g_output, idx_list) gy = _sum_grad_over_bcasted_dims(y, gy) - return [gx, gy] + [DisconnectedType()()] + return [gx, gy] + [DisconnectedType.subtype()()] advanced_inc_subtensor1 = AdvancedIncSubtensor1() @@ -2641,9 +2643,9 @@ def grad(self, inputs, grads): else: gx = x.zeros_like() rest = inputs[1:] - return [advanced_inc_subtensor(gx, gz, *rest)] + [DisconnectedType()()] * len( - rest - ) + return [advanced_inc_subtensor(gx, gz, *rest)] + [ + DisconnectedType.subtype()() + ] * len(rest) advanced_subtensor = AdvancedSubtensor() @@ -2743,7 +2745,7 @@ def grad(self, inpt, output_gradients): # Make sure to sum gy over the dimensions of y that have been # added or broadcasted gy = _sum_grad_over_bcasted_dims(y, gy) - return [gx, gy] + [DisconnectedType()() for _ in idxs] + return [gx, gy] + [DisconnectedType.subtype()() for _ in idxs] advanced_inc_subtensor = AdvancedIncSubtensor() diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 2affadb154..77b17abb1d 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -8,8 +8,7 @@ from aesara import scalar as aes from aesara.configdefaults import config from aesara.graph.basic import Variable -from aesara.graph.type import DataType, ShapeType -from aesara.graph.utils import MetaType +from aesara.graph.type import DataType, NewTypeMeta, ShapeType from aesara.link.c.type import CType from aesara.misc.safe_asarray import _asarray from aesara.utils import apply_across_args @@ -128,7 +127,7 @@ def clone( dtype = self.dtype if shape is None: shape = self.shape - return type(self)(dtype, shape, name=self.name) + return type(self).subtype(dtype, shape, name=self.name) def filter(self, data, strict=False, allow_downcast=None): """Convert `data` to something which can be associated to a `TensorVariable`. 
@@ -619,7 +618,7 @@ def c_code_cache_version(self):
         return ()
 
 
-class DenseTypeMeta(MetaType):
+class DenseTypeMeta(NewTypeMeta):
     def __instancecheck__(self, o):
         if type(o) == TensorType or isinstance(o, DenseTypeMeta):
             return True
@@ -771,21 +770,21 @@ def values_eq_approx_always_true(a, b):
 
 def tensor(*args, **kwargs):
     name = kwargs.pop("name", None)
-    return TensorType(*args, **kwargs)(name=name)
+    return TensorType.subtype(*args, **kwargs)(name=name)
 
 
-cscalar = TensorType("complex64", ())
-zscalar = TensorType("complex128", ())
-fscalar = TensorType("float32", ())
-dscalar = TensorType("float64", ())
-bscalar = TensorType("int8", ())
-wscalar = TensorType("int16", ())
-iscalar = TensorType("int32", ())
-lscalar = TensorType("int64", ())
-ubscalar = TensorType("uint8", ())
-uwscalar = TensorType("uint16", ())
-uiscalar = TensorType("uint32", ())
-ulscalar = TensorType("uint64", ())
+cscalar = TensorType.subtype("complex64", ())
+zscalar = TensorType.subtype("complex128", ())
+fscalar = TensorType.subtype("float32", ())
+dscalar = TensorType.subtype("float64", ())
+bscalar = TensorType.subtype("int8", ())
+wscalar = TensorType.subtype("int16", ())
+iscalar = TensorType.subtype("int32", ())
+lscalar = TensorType.subtype("int64", ())
+ubscalar = TensorType.subtype("uint8", ())
+uwscalar = TensorType.subtype("uint16", ())
+uiscalar = TensorType.subtype("uint32", ())
+ulscalar = TensorType.subtype("uint64", ())
 
 
 def scalar(name=None, dtype=None):
@@ -801,7 +800,7 @@ def scalar(name=None, dtype=None):
     """
     if dtype is None:
         dtype = config.floatX
-    type = TensorType(dtype, ())
+    type = TensorType.subtype(dtype, ())
     return type(name)
 
 
@@ -816,14 +815,14 @@ def scalar(name=None, dtype=None):
 float_scalar_types = float_types
 complex_scalar_types = complex_types
 
-cvector = TensorType("complex64", (False,))
-zvector = TensorType("complex128", (False,))
-fvector = TensorType("float32", (False,))
-dvector = TensorType("float64", (False,))
-bvector = TensorType("int8", (False,))
-wvector = TensorType("int16", (False,))
-ivector = TensorType("int32", (False,))
-lvector = TensorType("int64", (False,))
+cvector = TensorType.subtype("complex64", (False,))
+zvector = TensorType.subtype("complex128", (False,))
+fvector = TensorType.subtype("float32", (False,))
+dvector = TensorType.subtype("float64", (False,))
+bvector = TensorType.subtype("int8", (False,))
+wvector = TensorType.subtype("int16", (False,))
+ivector = TensorType.subtype("int32", (False,))
+lvector = TensorType.subtype("int64", (False,))
 
 
 def vector(name=None, dtype=None):
@@ -839,7 +838,7 @@ def vector(name=None, dtype=None):
     """
     if dtype is None:
        dtype = config.floatX
-    type = TensorType(dtype, (False,))
+    type = TensorType.subtype(dtype, (False,))
     return type(name)
 
 
@@ -851,14 +850,14 @@ def vector(name=None, dtype=None):
 float_vector_types = fvector, dvector
 complex_vector_types = cvector, zvector
 
-cmatrix = TensorType("complex64", (False, False))
-zmatrix = TensorType("complex128", (False, False))
-fmatrix = TensorType("float32", (False, False))
-dmatrix = TensorType("float64", (False, False))
-bmatrix = TensorType("int8", (False, False))
-wmatrix = TensorType("int16", (False, False))
-imatrix = TensorType("int32", (False, False))
-lmatrix = TensorType("int64", (False, False))
+cmatrix = TensorType.subtype("complex64", (False, False))
+zmatrix = TensorType.subtype("complex128", (False, False))
+fmatrix = TensorType.subtype("float32", (False, False))
+dmatrix = TensorType.subtype("float64", (False, False))
+bmatrix = TensorType.subtype("int8", (False, False))
+wmatrix = TensorType.subtype("int16", (False, False))
+imatrix = TensorType.subtype("int32", (False, False))
+lmatrix = TensorType.subtype("int64", (False, False))
 
 
 def matrix(name=None, dtype=None):
@@ -874,7 +873,7 @@ def matrix(name=None, dtype=None):
     """
     if dtype is None:
         dtype = config.floatX
-    type = TensorType(dtype, (False, False))
+    type = TensorType.subtype(dtype, (False, False))
     return type(name)
 
 
@@ -886,14 +885,14 @@ def matrix(name=None, dtype=None):
 float_matrix_types = fmatrix, dmatrix
 complex_matrix_types = cmatrix, zmatrix
 
-crow = TensorType("complex64", (True, False))
-zrow = TensorType("complex128", (True, False))
-frow = TensorType("float32", (True, False))
-drow = TensorType("float64", (True, False))
-brow = TensorType("int8", (True, False))
-wrow = TensorType("int16", (True, False))
-irow = TensorType("int32", (True, False))
-lrow = TensorType("int64", (True, False))
+crow = TensorType.subtype("complex64", (True, False))
+zrow = TensorType.subtype("complex128", (True, False))
+frow = TensorType.subtype("float32", (True, False))
+drow = TensorType.subtype("float64", (True, False))
+brow = TensorType.subtype("int8", (True, False))
+wrow = TensorType.subtype("int16", (True, False))
+irow = TensorType.subtype("int32", (True, False))
+lrow = TensorType.subtype("int64", (True, False))
 
 
 def row(name=None, dtype=None):
@@ -909,20 +908,20 @@ def row(name=None, dtype=None):
     """
     if dtype is None:
         dtype = config.floatX
-    type = TensorType(dtype, (True, False))
+    type = TensorType.subtype(dtype, (True, False))
     return type(name)
 
 
 rows, frows, drows, irows, lrows = apply_across_args(row, frow, drow, irow, lrow)
 
-ccol = TensorType("complex64", (False, True))
-zcol = TensorType("complex128", (False, True))
-fcol = TensorType("float32", (False, True))
-dcol = TensorType("float64", (False, True))
-bcol = TensorType("int8", (False, True))
-wcol = TensorType("int16", (False, True))
-icol = TensorType("int32", (False, True))
-lcol = TensorType("int64", (False, True))
+ccol = TensorType.subtype("complex64", (False, True))
+zcol = TensorType.subtype("complex128", (False, True))
+fcol = TensorType.subtype("float32", (False, True))
+dcol = TensorType.subtype("float64", (False, True))
+bcol = TensorType.subtype("int8", (False, True))
+wcol = TensorType.subtype("int16", (False, True))
+icol = TensorType.subtype("int32", (False, True))
+lcol = TensorType.subtype("int64", (False, True))
 
 
 def col(name=None, dtype=None):
@@ -938,20 +937,20 @@ def col(name=None, dtype=None):
     """
     if dtype is None:
         dtype = config.floatX
-    type = TensorType(dtype, (False, True))
+    type = TensorType.subtype(dtype, (False, True))
     return type(name)
 
 
 cols, fcols, dcols, icols, lcols = apply_across_args(col, fcol, dcol, icol, lcol)
 
-ctensor3 = TensorType("complex64", ((False,) * 3))
-ztensor3 = TensorType("complex128", ((False,) * 3))
-ftensor3 = TensorType("float32", ((False,) * 3))
-dtensor3 = TensorType("float64", ((False,) * 3))
-btensor3 = TensorType("int8", ((False,) * 3))
-wtensor3 = TensorType("int16", ((False,) * 3))
-itensor3 = TensorType("int32", ((False,) * 3))
-ltensor3 = TensorType("int64", ((False,) * 3))
+ctensor3 = TensorType.subtype("complex64", ((False,) * 3))
+ztensor3 = TensorType.subtype("complex128", ((False,) * 3))
+ftensor3 = TensorType.subtype("float32", ((False,) * 3))
+dtensor3 = TensorType.subtype("float64", ((False,) * 3))
+btensor3 = TensorType.subtype("int8", ((False,) * 3))
+wtensor3 = TensorType.subtype("int16", ((False,) * 3))
+itensor3 = TensorType.subtype("int32", ((False,) * 
3)) +ltensor3 = TensorType.subtype("int64", ((False,) * 3)) def tensor3(name=None, dtype=None): @@ -967,7 +966,7 @@ def tensor3(name=None, dtype=None): """ if dtype is None: dtype = config.floatX - type = TensorType(dtype, (False, False, False)) + type = TensorType.subtype(dtype, (False, False, False)) return type(name) @@ -975,14 +974,14 @@ def tensor3(name=None, dtype=None): tensor3, ftensor3, dtensor3, itensor3, ltensor3 ) -ctensor4 = TensorType("complex64", ((False,) * 4)) -ztensor4 = TensorType("complex128", ((False,) * 4)) -ftensor4 = TensorType("float32", ((False,) * 4)) -dtensor4 = TensorType("float64", ((False,) * 4)) -btensor4 = TensorType("int8", ((False,) * 4)) -wtensor4 = TensorType("int16", ((False,) * 4)) -itensor4 = TensorType("int32", ((False,) * 4)) -ltensor4 = TensorType("int64", ((False,) * 4)) +ctensor4 = TensorType.subtype("complex64", ((False,) * 4)) +ztensor4 = TensorType.subtype("complex128", ((False,) * 4)) +ftensor4 = TensorType.subtype("float32", ((False,) * 4)) +dtensor4 = TensorType.subtype("float64", ((False,) * 4)) +btensor4 = TensorType.subtype("int8", ((False,) * 4)) +wtensor4 = TensorType.subtype("int16", ((False,) * 4)) +itensor4 = TensorType.subtype("int32", ((False,) * 4)) +ltensor4 = TensorType.subtype("int64", ((False,) * 4)) def tensor4(name=None, dtype=None): @@ -998,7 +997,7 @@ def tensor4(name=None, dtype=None): """ if dtype is None: dtype = config.floatX - type = TensorType(dtype, (False, False, False, False)) + type = TensorType.subtype(dtype, (False, False, False, False)) return type(name) @@ -1006,14 +1005,14 @@ def tensor4(name=None, dtype=None): tensor4, ftensor4, dtensor4, itensor4, ltensor4 ) -ctensor5 = TensorType("complex64", ((False,) * 5)) -ztensor5 = TensorType("complex128", ((False,) * 5)) -ftensor5 = TensorType("float32", ((False,) * 5)) -dtensor5 = TensorType("float64", ((False,) * 5)) -btensor5 = TensorType("int8", ((False,) * 5)) -wtensor5 = TensorType("int16", ((False,) * 5)) -itensor5 = TensorType("int32", ((False,) * 5)) -ltensor5 = TensorType("int64", ((False,) * 5)) +ctensor5 = TensorType.subtype("complex64", ((False,) * 5)) +ztensor5 = TensorType.subtype("complex128", ((False,) * 5)) +ftensor5 = TensorType.subtype("float32", ((False,) * 5)) +dtensor5 = TensorType.subtype("float64", ((False,) * 5)) +btensor5 = TensorType.subtype("int8", ((False,) * 5)) +wtensor5 = TensorType.subtype("int16", ((False,) * 5)) +itensor5 = TensorType.subtype("int32", ((False,) * 5)) +ltensor5 = TensorType.subtype("int64", ((False,) * 5)) def tensor5(name=None, dtype=None): @@ -1029,7 +1028,7 @@ def tensor5(name=None, dtype=None): """ if dtype is None: dtype = config.floatX - type = TensorType(dtype, (False, False, False, False, False)) + type = TensorType.subtype(dtype, (False, False, False, False, False)) return type(name) @@ -1037,14 +1036,14 @@ def tensor5(name=None, dtype=None): tensor5, ftensor5, dtensor5, itensor5, ltensor5 ) -ctensor6 = TensorType("complex64", ((False,) * 6)) -ztensor6 = TensorType("complex128", ((False,) * 6)) -ftensor6 = TensorType("float32", ((False,) * 6)) -dtensor6 = TensorType("float64", ((False,) * 6)) -btensor6 = TensorType("int8", ((False,) * 6)) -wtensor6 = TensorType("int16", ((False,) * 6)) -itensor6 = TensorType("int32", ((False,) * 6)) -ltensor6 = TensorType("int64", ((False,) * 6)) +ctensor6 = TensorType.subtype("complex64", ((False,) * 6)) +ztensor6 = TensorType.subtype("complex128", ((False,) * 6)) +ftensor6 = TensorType.subtype("float32", ((False,) * 6)) +dtensor6 = TensorType.subtype("float64", 
((False,) * 6)) +btensor6 = TensorType.subtype("int8", ((False,) * 6)) +wtensor6 = TensorType.subtype("int16", ((False,) * 6)) +itensor6 = TensorType.subtype("int32", ((False,) * 6)) +ltensor6 = TensorType.subtype("int64", ((False,) * 6)) def tensor6(name=None, dtype=None): @@ -1060,7 +1059,7 @@ def tensor6(name=None, dtype=None): """ if dtype is None: dtype = config.floatX - type = TensorType(dtype, (False,) * 6) + type = TensorType.subtype(dtype, (False,) * 6) return type(name) @@ -1068,14 +1067,14 @@ def tensor6(name=None, dtype=None): tensor6, ftensor6, dtensor6, itensor6, ltensor6 ) -ctensor7 = TensorType("complex64", ((False,) * 7)) -ztensor7 = TensorType("complex128", ((False,) * 7)) -ftensor7 = TensorType("float32", ((False,) * 7)) -dtensor7 = TensorType("float64", ((False,) * 7)) -btensor7 = TensorType("int8", ((False,) * 7)) -wtensor7 = TensorType("int16", ((False,) * 7)) -itensor7 = TensorType("int32", ((False,) * 7)) -ltensor7 = TensorType("int64", ((False,) * 7)) +ctensor7 = TensorType.subtype("complex64", ((False,) * 7)) +ztensor7 = TensorType.subtype("complex128", ((False,) * 7)) +ftensor7 = TensorType.subtype("float32", ((False,) * 7)) +dtensor7 = TensorType.subtype("float64", ((False,) * 7)) +btensor7 = TensorType.subtype("int8", ((False,) * 7)) +wtensor7 = TensorType.subtype("int16", ((False,) * 7)) +itensor7 = TensorType.subtype("int32", ((False,) * 7)) +ltensor7 = TensorType.subtype("int64", ((False,) * 7)) def tensor7(name=None, dtype=None): @@ -1091,7 +1090,7 @@ def tensor7(name=None, dtype=None): """ if dtype is None: dtype = config.floatX - type = TensorType(dtype, (False,) * 7) + type = TensorType.subtype(dtype, (False,) * 7) return type(name) diff --git a/aesara/tensor/type_other.py b/aesara/tensor/type_other.py index e0c438c5e5..b0b7a91dc2 100644 --- a/aesara/tensor/type_other.py +++ b/aesara/tensor/type_other.py @@ -45,7 +45,7 @@ def perform(self, node, inp, out_): out[0] = slice(*inp) def grad(self, inputs, grads): - return [DisconnectedType()() for i in inputs] + return [DisconnectedType.subtype()() for i in inputs] make_slice = MakeSlice() @@ -76,7 +76,7 @@ def may_share_memory(a, b): return isinstance(a, slice) and a is b -slicetype = SliceType() +slicetype = SliceType.subtype() class SliceConstant(Constant): @@ -140,7 +140,7 @@ def may_share_memory(a, b): return False -none_type_t = NoneTypeT() +none_type_t = NoneTypeT.subtype() NoneConst = Constant(none_type_t, None, name="NoneConst") diff --git a/aesara/tensor/var.py b/aesara/tensor/var.py index 8b281e6bd0..0fcb8a52ed 100644 --- a/aesara/tensor/var.py +++ b/aesara/tensor/var.py @@ -10,7 +10,7 @@ from aesara import tensor as at from aesara.configdefaults import config from aesara.graph.basic import Constant, OptionalApplyType, Variable -from aesara.graph.utils import MetaType +from aesara.graph.type import NewTypeMeta from aesara.scalar import ComplexError, IntegerDivisionError from aesara.tensor import _get_vector_length, as_tensor_variable from aesara.tensor.exceptions import AdvancedIndexingError @@ -1068,7 +1068,7 @@ def __deepcopy__(self, memo): TensorType.constant_type = TensorConstant -class DenseVariableMeta(MetaType): +class DenseVariableMeta(NewTypeMeta): def __instancecheck__(self, o): if type(o) == TensorVariable or isinstance(o, DenseVariableMeta): return True @@ -1083,7 +1083,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta): """ -class DenseConstantMeta(MetaType): +class DenseConstantMeta(NewTypeMeta): def __instancecheck__(self, o): if type(o) == TensorConstant or 
isinstance(o, DenseConstantMeta): return True diff --git a/aesara/typed_list/basic.py b/aesara/typed_list/basic.py index 9208ee7609..470d3ff242 100644 --- a/aesara/typed_list/basic.py +++ b/aesara/typed_list/basic.py @@ -75,7 +75,7 @@ def make_node(self, x, index): assert isinstance(x.type, TypedListType) if not isinstance(index, Variable): if isinstance(index, slice): - index = Constant(SliceType(), index) + index = Constant(SliceType.subtype(), index) return Apply(self, [x, index], [x.type()]) else: index = at.constant(index, ndim=0, dtype="int64") @@ -654,7 +654,7 @@ def make_node(self, a): a2.append(elem) if not all(a2[0].type.is_super(elem.type) for elem in a2): raise TypeError("MakeList need all input variable to be of the same type.") - tl = TypedListType(a2[0].type)() + tl = TypedListType.subtype(a2[0].type)() return Apply(self, a2, [tl]) diff --git a/aesara/typed_list/type.py b/aesara/typed_list/type.py index 83e8a40f24..4936e6958f 100644 --- a/aesara/typed_list/type.py +++ b/aesara/typed_list/type.py @@ -24,7 +24,7 @@ def __init__(self, ttype, depth=0): if depth == 0: self.ttype = ttype else: - self.ttype = TypedListType(ttype, depth - 1) + self.ttype = TypedListType.subtype(ttype, depth - 1) def filter(self, x, strict=False, allow_downcast=None): """ diff --git a/doc/aesara_installer_for_anaconda.bat b/doc/aesara_installer_for_anaconda.bat index d8e2a22cc9..d788a0802d 100644 --- a/doc/aesara_installer_for_anaconda.bat +++ b/doc/aesara_installer_for_anaconda.bat @@ -1,43 +1,43 @@ -@echo off - -rem if ANACONDA_DIR is not defined -if [%ANACONDA_DIR%] == [^%ANACONDA_DIR^%] ( - if exist "c:\Anaconda" set ANACONDA_DIR=C:\Anaconda - ) - -if [%ANACONDA_DIR%] == [^%ANACONDA_DIR^%] ( - echo "Anaconda not found. Please install AnacondaCE or set the ANACONDA_DIR environment variable to the location of your Anaconda installation." - goto end - ) - -if not exist %ANACONDA_DIR% ( - echo Anaconda install directory %ANACONDA_DIR% does not exist - goto end) - -echo Anaconda found in %ANACONDA_DIR% -echo copying dlls from %ANACONDA_DIR%\MinGW\x86_64-w64-mingw32\lib to %ANACONDA_DIR%\ -copy %ANACONDA_DIR%\MinGW\x86_64-w64-mingw32\lib\*.dll %ANACONDA_DIR% -echo done - -echo Trying to install aesara -pip install Aesara -echo installed - -rem Put a default .aesararc.txt -set AESARARC=%USERPROFILE%\.aesararc.txt -set AESARARC_=%USERPROFILE%\.aesararc_install.txt -echo [global]> %AESARARC_% -echo openmp=False>> %AESARARC_% -echo.>> %AESARARC_% -echo [blas]>> %AESARARC_% -echo ldflags=>> %AESARARC_% - -if exist %AESARARC% ( - echo A .aesararc.txt config file already exists, so we will not change it. - echo The default version is in %AESARARC_%, we suggest you check it out. -) else ( - rename %AESARARC_% .aesararc.txt -) - -:end -echo end +@echo off + +rem if ANACONDA_DIR is not defined +if [%ANACONDA_DIR%] == [^%ANACONDA_DIR^%] ( + if exist "c:\Anaconda" set ANACONDA_DIR=C:\Anaconda + ) + +if [%ANACONDA_DIR%] == [^%ANACONDA_DIR^%] ( + echo "Anaconda not found. Please install AnacondaCE or set the ANACONDA_DIR environment variable to the location of your Anaconda installation." 
+ goto end + ) + +if not exist %ANACONDA_DIR% ( + echo Anaconda install directory %ANACONDA_DIR% does not exist + goto end) + +echo Anaconda found in %ANACONDA_DIR% +echo copying dlls from %ANACONDA_DIR%\MinGW\x86_64-w64-mingw32\lib to %ANACONDA_DIR%\ +copy %ANACONDA_DIR%\MinGW\x86_64-w64-mingw32\lib\*.dll %ANACONDA_DIR% +echo done + +echo Trying to install aesara +pip install Aesara +echo installed + +rem Put a default .aesararc.txt +set AESARARC=%USERPROFILE%\.aesararc.txt +set AESARARC_=%USERPROFILE%\.aesararc_install.txt +echo [global]> %AESARARC_% +echo openmp=False>> %AESARARC_% +echo.>> %AESARARC_% +echo [blas]>> %AESARARC_% +echo ldflags=>> %AESARARC_% + +if exist %AESARARC% ( + echo A .aesararc.txt config file already exists, so we will not change it. + echo The default version is in %AESARARC_%, we suggest you check it out. +) else ( + rename %AESARARC_% .aesararc.txt +) + +:end +echo end diff --git a/doc/extending/extending_aesara_solution_1.py b/doc/extending/extending_aesara_solution_1.py index d232756ebd..435a21a5ba 100755 --- a/doc/extending/extending_aesara_solution_1.py +++ b/doc/extending/extending_aesara_solution_1.py @@ -17,7 +17,7 @@ def make_node(self, x, y): x = at.as_tensor_variable(x) y = at.as_tensor_variable(y) outdim = x.ndim - output = TensorType( + output = TensorType.subtype( dtype=aesara.scalar.upcast(x.dtype, y.dtype), shape=[False] * outdim )() return Apply(self, inputs=[x, y], outputs=[output]) @@ -42,10 +42,10 @@ def make_node(self, x, y): x = at.as_tensor_variable(x) y = at.as_tensor_variable(y) outdim = x.ndim - output1 = TensorType( + output1 = TensorType.subtype( dtype=aesara.scalar.upcast(x.dtype, y.dtype), shape=[False] * outdim )() - output2 = TensorType( + output2 = TensorType.subtype( dtype=aesara.scalar.upcast(x.dtype, y.dtype), shape=[False] * outdim )() return Apply(self, inputs=[x, y], outputs=[output1, output2]) diff --git a/doc/extending/other_ops.rst b/doc/extending/other_ops.rst index f1f3765a08..84dc506c7c 100644 --- a/doc/extending/other_ops.rst +++ b/doc/extending/other_ops.rst @@ -54,7 +54,7 @@ you can create output variables like this: .. code-block:: python out_format = inputs[0].format # or 'csr' or 'csc' if the output format is fixed - SparseTensorType(dtype=inputs[0].dtype, format=out_format).make_variable() + SparseTensorType.subtype(dtype=inputs[0].dtype, format=out_format).make_variable() See the sparse :class:`Aesara.sparse.basic.Cast` `Op` code for a good example of a sparse `Op` with Python code. diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index 5ded264858..1d4debc3fa 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -751,8 +751,8 @@ def test_sparse_input_aliasing_affecting_inplace_operations(self): # operations are used) and to break the elemwise composition # with some non-elemwise op (here dot) - x = sparse.SparseTensorType("csc", dtype="float64")() - y = sparse.SparseTensorType("csc", dtype="float64")() + x = sparse.SparseTensorType.subtype("csc", dtype="float64")() + y = sparse.SparseTensorType.subtype("csc", dtype="float64")() f = function([In(x, mutable=True), In(y, mutable=True)], (x + y) + (x + y)) # Test 1. 
If the same variable is given twice diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index b770121134..223180d006 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -234,7 +234,7 @@ def go2(inps, gs): op_linear2 = cls_ofg( [x, w, b], [x * w + b], - grad_overrides=[go1, NullType()(), DisconnectedType()()], + grad_overrides=[go1, NullType.subtype()(), DisconnectedType.subtype()()], ) zz2 = at_sum(op_linear2(xx, ww, bb)) dx2, dw2, db2 = grad( diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index e724cfd50a..3b9a43c955 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -716,8 +716,8 @@ def make_node(self, v): v = at.as_tensor_variable(v) assert v.type.ndim == 1 type_class = type(v.type) - out_r_type = type_class(dtype=v.dtype, shape=(True, False)) - out_c_type = type_class(dtype=v.dtype, shape=(False, True)) + out_r_type = type_class.subtype(dtype=v.dtype, shape=(True, False)) + out_c_type = type_class.subtype(dtype=v.dtype, shape=(False, True)) return Apply(self, [v], [out_r_type(), out_c_type()]) def perform(self, node, inp, out): diff --git a/tests/compile/test_shared.py b/tests/compile/test_shared.py index 49058a7fee..0c891d8cba 100644 --- a/tests/compile/test_shared.py +++ b/tests/compile/test_shared.py @@ -36,11 +36,11 @@ def test_ctors(self): # test tensor constructor b = shared(np.zeros((5, 5), dtype="int32")) - assert b.type == TensorType("int32", shape=[False, False]) + assert b.type == TensorType.subtype("int32", shape=[False, False]) b = shared(np.random.random((4, 5))) - assert b.type == TensorType("float64", shape=[False, False]) + assert b.type == TensorType.subtype("float64", shape=[False, False]) b = shared(np.random.random((5, 1, 2))) - assert b.type == TensorType("float64", shape=[False, False, False]) + assert b.type == TensorType.subtype("float64", shape=[False, False, False]) assert shared([]).type == generic @@ -67,7 +67,7 @@ def test_create_numpy_strict_false(self): # so creation should work SharedVariable( name="u", - type=TensorType(shape=[False], dtype="float64"), + type=TensorType.subtype(shape=[False], dtype="float64"), value=np.asarray([1.0, 2.0]), strict=False, ) @@ -76,7 +76,7 @@ def test_create_numpy_strict_false(self): # so creation should work SharedVariable( name="u", - type=TensorType(shape=[False], dtype="float64"), + type=TensorType.subtype(shape=[False], dtype="float64"), value=[1.0, 2.0], strict=False, ) @@ -85,7 +85,7 @@ def test_create_numpy_strict_false(self): # so creation should work SharedVariable( name="u", - type=TensorType(shape=[False], dtype="float64"), + type=TensorType.subtype(shape=[False], dtype="float64"), value=[1, 2], # different dtype and not a numpy array strict=False, ) @@ -95,7 +95,7 @@ def test_create_numpy_strict_false(self): try: SharedVariable( name="u", - type=TensorType(shape=[False], dtype="float64"), + type=TensorType.subtype(shape=[False], dtype="float64"), value=dict(), # not an array by any stretch strict=False, ) @@ -109,7 +109,7 @@ def test_use_numpy_strict_false(self): # so creation should work u = SharedVariable( name="u", - type=TensorType(shape=[False], dtype="float64"), + type=TensorType.subtype(shape=[False], dtype="float64"), value=np.asarray([1.0, 2.0]), strict=False, ) diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index e68d42b9cb..15fc682179 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -162,9 
+162,9 @@ def test_ambiguous(self): assert str(g) == "FunctionGraph(Op1(x))" def test_constant(self): - x = Constant(MyType(), 2, name="x") + x = Constant(MyType.subtype(), 2, name="x") y = MyVariable("y") - z = Constant(MyType(), 2, name="z") + z = Constant(MyType.subtype(), 2, name="z") e = op1(op1(x, y), y) g = FunctionGraph([y], [e]) OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g) @@ -256,7 +256,7 @@ def __init__(self, param): self.param = param def make_node(self): - return Apply(self, [], [MyType()()]) + return Apply(self, [], [MyType.subtype()()]) def perform(self, node, inputs, output_storage): output_storage[0][0] = self.param @@ -275,8 +275,8 @@ def test_straightforward(self): def test_constant_merging(self): x = MyVariable("x") - y = Constant(MyType(), 2, name="y") - z = Constant(MyType(), 2, name="z") + y = Constant(MyType.subtype(), 2, name="y") + z = Constant(MyType.subtype(), 2, name="z") e = op1(op2(x, y), op2(x, y), op2(x, z)) g = FunctionGraph([x, y, z], [e], clone=False) MergeOptimizer().rewrite(g) @@ -311,8 +311,8 @@ def test_merge_outputs(self): def test_identical_constant_args(self): x = MyVariable("x") - y = Constant(MyType(), 2, name="y") - z = Constant(MyType(), 2, name="z") + y = Constant(MyType.subtype(), 2, name="y") + z = Constant(MyType.subtype(), 2, name="z") e1 = op1(y, z) g = FunctionGraph([x, y, z], [e1], clone=False) MergeOptimizer().rewrite(g) @@ -515,8 +515,8 @@ def test_pre_constant_merge(): x = MyVariable("x") y = MyVariable("y") - c1 = Constant(MyType(), 1, "c1") - c2 = Constant(MyType(), 1, "c1") + c1 = Constant(MyType.subtype(), 1, "c1") + c2 = Constant(MyType.subtype(), 1, "c1") o1 = op2(c1, x) o2 = op1(o1, y, c2) @@ -559,8 +559,8 @@ def test_pre_greedy_node_rewriter(): x = MyVariable("x") y = MyVariable("y") - c1 = Constant(MyType(), 1, "c1") - c2 = Constant(MyType(), 2, "c2") + c1 = Constant(MyType.subtype(), 1, "c1") + c2 = Constant(MyType.subtype(), 2, "c2") o1 = op2(c1, c2) o3 = op1(c1, y) o2 = op1(o1, c2, x, o3, o1) diff --git a/tests/graph/rewriting/test_kanren.py b/tests/graph/rewriting/test_kanren.py index 75d8ec037f..bccc369c3e 100644 --- a/tests/graph/rewriting/test_kanren.py +++ b/tests/graph/rewriting/test_kanren.py @@ -98,7 +98,7 @@ def results_filter(results): def test_KanrenRelationSub_multiout(): class MyMultiOutOp(Op): def make_node(self, *inputs): - outputs = [MyType()(), MyType()()] + outputs = [MyType.subtype()(), MyType.subtype()()] return Apply(self, list(inputs), outputs) def perform(self, node, inputs, outputs): diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index 6ce1284794..c98f1317b2 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -72,7 +72,7 @@ def test_cons(): assert car(op1) == CustomOp assert cdr(op1) == (1,) - tt1 = TensorType("float32", [True, False]) + tt1 = TensorType.subtype("float32", [True, False]) assert car(tt1) == TensorType assert cdr(tt1) == ("float32", (1, None)) @@ -130,7 +130,7 @@ def test_etuples(): class MyMultiOutOp(Op): def make_node(self, *inputs): - outputs = [MyType()(), MyType()()] + outputs = [MyType.subtype()(), MyType.subtype()()] return Apply(self, list(inputs), outputs) def perform(self, node, inputs, outputs): @@ -156,8 +156,8 @@ def test_unify_Variable(): assert s == {} # These `Variable`s have no owners - v1 = MyType()() - v2 = MyType()() + v1 = MyType.subtype()() + v2 = MyType.subtype()() assert v1 != v2 @@ -247,8 +247,8 @@ def test_unify_Constant(): def test_unify_Type(): - t1 = 
TensorType(np.float64, (True, False)) - t2 = TensorType(np.float64, (True, False)) + t1 = TensorType.subtype(np.float64, (True, False)) + t2 = TensorType.subtype(np.float64, (True, False)) # `Type`, `Type` s = unify(t1, t2) @@ -260,8 +260,8 @@ def test_unify_Type(): from aesara.scalar.basic import ScalarType - st1 = ScalarType(np.float64) - st2 = ScalarType(np.float64) + st1 = ScalarType.subtype(np.float64) + st2 = ScalarType.subtype(np.float64) s = unify(st1, st2) assert s == {} diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 9bbd282dd4..01bd2e4402 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -66,7 +66,7 @@ def __repr__(self): def MyVariable(thingy): - return Variable(MyType(thingy), None, None) + return Variable(MyType.subtype(thingy), None, None) class MyOp(Op): @@ -338,8 +338,8 @@ def test_tensorvariable(self): # Get counter value autoname_id = next(Variable.__count__) Variable.__count__ = count(autoname_id) - r1 = TensorType(dtype="int32", shape=())("myvar") - r2 = TensorVariable(TensorType(dtype="int32", shape=()), None) + r1 = TensorType.subtype(dtype="int32", shape=())("myvar") + r2 = TensorVariable(TensorType.subtype(dtype="int32", shape=()), None) r3 = shared(np.random.standard_normal((3, 4))) assert r1.auto_name == "auto_" + str(autoname_id) assert r2.auto_name == "auto_" + str(autoname_id + 1) @@ -731,7 +731,7 @@ def test_clone_get_equiv(): def test_NominalVariable(): - type1 = MyType(1) + type1 = MyType.subtype(1) nv1 = NominalVariable(1, type1) nv2 = NominalVariable(1, type1) @@ -740,13 +740,13 @@ def test_NominalVariable(): assert nv1.equals(nv2) assert hash(nv1) == hash(nv2) - type2 = MyType(2) + type2 = MyType.subtype(2) nv3 = NominalVariable(1, type2) assert not nv1.equals(nv3) assert hash(nv1) != hash(nv3) - type3 = MyType(1) + type3 = MyType.subtype(1) assert type3 == type1 @@ -779,7 +779,7 @@ def test_NominalVariable(): def test_NominalVariable_create_variable_type(): - ttype = TensorType("float64", (None, None)) + ttype = TensorType.subtype("float64", (None, None)) ntv = NominalVariable(0, ttype) assert isinstance(ntv, TensorVariable) diff --git a/tests/graph/test_compute_test_value.py b/tests/graph/test_compute_test_value.py index 821e029f95..3fc3bed7b6 100644 --- a/tests/graph/test_compute_test_value.py +++ b/tests/graph/test_compute_test_value.py @@ -65,7 +65,7 @@ def make_node(self, input): def perform(self, node, inputs, outputs): outputs[0][0] = inputs[0] - test_input = SomeType()() + test_input = SomeType.subtype()() orig_object = object() test_input.tag.test_value = orig_object diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py index 3470284e66..5dddfe75a1 100644 --- a/tests/graph/test_destroyhandler.py +++ b/tests/graph/test_destroyhandler.py @@ -46,11 +46,11 @@ def __eq__(self, other): def MyVariable(name): - return Variable(MyType(), None, None, name=name) + return Variable(MyType.subtype(), None, None, name=name) def MyConstant(data): - return Constant(MyType(), data=data) + return Constant(MyType.subtype(), data=data) class MyOp(Op): diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py index bd5044fb29..4906b5794c 100644 --- a/tests/graph/test_features.py +++ b/tests/graph/test_features.py @@ -44,7 +44,7 @@ def as_variable(x): for input in inputs: if not isinstance(input.type, MyType): raise Exception("Error 1") - outputs = [MyType(self.name + "_R")()] + outputs = [MyType.subtype(self.name + "_R")()] return Apply(self, inputs, outputs) def 
__str__(self): @@ -58,7 +58,7 @@ def perform(self, *args, **kwargs): dot = MyOp(2, "Dot") def MyVariable(name): - return Variable(MyType(name), None, None) + return Variable(MyType.subtype(name), None, None) def inputs(): x = MyVariable("x") diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 8e495ff44c..3df7186fed 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -191,7 +191,7 @@ def test_import_var(self): with pytest.raises(TypeError, match="Computation graph contains.*"): from aesara.graph.null_type import NullType - fg.import_var(NullType()(), "testing") + fg.import_var(NullType.subtype()(), "testing") def test_change_input(self): @@ -695,7 +695,7 @@ def test_empty(self): assert fg.clients == {var1: [], var2: []} def test_nominals(self): - t1 = MyType() + t1 = MyType.subtype() nm = NominalVariable(1, t1) nm2 = NominalVariable(2, t1) diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index 969e674e8b..d768d438a2 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -60,7 +60,7 @@ def make_node(self, *inputs): for input in inputs: if not isinstance(input.type, MyType): raise Exception("Error 1") - outputs = [MyType(sum(input.type.thingy for input in inputs))()] + outputs = [MyType.subtype(sum(input.type.thingy for input in inputs))()] return Apply(self, inputs, outputs) def perform(self, *args, **kwargs): @@ -76,7 +76,7 @@ class NoInputOp(Op): __props__ = () def make_node(self): - return Apply(self, [], [MyType("test")()]) + return Apply(self, [], [MyType.subtype("test")()]) def perform(self, node, inputs, output_storage): output_storage[0][0] = "test Op no input" @@ -86,18 +86,20 @@ class TestOp: # Sanity tests def test_sanity_0(self): - r1, r2 = MyType(1)(), MyType(2)() + r1, r2 = MyType.subtype(1)(), MyType.subtype(2)() node = MyOp.make_node(r1, r2) # Are the inputs what I provided? assert [x for x in node.inputs] == [r1, r2] # Are the outputs what I expect? 
- assert [x.type for x in node.outputs] == [MyType(3)] + assert [x.type for x in node.outputs] == [MyType.subtype(3)] assert node.outputs[0].owner is node and node.outputs[0].index == 0 # validate def test_validate(self): try: - MyOp(Generic()(), MyType(1)()) # MyOp requires MyType instances + MyOp( + Generic.subtype()(), MyType.subtype(1)() + ) # MyOp requires MyType instances raise Exception("Expected an exception") except Exception as e: if str(e) != "Error 1": @@ -234,5 +236,5 @@ class SomeOp(aesara.tensor.Op): def perform(self, *_): raise NotImplementedError() - x = at.TensorType(dtype="float64", shape=(1,))("x") + x = at.TensorType.subtype(dtype="float64", shape=(1,))("x") assert SomeOp()(x).type == at.dvector diff --git a/tests/graph/test_types.py b/tests/graph/test_types.py index 2b8a30367d..fa188eb605 100644 --- a/tests/graph/test_types.py +++ b/tests/graph/test_types.py @@ -28,30 +28,30 @@ def is_super(self, other): def test_is_super(): - t1 = MyType(1) - t2 = MyType(2) + t1 = MyType.subtype(1) + t2 = MyType.subtype(2) assert t1.is_super(t2) is None - t1_2 = MyType(1) + t1_2 = MyType.subtype(1) assert t1.is_super(t1_2) def test_in_same_class(): - t1 = MyType(1) - t2 = MyType(2) + t1 = MyType.subtype(1) + t2 = MyType.subtype(2) assert t1.in_same_class(t2) is False - t1_2 = MyType(1) + t1_2 = MyType.subtype(1) assert t1.in_same_class(t1_2) def test_convert_variable(): - t1 = MyType(1) - v1 = Variable(MyType(1), None, None) - v2 = Variable(MyType(2), None, None) - v3 = Variable(MyType2(0), None, None) + t1 = MyType.subtype(1) + v1 = Variable(MyType.subtype(1), None, None) + v2 = Variable(MyType.subtype(2), None, None) + v3 = Variable(MyType2.subtype(0), None, None) assert t1.convert_variable(v1) is v1 assert t1.convert_variable(v2) is None @@ -61,5 +61,5 @@ def test_convert_variable(): def test_default_clone(): - mt = MyType(1) + mt = MyType.subtype(1) assert isinstance(mt.clone(1), MyType) diff --git a/tests/graph/utils.py b/tests/graph/utils.py index a3f1c47be1..6122d5ceae 100644 --- a/tests/graph/utils.py +++ b/tests/graph/utils.py @@ -41,15 +41,15 @@ def __repr__(self): def MyVariable(name): - return Variable(MyType(), None, None, name=name) + return Variable(MyType.subtype(), None, None, name=name) def MyConstant(name, data=None): - return Constant(MyType(), data, name=name) + return Constant(MyType.subtype(), data, name=name) def MyVariable2(name): - return Variable(MyType2(), None, None, name=name) + return Variable(MyType2.subtype(), None, None, name=name) class MyOp(Op): @@ -66,7 +66,7 @@ def make_node(self, *inputs): for input in inputs: if not isinstance(input.type, MyType): raise Exception("Error 1") - outputs = [MyType()() for i in range(self.n_outs)] + outputs = [MyType.subtype()() for i in range(self.n_outs)] return Apply(self, inputs, outputs) def perform(self, node, inputs, outputs): @@ -101,7 +101,7 @@ def make_node(self, *inputs): if not isinstance(input.type, MyType): raise Exception("Error 1") - outputs = [MyType2()()] + outputs = [MyType2.subtype()()] return Apply(self, inputs, outputs) diff --git a/tests/link/c/test_basic.py b/tests/link/c/test_basic.py index 645972fc19..a3e61de4af 100644 --- a/tests/link/c/test_basic.py +++ b/tests/link/c/test_basic.py @@ -79,7 +79,7 @@ def __hash__(self): return hash(type(self)) -tdouble = TDouble() +tdouble = TDouble.subtype() def double(name): diff --git a/tests/link/c/test_op.py b/tests/link/c/test_op.py index ee448a8314..b73c9087b3 100644 --- a/tests/link/c/test_op.py +++ b/tests/link/c/test_op.py @@ -26,7 +26,7 @@ from 
aesara.link.c.type import Generic from aesara.tensor.type import TensorType -tensor_type_0d = TensorType("float64", tuple()) +tensor_type_0d = TensorType.subtype("float64", tuple()) scalar_type = ScalarType("float64") generic_type = Generic() @@ -227,9 +227,9 @@ def test_ExternalCOp_c_code_cache_version(): out_2, err = get_hash(modname, seed=3849) assert err is None - hash_1, msg, _ = out_1.decode().split("\n") + hash_1, msg, *_ = out_1.decode().split("\n") assert msg == "__success__" - hash_2, msg, _ = out_2.decode().split("\n") + hash_2, msg, *_ = out_2.decode().split("\n") assert msg == "__success__" assert hash_1 == hash_2 diff --git a/tests/link/c/test_params_type.py b/tests/link/c/test_params_type.py index 7053c054c7..c5d371410d 100644 --- a/tests/link/c/test_params_type.py +++ b/tests/link/c/test_params_type.py @@ -12,15 +12,15 @@ from tests import unittest_tools as utt -tensor_type_0d = TensorType("float64", tuple()) -scalar_type = ScalarType("float64") -generic_type = Generic() +tensor_type_0d = TensorType.subtype("float64", tuple()) +scalar_type = ScalarType.subtype("float64") +generic_type = Generic.subtype() # A test op to compute `y = a*x^2 + bx + c` for any tensor x, with a, b, c as op params. class QuadraticOpFunc(COp): __props__ = ("a", "b", "c") - params_type = ParamsType(a=tensor_type_0d, b=scalar_type, c=generic_type) + params_type = ParamsType.subtype(a=tensor_type_0d, b=scalar_type, c=generic_type) def __init__(self, a, b, c): self.a = a @@ -103,7 +103,7 @@ def c_code(self, node, name, inputs, outputs, sub): # external file). class QuadraticCOpFunc(ExternalCOp): __props__ = ("a", "b", "c") - params_type = ParamsType(a=tensor_type_0d, b=scalar_type, c=generic_type) + params_type = ParamsType.subtype(a=tensor_type_0d, b=scalar_type, c=generic_type) def __init__(self, a, b, c): super().__init__( @@ -125,17 +125,17 @@ def perform(self, node, inputs, output_storage, coefficients): class TestParamsType: def test_hash_and_eq_params(self): - wp1 = ParamsType( - a=Generic(), - array=TensorType("int64", (False,)), - floatting=ScalarType("float64"), - npy_scalar=TensorType("float64", tuple()), + wp1 = ParamsType.subtype( + a=Generic.subtype(), + array=TensorType.subtype("int64", (False,)), + floatting=ScalarType.subtype("float64"), + npy_scalar=TensorType.subtype("float64", tuple()), ) - wp2 = ParamsType( - a=Generic(), - array=TensorType("int64", (False,)), - floatting=ScalarType("float64"), - npy_scalar=TensorType("float64", tuple()), + wp2 = ParamsType.subtype( + a=Generic.subtype(), + array=TensorType.subtype("int64", (False,)), + floatting=ScalarType.subtype("float64"), + npy_scalar=TensorType.subtype("float64", tuple()), ) w1 = Params( wp1, @@ -155,11 +155,11 @@ def test_hash_and_eq_params(self): assert not (w1 != w2) assert hash(w1) == hash(w2) # Changing attributes names only (a -> other_name). 
- wp2_other = ParamsType( - other_name=Generic(), - array=TensorType("int64", (False,)), - floatting=ScalarType("float64"), - npy_scalar=TensorType("float64", tuple()), + wp2_other = ParamsType.subtype( + other_name=Generic.subtype(), + array=TensorType.subtype("int64", (False,)), + floatting=ScalarType.subtype("float64"), + npy_scalar=TensorType.subtype("float64", tuple()), ) w2 = Params( wp2_other, @@ -189,41 +189,41 @@ def test_hash_and_eq_params(self): assert w1 != w2 def test_hash_and_eq_params_type(self): - w1 = ParamsType( - a1=TensorType("int64", (False, False)), - a2=TensorType("int64", (False, True, False, False, True)), - a3=Generic(), + w1 = ParamsType.subtype( + a1=TensorType.subtype("int64", (False, False)), + a2=TensorType.subtype("int64", (False, True, False, False, True)), + a3=Generic.subtype(), ) - w2 = ParamsType( - a1=TensorType("int64", (False, False)), - a2=TensorType("int64", (False, True, False, False, True)), - a3=Generic(), + w2 = ParamsType.subtype( + a1=TensorType.subtype("int64", (False, False)), + a2=TensorType.subtype("int64", (False, True, False, False, True)), + a3=Generic.subtype(), ) assert w1 == w2 assert not (w1 != w2) assert hash(w1) == hash(w2) assert w1.name == w2.name # Changing attributes names only. - w2 = ParamsType( - a1=TensorType("int64", (False, False)), - other_name=TensorType( + w2 = ParamsType.subtype( + a1=TensorType.subtype("int64", (False, False)), + other_name=TensorType.subtype( "int64", (False, True, False, False, True) ), # a2 -> other_name - a3=Generic(), + a3=Generic.subtype(), ) assert w1 != w2 # Changing attributes types only. - w2 = ParamsType( - a1=TensorType("int64", (False, False)), - a2=Generic(), # changing class - a3=Generic(), + w2 = ParamsType.subtype( + a1=TensorType.subtype("int64", (False, False)), + a2=Generic.subtype(), # changing class + a3=Generic.subtype(), ) assert w1 != w2 # Changing attributes types characteristics only. - w2 = ParamsType( - a1=TensorType("int64", (False, True)), # changing broadcasting - a2=TensorType("int64", (False, True, False, False, True)), - a3=Generic(), + w2 = ParamsType.subtype( + a1=TensorType.subtype("int64", (False, True)), # changing broadcasting + a2=TensorType.subtype("int64", (False, True, False, False, True)), + a3=Generic.subtype(), ) assert w1 != w2 @@ -238,10 +238,10 @@ def test_params_type_filtering(self): ) random_tensor = np.random.normal(size=size_tensor5).reshape(shape_tensor5) - w = ParamsType( - a1=TensorType("int32", (False, False)), - a2=TensorType("float64", (False, False, False, False, False)), - a3=Generic(), + w = ParamsType.subtype( + a1=TensorType.subtype("int32", (False, False)), + a2=TensorType.subtype("float64", (False, False, False, False, False)), + a3=Generic.subtype(), ) # With a value that does not match the params type. @@ -298,7 +298,10 @@ def test_params_type_filtering(self): def test_params_type_with_enums(self): # Test that we fail if we create a params type with common enum names inside different enum types. try: - ParamsType(enum1=EnumList("A", "B", "C"), enum2=EnumList("A", "B", "F")) + ParamsType.subtype( + enum1=EnumList.subtype("A", "B", "C"), + enum2=EnumList.subtype("A", "B", "F"), + ) except AttributeError: pass else: @@ -308,14 +311,14 @@ def test_params_type_with_enums(self): # Test that we fail if we create a params type with common names in both aliases and constants. 
try: - ParamsType( - enum1=EnumList(("A", "a"), ("B", "b")), - enum2=EnumList(("ONE", "a"), ("TWO", "two")), + ParamsType.subtype( + enum1=EnumList.subtype(("A", "a"), ("B", "b")), + enum2=EnumList.subtype(("ONE", "a"), ("TWO", "two")), ) except AttributeError: - ParamsType( - enum1=EnumList(("A", "a"), ("B", "b")), - enum2=EnumList(("ONE", "one"), ("TWO", "two")), + ParamsType.subtype( + enum1=EnumList.subtype(("A", "a"), ("B", "b")), + enum2=EnumList.subtype(("ONE", "one"), ("TWO", "two")), ) else: raise Exception( @@ -323,9 +326,9 @@ def test_params_type_with_enums(self): ) # Test that we can access enum values through wrapper directly. - w = ParamsType( - enum1=EnumList("A", ("B", "beta"), "C"), - enum2=EnumList(("D", "delta"), "E", "F"), + w = ParamsType.subtype( + enum1=EnumList.subtype("A", ("B", "beta"), "C"), + enum2=EnumList.subtype(("D", "delta"), "E", "F"), ) assert w.A == 0 and w.B == 1 and w.C == 2 assert w.D == 0 and w.E == 1 and w.F == 2 diff --git a/tests/link/c/test_type.py b/tests/link/c/test_type.py index aedf00c9a7..edc397cd39 100644 --- a/tests/link/c/test_type.py +++ b/tests/link/c/test_type.py @@ -15,7 +15,7 @@ class ProdOp(COp): __props__ = () def make_node(self, i): - return Apply(self, [i], [CDataType("void *", "py_decref")()]) + return Apply(self, [i], [CDataType.subtype("void *", "py_decref")()]) def c_support_code(self, **kwargs): return """ @@ -44,7 +44,7 @@ class GetOp(COp): __props__ = () def make_node(self, c): - return Apply(self, [c], [TensorType("float32", (False,))()]) + return Apply(self, [c], [TensorType.subtype("float32", (False,))()]) def c_support_code(self, **kwargs): return """ @@ -73,7 +73,7 @@ def perform(self, *args, **kwargs): not aesara.config.cxx, reason="G++ not available, so we need to skip this test." ) def test_cdata(): - i = TensorType("float32", (False,))() + i = TensorType.subtype("float32", (False,))() c = ProdOp()(i) i2 = GetOp()(c) mode = None @@ -91,7 +91,7 @@ def test_cdata(): class MyOpEnumList(COp): __props__ = ("op_chosen",) - params_type = EnumList( + params_type = EnumList.subtype( ("ADD", "+"), ("SUB", "-"), ("MULTIPLY", "*"), @@ -164,7 +164,7 @@ def c_code(self, node, name, inputs, outputs, sub): class MyOpCEnumType(COp): __props__ = ("python_value",) - params_type = CEnumType( + params_type = CEnumType.subtype( ("MILLION", "million"), ("BILLION", "billion"), ("TWO_BILLIONS", "two_billions"), @@ -212,14 +212,14 @@ def test_enum_class(self): # Check that invalid enum name raises exception. for invalid_name in ("a", "_A", "0"): try: - EnumList(invalid_name) + EnumList.subtype(invalid_name) except AttributeError: pass else: raise Exception("EnumList with invalid name should fail.") try: - EnumType(**{invalid_name: 0}) + EnumType.subtype(**{invalid_name: 0}) except AttributeError: pass else: @@ -227,15 +227,15 @@ def test_enum_class(self): # Check that invalid enum value raises exception. try: - EnumType(INVALID_VALUE="string is not allowed.") + EnumType.subtype(INVALID_VALUE="string is not allowed.") except TypeError: pass else: raise Exception("EnumType with invalid value should fail.") # Check EnumType. 
- e1 = EnumType(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0) - e2 = EnumType(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0) + e1 = EnumType.subtype(C1=True, C2=12, C3=True, C4=-1, C5=False, C6=0.0) + e2 = EnumType.subtype(C1=1, C2=12, C3=1, C4=-1.0, C5=0.0, C6=0) assert e1 == e2 assert not (e1 != e2) assert hash(e1) == hash(e2) @@ -243,9 +243,9 @@ def test_enum_class(self): assert len((e1.ctype, e1.C1, e1.C2, e1.C3, e1.C4, e1.C5, e1.C6)) == 7 # Check enum with aliases. - e1 = EnumType(A=("alpha", 0), B=("beta", 1), C=2) - e2 = EnumType(A=("alpha", 0), B=("beta", 1), C=2) - e3 = EnumType(A=("a", 0), B=("beta", 1), C=2) + e1 = EnumType.subtype(A=("alpha", 0), B=("beta", 1), C=2) + e2 = EnumType.subtype(A=("alpha", 0), B=("beta", 1), C=2) + e3 = EnumType.subtype(A=("a", 0), B=("beta", 1), C=2) assert e1 == e2 assert e1 != e3 assert e1.filter("beta") == e1.fromalias("beta") == e1.B == 1 @@ -253,9 +253,9 @@ def test_enum_class(self): # Check that invalid alias (same as a constant) raises exception. try: - EnumList(("A", "a"), ("B", "B")) + EnumList.subtype(("A", "a"), ("B", "B")) except TypeError: - EnumList(("A", "a"), ("B", "b")) + EnumList.subtype(("A", "a"), ("B", "b")) else: raise Exception( "Enum with an alias name equal to a constant name should fail." diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index d3d6a1d870..d48dd56194 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -225,7 +225,7 @@ def assert_fn(x, y): @pytest.mark.parametrize( "v, expected, force_scalar, not_implemented", [ - (MyType(), None, False, True), + (MyType.subtype(), None, False, True), (aes.float32, numba.types.float32, False, False), (at.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False, False), (at.fscalar, numba.types.float32, True, False), diff --git a/tests/link/test_link.py b/tests/link/test_link.py index f45c68ab4f..51ce905ca0 100644 --- a/tests/link/test_link.py +++ b/tests/link/test_link.py @@ -67,7 +67,7 @@ def filter(self, data): return float(data) -tdouble = TDouble() +tdouble = TDouble.subtype() def double(name): diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 8609d9e843..04df4a5705 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -347,8 +347,8 @@ def test_true_div(self): xi = int8("xi") yi = int8("yi") - xf = ScalarType(aesara.config.floatX)("xf") - yf = ScalarType(aesara.config.floatX)("yf") + xf = ScalarType.subtype(aesara.config.floatX)("xf") + yf = ScalarType.subtype(aesara.config.floatX)("yf") ei = true_div(xi, yi) fi = aesara.function([xi, yi], ei) diff --git a/tests/scalar/test_type.py b/tests/scalar/test_type.py index 1bce85c9ff..81d2a973c3 100644 --- a/tests/scalar/test_type.py +++ b/tests/scalar/test_type.py @@ -14,7 +14,7 @@ def test_numpy_dtype(): - test_type = ScalarType(np.int32) + test_type = ScalarType.subtype(np.int32) assert test_type.dtype == "int32" @@ -39,7 +39,7 @@ def test_div_types(): def test_filter_float_subclass(): """Make sure `ScalarType.filter` can handle `float` subclasses.""" with config.change_flags(floatX="float64"): - test_type = ScalarType("float64") + test_type = ScalarType.subtype("float64") nan = np.array([np.nan], dtype="float64")[0] assert isinstance(nan, float) @@ -49,7 +49,7 @@ def test_filter_float_subclass(): with config.change_flags(floatX="float32"): # Try again, except this time `nan` isn't a `float` - test_type = ScalarType("float32") + test_type = ScalarType.subtype("float32") nan = np.array([np.nan], dtype="float32")[0] 
assert isinstance(nan, np.floating) @@ -63,6 +63,6 @@ def test_filter_float_subclass(): def test_clone(): - st = ScalarType("int64") + st = ScalarType.subtype("int64") assert st == st.clone() assert st.clone("float64").dtype == "float64" diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index a3fd87ec47..9dd09e8360 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -418,7 +418,7 @@ def test_getitem_2d(self): pass def test_getitem_scalar(self): - x = SparseTensorType("csr", dtype=config.floatX)() + x = SparseTensorType.subtype("csr", dtype=config.floatX)() self._compile_and_check( [x], [x[2, 2]], @@ -456,7 +456,7 @@ def test_csm_grad(self): ) def test_transpose(self): - x = SparseTensorType("csr", dtype=config.floatX)() + x = SparseTensorType.subtype("csr", dtype=config.floatX)() self._compile_and_check( [x], [x.T], @@ -465,7 +465,7 @@ def test_transpose(self): ) def test_neg(self): - x = SparseTensorType("csr", dtype=config.floatX)() + x = SparseTensorType.subtype("csr", dtype=config.floatX)() self._compile_and_check( [x], [-x], @@ -474,8 +474,8 @@ def test_neg(self): ) def test_add_ss(self): - x = SparseTensorType("csr", dtype=config.floatX)() - y = SparseTensorType("csr", dtype=config.floatX)() + x = SparseTensorType.subtype("csr", dtype=config.floatX)() + y = SparseTensorType.subtype("csr", dtype=config.floatX)() self._compile_and_check( [x, y], [x + y], @@ -487,7 +487,7 @@ def test_add_ss(self): ) def test_add_sd(self): - x = SparseTensorType("csr", dtype=config.floatX)() + x = SparseTensorType.subtype("csr", dtype=config.floatX)() y = matrix() self._compile_and_check( [x, y], @@ -500,8 +500,8 @@ def test_add_sd(self): ) def test_mul_ss(self): - x = SparseTensorType("csr", dtype=config.floatX)() - y = SparseTensorType("csr", dtype=config.floatX)() + x = SparseTensorType.subtype("csr", dtype=config.floatX)() + y = SparseTensorType.subtype("csr", dtype=config.floatX)() self._compile_and_check( [x, y], [x * y], @@ -513,7 +513,7 @@ def test_mul_ss(self): ) def test_mul_sd(self): - x = SparseTensorType("csr", dtype=config.floatX)() + x = SparseTensorType.subtype("csr", dtype=config.floatX)() y = matrix() self._compile_and_check( [x, y], @@ -527,7 +527,7 @@ def test_mul_sd(self): ) def test_remove0(self): - x = SparseTensorType("csr", dtype=config.floatX)() + x = SparseTensorType.subtype("csr", dtype=config.floatX)() self._compile_and_check( [x], [Remove0()(x)], @@ -536,8 +536,8 @@ def test_remove0(self): ) def test_dot(self): - x = SparseTensorType("csc", dtype=config.floatX)() - y = SparseTensorType("csc", dtype=config.floatX)() + x = SparseTensorType.subtype("csc", dtype=config.floatX)() + y = SparseTensorType.subtype("csc", dtype=config.floatX)() self._compile_and_check( [x, y], [Dot()(x, y)], @@ -550,12 +550,12 @@ def test_dot(self): def test_dot_broadcast(self): for x, y in [ - (SparseTensorType("csr", "float32")(), vector()[:, None]), - (SparseTensorType("csr", "float32")(), vector()[None, :]), - (SparseTensorType("csr", "float32")(), matrix()), - (vector()[:, None], SparseTensorType("csr", "float32")()), - (vector()[None, :], SparseTensorType("csr", "float32")()), - (matrix(), SparseTensorType("csr", "float32")()), + (SparseTensorType.subtype("csr", "float32")(), vector()[:, None]), + (SparseTensorType.subtype("csr", "float32")(), vector()[None, :]), + (SparseTensorType.subtype("csr", "float32")(), matrix()), + (vector()[:, None], SparseTensorType.subtype("csr", "float32")()), + (vector()[None, :], SparseTensorType.subtype("csr", 
"float32")()), + (matrix(), SparseTensorType.subtype("csr", "float32")()), ]: sparse_out = at.dot(x, y) @@ -567,8 +567,8 @@ def test_dot_broadcast(self): assert dense_out.broadcastable == sparse_out.broadcastable def test_structured_dot(self): - x = SparseTensorType("csc", dtype=config.floatX)() - y = SparseTensorType("csc", dtype=config.floatX)() + x = SparseTensorType.subtype("csc", dtype=config.floatX)() + y = SparseTensorType.subtype("csc", dtype=config.floatX)() self._compile_and_check( [x, y], [structured_dot(x, y)], @@ -588,8 +588,8 @@ def test_structured_dot_grad(self): ("csc", StructuredDotGradCSC), ("csr", StructuredDotGradCSR), ]: - x = SparseTensorType(format, dtype=config.floatX)() - y = SparseTensorType(format, dtype=config.floatX)() + x = SparseTensorType.subtype(format, dtype=config.floatX)() + y = SparseTensorType.subtype(format, dtype=config.floatX)() grads = aesara.grad(dense_from_sparse(structured_dot(x, y)).sum(), [x, y]) self._compile_and_check( [x, y], @@ -611,7 +611,7 @@ def test_structured_dot_grad(self): ) def test_dense_from_sparse(self): - x = SparseTensorType("csr", dtype=config.floatX)() + x = SparseTensorType.subtype("csr", dtype=config.floatX)() self._compile_and_check( [x], [dense_from_sparse(x)], @@ -693,7 +693,7 @@ def fn(m): def test_err(self): for ndim in [1, 3]: - t = TensorType(dtype=config.floatX, shape=(False,) * ndim)() + t = TensorType.subtype(dtype=config.floatX, shape=(False,) * ndim)() v = ivector() sub = t[v] @@ -1135,7 +1135,7 @@ def test_csm_properties(self): for format in ("csc", "csr"): for dtype in ("float32", "float64"): - x = SparseTensorType(format, dtype=dtype)() + x = SparseTensorType.subtype(format, dtype=dtype)() f = aesara.function([x], csm_properties(x)) spmat = sp_types[format](random_lil((4, 3), dtype, 3)) @@ -1293,7 +1293,7 @@ def test_upcast(self): for dense_dtype in typenames: for sparse_dtype in typenames: correct_dtype = aesara.scalar.upcast(sparse_dtype, dense_dtype) - a = SparseTensorType("csc", dtype=sparse_dtype)() + a = SparseTensorType.subtype("csc", dtype=sparse_dtype)() b = matrix(dtype=dense_dtype) d = structured_dot(a, b) assert d.type.dtype == correct_dtype @@ -1334,7 +1334,7 @@ def test_opt_unpack(self): return # - kerns = TensorType(dtype="int64", shape=[False])("kerns") + kerns = TensorType.subtype(dtype="int64", shape=[False])("kerns") spmat = sp.sparse.lil_matrix((4, 6), dtype="int64") for i in range(5): # set non-zeros in random locations (row x, col y) @@ -1343,7 +1343,7 @@ def test_opt_unpack(self): spmat[x, y] = np.random.random() * 10 spmat = sp.sparse.csc_matrix(spmat) - images = TensorType(dtype="float32", shape=[False, False])("images") + images = TensorType.subtype(dtype="float32", shape=[False, False])("images") cscmat = CSC(kerns, spmat.indices[: spmat.size], spmat.indptr, spmat.shape) f = aesara.function([kerns, images], structured_dot(cscmat, images.T)) @@ -1380,8 +1380,8 @@ def test_dot_sparse_sparse(self): for sparse_format_a in ["csc", "csr", "bsr"]: for sparse_format_b in ["csc", "csr", "bsr"]: - a = SparseTensorType(sparse_format_a, dtype=sparse_dtype)() - b = SparseTensorType(sparse_format_b, dtype=sparse_dtype)() + a = SparseTensorType.subtype(sparse_format_a, dtype=sparse_dtype)() + b = SparseTensorType.subtype(sparse_format_b, dtype=sparse_dtype)() d = at.dot(a, b) f = aesara.function([a, b], Out(d, borrow=True)) for M, N, K, nnz in [ @@ -1402,7 +1402,7 @@ def test_csc_correct_output_faster_than_scipy(self): sparse_dtype = "float64" dense_dtype = "float64" - a = SparseTensorType("csc", 
dtype=sparse_dtype)() + a = SparseTensorType.subtype("csc", dtype=sparse_dtype)() b = matrix(dtype=dense_dtype) d = at.dot(a, b) f = aesara.function([a, b], Out(d, borrow=True)) @@ -1450,7 +1450,7 @@ def test_csr_correct_output_faster_than_scipy(self): sparse_dtype = "float32" dense_dtype = "float32" - a = SparseTensorType("csr", dtype=sparse_dtype)() + a = SparseTensorType.subtype("csr", dtype=sparse_dtype)() b = matrix(dtype=dense_dtype) d = at.dot(a, b) f = aesara.function([a, b], d) @@ -1572,8 +1572,8 @@ def test_sparse_sparse(self): ("csr", "csc"), ("csr", "csr"), ]: - x = sparse.SparseTensorType(format=x_f, dtype=d1)("x") - y = sparse.SparseTensorType(format=x_f, dtype=d2)("x") + x = sparse.SparseTensorType.subtype(format=x_f, dtype=d1)("x") + y = sparse.SparseTensorType.subtype(format=x_f, dtype=d2)("x") def f_a(x, y): return x * y @@ -1889,7 +1889,7 @@ def test(self): def test_shape_i(): sparse_dtype = "float32" - a = SparseTensorType("csr", dtype=sparse_dtype)() + a = SparseTensorType.subtype("csr", dtype=sparse_dtype)() f = aesara.function([a], a.shape[1]) assert f(sp.sparse.csr_matrix(random_lil((100, 10), sparse_dtype, 3))) == 10 @@ -1899,7 +1899,7 @@ def test_shape(): # does not actually create a dense tensor in the process. sparse_dtype = "float32" - a = SparseTensorType("csr", dtype=sparse_dtype)() + a = SparseTensorType.subtype("csr", dtype=sparse_dtype)() f = aesara.function([a], a.shape) assert np.all( f(sp.sparse.csr_matrix(random_lil((100, 10), sparse_dtype, 3))) == (100, 10) @@ -1958,8 +1958,8 @@ def test_sparse_shared_memory(): a = random_lil((3, 4), "float32", 3).tocsr() m1 = random_lil((4, 4), "float32", 3).tocsr() m2 = random_lil((4, 4), "float32", 3).tocsr() - x = SparseTensorType("csr", dtype="float32")() - y = SparseTensorType("csr", dtype="float32")() + x = SparseTensorType.subtype("csr", dtype="float32")() + y = SparseTensorType.subtype("csr", dtype="float32")() sdot = sparse.structured_dot z = sdot(x * 3, m1) + sdot(y * 2, m2) @@ -3176,11 +3176,11 @@ def structured_function(*args): def test_useless_conj(): - x = sparse.SparseTensorType("csr", dtype="complex128")() + x = sparse.SparseTensorType.subtype("csr", dtype="complex128")() assert x.conj() is not x # No conjugate when the data type isn't complex - x = sparse.SparseTensorType("csr", dtype="float64")() + x = sparse.SparseTensorType.subtype("csr", dtype="float64")() assert x.conj() is x @@ -3200,7 +3200,7 @@ def test_mul_s_v(self): for format in ("csr", "csc"): for dtype in ("float32", "float64"): - x = sparse.SparseTensorType(format, dtype=dtype)() + x = sparse.SparseTensorType.subtype(format, dtype=dtype)() y = vector(dtype=dtype) f = aesara.function([x, y], mul_s_v(x, y)) @@ -3228,7 +3228,7 @@ def test_structured_add_s_v(self): for format in ("csr", "csc"): for dtype in ("float32", "float64"): - x = sparse.SparseTensorType(format, dtype=dtype)() + x = sparse.SparseTensorType.subtype(format, dtype=dtype)() y = vector(dtype=dtype) f = aesara.function([x, y], structured_add_s_v(x, y)) @@ -3275,7 +3275,7 @@ def test_op_sd(self): variable, data = sparse_random_inputs( format, shape=(10, 10), out_dtype=dtype, n=2, p=0.1 ) - variable[1] = TensorType(dtype=dtype, shape=(False, False))() + variable[1] = TensorType.subtype(dtype=dtype, shape=(False, False))() data[1] = data[1].toarray() f = aesara.function(variable, self.op(*variable)) diff --git a/tests/sparse/test_type.py b/tests/sparse/test_type.py index 5843e9c938..1ea7601fae 100644 --- a/tests/sparse/test_type.py +++ b/tests/sparse/test_type.py @@ -7,20 
+7,20 @@ def test_SparseTensorType_constructor(): - st = SparseTensorType("csc", "float64") + st = SparseTensorType.subtype("csc", "float64") assert st.format == "csc" assert st.shape == (None, None) - st = SparseTensorType("bsr", "float64", shape=(None, 1)) + st = SparseTensorType.subtype("bsr", "float64", shape=(None, 1)) assert st.format == "bsr" assert st.shape == (None, 1) with pytest.raises(ValueError): - SparseTensorType("blah", "float64") + SparseTensorType.subtype("blah", "float64") def test_SparseTensorType_clone(): - st = SparseTensorType("csr", "float64", shape=(3, None)) + st = SparseTensorType.subtype("csr", "float64", shape=(3, None)) assert st == st.clone() st_clone = st.clone(format="csc") diff --git a/tests/tensor/nnet/speed_test_conv.py b/tests/tensor/nnet/speed_test_conv.py index 0a413c9848..7dd4d6719a 100644 --- a/tests/tensor/nnet/speed_test_conv.py +++ b/tests/tensor/nnet/speed_test_conv.py @@ -39,7 +39,7 @@ def flip(kern, kshp): global_rng = np.random.default_rng(3423489) -dmatrix4 = TensorType("float64", (False, False, False, False)) +dmatrix4 = TensorType.subtype("float64", (False, False, False, False)) def exec_multilayer_conv_nnet_old( diff --git a/tests/tensor/nnet/test_abstract_conv.py b/tests/tensor/nnet/test_abstract_conv.py index 31a3df7aa3..c01dd02ef7 100644 --- a/tests/tensor/nnet/test_abstract_conv.py +++ b/tests/tensor/nnet/test_abstract_conv.py @@ -2529,7 +2529,7 @@ def setup_method(self): self.ref_mode = "FAST_RUN" def test_fwd(self): - tensor6 = TensorType(config.floatX, (False,) * 6) + tensor6 = TensorType.subtype(config.floatX, (False,) * 6) img_sym = tensor4("img") kern_sym = tensor6("kern") ref_kern_sym = tensor4("ref_kern") @@ -2652,7 +2652,7 @@ def conv_gradweight(inputs_val, output_val): utt.verify_grad(conv_gradweight, [img, top], mode=self.mode, eps=1) def test_gradinput(self): - tensor6 = TensorType(config.floatX, (False,) * 6) + tensor6 = TensorType.subtype(config.floatX, (False,) * 6) kern_sym = tensor6("kern") top_sym = tensor4("top") ref_kern_sym = tensor4("ref_kern") diff --git a/tests/tensor/nnet/test_batchnorm.py b/tests/tensor/nnet/test_batchnorm.py index b5c57d6117..56d71ae4de 100644 --- a/tests/tensor/nnet/test_batchnorm.py +++ b/tests/tensor/nnet/test_batchnorm.py @@ -495,7 +495,7 @@ def test_batch_normalization_train_broadcast(): params_dimshuffle[axis] = i # construct non-broadcasted parameter variables - param_type = TensorType(x.dtype, (False,) * len(non_bc_axes)) + param_type = TensorType.subtype(x.dtype, (False,) * len(non_bc_axes)) scale, bias, running_mean, running_var = ( param_type(n) for n in ("scale", "bias", "running_mean", "running_var") ) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 0301909c82..8646ce46d0 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -689,7 +689,7 @@ def setup_method(self): def test_consecutive(self): x = fmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("float64")))(x.astype("float64")) + o = Elemwise(aes.Cast(aes.ScalarType.subtype("float64")))(x.astype("float64")) f = function([x], o, mode=self.mode) dx = np.random.random((5, 4)).astype("float32") f(dx) @@ -698,7 +698,7 @@ def test_consecutive(self): assert isinstance(topo[0].op.scalar_op, aes.basic.Cast) x = dmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("float32")))(x.astype("float32")) + o = Elemwise(aes.Cast(aes.ScalarType.subtype("float32")))(x.astype("float32")) f = function([x], o, mode=self.mode) dx = np.random.random((5, 4)) f(dx) @@ 
-709,7 +709,9 @@ def test_consecutive(self): def test_upcast(self): # Upcast followed by any other cast x = fmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("complex128")))(x.astype("complex64")) + o = Elemwise(aes.Cast(aes.ScalarType.subtype("complex128")))( + x.astype("complex64") + ) f = function([x], o, mode=self.mode) dx = np.random.random((5, 4)).astype("float32") f(dx) @@ -719,7 +721,7 @@ def test_upcast(self): # Upcast followed by a downcast back to the base type x = fmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("float32")))(x.astype("float64")) + o = Elemwise(aes.Cast(aes.ScalarType.subtype("float32")))(x.astype("float64")) f = function([x], o, mode=self.mode) dx = np.random.random((5, 4)).astype("float32") f(dx) @@ -730,7 +732,7 @@ def test_upcast(self): # Downcast followed by an upcast back to the base type # The rewrite shouldn't be applied x = dmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("float64")))(x.astype("float32")) + o = Elemwise(aes.Cast(aes.ScalarType.subtype("float64")))(x.astype("float32")) f = function([x], o, mode=self.mode) dx = np.random.random((5, 4)) f(dx) @@ -1316,7 +1318,7 @@ def test_local_join_make_vector(): ], ) def test_local_tensor_scalar_tensor(dtype): - t_type = TensorType(dtype=dtype, shape=()) + t_type = TensorType.subtype(dtype=dtype, shape=()) t = t_type() s = at.scalar_from_tensor(t) t2 = at.tensor_from_scalar(s) @@ -1346,7 +1348,7 @@ def test_local_tensor_scalar_tensor(dtype): ], ) def test_local_scalar_tensor_scalar(dtype): - s_type = aes.ScalarType(dtype=dtype) + s_type = aes.ScalarType.subtype(dtype=dtype) s = s_type() t = at.tensor_from_scalar(s) s2 = at.scalar_from_tensor(t) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index cfb9b6a61d..d0d2ceaed7 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -71,9 +71,9 @@ def ds(x, y): def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)): - x = TensorType(shape=xbc, dtype="float64")("x") - y = TensorType(shape=ybc, dtype="float64")("y") - z = TensorType(shape=zbc, dtype="float64")("z") + x = TensorType.subtype(shape=xbc, dtype="float64")("x") + y = TensorType.subtype(shape=ybc, dtype="float64")("y") + z = TensorType.subtype(shape=zbc, dtype="float64")("z") return x, y, z @@ -205,10 +205,10 @@ def test_dimshuffle_on_broadcastable(self): def test_local_useless_dimshuffle_in_reshape(): - vec = TensorType(shape=(False,), dtype="float64")("vector") - mat = TensorType(shape=(False, False), dtype="float64")("mat") - row = TensorType(shape=(True, False), dtype="float64")("row") - col = TensorType(shape=(False, True), dtype="float64")("col") + vec = TensorType.subtype(shape=(False,), dtype="float64")("vector") + mat = TensorType.subtype(shape=(False, False), dtype="float64")("mat") + row = TensorType.subtype(shape=(True, False), dtype="float64")("row") + col = TensorType.subtype(shape=(False, True), dtype="float64")("col") reshape_dimshuffle_vector = reshape(vec.dimshuffle("x", 0), vec.shape) reshape_dimshuffle_mat = reshape(mat.dimshuffle("x", 0, "x", 1), mat.shape) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index a73632db9d..914baf3886 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -161,9 +161,9 @@ def rewrite(g, level="fast_run"): def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)): - x = TensorType(shape=xbc, dtype="float64")("x") - y = TensorType(shape=ybc, dtype="float64")("y") - z = TensorType(shape=zbc, 
dtype="float64")("z") + x = TensorType.subtype(shape=xbc, dtype="float64")("x") + y = TensorType.subtype(shape=ybc, dtype="float64")("y") + z = TensorType.subtype(shape=zbc, dtype="float64")("z") return x, y, z @@ -3558,7 +3558,7 @@ def test_local_reduce_broadcast_all_0(self): at_max, at_min, ]: - x = TensorType("int64", (True, True, True))() + x = TensorType.subtype("int64", (True, True, True))() f = function([x], [fct(x)], mode=self.mode) assert not any( isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() @@ -3573,7 +3573,7 @@ def test_local_reduce_broadcast_all_1(self): at_max, at_min, ]: - x = TensorType("int64", (True, True))() + x = TensorType.subtype("int64", (True, True))() f = function([x], [fct(x, axis=[0, 1])], mode=self.mode) assert not any( isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() @@ -3588,7 +3588,7 @@ def test_local_reduce_broadcast_some_0(self): at_max, at_min, ]: - x = TensorType("int64", (True, False, True))() + x = TensorType.subtype("int64", (True, False, True))() f = function([x], [fct(x, axis=[0, 1])], mode=self.mode) order = f.maker.fgraph.toposort() @@ -3613,7 +3613,7 @@ def test_local_reduce_broadcast_some_1(self): at_max, at_min, ]: - x = TensorType("int64", (True, True, True))() + x = TensorType.subtype("int64", (True, True, True))() f = function([x], [fct(x, axis=[0, 2])], mode=self.mode) assert not any( isinstance(node.op, CAReduce) for node in f.maker.fgraph.toposort() @@ -4092,7 +4092,7 @@ def test_local_log_sum_exp_maximum(): check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op) # If the sum is performed with keepdims=True - x = TensorType(dtype="floatX", shape=(False, True, False))("x") + x = TensorType.subtype(dtype="floatX", shape=(False, True, False))("x") sum_keepdims_op = x.sum(axis=(0, 1), keepdims=True).owner.op check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op) diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index 09dc0585d0..e4fc4da146 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -516,7 +516,7 @@ def __eq__(self, other): class MyVariable(Variable): pass - x = MyVariable(MyType(), None, None) + x = MyVariable(MyType.subtype(), None, None) s = Shape_i(0)(x) fgraph = FunctionGraph(outputs=[s], clone=False) _ = rewrite_graph(fgraph, clone=False) diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 754dfc6995..bdfa28c6b3 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -1939,7 +1939,7 @@ def filter(self, *args, **kwargs): def __eq__(self, other): return isinstance(other, MyType) and other.thingy == self.thingy - x = shape(Variable(MyType(), None, None))[0] + x = shape(Variable(MyType.subtype(), None, None))[0] assert not local_subtensor_shape_constant.transform(None, x.owner) diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py index 62cfbb0cc1..6aa39ec36e 100644 --- a/tests/tensor/signal/test_conv.py +++ b/tests/tensor/signal/test_conv.py @@ -16,8 +16,8 @@ def validate(self, image_shape, filter_shape, out_dim, verify_grad=True): image_dim = len(image_shape) filter_dim = len(filter_shape) - input = TensorType("float64", [False] * image_dim)() - filters = TensorType("float64", [False] * filter_dim)() + input = TensorType.subtype("float64", [False] * image_dim)() + filters = TensorType.subtype("float64", [False] * filter_dim)() bsize = image_shape[0] if image_dim != 3: 
diff --git a/tests/tensor/signal/test_pool.py b/tests/tensor/signal/test_pool.py index 4539090236..4e5dcb1476 100644 --- a/tests/tensor/signal/test_pool.py +++ b/tests/tensor/signal/test_pool.py @@ -1122,7 +1122,7 @@ def test_max_pool_2d_6D(self): rng = np.random.default_rng(utt.fetch_seed()) maxpoolshps = [(3, 2)] imval = rng.random((2, 1, 1, 1, 3, 4)) - images = TensorType("float64", [False] * 6)() + images = TensorType.subtype("float64", [False] * 6)() for maxpoolshp, ignore_border, mode in product( maxpoolshps, diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 509b651085..551b8aed5a 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -511,7 +511,7 @@ def perform(self, *args, **kwargs): def test_constant(): - int8_vector_type = TensorType(dtype="int8", shape=(False,)) + int8_vector_type = TensorType.subtype(dtype="int8", shape=(False,)) # Make sure we return a `TensorConstant` unchanged x = TensorConstant(int8_vector_type, [1, 2]) @@ -519,7 +519,7 @@ def test_constant(): assert y is x # Make sure we can add and remove broadcastable dimensions - int8_scalar_type = TensorType(dtype="int8", shape=()) + int8_scalar_type = TensorType.subtype(dtype="int8", shape=()) x_data = np.array(2, dtype="int8") x = TensorConstant(int8_scalar_type, x_data) @@ -575,17 +575,17 @@ def test_list(self): as_tensor_variable(bad_apply_var) def test_ndim_strip_leading_broadcastable(self): - x = TensorType(config.floatX, (True, False))("x") + x = TensorType.subtype(config.floatX, (True, False))("x") x = as_tensor_variable(x, ndim=1) assert x.ndim == 1 def test_ndim_all_broadcastable(self): - x = TensorType(config.floatX, (True, True))("x") + x = TensorType.subtype(config.floatX, (True, True))("x") res = as_tensor_variable(x, ndim=0) assert res.ndim == 0 def test_ndim_incompatible(self): - x = TensorType(config.floatX, (True, False))("x") + x = TensorType.subtype(config.floatX, (True, False))("x") with pytest.raises(ValueError, match="^Tensor of type.*"): as_tensor_variable(x, ndim=0) @@ -656,12 +656,12 @@ def test_constant_consistency(self, x, y): def test_constant_identity(self): # Values that are already `TensorType`s shouldn't be recreated by # `as_tensor_variable` - x_scalar = TensorConstant(TensorType(dtype="int8", shape=()), 2) + x_scalar = TensorConstant(TensorType.subtype(dtype="int8", shape=()), 2) a_scalar = as_tensor_variable(x_scalar) assert x_scalar is a_scalar x_vector = TensorConstant( - TensorType(dtype="int8", shape=(False,)), + TensorType.subtype(dtype="int8", shape=(False,)), np.array([1, 2], dtype="int8"), ) a_vector = as_tensor_variable(x_vector) @@ -1715,8 +1715,8 @@ def test_broadcastable_flag_assignment_mixed_thisaxes(self): # We can't set the value| with pytest.raises(TypeError): b.set_value(rng.random((3, 4, 1)).astype(self.floatX)) - a = TensorType(dtype=self.floatX, shape=[False, False, True])() - b = TensorType(dtype=self.floatX, shape=[True, False, True])() + a = TensorType.subtype(dtype=self.floatX, shape=[False, False, True])() + b = TensorType.subtype(dtype=self.floatX, shape=[True, False, True])() c = self.join_op(0, a, b) f = function([a, b], c, mode=self.mode) bad_b_val = rng.random((3, 4, 1)).astype(self.floatX) @@ -1774,19 +1774,19 @@ def test_broadcastable_single_input_broadcastable_dimension(self): def test_broadcastable_flags_many_dims_and_inputs(self): # Test that the right broadcastable flags get set for a join # with many inputs and many input dimensions. 
- a = TensorType( + a = TensorType.subtype( dtype=self.floatX, shape=[True, False, True, False, False, False] )() - b = TensorType( + b = TensorType.subtype( dtype=self.floatX, shape=[True, True, True, False, False, False] )() - c = TensorType( + c = TensorType.subtype( dtype=self.floatX, shape=[True, False, False, False, False, False] )() - d = TensorType( + d = TensorType.subtype( dtype=self.floatX, shape=[True, False, True, True, False, True] )() - e = TensorType( + e = TensorType.subtype( dtype=self.floatX, shape=[True, False, True, False, False, True] )() f = self.join_op(0, a, b, c, d, e) @@ -1881,8 +1881,8 @@ def get_mat(s1, s2): def test_rebroadcast(self): # Regression test for a crash that used to happen when rebroadcasting. - x = TensorType(self.floatX, [False, False, True])() - u = TensorType(self.floatX, [False, False, True])() + x = TensorType.subtype(self.floatX, [False, False, True])() + u = TensorType.subtype(self.floatX, [False, False, True])() # This line used to crash. at.concatenate([x, -u], axis=2) @@ -1939,7 +1939,7 @@ def test_split_neg(self): f() def test_split_static_shape(self): - x = TensorType("floatX", shape=(5,))("x") + x = TensorType.subtype("floatX", shape=(5,))("x") s = iscalar("s") y = Split(2)(x, 0, [s, 5 - s])[0] assert y.type.shape == (None,) @@ -2118,7 +2118,7 @@ def test_flatten_ndim2(): def test_flatten_ndim2_of_3(): - a = TensorType("float64", (False, False, False))() + a = TensorType.subtype("float64", (False, False, False))() c = flatten(a, 2) f = inplace_func([a], c) a_val = _asarray([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype="float64") @@ -2135,23 +2135,23 @@ def test_flatten_broadcastable(): # Ensure that the broadcastable pattern of the output is coherent with # that of the input - inp = TensorType("float64", (False, False, False, False))() + inp = TensorType.subtype("float64", (False, False, False, False))() out = flatten(inp, ndim=2) assert out.broadcastable == (False, False) - inp = TensorType("float64", (False, False, False, True))() + inp = TensorType.subtype("float64", (False, False, False, True))() out = flatten(inp, ndim=2) assert out.broadcastable == (False, False) - inp = TensorType("float64", (False, True, False, True))() + inp = TensorType.subtype("float64", (False, True, False, True))() out = flatten(inp, ndim=2) assert out.broadcastable == (False, False) - inp = TensorType("float64", (False, True, True, True))() + inp = TensorType.subtype("float64", (False, True, True, True))() out = flatten(inp, ndim=2) assert out.broadcastable == (False, True) - inp = TensorType("float64", (True, False, True, True))() + inp = TensorType.subtype("float64", (True, False, True, True))() out = flatten(inp, ndim=3) assert out.broadcastable == (True, False, True) @@ -2949,7 +2949,7 @@ def test_3b_2(self): # input.type.broadcastable = (False, True, False), # p.type.broadcastable = (False, False). 
- input = TensorType("floatX", (False, True, False))() + input = TensorType.subtype("floatX", (False, True, False))() p = imatrix() out = permute_row_elements(input, p) permute = function([input, p], out) @@ -3317,7 +3317,7 @@ def test_make_vector(self): assert get_scalar_constant_value(mv[np.int32(0)]) == 1 assert get_scalar_constant_value(mv[np.int64(1)]) == 2 assert get_scalar_constant_value(mv[np.uint(2)]) == 3 - t = aes.ScalarType("int64") + t = aes.ScalarType.subtype("int64") with pytest.raises(NotScalarConstantError): get_scalar_constant_value(mv[t()]) @@ -3532,7 +3532,7 @@ def _generator(self): for d in range(1, dims + 1): # Create a TensorType of the same dimensions as # as the data we want to test. - x = TensorType(dtype=config.floatX, shape=(False,) * d)("x") + x = TensorType.subtype(dtype=config.floatX, shape=(False,) * d)("x") # Make a slice of the test data that has the # dimensions we need by doing xv[0,...,0] diff --git a/tests/tensor/test_casting.py b/tests/tensor/test_casting.py index e7f4e63fc5..25fb735aba 100644 --- a/tests/tensor/test_casting.py +++ b/tests/tensor/test_casting.py @@ -75,7 +75,7 @@ def test_illegal(self): ), ) def test_basic(self, type1, type2, converter): - x = TensorType(dtype=type1, shape=(False,))() + x = TensorType.subtype(dtype=type1, shape=(False,))() y = converter(x) f = function([In(x, strict=True)], y) a = np.arange(10, dtype=type1) @@ -86,8 +86,8 @@ def test_convert_to_complex(self): val64 = np.ones(3, dtype="complex64") + 0.5j val128 = np.ones(3, dtype="complex128") + 0.5j - vec64 = TensorType("complex64", (False,))() - vec128 = TensorType("complex128", (False,))() + vec64 = TensorType.subtype("complex64", (False,))() + vec128 = TensorType.subtype("complex128", (False,))() f = function([vec64], _convert_to_complex128(vec64)) # we need to compare with the same type. 
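For reference, a minimal sketch (illustrative only, names hypothetical) of the two spellings of shape information that reach `subtype` in the elemwise and type tests that follow: a legacy broadcastable pattern of booleans, or a static shape of ints/None. The fixed-shape comparison tests later in this patch assert the two spellings yield equal types.

from aesara.tensor.type import TensorType

# Booleans are the legacy broadcastable pattern; True marks a length-1 axis.
t_bc = TensorType.subtype("float64", (True, True))

# Ints/None give the static shape directly; 1 plays the role of True above.
t_shape = TensorType.subtype("float64", (1, 1))

assert t_bc == t_shape and hash(t_bc) == hash(t_shape)

# An all-False pattern carries no static shape information.
assert TensorType.subtype("float64", (False, False)).shape == (None, None)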
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 6bd514f277..858831b932 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -57,12 +57,12 @@ def with_linker(self, linker): ((1,), ("x", "x"), (1, 1)), ]: ib = [(entry == 1) for entry in xsh] - x = self.type(self.dtype, ib)("x") + x = self.type.subtype(self.dtype, ib)("x") e = self.op(ib, shuffle)(x) f = aesara.function([x], e, mode=Mode(linker=linker)) assert f(np.ones(xsh, dtype=self.dtype)).shape == zsh # test that DimShuffle.infer_shape work correctly - x = self.type(self.dtype, ib)("x") + x = self.type.subtype(self.dtype, ib)("x") e = self.op(ib, shuffle)(x) f = aesara.function( [x], e.shape, mode=Mode(linker=linker), on_unused_input="ignore" @@ -71,13 +71,13 @@ def with_linker(self, linker): # Test when we drop a axis that is not broadcastable ib = [False, True, False] - x = self.type(self.dtype, ib)("x") + x = self.type.subtype(self.dtype, ib)("x") with pytest.raises(ValueError): self.op(ib, shuffle) # Test when we drop a axis that don't have shape 1 ib = [True, True, False] - x = self.type(self.dtype, ib)("x") + x = self.type.subtype(self.dtype, ib)("x") e = self.op(ib, (1, 2))(x) f = aesara.function([x], e.shape, mode=Mode(linker=linker)) with pytest.raises(TypeError): @@ -86,7 +86,7 @@ def with_linker(self, linker): # Test that we can't take a dimensions multiple time xsh, shuffle, zsh = ((1, 1, 4), (0, 1, 2, 0), (1, 4)) ib = [False, True, False] - x = self.type(self.dtype, ib)("x") + x = self.type.subtype(self.dtype, ib)("x") with pytest.raises(ValueError): DimShuffle(ib, shuffle) @@ -112,7 +112,7 @@ def test_infer_shape(self): ((1,), ("x", "x")), ]: ib = [(entry == 1) for entry in xsh] - adtens = self.type(self.dtype, ib)("x") + adtens = self.type.subtype(self.dtype, ib)("x") adtens_val = np.ones(xsh, dtype=self.dtype) self._compile_and_check( [adtens], @@ -123,7 +123,7 @@ def test_infer_shape(self): ) def test_too_big_rank(self): - x = self.type(self.dtype, shape=())() + x = self.type.subtype(self.dtype, shape=())() y = x.dimshuffle(("x",) * (np.MAXDIMS + 1)) with pytest.raises(ValueError): y.eval({x: 0}) @@ -227,22 +227,26 @@ def with_linker(self, linker, op, type, rand_val): ((), ()), ]: if shape_info == "complete": - x_type = type(aesara.config.floatX, shape=xsh) - y_type = type(aesara.config.floatX, shape=ysh) + x_type = type.subtype(aesara.config.floatX, shape=xsh) + y_type = type.subtype(aesara.config.floatX, shape=ysh) elif shape_info == "only_broadcastable": # This condition is here for backwards compatibility, when the only # type shape provided by Aesara was broadcastable/non-broadcastable - x_type = type( + x_type = type.subtype( aesara.config.floatX, broadcastable=[(entry == 1) for entry in xsh], ) - y_type = type( + y_type = type.subtype( aesara.config.floatX, broadcastable=[(entry == 1) for entry in ysh], ) else: - x_type = type(aesara.config.floatX, shape=[None for _ in xsh]) - y_type = type(aesara.config.floatX, shape=[None for _ in ysh]) + x_type = type.subtype( + aesara.config.floatX, shape=[None for _ in xsh] + ) + y_type = type.subtype( + aesara.config.floatX, shape=[None for _ in ysh] + ) x = x_type("x") y = y_type("y") @@ -278,22 +282,26 @@ def with_linker_inplace(self, linker, op, type, rand_val): ((), ()), ]: if shape_info == "complete": - x_type = type(aesara.config.floatX, shape=xsh) - y_type = type(aesara.config.floatX, shape=ysh) + x_type = type.subtype(aesara.config.floatX, shape=xsh) + y_type = type.subtype(aesara.config.floatX, shape=ysh) 
elif shape_info == "only_broadcastable": # This condition is here for backwards compatibility, when the only # type shape provided by Aesara was broadcastable/non-broadcastable - x_type = type( + x_type = type.subtype( aesara.config.floatX, broadcastable=[(entry == 1) for entry in xsh], ) - y_type = type( + y_type = type.subtype( aesara.config.floatX, broadcastable=[(entry == 1) for entry in ysh], ) else: - x_type = type(aesara.config.floatX, shape=[None for _ in xsh]) - y_type = type(aesara.config.floatX, shape=[None for _ in ysh]) + x_type = type.subtype( + aesara.config.floatX, shape=[None for _ in xsh] + ) + y_type = type.subtype( + aesara.config.floatX, shape=[None for _ in ysh] + ) x = x_type("x") y = y_type("y") @@ -349,8 +357,8 @@ def test_fill(self): [self.type, self.ctype], [self.rand_val, self.rand_cval], ): - x = t(aesara.config.floatX, (False, False))("x") - y = t(aesara.config.floatX, (True, True))("y") + x = t.subtype(aesara.config.floatX, (False, False))("x") + y = t.subtype(aesara.config.floatX, (True, True))("y") e = op(aes.Second(aes.transfer_type(0)), {0: 0})(x, y) f = make_function(linker().accept(FunctionGraph([x, y], [e]))) xv = rval((5, 5)) @@ -365,8 +373,8 @@ def test_fill_var(self): def test_fill_grad(self): # Fix bug reported at # https://groups.google.com/d/topic/theano-users/nQshB8gUA6k/discussion - x = TensorType(config.floatX, (False, True, False))("x") - y = TensorType(config.floatX, (False, True, False))("y") + x = TensorType.subtype(config.floatX, (False, True, False))("x") + y = TensorType.subtype(config.floatX, (False, True, False))("y") e = second(x, y) aesara.grad(e.sum(), y) @@ -380,8 +388,8 @@ def test_weird_strides(self): [self.type, self.ctype], [self.rand_val, self.rand_cval], ): - x = t(aesara.config.floatX, (False,) * 5)("x") - y = t(aesara.config.floatX, (False,) * 5)("y") + x = t.subtype(aesara.config.floatX, (False,) * 5)("x") + y = t.subtype(aesara.config.floatX, (False,) * 5)("y") e = op(aes.add)(x, y) f = make_function(linker().accept(FunctionGraph([x, y], [e]))) xv = rval((2, 2, 2, 2, 2)) @@ -399,7 +407,7 @@ def test_same_inputs(self): [self.type, self.ctype], [self.rand_val, self.rand_cval], ): - x = t(aesara.config.floatX, (False,) * 2)("x") + x = t.subtype(aesara.config.floatX, (False,) * 2)("x") e = op(aes.add)(x, x) f = make_function(linker().accept(FunctionGraph([x], [e]))) xv = rval((2, 2)) @@ -440,7 +448,7 @@ def with_mode( for xsh, tosum in self.cases: if dtype == "floatX": dtype = aesara.config.floatX - x = self.type(dtype, [(entry == 1) for entry in xsh])("x") + x = self.type.subtype(dtype, [(entry == 1) for entry in xsh])("x") d = {} if pre_scalar_op is not None: d = {"pre_scalar_op": pre_scalar_op} @@ -548,7 +556,7 @@ def with_mode( # GpuCAReduce don't implement all cases when size is 0 assert xv.size == 0 - x = self.type(dtype, [(entry == 1) for entry in xsh])("x") + x = self.type.subtype(dtype, [(entry == 1) for entry in xsh])("x") if tensor_op is None: e = self.op(scalar_op, axis=tosum)(x) else: @@ -653,7 +661,7 @@ def test_infer_shape(self, dtype=None, pre_scalar_op=None): if dtype is None: dtype = aesara.config.floatX for xsh, tosum in self.cases: - x = self.type(dtype, [(entry == 1) for entry in xsh])("x") + x = self.type.subtype(dtype, [(entry == 1) for entry in xsh])("x") if pre_scalar_op is not None: x = pre_scalar_op(x) if tosum is None: @@ -749,8 +757,8 @@ def test_infer_shape(self): ((2, 3, 4, 1), (2, 3, 4, 5)), ]: dtype = aesara.config.floatX - t_left = TensorType(dtype, [(entry == 1) for entry in s_left])() - 
t_right = TensorType(dtype, [(entry == 1) for entry in s_right])() + t_left = TensorType.subtype(dtype, [(entry == 1) for entry in s_left])() + t_right = TensorType.subtype(dtype, [(entry == 1) for entry in s_right])() t_left_val = np.zeros(s_left, dtype=dtype) t_right_val = np.zeros(s_right, dtype=dtype) self._compile_and_check( @@ -806,7 +814,7 @@ def test_str(self): def test_partial_static_shape_info(self): """Make sure that `Elemwise.infer_shape` can handle changes in the static shape information during rewriting.""" - x = TensorType("floatX", shape=(None, None))() + x = TensorType.subtype("floatX", shape=(None, None))() z = Elemwise(aes.add)(x, x) x_inferred_shape = (aes.constant(1), aes.constant(1)) @@ -829,7 +837,7 @@ def make_node(self, *args): res.inputs, # Return two outputs [ - TensorType(dtype="float64", shape=(None, None))() + TensorType.subtype(dtype="float64", shape=(None, None))() for i in range(2) ], ) diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 22759226bf..6cffa97361 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -337,10 +337,10 @@ def test_perform(self, axis, n): @pytest.mark.parametrize( "x_type", ( - at.TensorType("float64", (None, None)), - at.TensorType("float64", (None, 30)), - at.TensorType("float64", (10, None)), - at.TensorType("float64", (10, 30)), + at.TensorType.subtype("float64", (None, None)), + at.TensorType.subtype("float64", (None, 30)), + at.TensorType.subtype("float64", (10, None)), + at.TensorType.subtype("float64", (10, 30)), ), ) @pytest.mark.parametrize("axis", (-2, -1, 0, 1)) @@ -375,7 +375,7 @@ def setup_method(self): ) def test_op(self, shape, broadcast): data = np.random.random(size=shape).astype(config.floatX) - variable = TensorType(config.floatX, broadcast)() + variable = TensorType.subtype(config.floatX, broadcast)() f = aesara.function([variable], self.op(variable)) @@ -398,7 +398,7 @@ def test_op(self, shape, broadcast): ) def test_infer_shape(self, shape, broadcast): data = np.random.random(size=shape).astype(config.floatX) - variable = TensorType(config.floatX, broadcast)() + variable = TensorType.subtype(config.floatX, broadcast)() self._compile_and_check( [variable], [self.op(variable)], [data], DimShuffle, warn=False @@ -433,7 +433,7 @@ def test_grad(self, shape, broadcast): def test_var_interface(self, shape, broadcast): # same as test_op, but use a_aesara_var.squeeze. 
data = np.random.random(size=shape).astype(config.floatX) - variable = TensorType(config.floatX, broadcast)() + variable = TensorType.subtype(config.floatX, broadcast)() f = aesara.function([variable], variable.squeeze()) @@ -444,29 +444,29 @@ def test_var_interface(self, shape, broadcast): assert np.allclose(tested, expected) def test_axis(self): - variable = TensorType(config.floatX, [False, True, False])() + variable = TensorType.subtype(config.floatX, [False, True, False])() res = squeeze(variable, axis=1) assert res.broadcastable == (False, False) - variable = TensorType(config.floatX, [False, True, False])() + variable = TensorType.subtype(config.floatX, [False, True, False])() res = squeeze(variable, axis=(1,)) assert res.broadcastable == (False, False) - variable = TensorType(config.floatX, [False, True, False, True])() + variable = TensorType.subtype(config.floatX, [False, True, False, True])() res = squeeze(variable, axis=(1, 3)) assert res.broadcastable == (False, False) - variable = TensorType(config.floatX, [True, False, True, False, True])() + variable = TensorType.subtype(config.floatX, [True, False, True, False, True])() res = squeeze(variable, axis=(0, -1)) assert res.broadcastable == (False, True, False) def test_invalid_axis(self): # Test that trying to squeeze a non broadcastable dimension raises error - variable = TensorType(config.floatX, [True, False])() + variable = TensorType.subtype(config.floatX, [True, False])() with pytest.raises( ValueError, match="Cannot drop a non-broadcastable dimension" ): @@ -540,7 +540,7 @@ def setup_method(self): def test_basic(self, ndim, dtype): rng = np.random.default_rng(4282) - x = TensorType(config.floatX, [False] * ndim)() + x = TensorType.subtype(config.floatX, [False] * ndim)() a = rng.random((10,) * ndim).astype(config.floatX) for axis in self._possible_axis(ndim): @@ -579,7 +579,7 @@ def test_basic(self, ndim, dtype): ) # check when r is aesara tensortype that broadcastable is (True,) - r_var = TensorType(shape=(True,), dtype=dtype)() + r_var = TensorType.subtype(shape=(True,), dtype=dtype)() r = rng.integers(1, 6, size=(1,)).astype(dtype) f = aesara.function([x, r_var], repeat(x, r_var, axis=axis)) assert np.allclose(np.repeat(a, r[0], axis=axis), f(a, r)) @@ -593,7 +593,7 @@ def test_basic(self, ndim, dtype): def test_infer_shape(self, ndim, dtype): rng = np.random.default_rng(4282) - x = TensorType(config.floatX, [False] * ndim)() + x = TensorType.subtype(config.floatX, [False] * ndim)() shp = (np.arange(ndim) + 1) * 3 a = rng.random(shp).astype(config.floatX) @@ -635,7 +635,7 @@ def test_grad(self, ndim): utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a]) def test_broadcastable(self): - x = TensorType(config.floatX, [False, True, False])() + x = TensorType.subtype(config.floatX, [False, True, False])() r = Repeat(axis=1)(x, 2) assert r.broadcastable == (False, False, False) r = Repeat(axis=1)(x, 1) diff --git a/tests/tensor/test_io.py b/tests/tensor/test_io.py index addc0a54bf..38a75ba0a2 100644 --- a/tests/tensor/test_io.py +++ b/tests/tensor/test_io.py @@ -17,7 +17,7 @@ def setup_method(self): np.save(self.filename, self.data) def test_basic(self): - path = Variable(Generic(), None) + path = Variable(Generic.subtype(), None) # Not specifying mmap_mode defaults to None, and the data is # copied into main memory x = load(path, "int32", (False,)) @@ -29,13 +29,13 @@ def test_invalid_modes(self): # Modes 'r+', 'r', and 'w+' cannot work with Aesara, becausei # the output array may be modified inplace, and that should 
not # modify the original file. - path = Variable(Generic(), None) + path = Variable(Generic.subtype(), None) for mmap_mode in ("r+", "r", "w+", "toto"): with pytest.raises(ValueError): load(path, "int32", (False,), mmap_mode) def test1(self): - path = Variable(Generic(), None) + path = Variable(Generic.subtype(), None) # 'c' means "copy-on-write", which allow the array to be overwritten # by an inplace Op in the graph, without modifying the underlying # file. @@ -48,7 +48,7 @@ def test1(self): assert (fn(self.filename) == (self.data**2).sum()).all() def test_memmap(self): - path = Variable(Generic(), None) + path = Variable(Generic.subtype(), None) x = load(path, "int32", (False,), mmap_mode="c") fn = function([path], x) assert type(fn(self.filename)) == np.core.memmap diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 71114a03dd..8e772387b3 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -562,7 +562,7 @@ def test_maximum_minimum_grad(): def test_py_c_match(): - a = TensorType(dtype="int8", shape=(False,))() + a = TensorType.subtype(dtype="int8", shape=(False,))() f = function([a], arccos(a), mode="DebugMode") # This can fail in DebugMode f(np.asarray([1, 0, -1], dtype="int8")) @@ -1934,7 +1934,7 @@ def is_super_shape(var1, var2): (False, True), (False, False), ): - x = TensorType(dtype=dtype0, shape=bc0)() + x = TensorType.subtype(dtype=dtype0, shape=bc0)() for bc1 in ( (True,), (False,), @@ -1944,7 +1944,7 @@ def is_super_shape(var1, var2): (False, False), ): - y = TensorType(dtype=dtype1, shape=bc1)() + y = TensorType.subtype(dtype=dtype1, shape=bc1)() z = dense_dot(x, y) if dtype0.startswith("float") and dtype1.startswith("float"): @@ -2117,7 +2117,7 @@ def test_scalar0(self): def test_broadcastable1(self): rng = np.random.default_rng(seed=utt.fetch_seed()) - x = TensorType(dtype=config.floatX, shape=(True, False, False))("x") + x = TensorType.subtype(dtype=config.floatX, shape=(True, False, False))("x") y = tensor3("y") z = tensordot(x, y) assert z.broadcastable == (True, False) @@ -2129,7 +2129,7 @@ def test_broadcastable1(self): def test_broadcastable2(self): rng = np.random.default_rng(seed=utt.fetch_seed()) - x = TensorType(dtype=config.floatX, shape=(True, False, False))("x") + x = TensorType.subtype(dtype=config.floatX, shape=(True, False, False))("x") y = tensor3("y") axes = [[2, 1], [0, 1]] z = tensordot(x, y, axes=axes) @@ -2156,7 +2156,7 @@ def test_smallest(): def test_var(): - a = TensorType(dtype="float64", shape=[False, False, False])() + a = TensorType.subtype(dtype="float64", shape=[False, False, False])() f = function([a], var(a)) a_val = np.arange(6).reshape(1, 2, 3) @@ -2206,7 +2206,7 @@ def test_var(): class TestSum: def test_sum_overflow(self): # Ensure that overflow errors are a little bit harder to get - a = TensorType(dtype="int8", shape=[False])() + a = TensorType.subtype(dtype="int8", shape=[False])() f = function([a], at_sum(a)) assert f([1] * 300) == 300 @@ -2269,7 +2269,7 @@ def numpy_array(dtype): return np.array([1], dtype=dtype) def aesara_i_scalar(dtype): - return aes.ScalarType(str(dtype))() + return aes.ScalarType.subtype(str(dtype))() def numpy_i_scalar(dtype): return numpy_scalar(dtype) @@ -3262,7 +3262,7 @@ def test_grad_useless_sum(): mode = get_default_mode().including("canonicalize") mode.check_isfinite = False - x = TensorType(config.floatX, (True,))("x") + x = TensorType.subtype(config.floatX, (True,))("x") l = log(1.0 - sigmoid(x))[0] g = grad(l, x) diff --git a/tests/tensor/test_merge.py 
b/tests/tensor/test_merge.py index 0879d30e06..dd45bf3e0b 100644 --- a/tests/tensor/test_merge.py +++ b/tests/tensor/test_merge.py @@ -35,7 +35,7 @@ def make_node(self, *inputs): for input in inputs: if not isinstance(input.type, MyType): raise Exception("Error 1") - outputs = [MyType()()] + outputs = [MyType.subtype()()] return Apply(self, inputs, outputs) def perform(self, *args, **kwargs): diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 1db5d510ec..15bcb95c17 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -67,7 +67,7 @@ def filter(self, *args, **kwargs): def __eq__(self, other): return isinstance(other, MyType) and other.thingy == self.thingy - s = shape(Variable(MyType(), None)) + s = shape(Variable(MyType.subtype(), None)) assert s.type.broadcastable == (False,) s = shape(np.array(1)) @@ -398,11 +398,11 @@ def test_fixed_shapes(self): assert y.shape.equals(shape) def test_fixed_partial_shapes(self): - x = TensorType("floatX", (None, None))("x") + x = TensorType.subtype("floatX", (None, None))("x") y = specify_shape(x, (None, 5)) assert y.type.shape == (None, 5) - x = TensorType("floatX", (3, None))("x") + x = TensorType.subtype("floatX", (3, None))("x") y = specify_shape(x, (None, 5)) assert y.type.shape == (3, 5) @@ -479,7 +479,7 @@ def test_bad_shape(self): def test_infer_shape(self): rng = np.random.default_rng(3453) adtens4 = dtensor4() - aivec = TensorVariable(TensorType("int64", (4,)), None) + aivec = TensorVariable(TensorType.subtype("int64", (4,)), None) aivec_val = [3, 4, 2, 5] adtens4_val = rng.random(aivec_val) self._compile_and_check( @@ -505,7 +505,7 @@ def test_infer_shape_partial(self): def test_direct_return(self): """Test that when specified shape does not provide new information, input is returned directly.""" - x = TensorType("float64", shape=(1, 2, None))("x") + x = TensorType.subtype("float64", shape=(1, 2, None))("x") assert specify_shape(x, (1, 2, None)) is x assert specify_shape(x, (None, None, None)) is x diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 1af4e50ccd..98b3984310 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -960,7 +960,7 @@ def test_adv_sub1_idx_broadcast(self): # The idx can be a broadcastable vector. ones = np.ones((4, 3), dtype=self.dtype) n = self.shared(ones * 5) - idx = TensorType(dtype="int64", shape=(True,))() + idx = TensorType.subtype(dtype="int64", shape=(True,))() assert idx.type.broadcastable == (True,) t = n[idx] @@ -1166,7 +1166,7 @@ def test_advanced1_inc_and_set(self): # Symbolic variable to be incremented. # We create a new one every time in order not to # have duplicated variables in the function's inputs - data_var = TensorType( + data_var = TensorType.subtype( shape=[False] * data_n_dims, dtype=self.dtype )() # Symbolic variable with rows to be incremented. @@ -1189,7 +1189,7 @@ def test_advanced1_inc_and_set(self): ) idx_num = idx_num.astype("int64") # Symbolic variable with increment value. 
- inc_var = TensorType( + inc_var = TensorType.subtype( shape=[False] * inc_n_dims, dtype=self.dtype )() # Trick for the case where `inc_shape` is the same as @@ -1879,7 +1879,7 @@ def test_inc_adv_subtensor_w_2vec(self, ignore_duplicates): subt = self.m[self.ix1, self.ix12] a = inc_subtensor(subt, subt, ignore_duplicates=ignore_duplicates) - typ = TensorType(self.m.type.dtype, self.ix2.type.broadcastable) + typ = TensorType.subtype(self.m.type.dtype, self.ix2.type.broadcastable) assert a.type == typ f = aesara.function( diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py index eed93e2527..0c82337370 100644 --- a/tests/tensor/test_type.py +++ b/tests/tensor/test_type.py @@ -20,35 +20,35 @@ ], ) def test_numpy_dtype(dtype, exp_dtype): - test_type = TensorType(dtype, []) + test_type = TensorType.subtype(dtype, []) assert test_type.dtype == exp_dtype def test_in_same_class(): - test_type = TensorType(config.floatX, [False, False]) - test_type2 = TensorType(config.floatX, [False, True]) + test_type = TensorType.subtype(config.floatX, [False, False]) + test_type2 = TensorType.subtype(config.floatX, [False, True]) assert test_type.in_same_class(test_type) assert not test_type.in_same_class(test_type2) def test_is_super(): - test_type = TensorType(config.floatX, [False, False]) - test_type2 = TensorType(config.floatX, [False, True]) + test_type = TensorType.subtype(config.floatX, [False, False]) + test_type2 = TensorType.subtype(config.floatX, [False, True]) assert test_type.is_super(test_type) assert test_type.is_super(test_type2) assert not test_type2.is_super(test_type) - test_type3 = TensorType(config.floatX, [False, False, False]) + test_type3 = TensorType.subtype(config.floatX, [False, False, False]) assert not test_type3.is_super(test_type) def test_convert_variable(): - test_type = TensorType(config.floatX, [False, False]) + test_type = TensorType.subtype(config.floatX, [False, False]) test_var = test_type() - test_type2 = TensorType(config.floatX, [True, False]) + test_type2 = TensorType.subtype(config.floatX, [True, False]) test_var2 = test_type2() res = test_type.convert_variable(test_var) @@ -60,7 +60,7 @@ def test_convert_variable(): res = test_type2.convert_variable(test_var) assert res.type == test_type2 - test_type3 = TensorType(config.floatX, [True, False, True]) + test_type3 = TensorType.subtype(config.floatX, [True, False, True]) test_var3 = test_type3() res = test_type2.convert_variable(test_var3) @@ -72,9 +72,9 @@ def test_convert_variable(): def test_convert_variable_mixed_specificity(): - type1 = TensorType(config.floatX, shape=(1, None, 3)) - type2 = TensorType(config.floatX, shape=(None, 5, 3)) - type3 = TensorType(config.floatX, shape=(1, 5, 3)) + type1 = TensorType.subtype(config.floatX, shape=(1, None, 3)) + type2 = TensorType.subtype(config.floatX, shape=(None, 5, 3)) + type3 = TensorType.subtype(config.floatX, shape=(1, 5, 3)) test_var1 = type1() test_var2 = type2() @@ -84,12 +84,12 @@ def test_convert_variable_mixed_specificity(): def test_filter_variable(): - test_type = TensorType(config.floatX, []) + test_type = TensorType.subtype(config.floatX, []) with pytest.raises(TypeError): test_type.filter(test_type()) - test_type = TensorType(config.floatX, [True, False]) + test_type = TensorType.subtype(config.floatX, [True, False]) with pytest.raises(TypeError): test_type.filter(np.empty((0, 1), dtype=config.floatX)) @@ -103,7 +103,7 @@ def test_filter_variable(): test_type.filter_checks_isfinite = True test_type.filter(np.full((1, 2), np.inf, 
dtype=config.floatX)) - test_type2 = TensorType(config.floatX, [False, False]) + test_type2 = TensorType.subtype(config.floatX, [False, False]) test_var = test_type() test_var2 = test_type2() @@ -114,13 +114,13 @@ def test_filter_variable(): res = test_type.filter_variable(test_var2, allow_convert=True) assert res.type == test_type - test_type3 = TensorType(config.floatX, shape=(1, 20)) + test_type3 = TensorType.subtype(config.floatX, shape=(1, 20)) res = test_type3.filter_variable(test_var, allow_convert=True) assert res.type == test_type3 def test_filter_strict(): - test_type = TensorType(config.floatX, []) + test_type = TensorType.subtype(config.floatX, []) with pytest.raises(TypeError): test_type.filter(1, strict=True) @@ -131,7 +131,7 @@ def test_filter_strict(): def test_filter_ndarray_subclass(): """Make sure `TensorType.filter` can handle NumPy `ndarray` subclasses.""" - test_type = TensorType(config.floatX, [False]) + test_type = TensorType.subtype(config.floatX, [False]) class MyNdarray(np.ndarray): pass @@ -147,7 +147,7 @@ class MyNdarray(np.ndarray): def test_filter_float_subclass(): """Make sure `TensorType.filter` can handle `float` subclasses.""" with config.change_flags(floatX="float64"): - test_type = TensorType("float64", shape=[]) + test_type = TensorType.subtype("float64", shape=[]) nan = np.array([np.nan], dtype="float64")[0] assert isinstance(nan, float) and not isinstance(nan, np.ndarray) @@ -157,7 +157,7 @@ def test_filter_float_subclass(): with config.change_flags(floatX="float32"): # Try again, except this time `nan` isn't a `float` - test_type = TensorType("float32", shape=[]) + test_type = TensorType.subtype("float32", shape=[]) nan = np.array([np.nan], dtype="float32")[0] assert isinstance(nan, np.floating) and not isinstance(nan, np.ndarray) @@ -173,7 +173,7 @@ def test_filter_memmap(): filename = path.join(mkdtemp(), "newfile.dat") fp = np.memmap(filename, dtype=config.floatX, mode="w+", shape=(3, 4)) - test_type = TensorType(config.floatX, [False, False]) + test_type = TensorType.subtype(config.floatX, [False, False]) res = test_type.filter(fp) assert res is fp @@ -219,26 +219,26 @@ def test_tensor_values_eq_approx(): def test_fixed_shape_basic(): - t1 = TensorType("float64", (1, 1)) + t1 = TensorType.subtype("float64", (1, 1)) assert t1.shape == (1, 1) assert t1.broadcastable == (True, True) - t1 = TensorType("float64", (0,)) + t1 = TensorType.subtype("float64", (0,)) assert t1.shape == (0,) assert t1.broadcastable == (False,) - t1 = TensorType("float64", (False, False)) + t1 = TensorType.subtype("float64", (False, False)) assert t1.shape == (None, None) assert t1.broadcastable == (False, False) - t1 = TensorType("float64", (2, 3)) + t1 = TensorType.subtype("float64", (2, 3)) assert t1.shape == (2, 3) assert t1.broadcastable == (False, False) assert t1.value_zeros(t1.shape).shape == t1.shape assert str(t1) == "TensorType(float64, (2, 3))" - t1 = TensorType("float64", (1,)) + t1 = TensorType.subtype("float64", (1,)) assert t1.shape == (1,) assert t1.broadcastable == (True,) @@ -252,7 +252,7 @@ def test_fixed_shape_basic(): def test_fixed_shape_clone(): - t1 = TensorType("float64", (1,)) + t1 = TensorType.subtype("float64", (1,)) t2 = t1.clone(dtype="float32", shape=(2, 4)) assert t2.shape == (2, 4) @@ -262,8 +262,8 @@ def test_fixed_shape_clone(): def test_fixed_shape_comparisons(): - t1 = TensorType("float64", (True, True)) - t2 = TensorType("float64", (1, 1)) + t1 = TensorType.subtype("float64", (True, True)) + t2 = TensorType.subtype("float64", (1, 1)) 
assert t1 == t2 assert t1.is_super(t2) @@ -271,19 +271,19 @@ def test_fixed_shape_comparisons(): assert hash(t1) == hash(t2) - t3 = TensorType("float64", (True, False)) - t4 = TensorType("float64", (1, 2)) + t3 = TensorType.subtype("float64", (True, False)) + t4 = TensorType.subtype("float64", (1, 2)) assert t3 != t4 - t1 = TensorType("float64", (True, True)) - t2 = TensorType("float64", ()) + t1 = TensorType.subtype("float64", (True, True)) + t2 = TensorType.subtype("float64", ()) assert t1 != t2 def test_fixed_shape_convert_variable(): # These are equivalent types - t1 = TensorType("float64", (True, True)) - t2 = TensorType("float64", (1, 1)) + t1 = TensorType.subtype("float64", (True, True)) + t2 = TensorType.subtype("float64", (1, 1)) assert t1 == t2 assert t1.shape == t2.shape @@ -299,13 +299,13 @@ def test_fixed_shape_convert_variable(): res = t2.convert_variable(t1_var) assert res is t1_var - t3 = TensorType("float64", (False, True)) + t3 = TensorType.subtype("float64", (False, True)) t3_var = t3() res = t2.convert_variable(t3_var) assert isinstance(res.owner.op, SpecifyShape) - t3 = TensorType("float64", (False, False)) - t4 = TensorType("float64", (3, 2)) + t3 = TensorType.subtype("float64", (False, False)) + t4 = TensorType.subtype("float64", (3, 2)) t4_var = t4() assert t3.shape == (None, None) res = t3.convert_variable(t4_var) @@ -315,7 +315,7 @@ def test_fixed_shape_convert_variable(): def test_deprecated_kwargs(): with pytest.warns(DeprecationWarning, match=".*broadcastable.*"): - res = TensorType("float64", broadcastable=(True, False)) + res = TensorType.subtype("float64", broadcastable=(True, False)) assert res.shape == (1, None) diff --git a/tests/tensor/test_type_other.py b/tests/tensor/test_type_other.py index d4d84b4e79..1267c1c406 100644 --- a/tests/tensor/test_type_other.py +++ b/tests/tensor/test_type_other.py @@ -16,7 +16,7 @@ def test_SliceType(): - st = SliceType() + st = SliceType.subtype() assert st == st.clone() @@ -36,8 +36,8 @@ def test_none_Constant(): # Tests equals # We had an error in the past with unpickling - o1 = Constant(NoneTypeT(), None, name="NoneConst") - o2 = Constant(NoneTypeT(), None, name="NoneConst") + o1 = Constant(NoneTypeT.subtype(), None, name="NoneConst") + o2 = Constant(NoneTypeT.subtype(), None, name="NoneConst") assert o1.equals(o2) assert NoneConst.equals(o1) assert o1.equals(NoneConst) diff --git a/tests/tensor/test_var.py b/tests/tensor/test_var.py index 05127cb3d2..a268e6189c 100644 --- a/tests/tensor/test_var.py +++ b/tests/tensor/test_var.py @@ -155,18 +155,18 @@ def test__getitem__Subtensor(): def test__getitem__AdvancedSubtensor_bool(): x = matrix("x") - i = TensorType("bool", (False, False))("i") + i = TensorType.subtype("bool", (False, False))("i") z = x[i] op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])] assert op_types[-1] == AdvancedSubtensor - i = TensorType("bool", (False,))("i") + i = TensorType.subtype("bool", (False,))("i") z = x[:, i] op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])] assert op_types[-1] == AdvancedSubtensor - i = TensorType("bool", (False,))("i") + i = TensorType.subtype("bool", (False,))("i") z = x[..., i] op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])] assert op_types[-1] == AdvancedSubtensor @@ -244,23 +244,27 @@ def test__getitem__newaxis(x, indices, new_order): def test_fixed_shape_variable_basic(): - x = TensorVariable(TensorType("int64", (4,)), None) + x = 
TensorVariable(TensorType.subtype("int64", (4,)), None) assert isinstance(x.shape, Constant) assert np.array_equal(x.shape.data, (4,)) - x = TensorConstant(TensorType("int64", (False, False)), np.array([[1, 2], [2, 3]])) + x = TensorConstant( + TensorType.subtype("int64", (False, False)), np.array([[1, 2], [2, 3]]) + ) assert x.type.shape == (2, 2) with pytest.raises(ValueError): - TensorConstant(TensorType("int64", (True, False)), np.array([[1, 2], [2, 3]])) + TensorConstant( + TensorType.subtype("int64", (True, False)), np.array([[1, 2], [2, 3]]) + ) def test_get_vector_length(): - x = TensorVariable(TensorType("int64", (4,)), None) + x = TensorVariable(TensorType.subtype("int64", (4,)), None) res = get_vector_length(x) assert res == 4 - x = TensorVariable(TensorType("int64", (None,)), None) + x = TensorVariable(TensorType.subtype("int64", (None,)), None) with pytest.raises(ValueError): get_vector_length(x) diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index c59c99d99d..251504eab6 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -447,7 +447,7 @@ def test_good(self): for testname, inputs in good.items(): inputs = [copy(input) for input in inputs] inputrs = [ - TensorType( + TensorType.subtype( dtype=input.dtype, shape=[shape_elem == 1 for shape_elem in input.shape], )() @@ -609,7 +609,7 @@ def test_grad_none(self): for testname, inputs in self.good.items(): inputs = [copy(input) for input in inputs] inputrs = [ - TensorType( + TensorType.subtype( dtype=input.dtype, shape=[shape_elem == 1 for shape_elem in input.shape], )() @@ -633,7 +633,7 @@ def test_grad_none(self): else: dtype = str(out.dtype) bcast = [shape_elem == 1 for shape_elem in out.shape] - var = TensorType(dtype=dtype, shape=bcast)() + var = TensorType.subtype(dtype=dtype, shape=bcast)() out_grad_vars.append(var) try: diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 50dcf8170a..ef6079f069 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -541,7 +541,7 @@ def make_node(self, f, g): return Apply(self, inputs=[f, g], outputs=[scalar()]) def grad(self, inputs, output_grads): - return [inputs[0].zeros_like(), NullType()()] + return [inputs[0].zeros_like(), NullType.subtype()()] def perform(self, *args, **kwargs): raise NotImplementedError() @@ -696,7 +696,7 @@ def test_undefined_cost_grad(): cost = x + y assert cost.dtype in discrete_dtypes with pytest.raises(NullTypeGradError): - grad(cost, [x, y], known_grads={cost: NullType()()}) + grad(cost, [x, y], known_grads={cost: NullType.subtype()()}) def test_disconnected_cost_grad(): @@ -714,7 +714,7 @@ def test_disconnected_cost_grad(): grad( cost, [x, y], - known_grads={cost: DisconnectedType()()}, + known_grads={cost: DisconnectedType.subtype()()}, disconnected_inputs="raise", ) except DisconnectedInputError: diff --git a/tests/test_printing.py b/tests/test_printing.py index ac64024152..bc3a9b6aa2 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -331,7 +331,7 @@ def test_debugprint_inner_graph(): """ for exp_line, res_line in zip(exp_res.split("\n"), lines): - assert exp_line.strip() == res_line.strip() + assert res_line.strip() == exp_line.strip() # Test nested inner-graph `Op`s igo_2 = MyInnerGraphOp([r3, r4], [out]) diff --git a/tests/test_rop.py b/tests/test_rop.py index d1f85307f6..f63583931d 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -309,18 +309,18 @@ def test_conv(self): filter_shape = (2, 2, 2, 3) image_dim = len(image_shape) filter_dim = len(filter_shape) - input = 
TensorType(aesara.config.floatX, [False] * image_dim)( + input = TensorType.subtype(aesara.config.floatX, [False] * image_dim)( name="input" ) - filters = TensorType(aesara.config.floatX, [False] * filter_dim)( - name="filter" - ) - ev_input = TensorType(aesara.config.floatX, [False] * image_dim)( - name="ev_input" - ) - ev_filters = TensorType(aesara.config.floatX, [False] * filter_dim)( - name="ev_filters" - ) + filters = TensorType.subtype( + aesara.config.floatX, [False] * filter_dim + )(name="filter") + ev_input = TensorType.subtype( + aesara.config.floatX, [False] * image_dim + )(name="ev_input") + ev_filters = TensorType.subtype( + aesara.config.floatX, [False] * filter_dim + )(name="ev_filters") def sym_conv2d(input, filters): return conv_op(input, filters, border_mode=border_mode) diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py index 554f120843..2234335098 100644 --- a/tests/typed_list/test_basic.py +++ b/tests/typed_list/test_basic.py @@ -56,11 +56,11 @@ def random_lil(shape, dtype, nnz): class TestGetItem: def test_sanity_check_slice(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() - mySymbolicSlice = SliceType()() + mySymbolicSlice = SliceType.subtype()() z = GetItem()(mySymbolicMatricesList, mySymbolicSlice) @@ -74,8 +74,8 @@ def test_sanity_check_slice(self): def test_sanity_check_single(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() mySymbolicScalar = scalar(dtype="int64") @@ -89,8 +89,8 @@ def test_sanity_check_single(self): assert np.array_equal(f([x], np.asarray(0, dtype="int64")), x) def test_interface(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() mySymbolicScalar = scalar(dtype="int64") @@ -109,8 +109,8 @@ def test_interface(self): assert np.array_equal(f([x]), x) def test_wrong_input(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() mySymbolicMatrix = matrix() @@ -118,8 +118,8 @@ def test_wrong_input(self): GetItem()(mySymbolicMatricesList, mySymbolicMatrix) def test_constant_input(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = GetItem()(mySymbolicMatricesList, 0) @@ -139,8 +139,8 @@ def test_constant_input(self): class TestAppend: def test_inplace(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() @@ -155,8 +155,8 @@ def test_inplace(self): assert np.array_equal(f([x], y), [x, y]) def test_sanity_check(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() @@ -171,8 +171,8 
@@ def test_sanity_check(self): assert np.array_equal(f([x], y), [x, y]) def test_interfaces(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() @@ -189,11 +189,11 @@ def test_interfaces(self): class TestExtend: def test_inplace(self): - mySymbolicMatricesList1 = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList1 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() - mySymbolicMatricesList2 = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList2 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = Extend(True)(mySymbolicMatricesList1, mySymbolicMatricesList2) @@ -209,11 +209,11 @@ def test_inplace(self): assert np.array_equal(f([x], [y]), [x, y]) def test_sanity_check(self): - mySymbolicMatricesList1 = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList1 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() - mySymbolicMatricesList2 = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList2 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2) @@ -227,11 +227,11 @@ def test_sanity_check(self): assert np.array_equal(f([x], [y]), [x, y]) def test_interface(self): - mySymbolicMatricesList1 = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList1 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() - mySymbolicMatricesList2 = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList2 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = mySymbolicMatricesList1.extend(mySymbolicMatricesList2) @@ -247,8 +247,8 @@ def test_interface(self): class TestInsert: def test_inplace(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() myScalar = scalar(dtype="int64") @@ -266,8 +266,8 @@ def test_inplace(self): assert np.array_equal(f([x], np.asarray(1, dtype="int64"), y), [x, y]) def test_sanity_check(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() myScalar = scalar(dtype="int64") @@ -283,8 +283,8 @@ def test_sanity_check(self): assert np.array_equal(f([x], np.asarray(1, dtype="int64"), y), [x, y]) def test_interface(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() myScalar = scalar(dtype="int64") @@ -302,8 +302,8 @@ def test_interface(self): class TestRemove: def test_inplace(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = 
matrix() @@ -318,8 +318,8 @@ def test_inplace(self): assert np.array_equal(f([x, y], y), [x]) def test_sanity_check(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() @@ -334,8 +334,8 @@ def test_sanity_check(self): assert np.array_equal(f([x, y], y), [x]) def test_interface(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() @@ -352,8 +352,8 @@ def test_interface(self): class TestReverse: def test_inplace(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = Reverse(True)(mySymbolicMatricesList) @@ -367,8 +367,8 @@ def test_inplace(self): assert np.array_equal(f([x, y]), [y, x]) def test_sanity_check(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = Reverse()(mySymbolicMatricesList) @@ -382,8 +382,8 @@ def test_sanity_check(self): assert np.array_equal(f([x, y]), [y, x]) def test_interface(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = mySymbolicMatricesList.reverse() @@ -399,8 +399,8 @@ def test_interface(self): class TestIndex: def test_sanity_check(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() @@ -415,8 +415,8 @@ def test_sanity_check(self): assert f([x, y], y) == 1 def test_interface(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() @@ -431,11 +431,11 @@ def test_interface(self): assert f([x, y], y) == 1 def test_non_tensor_type(self): - mySymbolicNestedMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)), 1 + mySymbolicNestedMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)), 1 )() - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = Index()(mySymbolicNestedMatricesList, mySymbolicMatricesList) @@ -450,8 +450,8 @@ def test_non_tensor_type(self): def test_sparse(self): sp = pytest.importorskip("scipy") - mySymbolicSparseList = TypedListType( - sparse.SparseTensorType("csr", aesara.config.floatX) + mySymbolicSparseList = TypedListType.subtype( + sparse.SparseTensorType.subtype("csr", aesara.config.floatX) )() mySymbolicSparse = sparse.csr_matrix() @@ -467,8 +467,8 @@ def test_sparse(self): class TestCount: def test_sanity_check(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList 
= TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() @@ -483,8 +483,8 @@ def test_sanity_check(self): assert f([y, y, x, y], y) == 3 def test_interface(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() myMatrix = matrix() @@ -499,11 +499,11 @@ def test_interface(self): assert f([x, y], y) == 1 def test_non_tensor_type(self): - mySymbolicNestedMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)), 1 + mySymbolicNestedMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)), 1 )() - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = Count()(mySymbolicNestedMatricesList, mySymbolicMatricesList) @@ -518,8 +518,8 @@ def test_non_tensor_type(self): def test_sparse(self): sp = pytest.importorskip("scipy") - mySymbolicSparseList = TypedListType( - sparse.SparseTensorType("csr", aesara.config.floatX) + mySymbolicSparseList = TypedListType.subtype( + sparse.SparseTensorType.subtype("csr", aesara.config.floatX) )() mySymbolicSparse = sparse.csr_matrix() @@ -535,8 +535,8 @@ def test_sparse(self): class TestLength: def test_sanity_check(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = Length()(mySymbolicMatricesList) @@ -548,8 +548,8 @@ def test_sanity_check(self): assert f([x, x, x, x]) == 4 def test_interface(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = mySymbolicMatricesList.__len__() diff --git a/tests/typed_list/test_rewriting.py b/tests/typed_list/test_rewriting.py index 167424cfb8..e0457091d2 100644 --- a/tests/typed_list/test_rewriting.py +++ b/tests/typed_list/test_rewriting.py @@ -12,8 +12,8 @@ class TestInplace: def test_reverse_inplace(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = Reverse()(mySymbolicMatricesList) @@ -35,8 +35,8 @@ def test_reverse_inplace(self): assert np.array_equal(f([x, y]), [y, x]) def test_append_inplace(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() mySymbolicMatrix = matrix() z = Append()(mySymbolicMatricesList, mySymbolicMatrix) @@ -61,12 +61,12 @@ def test_append_inplace(self): assert np.array_equal(f([x], y), [x, y]) def test_extend_inplace(self): - mySymbolicMatricesList1 = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList1 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() - mySymbolicMatricesList2 = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList2 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() z = 
Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2) @@ -90,8 +90,8 @@ def test_extend_inplace(self): assert np.array_equal(f([x], [y]), [x, y]) def test_insert_inplace(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() mySymbolicIndex = scalar(dtype="int64") mySymbolicMatrix = matrix() @@ -120,8 +120,8 @@ def test_insert_inplace(self): assert np.array_equal(f([x], np.asarray(1, dtype="int64"), y), [x, y]) def test_remove_inplace(self): - mySymbolicMatricesList = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicMatricesList = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) )() mySymbolicMatrix = matrix() z = Remove()(mySymbolicMatricesList, mySymbolicMatrix) diff --git a/tests/typed_list/test_type.py b/tests/typed_list/test_type.py index 4ee1b76e02..47eddaa0be 100644 --- a/tests/typed_list/test_type.py +++ b/tests/typed_list/test_type.py @@ -15,7 +15,7 @@ def test_wrong_input_on_creation(self): # type is not a valid aesara type with pytest.raises(TypeError): - TypedListType(None) + TypedListType.subtype(None) def test_wrong_input_on_filter(self): # Typed list type should raises an @@ -24,7 +24,9 @@ def test_wrong_input_on_filter(self): # specified on creation # list of matrices - myType = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) with pytest.raises(TypeError): myType.filter([4]) @@ -34,7 +36,9 @@ def test_not_a_list_on_filter(self): # if no iterable variable is given on input # list of matrices - myType = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) with pytest.raises(TypeError): myType.filter(4) @@ -45,11 +49,15 @@ def test_type_equality(self): # variables # list of matrices - myType1 = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myType1 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) # list of matrices - myType2 = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myType2 = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) # list of scalars - myType3 = TypedListType(TensorType(aesara.config.floatX, ())) + myType3 = TypedListType.subtype(TensorType.subtype(aesara.config.floatX, ())) assert myType2 == myType1 assert myType3 != myType1 @@ -57,7 +65,9 @@ def test_type_equality(self): def test_filter_sanity_check(self): # Simple test on typed list type filter - myType = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) x = random_ranged(-1000, 1000, [100, 100]) @@ -68,14 +78,16 @@ def test_intern_filter(self): # filtered. If they weren't this code would raise # an exception. 
- myType = TypedListType(TensorType("float64", (False, False))) + myType = TypedListType.subtype(TensorType.subtype("float64", (False, False))) x = np.asarray([[4, 5], [4, 5]], dtype="float32") assert np.array_equal(myType.filter([x]), [x]) def test_load_alot(self): - myType = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) x = random_ranged(-1000, 1000, [10, 10]) testList = [] @@ -87,9 +99,11 @@ def test_load_alot(self): def test_basic_nested_list(self): # Testing nested list with one level of depth - myNestedType = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myNestedType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) - myType = TypedListType(myNestedType) + myType = TypedListType.subtype(myNestedType) x = random_ranged(-1000, 1000, [100, 100]) @@ -98,51 +112,65 @@ def test_basic_nested_list(self): def test_comparison_different_depth(self): # Nested list with different depth aren't the same - myNestedType = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myNestedType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) - myNestedType2 = TypedListType(myNestedType) + myNestedType2 = TypedListType.subtype(myNestedType) - myNestedType3 = TypedListType(myNestedType2) + myNestedType3 = TypedListType.subtype(myNestedType2) assert myNestedType2 != myNestedType3 def test_nested_list_arg(self): # test for the 'depth' optional argument - myNestedType = TypedListType( - TensorType(aesara.config.floatX, (False, False)), 3 + myNestedType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)), 3 ) - myType = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) - myManualNestedType = TypedListType(TypedListType(TypedListType(myType))) + myManualNestedType = TypedListType.subtype( + TypedListType.subtype(TypedListType.subtype(myType)) + ) assert myNestedType == myManualNestedType def test_get_depth(self): # test case for get_depth utilitary function - myType = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) - myManualNestedType = TypedListType(TypedListType(TypedListType(myType))) + myManualNestedType = TypedListType.subtype( + TypedListType.subtype(TypedListType.subtype(myType)) + ) assert myManualNestedType.get_depth() == 3 def test_comparison_uneven_nested(self): # test for comparison between uneven nested list - myType = TypedListType(TensorType(aesara.config.floatX, (False, False))) + myType = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) + ) - myManualNestedType1 = TypedListType(TypedListType(TypedListType(myType))) + myManualNestedType1 = TypedListType.subtype( + TypedListType.subtype(TypedListType.subtype(myType)) + ) - myManualNestedType2 = TypedListType(TypedListType(myType)) + myManualNestedType2 = TypedListType.subtype(TypedListType.subtype(myType)) assert myManualNestedType1 != myManualNestedType2 assert myManualNestedType2 != myManualNestedType1 def test_variable_is_Typed_List_variable(self): - mySymbolicVariable = TypedListType( - TensorType(aesara.config.floatX, (False, False)) + mySymbolicVariable = TypedListType.subtype( + TensorType.subtype(aesara.config.floatX, (False, False)) 
)() assert isinstance(mySymbolicVariable, TypedListVariable) From b1d40988ee5dbdf9a13b10f9ddc0ba9e9a507d7c Mon Sep 17 00:00:00 2001 From: Markus Schmaus Date: Sun, 25 Sep 2022 14:33:28 +0200 Subject: [PATCH 03/21] Convert `__init__`s to only accept keyword arguments --- aesara/graph/null_type.py | 2 + aesara/graph/type.py | 29 +++++++-- aesara/link/c/params_type.py | 55 +++++++++------- aesara/link/c/type.py | 117 ++++++++++++++++++++++++----------- aesara/scalar/basic.py | 6 +- aesara/sparse/type.py | 16 +++-- aesara/tensor/type.py | 34 ++++++---- aesara/tensor/type_other.py | 2 +- aesara/typed_list/type.py | 25 +++----- tests/graph/test_basic.py | 2 + tests/graph/test_features.py | 2 + tests/graph/test_op.py | 2 + tests/graph/test_types.py | 2 + 13 files changed, 192 insertions(+), 102 deletions(-) diff --git a/aesara/graph/null_type.py b/aesara/graph/null_type.py index eae0c04c14..de572aec4a 100644 --- a/aesara/graph/null_type.py +++ b/aesara/graph/null_type.py @@ -17,6 +17,8 @@ class NullType(Type): """ + __props__ = ("why_null",) + def __init__(self, why_null="(no explanation given)"): self.why_null = why_null diff --git a/aesara/graph/type.py b/aesara/graph/type.py index 91ce307ae6..ca19f09d0a 100644 --- a/aesara/graph/type.py +++ b/aesara/graph/type.py @@ -1,5 +1,6 @@ -from abc import abstractmethod -from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union +import inspect +from abc import ABCMeta, abstractmethod +from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union, final from typing_extensions import Protocol, TypeAlias, runtime_checkable @@ -11,14 +12,27 @@ D = TypeVar("D") -class NewTypeMeta(type): - # pass +class NewTypeMeta(ABCMeta): + __props__: tuple[str, ...] + def __call__(cls, *args, **kwargs): raise RuntimeError("Use subtype") # return super().__call__(*args, **kwargs) def subtype(cls, *args, **kwargs): - return super().__call__(*args, **kwargs) + kwargs = cls.type_parameters(*args, **kwargs) + return super().__call__(**kwargs) + + def type_parameters(cls, *args, **kwargs): + if args: + init_args = tuple(inspect.signature(cls.__init__).parameters.keys())[1:] + if cls.__props__[: len(args)] != init_args[: len(args)]: + raise RuntimeError( + f"{cls.__props__=} doesn't match {init_args=} for {args=}" + ) + + kwargs |= zip(cls.__props__, args) + return kwargs class Type(Generic[D], metaclass=NewTypeMeta): @@ -293,6 +307,11 @@ def _props_dict(self): """ return {a: getattr(self, a) for a in self.__props__} + @final + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + def __hash__(self): return hash((type(self), tuple(getattr(self, a) for a in self.__props__))) diff --git a/aesara/link/c/params_type.py b/aesara/link/c/params_type.py index 08d5937254..2b85873192 100644 --- a/aesara/link/c/params_type.py +++ b/aesara/link/c/params_type.py @@ -343,7 +343,9 @@ class ParamsType(CType): """ - def __init__(self, **kwargs): + @classmethod + def type_parameters(cls, **kwargs): + params = dict() if len(kwargs) == 0: raise ValueError("Cannot create ParamsType from empty data.") @@ -366,14 +368,14 @@ def __init__(self, **kwargs): % (attribute_name, type_name) ) - self.length = len(kwargs) - self.fields = tuple(sorted(kwargs.keys())) - self.types = tuple(kwargs[field] for field in self.fields) - self.name = self.generate_struct_name() + params["length"] = len(kwargs) + params["fields"] = tuple(sorted(kwargs.keys())) + params["types"] = tuple(kwargs[field] for field in params["fields"]) + params["name"] = 
cls.generate_struct_name(params) - self.__const_to_enum = {} - self.__alias_to_enum = {} - enum_types = [t for t in self.types if isinstance(t, EnumType)] + params["_const_to_enum"] = {} + params["_alias_to_enum"] = {} + enum_types = [t for t in params["types"] if isinstance(t, EnumType)] if enum_types: # We don't want same enum names in different enum types. if sum(len(t) for t in enum_types) != len( @@ -398,25 +400,27 @@ def __init__(self, **kwargs): ) # We map each enum name to the enum type in which it is defined. # We will then use this dict to find enum value when looking for enum name in ParamsType object directly. - self.__const_to_enum = { + params["_const_to_enum"] = { enum_name: enum_type for enum_type in enum_types for enum_name in enum_type } - self.__alias_to_enum = { + params["_alias_to_enum"] = { alias: enum_type for enum_type in enum_types for alias in enum_type.aliases } + return params + def __setstate__(self, state): # NB: # I have overridden __getattr__ to make enum constants available through # the ParamsType when it contains enum types. To do that, I use some internal - # attributes: self.__const_to_enum and self.__alias_to_enum. These attributes + # attributes: self._const_to_enum and self._alias_to_enum. These attributes # are normally found by Python without need to call getattr(), but when the # ParamsType is unpickled, it seems gettatr() may be called at a point before - # __const_to_enum or __alias_to_enum are unpickled, so that gettatr() can't find + # _const_to_enum or _alias_to_enum are unpickled, so that gettatr() can't find # those attributes, and then loop infinitely. # For this reason, I must add this trivial implementation of __setstate__() # to avoid errors when unpickling. @@ -424,9 +428,12 @@ def __setstate__(self, state): def __getattr__(self, key): # Now we can access value of each enum defined inside enum types wrapped into the current ParamsType. - if key in self.__const_to_enum: - return self.__const_to_enum[key][key] - return super().__getattr__(self, key) + # const_to_enum = super().__getattribute__("_const_to_enum") + if not key.startswith("__"): + const_to_enum = self._const_to_enum + if key in const_to_enum: + return const_to_enum[key][key] + raise AttributeError(f"'{self}' object has no attribute '{key}'") def __repr__(self): return "ParamsType<%s>" % ", ".join( @@ -446,13 +453,14 @@ def __eq__(self, other): def __hash__(self): return hash((type(self),) + self.fields + self.types) - def generate_struct_name(self): - # This method tries to generate an unique name for the current instance. + @staticmethod + def generate_struct_name(params): + # This method tries to generate a unique name for the current instance. # This name is intended to be used as struct name in C code and as constant # definition to check if a similar ParamsType has already been created # (see c_support_code() below). 
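As an aid for the ParamsType and EnumType hunks in this commit, here is a self-contained sketch (plain Python, no Aesara imports, all names hypothetical) of the construction flow the series converges on: the metaclass forbids direct calls, `subtype` funnels arguments through a `type_parameters` hook keyed on `__props__`, and a generic `__init__` assigns the resulting keyword arguments.

    from abc import ABCMeta

    class SketchTypeMeta(ABCMeta):
        def __call__(cls, *args, **kwargs):
            raise RuntimeError("Use subtype")

        def subtype(cls, *args, **kwargs):
            # Normalize/validate everything in one place, then build the instance.
            kwargs = cls.type_parameters(*args, **kwargs)
            return super().__call__(**kwargs)

    class SketchType(metaclass=SketchTypeMeta):
        __props__ = ()

        @classmethod
        def type_parameters(cls, *args, **kwargs):
            # Positional arguments map onto __props__; the real helper also
            # checks them against the __init__ signature, omitted here.
            kwargs.update(dict(zip(cls.__props__, args)))
            return kwargs

        def __init__(self, **kwargs):
            for k, v in kwargs.items():
                setattr(self, k, v)

    class PairSketch(SketchType):
        __props__ = ("left", "right")

    p = PairSketch.subtype("float64", right="int8")
    assert (p.left, p.right) == ("float64", "int8")

The per-type `type_parameters` overrides introduced in this commit (ParamsType, EnumType, ScalarType, SparseTensorType, TensorType, TypedListType) play the role of `PairSketch` customizing that hook.
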
- fields_string = ",".join(self.fields).encode("utf-8") - types_string = ",".join(str(t) for t in self.types).encode("utf-8") + fields_string = ",".join(params["fields"]).encode("utf-8") + types_string = ",".join(str(t) for t in params["types"]).encode("utf-8") fields_hex = hashlib.sha256(fields_string).hexdigest() types_hex = hashlib.sha256(types_string).hexdigest() return f"_Params_{fields_hex}_{types_hex}" @@ -510,7 +518,7 @@ def get_enum(self, key): print(wrapper.TWO) """ - return self.__const_to_enum[key][key] + return self._const_to_enum[key][key] def enum_from_alias(self, alias): """ @@ -547,10 +555,11 @@ def enum_from_alias(self, alias): method to do that. """ + alias_to_enum = self._alias_to_enum return ( - self.__alias_to_enum[alias].fromalias(alias) - if alias in self.__alias_to_enum - else self.__const_to_enum[alias][alias] + alias_to_enum[alias].fromalias(alias) + if alias in alias_to_enum + else self._const_to_enum[alias][alias] ) def get_params(self, *objects, **kwargs) -> Params: diff --git a/aesara/link/c/type.py b/aesara/link/c/type.py index 9b29d5355e..faa8e01895 100644 --- a/aesara/link/c/type.py +++ b/aesara/link/c/type.py @@ -1,6 +1,7 @@ import ctypes import platform import re +from collections.abc import Mapping from typing import TypeVar from aesara.graph.basic import Constant @@ -306,7 +307,29 @@ def signature(self): CDataType.constant_type = CDataTypeConstant -class EnumType(CType, dict): +class FrozenMap(dict): + def __setitem__(self, key, value): + raise TypeError("constant values are immutable.") + + def __delitem__(self, key): + raise TypeError("constant values are immutable.") + + def __hash__(self): + return hash(frozenset(self.items())) + + def __eq__(self, other): + return ( + type(self) == type(other) + and len(self) == len(other) + and all(k in other for k in self) + and all(self[k] == other[k] for k in self) + ) + + def __ne__(self, other): + return not self == other + + +class EnumType(Mapping, CType): """ Main subclasses: - :class:`EnumList` @@ -403,63 +426,75 @@ class EnumType(CType, dict): """ - def __init_ctype(self, ctype): + __props__ = ("constants", "aliases", "ctype", "cname") + + @classmethod + def __init_ctype(cls, ctype): # C type may be a list of keywords, e.g. "unsigned long long". # We should check each part. ctype_parts = ctype.split() if not all(re.match("^[A-Za-z_][A-Za-z0-9_]*$", el) for el in ctype_parts): - raise TypeError(f"{type(self).__name__}: invalid C type.") - self.ctype = " ".join(ctype_parts) + raise TypeError(f"{cls.__name__}: invalid C type.") + return " ".join(ctype_parts) - def __init_cname(self, cname): + @classmethod + def __init_cname(cls, cname): if not re.match("^[A-Za-z_][A-Za-z0-9_]*$", cname): - raise TypeError(f"{type(self).__name__}: invalid C name.") - self.cname = cname + raise TypeError(f"{cls.__name__}: invalid C name.") + return cname + + @classmethod + def type_parameters(cls, **kwargs): - def __init__(self, **kwargs): - self.__init_ctype(kwargs.pop("ctype", "double")) - self.__init_cname(kwargs.pop("cname", self.ctype.replace(" ", "_"))) - self.aliases = dict() + ctype = cls.__init_ctype(kwargs.pop("ctype", "double")) + cname = cls.__init_cname(kwargs.pop("cname", ctype.replace(" ", "_"))) + aliases = dict() for k in kwargs: if re.match("^[A-Z][A-Z0-9_]*$", k) is None: raise AttributeError( - f'{type(self).__name__}: invalid enum name: "{k}". ' + f'{cls.__name__}: invalid enum name: "{k}". ' "Only capital letters, underscores and digits " "are allowed." 
) if isinstance(kwargs[k], (list, tuple)): if len(kwargs[k]) != 2: raise TypeError( - f"{type(self).__name__}: when using a tuple to define a constant, your tuple should contain 2 values: " + f"{cls.__name__}: when using a tuple to define a constant, your tuple should contain 2 values: " "constant alias followed by constant value." ) alias, value = kwargs[k] if not isinstance(alias, str): raise TypeError( - f'{type(self).__name__}: constant alias should be a string, got "{alias}".' + f'{cls.__name__}: constant alias should be a string, got "{alias}".' ) if alias == k: raise TypeError( - f"{type(self).__name__}: it's useless to create an alias " + f"{cls.__name__}: it's useless to create an alias " "with the same name as its associated constant." ) - if alias in self.aliases: + if alias in aliases: raise TypeError( - f'{type(self).__name__}: consant alias "{alias}" already used.' + f'{cls.__name__}: consant alias "{alias}" already used.' ) - self.aliases[alias] = k + aliases[alias] = k kwargs[k] = value if isinstance(kwargs[k], bool): kwargs[k] = int(kwargs[k]) elif not isinstance(kwargs[k], (int, float)): raise TypeError( - f'{type(self).__name__}: constant "{k}": expected integer or floating value, got "{type(kwargs[k]).__name__}".' + f'{cls.__name__}: constant "{k}": expected integer or floating value, got "{type(kwargs[k]).__name__}".' ) - if [a for a in self.aliases if a in self]: + if [a for a in aliases if a in kwargs]: raise TypeError( - f"{type(self).__name__}: some aliases have same names as constants." + f"{cls.__name__}: some aliases have same names as constants." ) - super().__init__(**kwargs) + + return { + "constants": kwargs, + "aliases": aliases, + "ctype": ctype, + "cname": cname, + } def fromalias(self, alias): """ @@ -495,7 +530,7 @@ def __repr__(self): ) def __getattr__(self, key): - if key in self: + if key in self.constants: return self[key] else: raise AttributeError( @@ -503,15 +538,25 @@ def __getattr__(self, key): ) def __setattr__(self, key, value): - if key in self: - raise NotImplementedError("constant values are immutable.") - CType.__setattr__(self, key, value) + if key in self.__props__: + CType.__setattr__(self, key, value) + else: + raise TypeError("constant values are immutable.") + + def __iter__(self): + return self.constants.__iter__() + + def __len__(self): + return len(self.constants) + + def __getitem__(self, item): + return self.constants[item] def __setitem__(self, key, value): - raise NotImplementedError("constant values are immutable.") + raise TypeError("constant values are immutable.") def __delitem__(self, key): - raise NotImplementedError("constant values are immutable.") + raise TypeError("constant values are immutable.") def __hash__(self): # All values are Python basic types, then easy to hash. @@ -691,10 +736,10 @@ class EnumList(EnumType): """ - def __init__(self, *args, **kwargs): + @classmethod + def type_parameters(cls, *args, **kwargs): assert len(kwargs) in (0, 1, 2), ( - type(self).__name__ - + ': expected 0 to 2 extra parameters ("ctype", "cname").' + cls.__name__ + ': expected 0 to 2 extra parameters ("ctype", "cname").' 
) ctype = kwargs.pop("ctype", "int") cname = kwargs.pop("cname", None) @@ -703,13 +748,13 @@ def __init__(self, *args, **kwargs): if isinstance(arg, (list, tuple)): if len(arg) != 2: raise TypeError( - f"{type(self).__name__}: when using a tuple to define a constant, your tuple should contain 2 values: " + f"{cls.__name__}: when using a tuple to define a constant, your tuple should contain 2 values: " "constant name followed by constant alias." ) constant_name, constant_alias = arg if not isinstance(constant_alias, str): raise TypeError( - f'{type(self).__name__}: constant alias should be a string, got "{constant_alias}".' + f'{cls.__name__}: constant alias should be a string, got "{constant_alias}".' ) constant_value = (constant_alias, arg_rank) else: @@ -717,18 +762,18 @@ def __init__(self, *args, **kwargs): constant_value = arg_rank if not isinstance(constant_name, str): raise TypeError( - f'{type(self).__name__}: constant name should be a string, got "{constant_name}".' + f'{cls.__name__}: constant name should be a string, got "{constant_name}".' ) if constant_name in kwargs: raise TypeError( - f'{type(self).__name__}: constant name already used ("{constant_name}").' + f'{cls.__name__}: constant name already used ("{constant_name}").' ) kwargs[constant_name] = constant_value kwargs.update(ctype=ctype) if cname is not None: kwargs.update(cname=cname) - super().__init__(**kwargs) + return super().type_parameters(**kwargs) class CEnumType(EnumList): diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py index e21679500c..5cb9ac20ef 100644 --- a/aesara/scalar/basic.py +++ b/aesara/scalar/basic.py @@ -286,14 +286,14 @@ class ScalarType(CType): shape = () dtype: DataType - def __init__(self, dtype): + @classmethod + def type_parameters(cls, dtype): if isinstance(dtype, str) and dtype == "floatX": dtype = config.floatX else: dtype = np.dtype(dtype).name - self.dtype = dtype - self.dtype_specs() # error checking + return {"dtype": dtype} def clone(self, dtype=None, **kwargs): if dtype is None: diff --git a/aesara/sparse/type.py b/aesara/sparse/type.py index d8b39d0a80..584493da1b 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -41,7 +41,7 @@ class SparseTensorType(TensorType): """ - __props__ = ("dtype", "format", "shape") + __props__ = ("format", "dtype", "shape") format_cls = { "csr": scipy.sparse.csr_matrix, "csc": scipy.sparse.csc_matrix, @@ -63,8 +63,9 @@ class SparseTensorType(TensorType): } ndim = 2 - def __init__( - self, + @classmethod + def type_parameters( + cls, format: SparsityTypes, dtype: Union[str, np.dtype], shape: Optional[Iterable[Optional[Union[bool, int]]]] = None, @@ -74,14 +75,17 @@ def __init__( if shape is None and broadcastable is None: shape = (None, None) - if format not in self.format_cls: + if format not in cls.format_cls: raise ValueError( f'unsupported format "{format}" not in list', ) - self.format = format + params = super().type_parameters( + dtype, shape=shape, name=name, broadcastable=broadcastable + ) - super().__init__(dtype, shape=shape, name=name, broadcastable=broadcastable) + params["format"] = format + return params def clone( self, diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 77b17abb1d..89e4522cb4 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -64,8 +64,9 @@ class TensorType(CType[np.ndarray]): ``numpy.nan`` or ``numpy.inf`` entries. 
(Used in `DebugMode`) """ - def __init__( - self, + @classmethod + def type_parameters( + cls, dtype: Union[str, np.dtype], shape: Optional[Iterable[Optional[Union[bool, int]]]] = None, name: Optional[str] = None, @@ -88,6 +89,7 @@ def __init__( """ + params = dict() if broadcastable is not None: warnings.warn( "The `broadcastable` keyword is deprecated; use `shape`.", @@ -96,12 +98,12 @@ def __init__( shape = broadcastable if str(dtype) == "floatX": - self.dtype = config.floatX + params["dtype"] = config.floatX else: if np.obj2sctype(dtype) is None: raise TypeError(f"Invalid dtype: {dtype}") - self.dtype = np.dtype(dtype).name + params["dtype"] = np.dtype(dtype).name def parse_bcast_and_shape(s): if isinstance(s, (bool, np.bool_)): @@ -109,10 +111,12 @@ def parse_bcast_and_shape(s): else: return s - self.shape = tuple(parse_bcast_and_shape(s) for s in shape) - self.dtype_specs() # error checking is done there - self.name = name - self.numpy_dtype = np.dtype(self.dtype) + params["shape"] = tuple(parse_bcast_and_shape(s) for s in shape) + cls.dtype_specs_params(params) # error checking is done there + params["name"] = name + params["numpy_dtype"] = np.dtype(params["dtype"]) + + return params def clone( self, dtype=None, shape=None, broadcastable=None, **kwargs @@ -280,12 +284,18 @@ def dtype_specs(self): This function is used internally as part of C code generation. """ + return self.dtype_specs_dtype(self.dtype) + + @classmethod + def dtype_specs_params(cls, params): + return cls.dtype_specs_dtype(params["dtype"]) + + @classmethod + def dtype_specs_dtype(cls, dtype): try: - return self.dtype_specs_map[self.dtype] + return cls.dtype_specs_map[dtype] except KeyError: - raise TypeError( - f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}" - ) + raise TypeError(f"Unsupported dtype for {cls.__name__}: {dtype}") def to_scalar_type(self): return aes.get_scalar_type(dtype=self.dtype) diff --git a/aesara/tensor/type_other.py b/aesara/tensor/type_other.py index b0b7a91dc2..cd9eae18a8 100644 --- a/aesara/tensor/type_other.py +++ b/aesara/tensor/type_other.py @@ -53,7 +53,7 @@ def grad(self, inputs, grads): class SliceType(Type[slice]): def clone(self, **kwargs): - return type(self)() + return type(self).subtype() def filter(self, x, strict=False, allow_downcast=None): if isinstance(x, slice): diff --git a/aesara/typed_list/type.py b/aesara/typed_list/type.py index 4936e6958f..059c64105c 100644 --- a/aesara/typed_list/type.py +++ b/aesara/typed_list/type.py @@ -7,24 +7,27 @@ class TypedListType(CType): Parameters ---------- ttype - Type of aesara variable this list will contains, can be another list. + Type of aesara variable this list will contain, can be another list. depth Optional parameters, any value above 0 will create a nested list of this depth. 
(0-based) """ - def __init__(self, ttype, depth=0): + __props__ = ("ttype",) + + @classmethod + def type_parameters(cls, ttype, depth=0): if depth < 0: raise ValueError("Please specify a depth superior or" "equal to 0") if not isinstance(ttype, Type): raise TypeError("Expected an Aesara Type") - if depth == 0: - self.ttype = ttype - else: - self.ttype = TypedListType.subtype(ttype, depth - 1) + if depth > 0: + ttype = TypedListType.subtype(ttype, depth - 1) + + return {"ttype": ttype} def filter(self, x, strict=False, allow_downcast=None): """ @@ -51,16 +54,6 @@ def filter(self, x, strict=False, allow_downcast=None): else: raise TypeError(f"Expected all elements to be {self.ttype}") - def __eq__(self, other): - """ - Two lists are equal if they contain the same type. - - """ - return type(self) == type(other) and self.ttype == other.ttype - - def __hash__(self): - return hash((type(self), self.ttype)) - def __str__(self): return "TypedList <" + str(self.ttype) + ">" diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 01bd2e4402..fb0118cdce 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -46,6 +46,8 @@ class MyType(Type): + __props__ = ("thingy",) + def __init__(self, thingy): self.thingy = thingy diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py index 4906b5794c..06a48d1e10 100644 --- a/tests/graph/test_features.py +++ b/tests/graph/test_features.py @@ -11,6 +11,8 @@ class TestNodeFinder: def test_straightforward(self): class MyType(Type): + __props__ = ("name",) + def __init__(self, name): self.name = name diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index d768d438a2..3b755b142d 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -21,6 +21,8 @@ def as_variable(x): class MyType(Type): + __props__ = ("thingy",) + def __init__(self, thingy): self.thingy = thingy diff --git a/tests/graph/test_types.py b/tests/graph/test_types.py index fa188eb605..6c37e3ecc2 100644 --- a/tests/graph/test_types.py +++ b/tests/graph/test_types.py @@ -5,6 +5,8 @@ class MyType(Type): + __props__ = ("thingy",) + def __init__(self, thingy): self.thingy = thingy From c99a12025db7dd06edf2ab95a9c45dff87ce2017 Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Wed, 28 Sep 2022 15:45:53 -0500 Subject: [PATCH 04/21] Stop renaming variables in Scan's merge rewrite --- aesara/scan/rewriting.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/aesara/scan/rewriting.py b/aesara/scan/rewriting.py index f63db7b74c..f67d41c9c3 100644 --- a/aesara/scan/rewriting.py +++ b/aesara/scan/rewriting.py @@ -1694,52 +1694,46 @@ def merge(self, nodes): inner_outs = [[] for nd in nodes] outer_outs = [] - def rename(ls, suffix): - for k in ls: - if k.name: - k.name += str(suffix) - return ls - for idx, nd in enumerate(nodes): - inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inner_inputs), idx)) - outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx) + inner_ins[idx].append(nd.op.inner_seqs(nd.op.inner_inputs)) + outer_ins += nd.op.outer_seqs(nd.inputs) mit_mot_out_slices = () mit_mot_in_slices = () for idx, nd in enumerate(nodes): - inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inner_inputs), idx)) + inner_ins[idx].append(nd.op.inner_mitmot(nd.op.inner_inputs)) inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.inner_outputs)) mit_mot_in_slices += nd.op.info.mit_mot_in_slices mit_mot_out_slices += nd.op.info.mit_mot_out_slices[: nd.op.info.n_mit_mot] - outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx) + outer_ins += nd.op.outer_mitmot(nd.inputs) outer_outs += nd.op.outer_mitmot_outs(nd.outputs) mit_sot_in_slices = () for idx, nd in enumerate(nodes): - inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inner_inputs), idx)) + inner_ins[idx].append(nd.op.inner_mitsot(nd.op.inner_inputs)) inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.inner_outputs)) mit_sot_in_slices += nd.op.info.mit_sot_in_slices - outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx) + outer_ins += nd.op.outer_mitsot(nd.inputs) outer_outs += nd.op.outer_mitsot_outs(nd.outputs) sit_sot_in_slices = () for idx, nd in enumerate(nodes): - inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inner_inputs), idx)) + inner_ins[idx].append(nd.op.inner_sitsot(nd.op.inner_inputs)) sit_sot_in_slices += tuple((-1,) for x in range(nd.op.info.n_sit_sot)) inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.inner_outputs)) - outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx) + outer_ins += nd.op.outer_sitsot(nd.inputs) outer_outs += nd.op.outer_sitsot_outs(nd.outputs) for idx, nd in enumerate(nodes): # Shared - inner_ins[idx].append(rename(nd.op.inner_shared(nd.op.inner_inputs), idx)) - outer_ins += rename(nd.op.outer_shared(nd.inputs), idx) + inner_ins[idx].append(nd.op.inner_shared(nd.op.inner_inputs)) + outer_ins += nd.op.outer_shared(nd.inputs) for idx, nd in enumerate(nodes): # NitSot inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.inner_outputs)) - outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx) + outer_ins += nd.op.outer_nitsot(nd.inputs) outer_outs += nd.op.outer_nitsot_outs(nd.outputs) for idx, nd in enumerate(nodes): @@ -1752,8 +1746,8 @@ def rename(ls, suffix): # Non Seqs node_inner_non_seqs = nd.op.inner_non_seqs(nd.op.inner_inputs) n_non_seqs += len(node_inner_non_seqs) - inner_ins[idx].append(rename(node_inner_non_seqs, idx)) - outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx) + inner_ins[idx].append(node_inner_non_seqs) + outer_ins += nd.op.outer_non_seqs(nd.inputs) # Add back the number of steps outer_ins = [nodes[0].inputs[0]] + outer_ins From 600a1186fab2fb0b87f05ac7ddb56a5f1f5473f2 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 26 Sep 2022 13:38:37 +0200 Subject: 
[PATCH 05/21] Allow shared RandomState/Generator updates in JAX compiled functions --- aesara/link/jax/linker.py | 51 ++++++++++++++++++++++++++++++----- tests/link/jax/test_random.py | 33 +++++++++++++++++++++-- 2 files changed, 75 insertions(+), 9 deletions(-) diff --git a/aesara/link/jax/linker.py b/aesara/link/jax/linker.py index 7f1a12164e..49ef83b293 100644 --- a/aesara/link/jax/linker.py +++ b/aesara/link/jax/linker.py @@ -1,5 +1,8 @@ +import warnings + from numpy.random import Generator, RandomState +from aesara.compile.sharedvalue import SharedVariable, shared from aesara.graph.basic import Constant from aesara.link.basic import JITLinker @@ -7,10 +10,48 @@ class JAXLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using JAX.""" - def fgraph_convert(self, fgraph, **kwargs): + def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from aesara.link.jax.dispatch import jax_funcify + from aesara.tensor.random.type import RandomType + + shared_rng_inputs = [ + inp + for inp in fgraph.inputs + if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType)) + ] + + # Replace any shared RNG inputs so that their values can be updated in place + # without affecting the original RNG container. This is necessary because + # JAX does not accept RandomState/Generators as inputs, and they will have to + # be typyfied + if shared_rng_inputs: + warnings.warn( + f"The RandomType SharedVariables {shared_rng_inputs} will not be used " + f"in the compiled JAX graph. Instead a copy will be used.", + UserWarning, + ) + new_shared_rng_inputs = [ + shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs + ] + + fgraph.replace_all( + zip(shared_rng_inputs, new_shared_rng_inputs), + import_missing=True, + reason="JAXLinker.fgraph_convert", + ) + + for old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs): + new_inp_storage = [new_inp.get_value(borrow=True)] + storage_map[new_inp] = new_inp_storage + old_inp_storage = storage_map.pop(old_inp) + input_storage[input_storage.index(old_inp_storage)] = new_inp_storage + fgraph.remove_input( + fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert" + ) - return jax_funcify(fgraph, **kwargs) + return jax_funcify( + fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs + ) def jit_compile(self, fn): import jax @@ -32,11 +73,7 @@ def create_thunk_inputs(self, storage_map): new_value = jax_typify( sinput[0], dtype=getattr(sinput[0], "dtype", None) ) - # We need to remove the reference-based connection to the - # original `RandomState`/shared variable's storage, because - # subsequent attempts to use the same shared variable within - # other non-JAXified graphs will have problems. 
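To make the linker change above concrete, this is roughly the user-facing behaviour the new test below exercises. It is a sketch, not part of the patch: it assumes JAX is installed and that Aesara's string mode name "JAX" selects this linker; the warning text is the one added in this commit.

    import numpy as np
    import aesara
    import aesara.tensor as at
    from aesara.compile.sharedvalue import shared

    rng = shared(np.random.default_rng(98), name="original_rng", borrow=False)
    next_rng, x = at.random.normal(rng=rng).owner.outputs

    # Warns: "The RandomType SharedVariables [original_rng] will not be used
    # in the compiled JAX graph. Instead a copy will be used."
    f = aesara.function([], [x], updates={rng: next_rng}, mode="JAX")
    assert f() != f()  # the copied RNG is advanced between calls, so draws differ
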
- sinput = [new_value] + sinput[0] = new_value thunk_inputs.append(sinput) return thunk_inputs diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index d23c6a096a..a8bc06b5db 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -1,7 +1,10 @@ +import re + import numpy as np import pytest from packaging.version import parse as version_parse +import aesara import aesara.tensor as at from aesara.compile.function import function from aesara.compile.sharedvalue import shared @@ -79,8 +82,34 @@ def test_RandomStream(): srng = RandomStream(seed=123) out = srng.normal() - srng.normal() - fn = function([], out, mode=jax_mode) + with pytest.warns( + UserWarning, + match=r"The RandomType SharedVariables \[.+\] will not be used", + ): + fn = function([], out, mode=jax_mode) jax_res_1 = fn() jax_res_2 = fn() - assert np.array_equal(jax_res_1, jax_res_2) + assert not np.array_equal(jax_res_1, jax_res_2) + + +@pytest.mark.parametrize("rng_ctor", (np.random.RandomState, np.random.default_rng)) +def test_random_updates(rng_ctor): + original_value = rng_ctor(seed=98) + rng = shared(original_value, name="original_rng", borrow=False) + next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs + + with pytest.warns( + UserWarning, + match=re.escape( + "The RandomType SharedVariables [original_rng] will not be used" + ), + ): + f = aesara.function([], [x], updates={rng: next_rng}, mode=jax_mode) + assert f() != f() + + # Check that original rng variable content was not overwritten when calling jax_typify + assert all( + a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b) + for a, b in zip(rng.get_value().__getstate__(), original_value.__getstate__()) + ) From 31250910720e49b37789be60df718d7c5a22af6a Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Fri, 30 Sep 2022 13:44:17 -0500 Subject: [PATCH 06/21] Fix _debugprint's handling of empty profile data --- aesara/printing.py | 8 +- tests/test_printing.py | 195 +++++++++++++++++++++++------------------ 2 files changed, 112 insertions(+), 91 deletions(-) diff --git a/aesara/printing.py b/aesara/printing.py index 4324c63bb0..55db6d3500 100644 --- a/aesara/printing.py +++ b/aesara/printing.py @@ -630,11 +630,7 @@ def get_id_str( if node_info and var in node_info: var_output = f"{var_output} ({node_info[var]})" - if profile is None: - print(var_output, file=file) - elif profile.apply_time and node not in profile.apply_time: - print(var_output, file=file) - elif profile.apply_time and node in profile.apply_time: + if profile and profile.apply_time and node in profile.apply_time: op_time = profile.apply_time[node] op_time_percent = (op_time / profile.fct_call_time) * 100 tot_time_dict = profile.compute_total_times() @@ -652,6 +648,8 @@ def get_id_str( ), file=file, ) + else: + print(var_output, file=file) if not already_done and ( not stop_on_name or not (hasattr(var, "name") and var.name is not None) diff --git a/tests/test_printing.py b/tests/test_printing.py index bc3a9b6aa2..9f8f8e6480 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -3,7 +3,9 @@ """ import logging from io import StringIO +from textwrap import dedent +import numpy as np import pytest import aesara @@ -121,12 +123,12 @@ def test_debugprint(): with pytest.raises(TypeError): debugprint("blah") - A = matrix(name="A") - B = matrix(name="B") + A = dmatrix(name="A") + B = dmatrix(name="B") C = A + B C.name = "C" - D = matrix(name="D") - E = matrix(name="E") + D = dmatrix(name="D") + E = dmatrix(name="E") F = D + E G = C + F @@ -140,21 +142,17 @@ def test_debugprint(): s = StringIO() debugprint(G, file=s, id_type="int") s = s.getvalue() - # The additional white space are needed! - reference = ( - "\n".join( - [ - "Elemwise{add,no_inplace} [id 0]", - " |Elemwise{add,no_inplace} [id 1] 'C'", - " | |A [id 2]", - " | |B [id 3]", - " |Elemwise{add,no_inplace} [id 4]", - " |D [id 5]", - " |E [id 6]", - ] - ) - + "\n" - ) + reference = dedent( + r""" + Elemwise{add,no_inplace} [id 0] + |Elemwise{add,no_inplace} [id 1] 'C' + | |A [id 2] + | |B [id 3] + |Elemwise{add,no_inplace} [id 4] + |D [id 5] + |E [id 6] + """ + ).lstrip() assert s == reference @@ -162,20 +160,17 @@ def test_debugprint(): debugprint(G, file=s, id_type="CHAR") s = s.getvalue() # The additional white space are needed! - reference = ( - "\n".join( - [ - "Elemwise{add,no_inplace} [id A]", - " |Elemwise{add,no_inplace} [id B] 'C'", - " | |A [id C]", - " | |B [id D]", - " |Elemwise{add,no_inplace} [id E]", - " |D [id F]", - " |E [id G]", - ] - ) - + "\n" - ) + reference = dedent( + r""" + Elemwise{add,no_inplace} [id A] + |Elemwise{add,no_inplace} [id B] 'C' + | |A [id C] + | |B [id D] + |Elemwise{add,no_inplace} [id E] + |D [id F] + |E [id G] + """ + ).lstrip() assert s == reference @@ -183,61 +178,86 @@ def test_debugprint(): debugprint(G, file=s, id_type="CHAR", stop_on_name=True) s = s.getvalue() # The additional white space are needed! 
- reference = ( - "\n".join( - [ - "Elemwise{add,no_inplace} [id A]", - " |Elemwise{add,no_inplace} [id B] 'C'", - " |Elemwise{add,no_inplace} [id C]", - " |D [id D]", - " |E [id E]", - ] - ) - + "\n" - ) + reference = dedent( + r""" + Elemwise{add,no_inplace} [id A] + |Elemwise{add,no_inplace} [id B] 'C' + |Elemwise{add,no_inplace} [id C] + |D [id D] + |E [id E] + """ + ).lstrip() assert s == reference s = StringIO() debugprint(G, file=s, id_type="") s = s.getvalue() - # The additional white space are needed! - reference = ( - "\n".join( - [ - "Elemwise{add,no_inplace}", - " |Elemwise{add,no_inplace} 'C'", - " | |A", - " | |B", - " |Elemwise{add,no_inplace}", - " |D", - " |E", - ] - ) - + "\n" - ) + reference = dedent( + r""" + Elemwise{add,no_inplace} + |Elemwise{add,no_inplace} 'C' + | |A + | |B + |Elemwise{add,no_inplace} + |D + |E + """ + ).lstrip() assert s == reference - # test print_storage=True s = StringIO() debugprint(g, file=s, id_type="", print_storage=True) s = s.getvalue() - reference = ( - "\n".join( - [ - "Elemwise{add,no_inplace} 0 [None]", - " |A [None]", - " |B [None]", - " |D [None]", - " |E [None]", - ] - ) - + "\n" - ) + reference = dedent( + r""" + Elemwise{add,no_inplace} 0 [None] + |A [None] + |B [None] + |D [None] + |E [None] + """ + ).lstrip() assert s == reference + # Test the `profile` handling when profile data is missing + g = aesara.function([A, B, D, E], G, mode=mode, profile=True) + + s = StringIO() + debugprint(g, file=s, id_type="", print_storage=True) + s = s.getvalue() + reference = dedent( + r""" + Elemwise{add,no_inplace} 0 [None] + |A [None] + |B [None] + |D [None] + |E [None] + """ + ).lstrip() + + assert s == reference + + # Add profile data + g(np.c_[[1.0]], np.c_[[1.0]], np.c_[[1.0]], np.c_[[1.0]]) + + s = StringIO() + debugprint(g, file=s, id_type="", print_storage=True) + s = s.getvalue() + reference = dedent( + r""" + Elemwise{add,no_inplace} 0 [None] + |A [None] + |B [None] + |D [None] + |E [None] + """ + ).lstrip() + + assert reference in s + A = dmatrix(name="A") B = dmatrix(name="B") D = dmatrix(name="D") @@ -251,19 +271,22 @@ def test_debugprint(): print_view_map=True, ) s = s.getvalue() - exp_res = r"""Elemwise{Composite{(i0 + (i1 - i2))}} 4 - |A - |InplaceDimShuffle{x,0} v={0: [0]} 3 - | |CGemv{inplace} d={0: [0]} 2 - | |AllocEmpty{dtype='float64'} 1 - | | |Shape_i{0} 0 - | | |B - | |TensorConstant{1.0} - | |B - | | - | |TensorConstant{0.0} - |D - """ + exp_res = dedent( + r""" + Elemwise{Composite{(i0 + (i1 - i2))}} 4 + |A + |InplaceDimShuffle{x,0} v={0: [0]} 3 + | |CGemv{inplace} d={0: [0]} 2 + | |AllocEmpty{dtype='float64'} 1 + | | |Shape_i{0} 0 + | | |B + | |TensorConstant{1.0} + | |B + | | + | |TensorConstant{0.0} + |D + """ + ).lstrip() assert [l.strip() for l in s.split("\n")] == [ l.strip() for l in exp_res.split("\n") From bc2f292a232510eb5dbea20792f0ba558bc52af4 Mon Sep 17 00:00:00 2001 From: Markus Schmaus Date: Mon, 3 Oct 2022 11:19:31 +0200 Subject: [PATCH 07/21] Make `Type` a subclass of `type` --- aesara/__init__.py | 4 +- aesara/compile/builders.py | 21 +-- aesara/compile/compiledir.py | 5 +- aesara/compile/debugmode.py | 13 +- aesara/compile/nanguardmode.py | 3 +- aesara/compile/ops.py | 9 +- aesara/gradient.py | 43 +++--- aesara/graph/fg.py | 3 +- aesara/graph/null_type.py | 19 +-- aesara/graph/op.py | 3 +- aesara/graph/type.py | 173 +++++++++++++++------- aesara/issubtype.py | 13 ++ aesara/link/c/op.py | 3 +- aesara/link/c/params_type.py | 39 +++-- aesara/link/c/type.py | 174 ++++++++++++++--------- 
aesara/link/jax/linker.py | 3 +- aesara/link/numba/dispatch/basic.py | 7 +- aesara/link/numba/linker.py | 3 +- aesara/raise_op.py | 19 +-- aesara/sandbox/rng_mrg.py | 4 +- aesara/scalar/basic.py | 31 ++-- aesara/scan/basic.py | 5 +- aesara/scan/op.py | 43 +++--- aesara/sparse/basic.py | 17 +-- aesara/sparse/type.py | 28 ++-- aesara/tensor/basic.py | 23 +-- aesara/tensor/blas.py | 16 ++- aesara/tensor/elemwise.py | 13 +- aesara/tensor/math.py | 9 +- aesara/tensor/nlinalg.py | 3 +- aesara/tensor/nnet/basic.py | 9 +- aesara/tensor/nnet/batchnorm.py | 45 +++--- aesara/tensor/nnet/rewriting.py | 25 ++-- aesara/tensor/random/basic.py | 7 +- aesara/tensor/random/op.py | 3 +- aesara/tensor/random/type.py | 26 ++-- aesara/tensor/rewriting/basic.py | 3 +- aesara/tensor/rewriting/shape.py | 3 +- aesara/tensor/rewriting/subtensor.py | 13 +- aesara/tensor/shape.py | 11 +- aesara/tensor/subtensor.py | 47 +++--- aesara/tensor/type.py | 48 +++---- aesara/tensor/type_other.py | 23 +-- aesara/tensor/var.py | 7 +- aesara/typed_list/basic.py | 21 +-- aesara/typed_list/type.py | 18 ++- tests/compile/function/test_types.py | 16 ++- tests/compile/test_builders.py | 7 +- tests/compile/test_debugmode.py | 5 +- tests/graph/rewriting/test_unify.py | 7 +- tests/graph/test_basic.py | 19 +-- tests/graph/test_compute_test_value.py | 7 +- tests/graph/test_destroyhandler.py | 16 ++- tests/graph/test_features.py | 20 +-- tests/graph/test_op.py | 27 ++-- tests/graph/test_types.py | 29 ++-- tests/graph/utils.py | 27 ++-- tests/link/c/test_basic.py | 10 +- tests/link/c/test_cmodule.py | 2 +- tests/link/numba/test_basic.py | 12 +- tests/link/test_link.py | 8 +- tests/scalar/test_basic.py | 5 +- tests/sparse/test_var.py | 15 +- tests/tensor/rewriting/test_shape.py | 8 +- tests/tensor/rewriting/test_subtensor.py | 8 +- tests/tensor/test_elemwise.py | 3 +- tests/tensor/test_merge.py | 13 +- tests/tensor/test_shape.py | 14 +- tests/tensor/test_subtensor.py | 5 +- tests/tensor/test_type.py | 3 +- tests/test_raise_op.py | 5 +- 71 files changed, 788 insertions(+), 561 deletions(-) create mode 100644 aesara/issubtype.py diff --git a/aesara/__init__.py b/aesara/__init__.py index 0d81ab1087..658a273e86 100644 --- a/aesara/__init__.py +++ b/aesara/__init__.py @@ -29,6 +29,8 @@ from functools import singledispatch from typing import Any, NoReturn, Optional +from aesara.issubtype import issubtype + aesara_logger = logging.getLogger("aesara") logging_default_handler = logging.StreamHandler() @@ -151,7 +153,7 @@ def get_scalar_constant_value(v): """ # Is it necessary to test for presence of aesara.sparse at runtime? sparse = globals().get("sparse") - if sparse and isinstance(v.type, sparse.SparseTensorType): + if sparse and issubtype(v.type, sparse.SparseTensorType): if v.owner is not None and isinstance(v.owner.op, sparse.CSM): data = v.owner.inputs[0] return tensor.get_scalar_constant_value(data) diff --git a/aesara/compile/builders.py b/aesara/compile/builders.py index b751172c0a..84914231b7 100644 --- a/aesara/compile/builders.py +++ b/aesara/compile/builders.py @@ -26,6 +26,7 @@ from aesara.graph.op import HasInnerGraph, Op from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.graph.utils import MissingInputError +from aesara.issubtype import issubtype from aesara.tensor.rewriting.shape import ShapeFeature @@ -210,7 +211,7 @@ def _filter_grad_var(grad, inp): # # For now, this converts NullType or DisconnectedType into zeros_like. 
# other types are unmodified: overrider_var -> None - if isinstance(grad.type, (NullType, DisconnectedType)): + if issubtype(grad.type, (NullType, DisconnectedType)): if hasattr(inp, "zeros_like"): return inp.zeros_like(), grad else: @@ -221,9 +222,9 @@ def _filter_grad_var(grad, inp): @staticmethod def _filter_rop_var(inpJ, out): # mostly similar to _filter_grad_var - if isinstance(inpJ.type, NullType): + if issubtype(inpJ.type, NullType): return out.zeros_like(), inpJ - if isinstance(inpJ.type, DisconnectedType): + if issubtype(inpJ.type, DisconnectedType): # since R_op does not have DisconnectedType yet, we will just # make them zeros. return out.zeros_like(), None @@ -502,7 +503,7 @@ def lop_op(inps, grads): all_grads_l = list(all_grads_l) all_grads_ov_l = list(all_grads_ov_l) elif isinstance(lop_op, Variable): - if isinstance(lop_op.type, (DisconnectedType, NullType)): + if issubtype(lop_op.type, (DisconnectedType, NullType)): all_grads_l = [inp.zeros_like() for inp in local_inputs] all_grads_ov_l = [lop_op.type() for _ in range(inp_len)] else: @@ -529,7 +530,7 @@ def lop_op(inps, grads): all_grads_l.append(gnext) all_grads_ov_l.append(gnext_ov) elif isinstance(fn_gov, Variable): - if isinstance(fn_gov.type, (DisconnectedType, NullType)): + if issubtype(fn_gov.type, (DisconnectedType, NullType)): all_grads_l.append(inp.zeros_like()) all_grads_ov_l.append(fn_gov.type()) else: @@ -614,10 +615,10 @@ def _recompute_rop_op(self): all_rops_l = list(all_rops_l) all_rops_ov_l = list(all_rops_ov_l) elif isinstance(rop_op, Variable): - if isinstance(rop_op.type, NullType): + if issubtype(rop_op.type, NullType): all_rops_l = [inp.zeros_like() for inp in local_inputs] all_rops_ov_l = [rop_op.type() for _ in range(out_len)] - elif isinstance(rop_op.type, DisconnectedType): + elif issubtype(rop_op.type, DisconnectedType): all_rops_l = [inp.zeros_like() for inp in local_inputs] all_rops_ov_l = [None] * out_len else: @@ -644,10 +645,10 @@ def _recompute_rop_op(self): all_rops_l.append(rnext) all_rops_ov_l.append(rnext_ov) elif isinstance(fn_rov, Variable): - if isinstance(fn_rov.type, NullType): + if issubtype(fn_rov.type, NullType): all_rops_l.append(out.zeros_like()) all_rops_ov_l.append(fn_rov.type()) - if isinstance(fn_rov.type, DisconnectedType): + if issubtype(fn_rov.type, DisconnectedType): all_rops_l.append(out.zeros_like()) all_rops_ov_l.append(None) else: @@ -857,7 +858,7 @@ def connection_pattern(self, node): # cpmat_self &= out_is_disconnected for i, t in enumerate(self._lop_op_stypes_l): if t is not None: - if isinstance(t.type, DisconnectedType): + if issubtype(t.type, DisconnectedType): for o in range(out_len): cpmat_self[i][o] = False for o in range(out_len): diff --git a/aesara/compile/compiledir.py b/aesara/compile/compiledir.py index 6ecb6e0eda..790490464c 100644 --- a/aesara/compile/compiledir.py +++ b/aesara/compile/compiledir.py @@ -1,3 +1,6 @@ +from aesara import issubtype + + """ This module contains housekeeping functions for cleaning/purging the "compiledir". It is used by the "aesara-cache" CLI tool, located in the /bin folder of the repository. 
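The `isinstance` to `issubtype` replacements above all follow one pattern; a
minimal sketch of it, using `DisconnectedType` and the `issubtype` helper that
this patch adds later as `aesara/issubtype.py`:

    from aesara.gradient import DisconnectedType
    from aesara.graph.null_type import NullType
    from aesara.issubtype import issubtype

    # After this patch a variable's `.type` is a dynamically created *class*,
    # so the old isinstance checks become subclass checks.
    grad_type = DisconnectedType.subtype()

    # old: isinstance(grad.type, (NullType, DisconnectedType))
    assert issubtype(grad_type, (NullType, DisconnectedType))
    assert not issubtype(grad_type, NullType)
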
@@ -131,7 +134,7 @@ def print_compiledir_content(): zeros_op += 1 else: types = list( - {x for x in flatten(keydata.keys) if isinstance(x, CType)} + {x for x in flatten(keydata.keys) if issubtype(x, CType)} ) compile_start = compile_end = float("nan") for fn in os.listdir(os.path.join(compiledir, dir)): diff --git a/aesara/compile/debugmode.py b/aesara/compile/debugmode.py index 6cd1cf4cfc..54f715a0d4 100644 --- a/aesara/compile/debugmode.py +++ b/aesara/compile/debugmode.py @@ -20,6 +20,7 @@ import numpy as np import aesara +from aesara import issubtype from aesara.compile.function.types import ( Function, FunctionMaker, @@ -792,7 +793,7 @@ def _get_preallocated_maps( for r in considered_outputs: # There is no risk to overwrite inputs, since r does not work # inplace. - if isinstance(r.type, TensorType): + if issubtype(r.type, TensorType): reuse_outputs[r][...] = np.asarray(def_val).astype(r.type.dtype) if reuse_outputs: @@ -805,7 +806,7 @@ def _get_preallocated_maps( if "c_contiguous" in prealloc_modes or "ALL" in prealloc_modes: c_cont_outputs = {} for r in considered_outputs: - if isinstance(r.type, TensorType): + if issubtype(r.type, TensorType): # Build a C-contiguous buffer new_buf = r.type.value_zeros(r_vals[r].shape) assert new_buf.flags["C_CONTIGUOUS"] @@ -822,7 +823,7 @@ def _get_preallocated_maps( if "f_contiguous" in prealloc_modes or "ALL" in prealloc_modes: f_cont_outputs = {} for r in considered_outputs: - if isinstance(r.type, TensorType): + if issubtype(r.type, TensorType): new_buf = np.zeros( shape=r_vals[r].shape, dtype=r_vals[r].dtype, order="F" ) @@ -850,7 +851,7 @@ def _get_preallocated_maps( max_ndim = 0 rev_out_broadcastable = [] for r in considered_outputs: - if isinstance(r.type, TensorType): + if issubtype(r.type, TensorType): if max_ndim < r.ndim: rev_out_broadcastable += [True] * (r.ndim - max_ndim) max_ndim = r.ndim @@ -865,7 +866,7 @@ def _get_preallocated_maps( # Initial allocation init_strided = {} for r in considered_outputs: - if isinstance(r.type, TensorType): + if issubtype(r.type, TensorType): # Create a buffer twice as large in every dimension, # except if broadcastable, or for dimensions above # config.DebugMode__check_preallocated_output_ndim @@ -944,7 +945,7 @@ def _get_preallocated_maps( name = f"wrong_size{tuple(shape_diff)}" for r in considered_outputs: - if isinstance(r.type, TensorType): + if issubtype(r.type, TensorType): r_shape_diff = shape_diff[: r.ndim] out_shape = [ max((s + sd), 0) diff --git a/aesara/compile/nanguardmode.py b/aesara/compile/nanguardmode.py index 3e0cbf719a..de165a926a 100644 --- a/aesara/compile/nanguardmode.py +++ b/aesara/compile/nanguardmode.py @@ -5,6 +5,7 @@ import numpy as np import aesara +from aesara import issubtype from aesara.compile.mode import Mode from aesara.configdefaults import config from aesara.tensor.type import discrete_dtypes @@ -36,7 +37,7 @@ def _is_numeric_value(arr, var): return False elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)): return False - elif var and isinstance(var.type, RandomType): + elif var and issubtype(var.type, RandomType): return False elif isinstance(arr, slice): return False diff --git a/aesara/compile/ops.py b/aesara/compile/ops.py index 6bad920631..1f3897f0d5 100644 --- a/aesara/compile/ops.py +++ b/aesara/compile/ops.py @@ -12,6 +12,7 @@ from aesara.graph.basic import Apply from aesara.graph.op import Op +from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.type import CType @@ -64,7 +65,7 @@ def 
c_code(self, node, nodename, inp, out, sub): (oname,) = out fail = sub["fail"] - itype = node.inputs[0].type.__class__ + itype = node.inputs[0].type.base_type if itype in self.c_code_and_version: code, version = self.c_code_and_version[itype] return code % locals() @@ -199,7 +200,7 @@ def c_code(self, node, name, inames, onames, sub): (oname,) = onames fail = sub["fail"] - itype = node.inputs[0].type.__class__ + itype = node.inputs[0].type.base_type if itype in self.c_code_and_version: code, version = self.c_code_and_version[itype] return code % locals() @@ -311,11 +312,11 @@ def numpy_dot(a, b): """ if not isinstance(itypes, (list, tuple)): itypes = [itypes] - if not all(isinstance(t, CType) for t in itypes): + if not all(issubtype(t, CType) for t in itypes): raise TypeError("itypes has to be a list of Aesara types") if not isinstance(otypes, (list, tuple)): otypes = [otypes] - if not all(isinstance(t, CType) for t in otypes): + if not all(issubtype(t, CType) for t in otypes): raise TypeError("otypes has to be a list of Aesara types") # make sure they are lists and not tuples diff --git a/aesara/gradient.py b/aesara/gradient.py index bc8fd67117..50b4dc80ac 100644 --- a/aesara/gradient.py +++ b/aesara/gradient.py @@ -27,7 +27,8 @@ from aesara.graph.basic import Apply, NominalVariable, Variable from aesara.graph.null_type import NullType, null_type from aesara.graph.op import get_test_values -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type +from aesara.issubtype import issubtype if TYPE_CHECKING: @@ -122,7 +123,7 @@ def grad_undefined(op, x_pos, x, comment=""): )() -class DisconnectedType(Type): +class DisconnectedTypeMeta(NewTypeMeta): """A type indicating that a variable is the result of taking the gradient of ``c`` with respect to ``x`` when ``c`` is not a function of ``x``. @@ -158,6 +159,10 @@ def __str__(self): return "DisconnectedType" +class DisconnectedType(Type, metaclass=DisconnectedTypeMeta): + pass + + disconnected_type = DisconnectedType.subtype() @@ -498,7 +503,7 @@ def grad( if known_grads is None: raise ValueError("cost and known_grads can't both be None.") - if cost is not None and isinstance(cost.type, NullType): + if cost is not None and issubtype(cost.type, NullType): raise ValueError( "Can't differentiate a NaN cost. " f"Cost is NaN because {cost.type.why_null}" @@ -562,7 +567,7 @@ def grad( " or sparse aesara variable" ) - if not isinstance( + if not issubtype( g_var.type, (NullType, DisconnectedType) ) and "float" not in str(g_var.type.dtype): raise TypeError( @@ -627,14 +632,14 @@ def handle_disconnected(var): rval: MutableSequence[Optional[Variable]] = list(_rval) for i in range(len(_rval)): - if isinstance(_rval[i].type, NullType): + if issubtype(_rval[i].type, NullType): if null_gradients == "raise": raise NullTypeGradError( f"`grad` encountered a NaN. 
{_rval[i].type.why_null}" ) else: assert null_gradients == "return" - if isinstance(_rval[i].type, DisconnectedType): + if issubtype(_rval[i].type, DisconnectedType): handle_disconnected(_rval[i]) if return_disconnected == "zero": rval[i] = _float_zeros_like(_wrt[i]) @@ -1059,7 +1064,7 @@ def access_term_cache(node): # list of bools indicating if each output is connected to the cost outputs_connected = [ - not isinstance(g.type, DisconnectedType) for g in output_grads + not issubtype(g.type, DisconnectedType) for g in output_grads ] connection_pattern = _node_to_pattern(node) @@ -1086,9 +1091,7 @@ def access_term_cache(node): ] # List of bools indicating if each output is NullType - ograd_is_nan = [ - isinstance(output.type, NullType) for output in output_grads - ] + ograd_is_nan = [issubtype(output.type, NullType) for output in output_grads] # List of bools indicating if each input only has NullType outputs only_connected_to_nan = [ @@ -1197,7 +1200,7 @@ def try_to_copy_if_needed(var): orig_output, new_output_grad = packed if not hasattr(orig_output, "shape"): continue - if isinstance(new_output_grad.type, DisconnectedType): + if issubtype(new_output_grad.type, DisconnectedType): continue for orig_output_v, new_output_grad_v in get_test_values(*packed): o_shape = orig_output_v.shape @@ -1225,7 +1228,7 @@ def try_to_copy_if_needed(var): # return the sparse grad for optimization reason. # for ig, i in zip(input_grads, inputs): - # if (not isinstance(ig.type, (DisconnectedType, NullType)) and + # if (not issubtype(ig.type, (DisconnectedType, NullType)) and # type(ig.type) != type(i.type)): # raise ValueError( # "%s returned the wrong type for gradient terms." @@ -1246,7 +1249,7 @@ def try_to_copy_if_needed(var): if ( ograd_is_nan[out_idx] and connection_pattern[inp_idx][out_idx] - and not isinstance(input_grads[inp_idx].type, DisconnectedType) + and not issubtype(input_grads[inp_idx].type, DisconnectedType) ): input_grads[inp_idx] = output_grads[out_idx] @@ -1300,7 +1303,7 @@ def try_to_copy_if_needed(var): f"of shape {i_shape}" ) - if not isinstance(term.type, (NullType, DisconnectedType)): + if not issubtype(term.type, (NullType, DisconnectedType)): if term.type.dtype not in aesara.tensor.type.float_dtypes: raise TypeError( str(node.op) + ".grad illegally " @@ -1309,7 +1312,7 @@ def try_to_copy_if_needed(var): ) if only_connected_to_nan[i]: - assert isinstance(term.type, NullType) + assert issubtype(term.type, NullType) if only_connected_to_int[i]: # This term has only integer outputs and we know @@ -1345,7 +1348,7 @@ def try_to_copy_if_needed(var): for i, (ipt, ig, connected) in enumerate( zip(inputs, input_grads, inputs_connected) ): - actually_connected = not isinstance(ig.type, DisconnectedType) + actually_connected = not issubtype(ig.type, DisconnectedType) if actually_connected and not connected: msg = ( @@ -1392,12 +1395,12 @@ def access_grad_cache(var): " Variable instance." 
) - if isinstance(term.type, NullType): + if issubtype(term.type, NullType): null_terms.append(term) continue # Don't try to sum up DisconnectedType placeholders - if isinstance(term.type, DisconnectedType): + if issubtype(term.type, DisconnectedType): continue if hasattr(var, "ndim") and term.ndim != var.ndim: @@ -2099,9 +2102,9 @@ def _is_zero(x): """ if not hasattr(x, "type"): return np.all(x == 0.0) - if isinstance(x.type, NullType): + if issubtype(x.type, NullType): return "no" - if isinstance(x.type, DisconnectedType): + if issubtype(x.type, DisconnectedType): return "yes" no_constant_value = True diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index 26fb74bd7f..dc538e3158 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -24,6 +24,7 @@ from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate from aesara.graph.utils import MetaObject, MissingInputError, TestValueError +from aesara.issubtype import issubtype from aesara.misc.ordered_set import OrderedSet @@ -309,7 +310,7 @@ def import_var( ): from aesara.graph.null_type import NullType - if isinstance(var.type, NullType): + if issubtype(var.type, NullType): raise TypeError( f"Computation graph contains a NaN. {var.type.why_null}" ) diff --git a/aesara/graph/null_type.py b/aesara/graph/null_type.py index de572aec4a..754a5f65f6 100644 --- a/aesara/graph/null_type.py +++ b/aesara/graph/null_type.py @@ -1,7 +1,7 @@ -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Props, Type -class NullType(Type): +class NullTypeMeta(NewTypeMeta): """ A type that allows no values. @@ -17,10 +17,7 @@ class NullType(Type): """ - __props__ = ("why_null",) - - def __init__(self, why_null="(no explanation given)"): - self.why_null = why_null + why_null: Props[str] = None def filter(self, data, strict=False, allow_downcast=None): raise ValueError("No values may be assigned to a NullType") @@ -34,14 +31,12 @@ def may_share_memory(a, b): def values_eq(self, a, b, force_same_dtype=True): raise ValueError("NullType has no values to compare") - def __eq__(self, other): - return type(self) == type(other) - - def __hash__(self): - return hash(type(self)) - def __str__(self): return "NullType" +class NullType(Type, metaclass=NullTypeMeta): + pass + + null_type = NullType.subtype() diff --git a/aesara/graph/op.py b/aesara/graph/op.py index 6eb62cd2b4..60dee3cc56 100644 --- a/aesara/graph/op.py +++ b/aesara/graph/op.py @@ -29,6 +29,7 @@ add_tag_trace, get_variable_trace_string, ) +from aesara.issubtype import issubtype from aesara.link.c.params_type import Params, ParamsType @@ -477,7 +478,7 @@ def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool: def get_params(self, node: Apply) -> Params: """Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`.""" - if isinstance(self.params_type, ParamsType): + if issubtype(self.params_type, ParamsType): wrapper = self.params_type if not all(hasattr(self, field) for field in wrapper.fields): # Let's print missing attributes for debugging. 
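A small sketch of the new parameter machinery, using the `NullTypeMeta`
definition above (attributes annotated with `Props[...]` become type
parameters, and `subtype(...)` builds a cached, parametrized subclass; the
supporting code lives in the `aesara/graph/type.py` changes that follow):

    from aesara.graph.null_type import NullType

    nt1 = NullType.subtype(why_null="gradient is undefined here")
    nt2 = NullType.subtype(why_null="gradient is undefined here")

    assert nt1 is nt2                 # parametrized subtypes are cached and reused
    assert issubclass(nt1, NullType)  # the subtype is a genuine subclass
    assert nt1.why_null == "gradient is undefined here"

    v = nt1()  # calling the subtype still creates a `Variable` of that type
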
diff --git a/aesara/graph/type.py b/aesara/graph/type.py index ca19f09d0a..8b6699e1b7 100644 --- a/aesara/graph/type.py +++ b/aesara/graph/type.py @@ -1,6 +1,17 @@ -import inspect +import copyreg from abc import ABCMeta, abstractmethod -from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union, final +from itertools import chain +from typing import ( + Annotated, + Any, + Optional, + Text, + Tuple, + TypeVar, + Union, + get_args, + get_type_hints, +) from typing_extensions import Protocol, TypeAlias, runtime_checkable @@ -10,32 +21,11 @@ D = TypeVar("D") +PropsV = TypeVar("PropsV") +Props = Annotated[Optional[PropsV], "props"] class NewTypeMeta(ABCMeta): - __props__: tuple[str, ...] - - def __call__(cls, *args, **kwargs): - raise RuntimeError("Use subtype") - # return super().__call__(*args, **kwargs) - - def subtype(cls, *args, **kwargs): - kwargs = cls.type_parameters(*args, **kwargs) - return super().__call__(**kwargs) - - def type_parameters(cls, *args, **kwargs): - if args: - init_args = tuple(inspect.signature(cls.__init__).parameters.keys())[1:] - if cls.__props__[: len(args)] != init_args[: len(args)]: - raise RuntimeError( - f"{cls.__props__=} doesn't match {init_args=} for {args=}" - ) - - kwargs |= zip(cls.__props__, args) - return kwargs - - -class Type(Generic[D], metaclass=NewTypeMeta): """ Interface specification for variable type instances. @@ -59,7 +49,84 @@ class Type(Generic[D], metaclass=NewTypeMeta): The `Type` that will be created by a call to `Type.make_constant`. """ - __props__: tuple[str, ...] = () + _prop_names: tuple[str, ...] = tuple() + _subclass_cache = dict() + + _base_type: Optional["NewTypeMeta"] = None + _type_parameters: dict[str, Any] = dict() + + @staticmethod + def make_key(params): + res = [] + for k, v in sorted(params.items()): + if isinstance(v, dict): + v = NewTypeMeta.make_key(v) + res.append((k, v)) + + return tuple(res) + + @classmethod + def __new__(cls, *args, **kwargs): + res = super().__new__(*args, **kwargs) + props = tuple( + k + for k, v in chain( + get_type_hints(type(res), include_extras=True).items(), + get_type_hints(res, include_extras=True).items(), + ) + if "props" in get_args(v) + ) + res._prop_names = props + copyreg.pickle(type(res), _pickle_NewTypeMeta) + return res + + def subtype(cls, *args, **kwargs): + # For dynamically created types the attribute base_type exists and points to the base type it was derived from + base_type = cls.base_type + kwargs = base_type.type_parameters(*args, **kwargs) + + return base_type.subtype_params(kwargs) + + @property + def base_type(cls): + if cls._base_type is None: + return cls + else: + return cls._base_type + + def subtype_params(cls, params): + if not params: + return cls + + key = (cls, *NewTypeMeta.make_key(params)) + try: + return NewTypeMeta._subclass_cache[key] + except KeyError: + pass + cls_name = f"{cls.__name__}{params}" + + res = type(cls)(cls_name, (cls,), params) + res._base_type = cls + res._type_parameters = params + + NewTypeMeta._subclass_cache[key] = res + return res + + def __call__(self, name: Optional[Text] = None) -> Any: + """Return a new `Variable` instance of Type `self`. + + Parameters + ---------- + name : None or str + A pretty string for printing and debugging. 
+ + """ + return utils.add_tag_trace(self.make_variable(name)) + + def type_parameters(cls, *args, **kwargs): + if args: + kwargs |= zip(cls._prop_names, args) + return kwargs @classmethod def create(cls, **kwargs): @@ -161,7 +228,7 @@ def filter_inplace( def filter_variable( self, other: Union[Variable, D], allow_convert: bool = True - ) -> variable_type: + ) -> Any: r"""Convert a `other` into a `Variable` with a `Type` that's compatible with `self`. If the involved `Type`\s are not compatible, a `TypeError` will be raised. @@ -244,18 +311,7 @@ def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type: def clone(self, *args, **kwargs) -> "Type": """Clone a copy of this type with the given arguments/keyword values, if any.""" - return type(self).subtype(*args, **kwargs) - - def __call__(self, name: Optional[Text] = None) -> variable_type: - """Return a new `Variable` instance of Type `self`. - - Parameters - ---------- - name : None or str - A pretty string for printing and debugging. - - """ - return utils.add_tag_trace(self.make_variable(name)) + return self.subtype(*args, **kwargs) @classmethod def values_eq(cls, a: D, b: D) -> bool: @@ -295,7 +351,7 @@ def _props(self): """ Tuple of properties of all attributes """ - return tuple(getattr(self, a) for a in self.__props__) + return tuple(getattr(self, a) for a in self._prop_names) def _props_dict(self): """This return a dict of all ``__props__`` key-> value. @@ -305,33 +361,42 @@ def _props_dict(self): least all the original props. """ - return {a: getattr(self, a) for a in self.__props__} + return {a: getattr(self, a) for a in self._prop_names} - @final - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) + # def __hash__(self): + # return hash((type(self), tuple(getattr(self, a, None) for a in self._prop_names))) - def __hash__(self): - return hash((type(self), tuple(getattr(self, a) for a in self.__props__))) - - def __eq__(self, other): - return type(self) == type(other) and tuple( - getattr(self, a) for a in self.__props__ - ) == tuple(getattr(other, a) for a in self.__props__) + # def __eq__(self, other): + # return type(self) == type(other) and tuple( + # getattr(self, a) for a in self._prop_names + # ) == tuple(getattr(other, a) for a in self._prop_names) def __str__(self): - if self.__props__ is None or len(self.__props__) == 0: + if self._prop_names is None or len(self._prop_names) == 0: return f"{self.__class__.__name__}()" else: return "{}{{{}}}".format( self.__class__.__name__, ", ".join( - "{}={!r}".format(p, getattr(self, p)) for p in self.__props__ + "{}={!r}".format(p, getattr(self, p)) for p in self._prop_names ), ) +def _pickle_NewTypeMeta(type_: NewTypeMeta): + base_type = type_.base_type + if base_type is type_: + return type_.__name__ + return base_type.subtype_params, (type_._type_parameters,) + + +copyreg.pickle(NewTypeMeta, _pickle_NewTypeMeta) + + +class Type(metaclass=NewTypeMeta): + pass + + DataType = str diff --git a/aesara/issubtype.py b/aesara/issubtype.py new file mode 100644 index 0000000000..e52c02ffa0 --- /dev/null +++ b/aesara/issubtype.py @@ -0,0 +1,13 @@ +def issubtype(x, typ): + if not isinstance(typ, tuple): + typ = (typ,) + + for t in typ: + if isinstance(x, type): + if issubclass(x, t): + return True + else: + if isinstance(x, typ): + return True + + return False diff --git a/aesara/link/c/op.py b/aesara/link/c/op.py index 9c35968e4d..07a959e028 100644 --- a/aesara/link/c/op.py +++ b/aesara/link/c/op.py @@ -25,6 +25,7 @@ from aesara.graph.op 
import ComputeMapType, Op, StorageMapType, ThunkType from aesara.graph.type import HasDataType from aesara.graph.utils import MethodNotDefined +from aesara.issubtype import issubtype from aesara.link.c.interface import CLinkerOp from aesara.link.c.params_type import ParamsType from aesara.utils import hash_from_code @@ -432,7 +433,7 @@ def __get_op_params(self) -> List[Tuple[str, Any]]: """ params: List[Tuple[str, Any]] = [] - if isinstance(self.params_type, ParamsType): + if issubtype(self.params_type, ParamsType): wrapper = self.params_type params.append(("PARAMS_TYPE", wrapper.name)) for i in range(wrapper.length): diff --git a/aesara/link/c/params_type.py b/aesara/link/c/params_type.py index 2b85873192..7f028f5bd3 100644 --- a/aesara/link/c/params_type.py +++ b/aesara/link/c/params_type.py @@ -116,15 +116,19 @@ def __init__(value_attr1, value_attr2): import hashlib import re +from typing import Any +from aesara.graph.type import Props from aesara.graph.utils import MethodNotDefined -from aesara.link.c.type import CType, EnumType +from aesara.issubtype import issubtype +from aesara.link.c.type import CType, CTypeMeta, EnumType # Set of C and C++ keywords as defined (at March 2nd, 2017) in the pages below: # - http://fr.cppreference.com/w/c/keyword # - http://fr.cppreference.com/w/cpp/keyword # Added `NULL` and `_Pragma` keywords. + c_cpp_keywords = { "_Alignas", "_Alignof", @@ -252,7 +256,7 @@ class Params(dict): """ def __init__(self, params_type, **kwargs): - if not isinstance(params_type, ParamsType): + if not issubtype(params_type, ParamsType): raise TypeError("Params: 1st constructor argument should be a ParamsType.") for field in params_type.fields: if field not in kwargs: @@ -316,7 +320,7 @@ def __ne__(self, other): return not self.__eq__(other) -class ParamsType(CType): +class ParamsTypeMeta(CTypeMeta): """ This class can create a struct of Aesara types (like `TensorType`, etc.) to be used as a convenience `Op` parameter wrapping many data. @@ -343,6 +347,9 @@ class ParamsType(CType): """ + fields: Props[tuple[Any, ...]] = tuple() + types: Props[tuple[Any, ...]] = tuple() + @classmethod def type_parameters(cls, **kwargs): params = dict() @@ -362,7 +369,7 @@ def type_parameters(cls, **kwargs): ) type_instance = kwargs[attribute_name] type_name = type_instance.__class__.__name__ - if not isinstance(type_instance, CType): + if not issubtype(type_instance, CType): raise TypeError( 'ParamsType: attribute "%s" should inherit from Aesara CType, got "%s".' % (attribute_name, type_name) @@ -375,7 +382,7 @@ def type_parameters(cls, **kwargs): params["_const_to_enum"] = {} params["_alias_to_enum"] = {} - enum_types = [t for t in params["types"] if isinstance(t, EnumType)] + enum_types = [t for t in params["types"] if issubtype(t, EnumType)] if enum_types: # We don't want same enum names in different enum types. if sum(len(t) for t in enum_types) != len( @@ -410,7 +417,6 @@ def type_parameters(cls, **kwargs): for enum_type in enum_types for alias in enum_type.aliases } - return params def __setstate__(self, state): @@ -429,10 +435,9 @@ def __setstate__(self, state): def __getattr__(self, key): # Now we can access value of each enum defined inside enum types wrapped into the current ParamsType. 
# const_to_enum = super().__getattribute__("_const_to_enum") - if not key.startswith("__"): - const_to_enum = self._const_to_enum - if key in const_to_enum: - return const_to_enum[key][key] + if key != "_const_to_enum" and hasattr(self, "_const_to_enum"): + if key in self._const_to_enum: + return self._const_to_enum[key][key] raise AttributeError(f"'{self}' object has no attribute '{key}'") def __repr__(self): @@ -443,16 +448,6 @@ def __repr__(self): ] ) - def __eq__(self, other): - return ( - type(self) == type(other) - and self.fields == other.fields - and self.types == other.types - ) - - def __hash__(self): - return hash((type(self),) + self.fields + self.types) - @staticmethod def generate_struct_name(params): # This method tries to generate a unique name for the current instance. @@ -897,3 +892,7 @@ def c_sync(self, name, sub): # `Type` cannot be (compiled) graph _outputs_, because that's when # `CType.c_sync` is used. raise NotImplementedError("Variables of this type cannot be graph outputs") + + +class ParamsType(CType, metaclass=ParamsTypeMeta): + pass diff --git a/aesara/link/c/type.py b/aesara/link/c/type.py index faa8e01895..63491606e6 100644 --- a/aesara/link/c/type.py +++ b/aesara/link/c/type.py @@ -1,11 +1,10 @@ import ctypes import platform import re -from collections.abc import Mapping -from typing import TypeVar +from typing import Any, ItemsView, KeysView, TypeVar, ValuesView from aesara.graph.basic import Constant -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Props, Type from aesara.link.c.interface import CLinkerType from aesara.utils import Singleton @@ -14,7 +13,7 @@ T = TypeVar("T", bound=Type) -class CType(Type[D], CLinkerType): +class CTypeMeta(NewTypeMeta, CLinkerType): """Convenience wrapper combining `Type` and `CLinkerType`. Aesara comes with several subclasses of such as: @@ -57,7 +56,11 @@ class CType(Type[D], CLinkerType): """ -class Generic(CType, Singleton): +class CType(Type, metaclass=CTypeMeta): + pass + + +class GenericMeta(CTypeMeta, Singleton): r"""A type for a generic Python object exposed directly in C. This class implements the `CType` and `CLinkerType` interfaces @@ -116,6 +119,10 @@ def __str__(self): return self.__class__.__name__ +class Generic(CType, metaclass=GenericMeta): + pass + + generic = Generic.subtype() _cdata_type = None @@ -126,7 +133,7 @@ def __str__(self): ).value -class CDataType(CType[D]): +class CDataTypeMeta(CTypeMeta): """ Represents opaque C data to be passed around. The intent is to ease passing arbitrary data between ops C code. @@ -146,19 +153,17 @@ class CDataType(CType[D]): The version to use in Aesara cache system. 
""" - __props__ = ( - "ctype", - "freefunc", - "headers", - "header_dirs", - "libraries", - "lib_dirs", - "extra_support_code", - "compile_args", - "version", - ) - - def __init__( + ctype: Props[Any] = None + freefunc: Props[Any] = None + headers: Props[Any] = None + header_dirs: Props[Any] = None + libraries: Props[Any] = None + lib_dirs: Props[Any] = None + extra_support_code: Props[Any] = None + compile_args: Props[Any] = None + version: Props[Any] = None + + def type_parameters( self, ctype, freefunc=None, @@ -170,19 +175,22 @@ def __init__( extra_support_code="", version=None, ): + params = dict() assert isinstance(ctype, str) - self.ctype = ctype + params["ctype"] = ctype if freefunc is not None: assert isinstance(freefunc, str) - self.freefunc = freefunc - self.headers = tuple(headers) - self.header_dirs = tuple(header_dirs) - self.libraries = tuple(libraries) - self.lib_dirs = tuple(lib_dirs) - self.compile_args = tuple(compile_args) - self.extra_support_code = extra_support_code - self._fn = None - self.version = version + params["freefunc"] = freefunc + params["headers"] = tuple(headers) + params["header_dirs"] = tuple(header_dirs) + params["libraries"] = tuple(libraries) + params["lib_dirs"] = tuple(lib_dirs) + params["compile_args"] = tuple(compile_args) + params["extra_support_code"] = extra_support_code + params["_fn"] = None + params["version"] = version + + return params def filter(self, data, strict=False, allow_downcast=None): # We ignore this type-check (_cdata_type is None) in PyPy @@ -292,6 +300,10 @@ def __setstate__(self, dct): self.version = None +class CDataType(CType, metaclass=CDataTypeMeta): + pass + + class CDataTypeConstant(Constant[T]): def merge_signature(self): # We don't want to merge constants that don't point to the @@ -329,7 +341,7 @@ def __ne__(self, other): return not self == other -class EnumType(Mapping, CType): +class EnumTypeMeta(CTypeMeta): """ Main subclasses: - :class:`EnumList` @@ -426,7 +438,10 @@ class EnumType(Mapping, CType): """ - __props__ = ("constants", "aliases", "ctype", "cname") + constants: Props[FrozenMap] = FrozenMap() + aliases: Props[FrozenMap] = FrozenMap() + ctype: Props[Any] = None + cname: Props[Any] = None @classmethod def __init_ctype(cls, ctype): @@ -516,21 +531,21 @@ def get_aliases(self): """ return tuple(sorted(self.aliases.keys())) - def __repr__(self): - names_to_aliases = {constant_name: "" for constant_name in self} - for alias in self.aliases: - names_to_aliases[self.aliases[alias]] = f"({alias})" - return "{}<{}>({})".format( - type(self).__name__, - self.ctype, - ", ".join( - "{}{}:{}".format(k, names_to_aliases[k], self[k]) - for k in sorted(self.keys()) - ), - ) + # def __repr__(self): + # names_to_aliases = {constant_name: "" for constant_name in self} + # for alias in self.aliases: + # names_to_aliases[self.aliases[alias]] = f"({alias})" + # return "{}<{}>({})".format( + # type(self).__name__, + # self.ctype, + # ", ".join( + # "{}{}:{}".format(k, names_to_aliases[k], self[k]) + # for k in sorted(self.keys()) + # ), + # ) def __getattr__(self, key): - if key in self.constants: + if key != "constants" and key in self.constants: return self[key] else: raise AttributeError( @@ -538,10 +553,10 @@ def __getattr__(self, key): ) def __setattr__(self, key, value): - if key in self.__props__: - CType.__setattr__(self, key, value) - else: + if hasattr(self, "constants") and key in self.constants: raise TypeError("constant values are immutable.") + else: + super().__setattr__(key, value) def __iter__(self): return 
self.constants.__iter__() @@ -558,28 +573,34 @@ def __setitem__(self, key, value): def __delitem__(self, key): raise TypeError("constant values are immutable.") - def __hash__(self): - # All values are Python basic types, then easy to hash. - return hash( - (type(self), self.ctype) - + tuple((k, self[k]) for k in sorted(self.keys())) - + tuple((a, self.aliases[a]) for a in sorted(self.aliases.keys())) - ) + # Copied from abc.collections.Mapping without `__eq__` since mixin would make this class + # unhashable + def get(self, key, default=None): + "D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None." + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + self[key] + except KeyError: + return False + else: + return True - def __eq__(self, other): - return ( - type(self) == type(other) - and self.ctype == other.ctype - and len(self) == len(other) - and len(self.aliases) == len(other.aliases) - and all(k in other for k in self) - and all(a in other.aliases for a in self.aliases) - and all(self[k] == other[k] for k in self) - and all(self.aliases[a] == other.aliases[a] for a in self.aliases) - ) + def keys(self): + "D.keys() -> a set-like object providing a view on D's keys" + return KeysView(self) - def __ne__(self, other): - return not self == other + def items(self): + "D.items() -> a set-like object providing a view on D's items" + return ItemsView(self) + + def values(self): + "D.values() -> an object providing a view on D's values" + return ValuesView(self) # EnumType should be used to create constants available in both Python and C code. # However, for convenience, we make sure EnumType can have a value, like other common types, @@ -703,7 +724,11 @@ def c_sync(self, name, sub): raise NotImplementedError("Variables of this type cannot be graph outputs") -class EnumList(EnumType): +class EnumType(CType, metaclass=EnumTypeMeta): + pass + + +class EnumListMeta(EnumTypeMeta): """ **Inherit from**: - :class:`EnumType` @@ -776,7 +801,14 @@ def type_parameters(cls, *args, **kwargs): return super().type_parameters(**kwargs) -class CEnumType(EnumList): +class EnumList(EnumType, metaclass=EnumListMeta): + pass + + +print({EnumList: 1}) + + +class CEnumTypeMeta(EnumListMeta): """ **Inherit from**: - :class:`EnumList` @@ -836,3 +868,7 @@ def c_extract(self, name, sub, check_input=True, **kwargs): def c_code_cache_version(self): return (1, super().c_code_cache_version()) + + +class CEnumType(EnumList, metaclass=CEnumTypeMeta): + pass diff --git a/aesara/link/jax/linker.py b/aesara/link/jax/linker.py index 49ef83b293..615a20e166 100644 --- a/aesara/link/jax/linker.py +++ b/aesara/link/jax/linker.py @@ -2,6 +2,7 @@ from numpy.random import Generator, RandomState +from aesara import issubtype from aesara.compile.sharedvalue import SharedVariable, shared from aesara.graph.basic import Constant from aesara.link.basic import JITLinker @@ -17,7 +18,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): shared_rng_inputs = [ inp for inp in fgraph.inputs - if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType)) + if (isinstance(inp, SharedVariable) and issubtype(inp.type, RandomType)) ] # Replace any shared RNG inputs so that their values can be updated in place diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index 3c579637b5..b82278e923 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -21,6 +21,7 @@ from aesara.graph.fg import 
FunctionGraph from aesara.graph.type import Type from aesara.ifelse import IfElse +from aesara.issubtype import issubtype from aesara.link.utils import ( compile_function_src, fgraph_to_python, @@ -79,7 +80,7 @@ def get_numba_type( Return Numba scalars for zero dimensional :class:`TensorType`\s. """ - if isinstance(aesara_type, TensorType): + if issubtype(aesara_type, TensorType): dtype = aesara_type.numpy_dtype numba_dtype = numba.from_dtype(dtype) if force_scalar or ( @@ -87,7 +88,7 @@ def get_numba_type( ): return numba_dtype return numba.types.Array(numba_dtype, aesara_type.ndim, layout) - elif isinstance(aesara_type, ScalarType): + elif issubtype(aesara_type, ScalarType): dtype = np.dtype(aesara_type.dtype) numba_dtype = numba.from_dtype(dtype) return numba_dtype @@ -391,7 +392,7 @@ def create_index_func(node, objmode=False): """Create a Python function that assembles and uses an index on an array.""" def convert_indices(indices, entry): - if indices and isinstance(entry, Type): + if indices and issubtype(entry, Type): rval = indices.pop(0) return rval.auto_name elif isinstance(entry, slice): diff --git a/aesara/link/numba/linker.py b/aesara/link/numba/linker.py index bb390b0523..eef671aa13 100644 --- a/aesara/link/numba/linker.py +++ b/aesara/link/numba/linker.py @@ -3,6 +3,7 @@ import numpy as np import aesara +from aesara import issubtype from aesara.link.basic import JITLinker @@ -14,7 +15,7 @@ class NumbaLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" def output_filter(self, var: "Variable", out: Any) -> Any: - if not isinstance(var, np.ndarray) and isinstance( + if not isinstance(var, np.ndarray) and issubtype( var.type, aesara.tensor.TensorType ): return np.asarray(out, dtype=var.type.dtype) diff --git a/aesara/raise_op.py b/aesara/raise_op.py index 2d851aaa65..f062c7c274 100644 --- a/aesara/raise_op.py +++ b/aesara/raise_op.py @@ -7,19 +7,20 @@ from aesara.gradient import DisconnectedType from aesara.graph.basic import Apply, Variable +from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType -from aesara.link.c.type import Generic +from aesara.link.c.type import Generic, GenericMeta from aesara.scalar.basic import ScalarType from aesara.tensor.type import DenseTensorType -class ExceptionType(Generic): - def __eq__(self, other): - return type(self) == type(other) +class ExceptionTypeMeta(GenericMeta): + pass - def __hash__(self): - return hash(type(self)) + +class ExceptionType(Generic, metaclass=ExceptionTypeMeta): + pass exception_type = ExceptionType.subtype() @@ -106,7 +107,7 @@ def connection_pattern(self, node): return [[1]] + [[0]] * (len(node.inputs) - 1) def c_code(self, node, name, inames, onames, props): - if not isinstance(node.inputs[0].type, (DenseTensorType, ScalarType)): + if not issubtype(node.inputs[0].type, (DenseTensorType, ScalarType)): raise NotImplementedError( f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}" ) @@ -118,7 +119,7 @@ def c_code(self, node, name, inames, onames, props): msg = self.msg.replace('"', '\\"').replace("\n", "\\n") for idx, cond_name in enumerate(cond_names): - if isinstance(node.inputs[0].type, DenseTensorType): + if issubtype(node.inputs[0].type, DenseTensorType): check.append( f""" if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{ @@ -145,7 +146,7 @@ def c_code(self, node, name, inames, onames, props): check = "\n".join(check) - if isinstance(node.inputs[0].type, DenseTensorType): + if 
issubtype(node.inputs[0].type, DenseTensorType): res = f""" {check} Py_XDECREF({out_name}); diff --git a/aesara/sandbox/rng_mrg.py b/aesara/sandbox/rng_mrg.py index 6bafe4c0e8..5bf26d746a 100644 --- a/aesara/sandbox/rng_mrg.py +++ b/aesara/sandbox/rng_mrg.py @@ -17,7 +17,7 @@ import numpy as np -from aesara import function, gradient +from aesara import function, gradient, issubtype from aesara import scalar as aes from aesara import shared from aesara import tensor as at @@ -536,7 +536,7 @@ def c_support_code(self, **kwargs): def c_code(self, node, name, inp, out, sub): # If we try to use the C code here with something else than a # TensorType, something is wrong. - assert isinstance(node.inputs[0].type, TensorType) + assert issubtype(node.inputs[0].type, TensorType) if self.output_type.dtype == "float16": # C code is not tested, fall back to Python raise NotImplementedError() diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py index 5cb9ac20ef..66aaf38e28 100644 --- a/aesara/scalar/basic.py +++ b/aesara/scalar/basic.py @@ -28,10 +28,11 @@ from aesara.graph.basic import Apply, Constant, Variable, clone, list_of_nodes from aesara.graph.fg import FunctionGraph from aesara.graph.rewriting.basic import MergeOptimizer -from aesara.graph.type import DataType +from aesara.graph.type import DataType, Props from aesara.graph.utils import MetaObject, MethodNotDefined +from aesara.issubtype import issubtype from aesara.link.c.op import COp -from aesara.link.c.type import CType +from aesara.link.c.type import CType, CTypeMeta from aesara.misc.safe_asarray import _asarray from aesara.printing import pprint from aesara.utils import ( @@ -268,7 +269,7 @@ def convert(x, dtype=None): return x_ -class ScalarType(CType): +class ScalarTypeMeta(CTypeMeta): """ Internal class, should not be used by clients. @@ -281,7 +282,7 @@ class ScalarType(CType): """ - __props__ = ("dtype",) + dtype: Props[Any] = None ndim = 0 shape = () dtype: DataType @@ -298,7 +299,7 @@ def type_parameters(cls, dtype): def clone(self, dtype=None, **kwargs): if dtype is None: dtype = self.dtype - return type(self).subtype(dtype) + return self.subtype(dtype) @staticmethod def may_share_memory(a, b): @@ -413,7 +414,11 @@ def dtype_specs(self): ) def upcast(self, *others): - return upcast(*[x.dtype for x in [self] + list(others)]) + types = list(others) + if self.dtype is not None: + # this is None, if this method has been called for ScalarType + types = [self] + types + return upcast(*[x.dtype for x in types]) def make_variable(self, name=None): return ScalarVariable(self, None, name=name) @@ -671,6 +676,10 @@ def get_size(self, shape_info): return shape_info +class ScalarType(CType, metaclass=ScalarTypeMeta): + pass + + def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType: """ Return a ScalarType(dtype) object. 
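A short sketch of the reworked `upcast` above (illustrative; the dtypes are
chosen so the result does not depend on the `floatX`/`cast_policy`
configuration):

    from aesara.scalar.basic import ScalarType, get_scalar_type

    float32 = get_scalar_type("float32")
    float64 = get_scalar_type("float64")
    int32 = get_scalar_type("int32")
    int64 = get_scalar_type("int64")

    # Called on a parametrized subtype: its own dtype takes part in the upcast.
    assert float32.upcast(float64) == "float64"

    # Called on the bare `ScalarType` class, whose `dtype` is None: only the
    # arguments' dtypes are considered (the branch guarded by the new
    # `if self.dtype is not None` check).
    assert ScalarType.upcast(int32, int64) == "int64"
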
@@ -875,7 +884,7 @@ def as_scalar(x: Any, name: Optional[str] = None) -> ScalarVariable: if isinstance(x, Variable): if isinstance(x, ScalarVariable): return x - elif isinstance(x.type, TensorType) and x.type.ndim == 0: + elif issubtype(x.type, TensorType) and x.type.ndim == 0: return scalar_from_tensor(x) else: raise TypeError(f"Cannot convert {x} to a scalar type") @@ -890,8 +899,8 @@ def as_scalar(x: Any, name: Optional[str] = None) -> ScalarVariable: complexs128 = apply_across_args(complex128) -def upcast_out(*types): - dtype = ScalarType.upcast(*types) +def upcast_out(typ, *types): + dtype = typ.upcast(*types) return (get_scalar_type(dtype),) @@ -1116,7 +1125,7 @@ def output_types(self, types): if hasattr(self, "output_types_preference"): variables = self.output_types_preference(*types) if not isinstance(variables, (list, tuple)) or any( - not isinstance(x, CType) for x in variables + not issubtype(x, CType) for x in variables ): raise TypeError( "output_types_preference should return a list or a tuple of types", @@ -2441,7 +2450,7 @@ def grad(self, inputs, gout): # CASTING OPERATIONS class Cast(UnaryScalarOp): def __init__(self, o_type, name=None): - if not isinstance(o_type, ScalarType): + if not issubtype(o_type, ScalarType): raise TypeError(o_type) super().__init__(specific_out(o_type), name=name) self.o_type = o_type diff --git a/aesara/scan/basic.py b/aesara/scan/basic.py index 81c42cdc1f..ba87155693 100644 --- a/aesara/scan/basic.py +++ b/aesara/scan/basic.py @@ -9,6 +9,7 @@ from aesara.graph.basic import Constant, Variable, clone_replace, graph_inputs from aesara.graph.op import get_test_value from aesara.graph.utils import MissingInputError, TestValueError +from aesara.issubtype import issubtype from aesara.scan.op import Scan, ScanInfo from aesara.scan.utils import expand_empty, safe_new, until from aesara.tensor.basic import get_scalar_constant_value @@ -880,7 +881,7 @@ def wrap_into_list(x): # then, if we return the output as given by the innner function # this will represent only a slice and it will have one # dimension less. - if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1: + if issubtype(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1: outputs[pos] = unbroadcast(shape_padleft(inner_out), 0) if not return_list and len(outputs) == 1: @@ -1006,7 +1007,7 @@ def wrap_into_list(x): inner_replacements[input.variable] = new_var - if isinstance(new_var.type, TensorType): + if issubtype(new_var.type, TensorType): sit_sot_inner_inputs.append(new_var) sit_sot_scan_inputs.append( expand_empty( diff --git a/aesara/scan/op.py b/aesara/scan/op.py index d13435ca44..21166fe1af 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -78,6 +78,7 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import HasInnerGraph, Op from aesara.graph.utils import InconsistencyError, MissingInputError +from aesara.issubtype import issubtype from aesara.link.c.basic import CLinker from aesara.link.c.exceptions import MissingGXX from aesara.link.utils import raise_with_op @@ -1234,7 +1235,7 @@ def make_node(self, *inputs): # strange NumPy behavior: vector_ndarray[int] return a NumPy # scalar and not a NumPy ndarray of 0 dimensions. 
def is_cpu_vector(s): - return isinstance(s.type, TensorType) and s.ndim == 1 + return issubtype(s.type, TensorType) and s.ndim == 1 self.vector_seqs = [ is_cpu_vector(seq) for seq in new_inputs[1 : 1 + self.info.n_seqs] @@ -1246,7 +1247,7 @@ def is_cpu_vector(s): ] ] self.vector_outs += [ - isinstance(t.type, TensorType) and t.ndim == 0 + issubtype(t.type, TensorType) and t.ndim == 0 for t in self.outer_nitsot_outs(self.inner_outputs) ] @@ -2574,7 +2575,7 @@ def compute_all_gradients(known_grads): info.n_seqs + pos ] - if not isinstance(dC_douts[outer_oidx].type, DisconnectedType): + if not issubtype(dC_douts[outer_oidx].type, DisconnectedType): dtypes.append(dC_douts[outer_oidx].dtype) if dtypes: new_dtype = aesara.scalar.upcast(*dtypes) @@ -2582,7 +2583,7 @@ def compute_all_gradients(known_grads): new_dtype = config.floatX dC_dXt = safe_new(Xt, dtype=new_dtype) else: - if isinstance(dC_douts[idx].type, DisconnectedType): + if issubtype(dC_douts[idx].type, DisconnectedType): continue dC_dXt = safe_new(dC_douts[idx][0]) dC_dXts.append(dC_dXt) @@ -2597,7 +2598,7 @@ def compute_all_gradients(known_grads): known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] dc_dxts_idx += 1 else: - if isinstance(dC_douts[i].type, DisconnectedType): + if issubtype(dC_douts[i].type, DisconnectedType): continue else: if diff_outputs[i] in known_grads: @@ -2645,10 +2646,10 @@ def compute_all_gradients(known_grads): dC_dXtm1s.append(safe_new(x)) for dx, dC_dXtm1 in enumerate(dC_dXtm1s): - if isinstance(dC_dinps_t[dx + info.n_seqs].type, NullType): + if issubtype(dC_dinps_t[dx + info.n_seqs].type, NullType): # The accumulated gradient is undefined pass - elif isinstance(dC_dXtm1.type, NullType): + elif issubtype(dC_dXtm1.type, NullType): # The new gradient is undefined, this makes the accumulated # gradient undefined as weell dC_dinps_t[dx + info.n_seqs] = dC_dXtm1 @@ -2677,7 +2678,7 @@ def compute_all_gradients(known_grads): outer_inp_seqs.append(nw_seq) outer_inp_seqs += [x[:-1][::-1] for x in self.outer_sitsot_outs(outs)] for x in self.outer_nitsot_outs(dC_douts): - if not isinstance(x.type, DisconnectedType): + if not issubtype(x.type, DisconnectedType): if info.as_while: # equivalent to x[:n_steps][::-1] outer_inp_seqs.append(x[n_steps - 1 :: -1]) @@ -2737,7 +2738,7 @@ def compute_all_gradients(known_grads): n_mitmot_inps = 0 for idx, taps in enumerate(info.mit_mot_in_slices): - if isinstance(dC_douts[idx].type, DisconnectedType): + if issubtype(dC_douts[idx].type, DisconnectedType): out = outs[idx] outer_inp_mitmot.append(at.zeros_like(out)) else: @@ -2762,7 +2763,7 @@ def compute_all_gradients(known_grads): if tap not in mitmot_inp_taps[idx]: inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs]) - if isinstance(dC_dinps_t[ins_pos].type, NullType): + if issubtype(dC_dinps_t[ins_pos].type, NullType): # We cannot use Null in the inner graph, so we # use a zero tensor of the appropriate shape instead. 
inner_out_mitmot.append( @@ -2814,7 +2815,7 @@ def compute_all_gradients(known_grads): offset = info.n_mit_mot for idx, taps in enumerate(info.mit_sot_in_slices): - if isinstance(dC_douts[idx + offset].type, DisconnectedType): + if issubtype(dC_douts[idx + offset].type, DisconnectedType): outer_inp_mitmot.append(outs[idx + offset].zeros_like()) else: outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) @@ -2831,7 +2832,7 @@ def compute_all_gradients(known_grads): tap = -tap inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs]) - if isinstance(dC_dinps_t[ins_pos].type, NullType): + if issubtype(dC_dinps_t[ins_pos].type, NullType): # We cannot use Null in the inner graph, so we # use a zero tensor of the appropriate shape instead. inner_out_mitmot.append( @@ -2869,10 +2870,10 @@ def compute_all_gradients(known_grads): mitmot_inp_taps.append([0, 1]) mitmot_out_taps.append([1]) through_shared = False - if not isinstance(dC_douts[idx + offset].type, DisconnectedType): + if not issubtype(dC_douts[idx + offset].type, DisconnectedType): outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) else: - if isinstance(dC_dinps_t[ins_pos].type, NullType): + if issubtype(dC_dinps_t[ins_pos].type, NullType): # Cannot use dC_dinps_t[ins_pos].dtype, so we use # floatX instead, as it is a dummy value that will not # be used anyway. @@ -2886,7 +2887,7 @@ def compute_all_gradients(known_grads): ) ) - if isinstance(dC_dinps_t[ins_pos].type, NullType): + if issubtype(dC_dinps_t[ins_pos].type, NullType): # We cannot use Null in the inner graph, so we # use a zero tensor of the appropriate shape instead. inner_out_mitmot.append( @@ -2900,7 +2901,7 @@ def compute_all_gradients(known_grads): for _sh in self.inner_shared(self_inputs) ) - if isinstance(dC_dinps_t[ins_pos].type, NullType): + if issubtype(dC_dinps_t[ins_pos].type, NullType): type_outs.append(dC_dinps_t[ins_pos].type.why_null) elif through_shared: type_outs.append("through_shared") @@ -2926,7 +2927,7 @@ def compute_all_gradients(known_grads): for _sh in self.inner_shared(self_inputs): if _sh in graph_inputs([vl]): through_shared = True - if isinstance(vl.type, NullType): + if issubtype(vl.type, NullType): type_outs.append(vl.type.why_null) # Replace the inner output with a zero tensor of # the right shape @@ -2945,7 +2946,7 @@ def compute_all_gradients(known_grads): for _sh in self.inner_shared(self_inputs): if _sh in graph_inputs([vl]): through_shared = True - if isinstance(vl.type, NullType): + if issubtype(vl.type, NullType): type_outs.append(vl.type.why_null) # Replace the inner output with a zero tensor of # the right shape @@ -2964,7 +2965,7 @@ def compute_all_gradients(known_grads): outer_inp_sitsot = [] for _idx, y in enumerate(inner_inp_sitsot): x = self.outer_non_seqs(inputs)[_idx] - if isinstance(y.type, NullType): + if issubtype(y.type, NullType): # Cannot use dC_dXtm1s.dtype, so we use floatX instead. 
outer_inp_sitsot.append( at.zeros( @@ -3105,7 +3106,7 @@ def compute_all_gradients(known_grads): disconnected = True connected_flags = self.connection_pattern(node)[idx + start] for dC_dout, connected in zip(dC_douts, connected_flags): - if not isinstance(dC_dout.type, DisconnectedType) and connected: + if not issubtype(dC_dout.type, DisconnectedType) and connected: disconnected = False if disconnected: gradients.append(DisconnectedType.subtype()()) @@ -3148,7 +3149,7 @@ def compute_all_gradients(known_grads): for idx in range(len(gradients)): disconnected = True for kdx in range(len(node.outputs)): - if connection_pattern[idx][kdx] and not isinstance( + if connection_pattern[idx][kdx] and not issubtype( dC_douts[kdx].type, DisconnectedType ): disconnected = False diff --git a/aesara/sparse/basic.py b/aesara/sparse/basic.py index 8a872efa1e..cc4407f1e2 100644 --- a/aesara/sparse/basic.py +++ b/aesara/sparse/basic.py @@ -19,6 +19,7 @@ from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op +from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.type import generic from aesara.misc.safe_asarray import _asarray @@ -85,7 +86,7 @@ def _is_sparse_variable(x): "or TensorType, for instance), not ", x, ) - return isinstance(x.type, SparseTensorType) + return issubtype(x.type, SparseTensorType) def _is_dense_variable(x): @@ -105,7 +106,7 @@ def _is_dense_variable(x): "TensorType, for instance), not ", x, ) - return isinstance(x.type, TensorType) + return issubtype(x.type, TensorType) def _is_dense(x): @@ -160,7 +161,7 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs): else: x = x.outputs[0] if isinstance(x, Variable): - if not isinstance(x.type, SparseTensorType): + if not issubtype(x.type, SparseTensorType): raise TypeError( "Variable type field must be a SparseTensorType.", x, x.type ) @@ -264,7 +265,7 @@ def to_dense(self, *args, **kwargs): self = self.toarray() new_args = [ arg.toarray() - if hasattr(arg, "type") and isinstance(arg.type, SparseTensorType) + if hasattr(arg, "type") and issubtype(arg.type, SparseTensorType) else arg for arg in args ] @@ -618,7 +619,7 @@ def grad(self, inputs, g): # g[1:] is connected, or this grad method wouldn't have been # called, so we should report zeros (csm,) = inputs - if isinstance(g[0].type, DisconnectedType): + if issubtype(g[0].type, DisconnectedType): return [csm.zeros_like()] data, indices, indptr, _shape = csm_properties(csm) @@ -980,7 +981,7 @@ def __str__(self): return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}" def __call__(self, x): - if not isinstance(x.type, SparseTensorType): + if not issubtype(x.type, SparseTensorType): return x return super().__call__(x) @@ -1052,7 +1053,7 @@ def __str__(self): return f"{self.__class__.__name__}{{{self.format}}}" def __call__(self, x): - if isinstance(x.type, SparseTensorType): + if issubtype(x.type, SparseTensorType): return x return super().__call__(x) @@ -3498,7 +3499,7 @@ def perform(self, node, inputs, outputs): ) variable = a * b - if isinstance(node.outputs[0].type, SparseTensorType): + if issubtype(node.outputs[0].type, SparseTensorType): assert _is_sparse(variable) out[0] = variable return diff --git a/aesara/sparse/type.py b/aesara/sparse/type.py index 584493da1b..9e03a14e88 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -7,7 +7,9 @@ import aesara from aesara import scalar as aes from 
aesara.graph.basic import Variable -from aesara.tensor.type import DenseTensorType, TensorType +from aesara.graph.type import Props +from aesara.issubtype import issubtype +from aesara.tensor.type import DenseTensorType, TensorType, TensorTypeMeta SparsityTypes = Literal["csr", "csc", "bsr"] @@ -32,7 +34,7 @@ def _is_sparse(x): return isinstance(x, scipy.sparse.spmatrix) -class SparseTensorType(TensorType): +class SparseTensorTypeMeta(TensorTypeMeta): """A `Type` for sparse tensors. Notes @@ -41,7 +43,8 @@ class SparseTensorType(TensorType): """ - __props__ = ("format", "dtype", "shape") + format: Props[SparsityTypes] = None + format_cls = { "csr": scipy.sparse.csr_matrix, "csc": scipy.sparse.csc_matrix, @@ -99,7 +102,7 @@ def clone( dtype = self.dtype if shape is None: shape = self.shape - return type(self).subtype(format, dtype, shape=shape, **kwargs) + return self.subtype(format, dtype, shape=shape, **kwargs) def filter(self, value, strict=False, allow_downcast=None): if isinstance(value, Variable): @@ -162,7 +165,7 @@ def convert_variable(self, var): return res if not isinstance(res.type, type(self)): - if isinstance(res.type, DenseTensorType): + if issubtype(res.type, DenseTensorType): if self.format == "csr": from aesara.sparse.basic import csr_from_dense @@ -180,9 +183,6 @@ def convert_variable(self, var): return res - def __hash__(self): - return super().__hash__() ^ hash(self.format) - def __repr__(self): return f"Sparse({self.dtype}, {self.shape}, {self.format})" @@ -235,14 +235,6 @@ def value_zeros(self, shape): return matrix_constructor(shape, dtype=self.dtype) - def __eq__(self, other): - res = super().__eq__(other) - - if isinstance(res, bool): - return res and other.format == self.format - - return res - def is_super(self, otype): if not super().is_super(otype): return False @@ -253,6 +245,10 @@ def is_super(self, otype): return False +class SparseTensorType(TensorType, metaclass=SparseTensorTypeMeta): + pass + + aesara.compile.register_view_op_c_code( SparseTensorType, """ diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index 50c3c5b104..d9df5f79a9 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -29,6 +29,7 @@ from aesara.graph.op import Op from aesara.graph.rewriting.utils import rewrite_graph from aesara.graph.type import Type +from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.misc.safe_asarray import _asarray @@ -96,7 +97,7 @@ def _as_tensor_Scalar(x, name, ndim, **kwargs): @_as_tensor_variable.register(Variable) def _as_tensor_Variable(x, name, ndim, **kwargs): - if not isinstance(x.type, TensorType): + if not issubtype(x.type, TensorType): raise TypeError( f"Tensor type field must be a TensorType; found {type(x.type)}." 
) @@ -315,7 +316,7 @@ def get_scalar_constant_value( from aesara.sparse.type import SparseTensorType - if isinstance(v.type, SparseTensorType): + if issubtype(v.type, SparseTensorType): raise NotScalarConstantError() return data @@ -434,7 +435,7 @@ def get_scalar_constant_value( var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:] ): idx = v.owner.op.idx_list[0] - if isinstance(idx, Type): + if issubtype(idx, Type): idx = get_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) @@ -469,7 +470,7 @@ def get_scalar_constant_value( ): idx = v.owner.op.idx_list[0] - if isinstance(idx, Type): + if issubtype(idx, Type): idx = get_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) @@ -491,7 +492,7 @@ def get_scalar_constant_value( op = owner.op idx_list = op.idx_list idx = idx_list[0] - if isinstance(idx, Type): + if issubtype(idx, Type): idx = get_scalar_constant_value( owner.inputs[1], max_recur=max_recur ) @@ -536,7 +537,7 @@ class TensorFromScalar(COp): __props__ = () def make_node(self, s): - if not isinstance(s.type, aes.ScalarType): + if not issubtype(s.type, aes.ScalarType): raise TypeError("Input must be a `ScalarType` `Type`") return Apply(self, [s], [tensor(dtype=s.type.dtype, shape=())]) @@ -596,7 +597,7 @@ def __call__(self, *args, **kwargs) -> ScalarVariable: return type_cast(ScalarVariable, super().__call__(*args, **kwargs)) def make_node(self, t): - if not isinstance(t.type, TensorType) or t.type.ndim > 0: + if not issubtype(t.type, TensorType) or t.type.ndim > 0: raise TypeError("Input must be a scalar `TensorType`") return Apply( @@ -1951,7 +1952,7 @@ def grad(self, inputs, g_outputs): x, axis, n = inputs outputs = self(*inputs, return_list=True) # If all the output gradients are disconnected, then so are the inputs - if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs): + if builtins.all(issubtype(g.type, DisconnectedType) for g in g_outputs): return [ DisconnectedType.subtype()(), grad_undefined(self, 1, axis), @@ -1960,7 +1961,7 @@ def grad(self, inputs, g_outputs): # Else, we have to make them zeros before joining them new_g_outputs = [] for o, g in zip(outputs, g_outputs): - if isinstance(g.type, DisconnectedType): + if issubtype(g.type, DisconnectedType): new_g_outputs.append(o.zeros_like()) else: new_g_outputs.append(g) @@ -2611,7 +2612,7 @@ def stack(*tensors, **kwargs): if all( # In case there are explicit ints in tensors isinstance(t, (np.number, float, int, builtins.complex)) - or (isinstance(t, Variable) and isinstance(t.type, TensorType) and t.ndim == 0) + or (isinstance(t, Variable) and issubtype(t.type, TensorType) and t.ndim == 0) for t in tensors ): # in case there is direct int @@ -3377,7 +3378,7 @@ def make_node(self, x): return Apply( self, [x], - [x.type.__class__.subtype(dtype=x.dtype, shape=[False] * (x.ndim - 1))()], + [x.type.subtype(dtype=x.dtype, shape=[False] * (x.ndim - 1))()], ) def perform(self, node, inputs, outputs): diff --git a/aesara/tensor/blas.py b/aesara/tensor/blas.py index ffbf236496..3717ddb5b2 100644 --- a/aesara/tensor/blas.py +++ b/aesara/tensor/blas.py @@ -131,6 +131,8 @@ import numpy as np +from aesara.issubtype import issubtype + try: import numpy.__config__ # noqa @@ -270,7 +272,7 @@ def make_node(self, y, alpha, A, x, beta): inputs = [y, alpha, A, x, beta] - if any(not isinstance(i.type, DenseTensorType) for i in inputs): + if any(not issubtype(i.type, DenseTensorType) for i in inputs): raise NotImplementedError("Only dense tensor types are supported") return Apply(self, inputs, 
[y.type()]) @@ -372,7 +374,7 @@ def make_node(self, A, alpha, x, y): raise TypeError("only float and complex types supported", x.dtype) inputs = [A, alpha, x, y] - if any(not isinstance(i.type, DenseTensorType) for i in inputs): + if any(not issubtype(i.type, DenseTensorType) for i in inputs): raise NotImplementedError("Only dense tensor types are supported") return Apply(self, inputs, [A.type()]) @@ -933,7 +935,7 @@ def __getstate__(self): def make_node(self, *inputs): inputs = list(map(at.as_tensor_variable, inputs)) - if any(not isinstance(i.type, DenseTensorType) for i in inputs): + if any(not issubtype(i.type, DenseTensorType) for i in inputs): raise NotImplementedError("Only dense tensor types are supported") if len(inputs) != 5: @@ -1671,7 +1673,7 @@ def make_node(self, x, y): x = at.as_tensor_variable(x) y = at.as_tensor_variable(y) - if any(not isinstance(i.type, DenseTensorType) for i in (x, y)): + if any(not issubtype(i.type, DenseTensorType) for i in (x, y)): raise NotImplementedError("Only dense tensor types are supported") dtypes = ("float16", "float32", "float64", "complex64", "complex128") @@ -1760,7 +1762,7 @@ def local_dot_to_dot22(fgraph, node): if not isinstance(node.op, Dot): return - if any(not isinstance(i.type, DenseTensorType) for i in node.inputs): + if any(not issubtype(i.type, DenseTensorType) for i in node.inputs): return False x, y = node.inputs @@ -1968,7 +1970,7 @@ class Dot22Scalar(GemmRelated): def make_node(self, x, y, a): - if any(not isinstance(i.type, DenseTensorType) for i in (x, y, a)): + if any(not issubtype(i.type, DenseTensorType) for i in (x, y, a)): raise NotImplementedError("Only dense tensor types are supported") if a.ndim != 0: @@ -2192,7 +2194,7 @@ class BatchedDot(COp): def make_node(self, *inputs): inputs = list(map(at.as_tensor_variable, inputs)) - if any(not isinstance(i.type, DenseTensorType) for i in inputs): + if any(not issubtype(i.type, DenseTensorType) for i in inputs): raise NotImplementedError("Only dense tensor types are supported") if len(inputs) != 2: diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py index ea89561b53..910681bbea 100644 --- a/aesara/tensor/elemwise.py +++ b/aesara/tensor/elemwise.py @@ -9,6 +9,7 @@ from aesara.graph.basic import Apply from aesara.graph.null_type import NullType from aesara.graph.utils import MethodNotDefined +from aesara.issubtype import issubtype from aesara.link.c.basic import failure_code from aesara.link.c.op import COp, ExternalCOp, OpenMPOp from aesara.link.c.params_type import ParamsType @@ -520,9 +521,7 @@ def R_op(self, inputs, eval_points): # the right thing to do .. have to talk to Ian and James # about it - if bgrads[jdx] is None or isinstance( - bgrads[jdx].type, DisconnectedType - ): + if bgrads[jdx] is None or issubtype(bgrads[jdx].type, DisconnectedType): pass elif eval_point is not None: if rop_out is None: @@ -558,7 +557,7 @@ def L_op(self, inputs, outs, ograds): # this op did the right thing. 
new_rval = [] for elem, ipt in zip(rval, inputs): - if isinstance(elem.type, (NullType, DisconnectedType)): + if issubtype(elem.type, (NullType, DisconnectedType)): new_rval.append(elem) else: elem = ipt.zeros_like() @@ -570,7 +569,7 @@ def L_op(self, inputs, outs, ograds): # sum out the broadcasted dimensions for i, ipt in enumerate(inputs): - if isinstance(rval[i].type, (NullType, DisconnectedType)): + if issubtype(rval[i].type, (NullType, DisconnectedType)): continue # List of all the dimensions that are broadcastable for input[i] so @@ -594,7 +593,7 @@ def _bgrad(self, inputs, outputs, ograds): with config.change_flags(compute_test_value="off"): def as_scalar(t): - if isinstance(t.type, (NullType, DisconnectedType)): + if issubtype(t.type, (NullType, DisconnectedType)): return t return get_scalar_type(t.type.dtype)() @@ -618,7 +617,7 @@ def as_scalar(t): def transform(r): # From a graph of ScalarOps, make a graph of Broadcast ops. - if isinstance(r.type, (NullType, DisconnectedType)): + if issubtype(r.type, (NullType, DisconnectedType)): return r if r in scalar_inputs: return inputs[scalar_inputs.index(r)] diff --git a/aesara/tensor/math.py b/aesara/tensor/math.py index f82250965e..20dde349ec 100644 --- a/aesara/tensor/math.py +++ b/aesara/tensor/math.py @@ -9,6 +9,7 @@ from aesara.gradient import DisconnectedType from aesara.graph.basic import Apply, Variable from aesara.graph.op import Op +from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.link.c.type import Generic @@ -302,8 +303,8 @@ def grad(self, inp, grads): axis = as_tensor_variable(self.axis) g_max, g_max_idx = grads - g_max_disconnected = isinstance(g_max.type, DisconnectedType) - g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType) + g_max_disconnected = issubtype(g_max.type, DisconnectedType) + g_max_idx_disconnected = issubtype(g_max_idx.type, DisconnectedType) # if the op is totally disconnected, so are its inputs if g_max_disconnected and g_max_idx_disconnected: @@ -2089,9 +2090,7 @@ def dense_dot(a, b): """ a, b = as_tensor_variable(a), as_tensor_variable(b) - if not isinstance(a.type, DenseTensorType) or not isinstance( - b.type, DenseTensorType - ): + if not issubtype(a.type, DenseTensorType) or not issubtype(b.type, DenseTensorType): raise TypeError("The dense dot product is only supported for dense types") if a.ndim == 0 or b.ndim == 0: diff --git a/aesara/tensor/nlinalg.py b/aesara/tensor/nlinalg.py index 7a7085cbcf..f4fbcbb809 100644 --- a/aesara/tensor/nlinalg.py +++ b/aesara/tensor/nlinalg.py @@ -7,6 +7,7 @@ from aesara.gradient import DisconnectedType from aesara.graph.basic import Apply from aesara.graph.op import Op +from aesara.issubtype import issubtype from aesara.tensor import basic as at from aesara.tensor import math as tm from aesara.tensor.basic import as_tensor_variable, extract_diag @@ -325,7 +326,7 @@ def grad(self, inputs, g_outputs): def _zero_disconnected(outputs, grads): l = [] for o, g in zip(outputs, grads): - if isinstance(g.type, DisconnectedType): + if issubtype(g.type, DisconnectedType): l.append(o.zeros_like()) else: l.append(g) diff --git a/aesara/tensor/nnet/basic.py b/aesara/tensor/nnet/basic.py index 096f57eae5..06fb2a1643 100644 --- a/aesara/tensor/nnet/basic.py +++ b/aesara/tensor/nnet/basic.py @@ -19,6 +19,7 @@ from aesara.graph.basic import Apply from aesara.graph.op import Op from aesara.graph.rewriting.basic import copy_stack_trace, graph_rewriter, node_rewriter +from 
aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.raise_op import Assert from aesara.scalar import UnaryScalarOp @@ -126,7 +127,7 @@ def L_op(self, inp, outputs, grads): x, b = inp (g_sm,) = grads - if isinstance(g_sm.type, DisconnectedType): + if issubtype(g_sm.type, DisconnectedType): return [DisconnectedType.subtype()(), DisconnectedType.subtype()()] dx = softmax_grad_legacy(g_sm, outputs[0]) @@ -1421,19 +1422,19 @@ def grad(self, inp, grads): db_terms = [] d_idx_terms = [] - if not isinstance(g_nll.type, DisconnectedType): + if not issubtype(g_nll.type, DisconnectedType): nll, sm = crossentropy_softmax_1hot_with_bias(x, b, y_idx) dx = crossentropy_softmax_1hot_with_bias_dx(g_nll, sm, y_idx) db = at_sum(dx, axis=[0]) dx_terms.append(dx) db_terms.append(db) - if not isinstance(g_sm.type, DisconnectedType): + if not issubtype(g_sm.type, DisconnectedType): dx, db = softmax_with_bias.L_op((x, b), [softmax_with_bias(x, b)], (g_sm,)) dx_terms.append(dx) db_terms.append(db) - if not isinstance(g_am.type, DisconnectedType): + if not issubtype(g_am.type, DisconnectedType): dx_terms.append(x.zeros_like()) db_terms.append(b.zeros_like()) d_idx_terms.append(y_idx.zeros_like()) diff --git a/aesara/tensor/nnet/batchnorm.py b/aesara/tensor/nnet/batchnorm.py index 0693fc7dc6..f2a5b1e577 100644 --- a/aesara/tensor/nnet/batchnorm.py +++ b/aesara/tensor/nnet/batchnorm.py @@ -5,6 +5,7 @@ from aesara.graph.basic import Apply from aesara.graph.op import Op from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter +from aesara.issubtype import issubtype from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div from aesara.tensor import basic as at from aesara.tensor.basic import as_tensor_variable @@ -690,7 +691,7 @@ def grad(self, inp, grads): g_wrt_x_mean = 0 g_wrt_x_invstd = 0 - if not isinstance(ddinputs.type, aesara.gradient.DisconnectedType): + if not issubtype(ddinputs.type, aesara.gradient.DisconnectedType): ccc = scale * (ddinputs - mean(ddinputs, axis=self.axes, keepdims=True)) ddd = (x_invstd**3) * ( ccc * mean(dy * x_diff, axis=self.axes, keepdims=True) @@ -721,7 +722,7 @@ def grad(self, inp, grads): keepdims=True, ) - if not isinstance(ddscale.type, aesara.gradient.DisconnectedType): + if not issubtype(ddscale.type, aesara.gradient.DisconnectedType): g_wrt_x = g_wrt_x + (x_invstd * ddscale * dy) g_wrt_dy = g_wrt_dy + (x_invstd * ddscale * x_diff) g_wrt_x_mean = g_wrt_x_mean - ( @@ -731,7 +732,7 @@ def grad(self, inp, grads): ddscale * at_sum(dy * x_diff, axis=self.axes, keepdims=True) ) - if not isinstance(ddbias.type, aesara.gradient.DisconnectedType): + if not issubtype(ddbias.type, aesara.gradient.DisconnectedType): g_wrt_dy = g_wrt_dy + at.fill(dy, ddbias) # depending on which output gradients are given, @@ -795,17 +796,17 @@ def local_abstract_batch_norm_train(fgraph, node): if min(axes) < 0 or max(axes) > x.ndim: return None if ( - not isinstance(x.type, TensorType) - or not isinstance(scale.type, TensorType) - or not isinstance(bias.type, TensorType) - or not isinstance(epsilon.type, TensorType) - or not isinstance(running_average_factor.type, TensorType) + not issubtype(x.type, TensorType) + or not issubtype(scale.type, TensorType) + or not issubtype(bias.type, TensorType) + or not issubtype(epsilon.type, TensorType) + or not issubtype(running_average_factor.type, TensorType) ): return None # optional running_mean and running_var - if len(node.inputs) > 5 and not isinstance(node.inputs[5].type, TensorType): + if 
len(node.inputs) > 5 and not issubtype(node.inputs[5].type, TensorType): return None - if len(node.inputs) > 6 and not isinstance(node.inputs[6].type, TensorType): + if len(node.inputs) > 6 and not issubtype(node.inputs[6].type, TensorType): return None mean = x.mean(axes, keepdims=True) @@ -849,12 +850,12 @@ def local_abstract_batch_norm_train_grad(fgraph, node): if min(axes) < 0 or max(axes) > x.ndim: return None if ( - not isinstance(x.type, TensorType) - or not isinstance(dy.type, TensorType) - or not isinstance(scale.type, TensorType) - or not isinstance(x_mean.type, TensorType) - or not isinstance(x_invstd.type, TensorType) - or not isinstance(epsilon.type, TensorType) + not issubtype(x.type, TensorType) + or not issubtype(dy.type, TensorType) + or not issubtype(scale.type, TensorType) + or not issubtype(x_mean.type, TensorType) + or not issubtype(x_invstd.type, TensorType) + or not issubtype(epsilon.type, TensorType) ): return None @@ -881,12 +882,12 @@ def local_abstract_batch_norm_inference(fgraph, node): x, scale, bias, estimated_mean, estimated_variance, epsilon = node.inputs if ( - not isinstance(x.type, TensorType) - or not isinstance(scale.type, TensorType) - or not isinstance(bias.type, TensorType) - or not isinstance(estimated_mean.type, TensorType) - or not isinstance(estimated_variance.type, TensorType) - or not isinstance(epsilon.type, TensorType) + not issubtype(x.type, TensorType) + or not issubtype(scale.type, TensorType) + or not issubtype(bias.type, TensorType) + or not issubtype(estimated_mean.type, TensorType) + or not issubtype(estimated_variance.type, TensorType) + or not issubtype(epsilon.type, TensorType) ): return None diff --git a/aesara/tensor/nnet/rewriting.py b/aesara/tensor/nnet/rewriting.py index 3a32e557c7..682c68ed4f 100644 --- a/aesara/tensor/nnet/rewriting.py +++ b/aesara/tensor/nnet/rewriting.py @@ -13,6 +13,7 @@ in2out, node_rewriter, ) +from aesara.issubtype import issubtype from aesara.tensor.nnet.abstract_conv import ( AbstractConv2d, AbstractConv2d_gradInputs, @@ -95,7 +96,7 @@ def local_abstractconv_gemm(fgraph, node): if not isinstance(node.op, AbstractConv2d): return None img, kern = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(kern.type, TensorType): + if not issubtype(img.type, TensorType) or not issubtype(kern.type, TensorType): return None # need to flip the kernel if necessary @@ -123,7 +124,7 @@ def local_abstractconv3d_gemm(fgraph, node): if not isinstance(node.op, AbstractConv3d): return None img, kern = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(kern.type, TensorType): + if not issubtype(img.type, TensorType) or not issubtype(kern.type, TensorType): return None # need to flip the kernel if necessary @@ -149,7 +150,7 @@ def local_abstractconv_gradweight_gemm(fgraph, node): if not isinstance(node.op, AbstractConv2d_gradWeights): return None img, topgrad, shape = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(topgrad.type, TensorType): + if not issubtype(img.type, TensorType) or not issubtype(topgrad.type, TensorType): return None rval = CorrMM_gradWeights( @@ -179,7 +180,7 @@ def local_abstractconv3d_gradweight_gemm(fgraph, node): if not isinstance(node.op, AbstractConv3d_gradWeights): return None img, topgrad, shape = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(topgrad.type, TensorType): + if not issubtype(img.type, TensorType) or not issubtype(topgrad.type, TensorType): return None rval = Corr3dMMGradWeights( @@ -207,9 
+208,7 @@ def local_abstractconv_gradinputs_gemm(fgraph, node): if not isinstance(node.op, AbstractConv2d_gradInputs): return None kern, topgrad, shape = node.inputs - if not isinstance(kern.type, TensorType) or not isinstance( - topgrad.type, TensorType - ): + if not issubtype(kern.type, TensorType) or not issubtype(topgrad.type, TensorType): return None # need to flip the kernel if necessary @@ -237,9 +236,7 @@ def local_abstractconv3d_gradinputs_gemm(fgraph, node): if not isinstance(node.op, AbstractConv3d_gradInputs): return None kern, topgrad, shape = node.inputs - if not isinstance(kern.type, TensorType) or not isinstance( - topgrad.type, TensorType - ): + if not issubtype(kern.type, TensorType) or not issubtype(topgrad.type, TensorType): return None # need to flip the kernel if necessary @@ -263,7 +260,7 @@ def local_conv2d_cpu(fgraph, node): return None img, kern = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(kern.type, TensorType): + if not issubtype(img.type, TensorType) or not issubtype(kern.type, TensorType): return None if node.op.border_mode not in ("full", "valid"): return None @@ -298,7 +295,7 @@ def local_conv2d_gradweight_cpu(fgraph, node): img, topgrad, shape = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(topgrad.type, TensorType): + if not issubtype(img.type, TensorType) or not issubtype(topgrad.type, TensorType): return None if node.op.border_mode not in ("full", "valid"): return None @@ -407,9 +404,7 @@ def local_conv2d_gradinputs_cpu(fgraph, node): kern, topgrad, shape = node.inputs - if not isinstance(kern.type, TensorType) or not isinstance( - topgrad.type, TensorType - ): + if not issubtype(kern.type, TensorType) or not issubtype(topgrad.type, TensorType): return None if node.op.border_mode not in ("full", "valid"): return None diff --git a/aesara/tensor/random/basic.py b/aesara/tensor/random/basic.py index 6b79151451..11824ab940 100644 --- a/aesara/tensor/random/basic.py +++ b/aesara/tensor/random/basic.py @@ -5,6 +5,7 @@ import scipy.stats as stats import aesara +from aesara import issubtype from aesara.tensor.basic import as_tensor_variable from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params from aesara.tensor.random.type import RandomGeneratorType, RandomStateType @@ -1686,7 +1687,7 @@ def __call__(self, low, high=None, size=None, **kwargs): return super().__call__(low, high, size=size, **kwargs) def make_node(self, rng, *args, **kwargs): - if not isinstance( + if not issubtype( getattr(rng, "type", None), (RandomStateType, RandomStateSharedVariable) ): raise TypeError("`randint` is only available for `RandomStateType`s") @@ -1731,7 +1732,7 @@ def __call__(self, low, high=None, size=None, **kwargs): return super().__call__(low, high, size=size, **kwargs) def make_node(self, rng, *args, **kwargs): - if not isinstance( + if not issubtype( getattr(rng, "type", None), (RandomGeneratorType, RandomGeneratorSharedVariable), ): @@ -1761,7 +1762,7 @@ def _supp_shape_from_params(self, *args, **kwargs): def _infer_shape(self, size, dist_params, param_shapes=None): (a, p, _) = dist_params - if isinstance(p.type, aesara.tensor.type_other.NoneTypeT): + if issubtype(p.type, aesara.tensor.type_other.NoneTypeT): shape = super()._infer_shape(size, (a,), param_shapes) else: shape = super()._infer_shape(size, (a, p), param_shapes) diff --git a/aesara/tensor/random/op.py b/aesara/tensor/random/op.py index 8a6e36afde..a07c18e68c 100644 --- a/aesara/tensor/random/op.py +++ 
b/aesara/tensor/random/op.py @@ -7,6 +7,7 @@ from aesara.configdefaults import config from aesara.graph.basic import Apply, Variable from aesara.graph.op import Op +from aesara.issubtype import issubtype from aesara.misc.safe_asarray import _asarray from aesara.scalar import ScalarVariable from aesara.tensor.basic import ( @@ -316,7 +317,7 @@ def make_node(self, rng, size, dtype, *dist_params): if rng is None: rng = aesara.shared(np.random.default_rng()) - elif not isinstance(rng.type, RandomType): + elif not issubtype(rng.type, RandomType): raise TypeError( "The type of rng should be an instance of either RandomGeneratorType or RandomStateType" ) diff --git a/aesara/tensor/random/type.py b/aesara/tensor/random/type.py index 456606fc91..0edea88759 100644 --- a/aesara/tensor/random/type.py +++ b/aesara/tensor/random/type.py @@ -3,7 +3,7 @@ import numpy as np import aesara -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type T = TypeVar("T", np.random.RandomState, np.random.Generator) @@ -23,15 +23,19 @@ numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"} -class RandomType(Type[T]): +class RandomTypeMeta(NewTypeMeta): r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`.""" @staticmethod - def may_share_memory(a: T, b: T): + def may_share_memory(a, b): return a._bit_generator is b._bit_generator # type: ignore[attr-defined] -class RandomStateType(RandomType[np.random.RandomState]): +class RandomType(Type, metaclass=RandomTypeMeta): + pass + + +class RandomStateTypeMeta(RandomTypeMeta): r"""A Type wrapper for `numpy.random.RandomState`. The reason this exists (and `Generic` doesn't suffice) is that @@ -101,11 +105,9 @@ def _eq(sa, sb): return _eq(sa, sb) - def __eq__(self, other): - return type(self) == type(other) - def __hash__(self): - return hash(type(self)) +class RandomStateType(RandomType, metaclass=RandomStateTypeMeta): + pass # Register `RandomStateType`'s C code for `ViewOp`. @@ -122,7 +124,7 @@ def __hash__(self): random_state_type = RandomStateType.subtype() -class RandomGeneratorType(RandomType[np.random.Generator]): +class RandomGeneratorTypeMeta(RandomTypeMeta): r"""A Type wrapper for `numpy.random.Generator`. The reason this exists (and `Generic` doesn't suffice) is that @@ -197,11 +199,9 @@ def _eq(sa, sb): return _eq(sa, sb) - def __eq__(self, other): - return type(self) == type(other) - def __hash__(self): - return hash(type(self)) +class RandomGeneratorType(RandomType, metaclass=RandomGeneratorTypeMeta): + pass # Register `RandomGeneratorType`'s C code for `ViewOp`. diff --git a/aesara/tensor/rewriting/basic.py b/aesara/tensor/rewriting/basic.py index fa6066b72f..009735b7e0 100644 --- a/aesara/tensor/rewriting/basic.py +++ b/aesara/tensor/rewriting/basic.py @@ -18,6 +18,7 @@ node_rewriter, ) from aesara.graph.rewriting.db import RewriteDatabase +from aesara.issubtype import issubtype from aesara.raise_op import Assert, CheckAndRaise, assert_op from aesara.tensor.basic import ( Alloc, @@ -1152,7 +1153,7 @@ def constant_folding(fgraph, node): # TODO: `Type` itself should provide an interface for constructing # instances appropriate for a given constant. # TODO: Add handling for sparse types. 
- if isinstance(output.type, DenseTensorType): + if issubtype(output.type, DenseTensorType): output_type = TensorType.subtype( output.type.dtype, tuple(s == 1 for s in data.shape), diff --git a/aesara/tensor/rewriting/shape.py b/aesara/tensor/rewriting/shape.py index a3b30177f0..6f1df9d3f1 100644 --- a/aesara/tensor/rewriting/shape.py +++ b/aesara/tensor/rewriting/shape.py @@ -18,6 +18,7 @@ node_rewriter, ) from aesara.graph.utils import InconsistencyError, get_variable_trace_string +from aesara.issubtype import issubtype from aesara.tensor.basic import ( MakeVector, as_tensor_variable, @@ -1031,7 +1032,7 @@ def local_Shape_i_of_broadcastable(fgraph, node): shape_arg = node.inputs[0] - if not isinstance(shape_arg.type, TensorType): + if not issubtype(shape_arg.type, TensorType): return False if shape_arg.broadcastable[node.op.i]: diff --git a/aesara/tensor/rewriting/subtensor.py b/aesara/tensor/rewriting/subtensor.py index c25b77f8ee..63060b409d 100644 --- a/aesara/tensor/rewriting/subtensor.py +++ b/aesara/tensor/rewriting/subtensor.py @@ -13,6 +13,7 @@ in2out, node_rewriter, ) +from aesara.issubtype import issubtype from aesara.raise_op import Assert from aesara.tensor.basic import ( Alloc, @@ -165,10 +166,10 @@ def is_full_slice(x): or (isinstance(x, SliceConstant) and x.value == slice(None)) or ( not isinstance(x, SliceConstant) - and isinstance(getattr(x, "type", None), SliceType) + and issubtype(getattr(x, "type", None), SliceType) and x.owner is not None and all( - isinstance(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs + issubtype(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs ) ) ): @@ -560,7 +561,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node): remove_dim = [] node_inputs_idx = 1 for dim, elem in enumerate(idx): - if isinstance(elem, (aes.ScalarType)): + if issubtype(elem, aes.ScalarType): # The idx is a ScalarType, ie a Type. 
This means the actual index # is contained in node.inputs[1] dim_index = node.inputs[node_inputs_idx] @@ -734,7 +735,7 @@ def local_subtensor_make_vector(fgraph, node): if isinstance(node.op, Subtensor): (idx,) = node.op.idx_list - if isinstance(idx, (aes.ScalarType, TensorType)): + if issubtype(idx, (aes.ScalarType, TensorType)): old_idx, idx = idx, node.inputs[1] assert idx.type.is_super(old_idx) elif isinstance(node.op, AdvancedSubtensor1): @@ -1602,7 +1603,7 @@ def local_subtensor_shape_constant(fgraph, node): assert idx_val != np.newaxis - if not isinstance(shape_arg.type, TensorType): + if not issubtype(shape_arg.type, TensorType): return False shape_parts = shape_arg.type.broadcastable[idx_val] @@ -1636,7 +1637,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node): indices = get_idx_list(node.inputs, node.op.idx_list) if any( - isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType) + isinstance(index, slice) or issubtype(getattr(index, "type", None), SliceType) for index in indices ): return False diff --git a/aesara/tensor/shape.py b/aesara/tensor/shape.py index cf56730c81..f2e3768799 100644 --- a/aesara/tensor/shape.py +++ b/aesara/tensor/shape.py @@ -8,6 +8,7 @@ import aesara from aesara.gradient import DisconnectedType from aesara.graph.basic import Apply, Variable +from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.misc.safe_asarray import _asarray @@ -64,7 +65,7 @@ def make_node(self, x): if not isinstance(x, Variable): x = at.as_tensor_variable(x) - if isinstance(x.type, TensorType): + if issubtype(x.type, TensorType): out_var = TensorType.subtype("int64", (x.type.ndim,))() else: out_var = aesara.tensor.type.lvector() @@ -103,7 +104,7 @@ def c_code(self, node, name, inames, onames, sub): (oname,) = onames fail = sub["fail"] - itype = node.inputs[0].type.__class__ + itype = node.inputs[0].type.base_type if itype in self.c_code_and_version: code, version = self.c_code_and_version[itype] return code % locals() @@ -144,7 +145,7 @@ def shape(x: Union[np.ndarray, Number, Variable]) -> Variable: x_type = x.type - if isinstance(x_type, TensorType) and all(s is not None for s in x_type.shape): + if issubtype(x_type, TensorType) and all(s is not None for s in x_type.shape): res = at.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64) else: res = _shape(x) @@ -263,7 +264,7 @@ def c_code(self, node, name, inames, onames, sub): # i is then 'params->i', not just 'params'. 
i = sub["params"] + "->i" - itype = node.inputs[0].type.__class__ + itype = node.inputs[0].type.base_type if itype in self.c_code_and_version: code, check_input, version = self.c_code_and_version[itype] return (check_input + code) % locals() @@ -473,7 +474,7 @@ def R_op(self, inputs, eval_points): return self.make_node(eval_points[0], *inputs[1:]).outputs def c_code(self, node, name, i_names, o_names, sub): - if not isinstance(node.inputs[0].type, DenseTensorType): + if not issubtype(node.inputs[0].type, DenseTensorType): raise NotImplementedError( f"Specify_shape c_code not implemented for input type {node.inputs[0].type}" ) diff --git a/aesara/tensor/subtensor.py b/aesara/tensor/subtensor.py index 8717f75619..86434fcda5 100644 --- a/aesara/tensor/subtensor.py +++ b/aesara/tensor/subtensor.py @@ -14,6 +14,7 @@ from aesara.graph.op import Op from aesara.graph.type import Type from aesara.graph.utils import MethodNotDefined +from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.misc.safe_asarray import _asarray @@ -108,7 +109,7 @@ def indices_from_subtensor( def convert_indices(indices, entry): """Reconstruct ``*Subtensor*`` index input parameter entries.""" - if indices and isinstance(entry, Type): + if indices and issubtype(entry, Type): rval = indices.pop(0) return rval elif isinstance(entry, slice): @@ -163,13 +164,13 @@ def as_index_literal( ------ NotScalarConstantError """ - if idx == np.newaxis or isinstance(getattr(idx, "type", None), NoneTypeT): + if idx == np.newaxis or issubtype(getattr(idx, "type", None), NoneTypeT): return np.newaxis if isinstance(idx, Constant): return idx.data.item() if isinstance(idx, np.ndarray) else idx.data - if isinstance(getattr(idx, "type", None), SliceType): + if issubtype(getattr(idx, "type", None), SliceType): idx = slice(*idx.owner.inputs) if isinstance(idx, slice): @@ -398,7 +399,7 @@ def is_basic_idx(idx): integer can indicate advanced indexing. 
""" - return isinstance(idx, (slice, type(None))) or isinstance( + return isinstance(idx, (slice, type(None))) or issubtype( getattr(idx, "type", None), (SliceType, NoneTypeT) ) @@ -421,7 +422,7 @@ def basic_shape(shape, indices): for idx, n in zip(indices, shape): if isinstance(idx, slice): res_shape += (slice_len(idx, n),) - elif isinstance(getattr(idx, "type", None), SliceType): + elif issubtype(getattr(idx, "type", None), SliceType): if idx.owner: idx_inputs = idx.owner.inputs else: @@ -429,7 +430,7 @@ def basic_shape(shape, indices): res_shape += (slice_len(slice(*idx_inputs), n),) elif idx is None: res_shape += (aes.ScalarConstant(aes.int64, 1),) - elif isinstance(getattr(idx, "type", None), NoneTypeT): + elif issubtype(getattr(idx, "type", None), NoneTypeT): res_shape += (aes.ScalarConstant(aes.int64, 1),) else: raise ValueError(f"Invalid index type: {idx}") @@ -453,7 +454,7 @@ def group_indices(indices): for idx in grp_indices: # We "zip" the dimension number to each index, which means we can't # count indices that add new axes - if (idx is not None) and not isinstance( + if (idx is not None) and not issubtype( getattr(idx, "type", None), NoneTypeT ): dim_num += 1 @@ -572,7 +573,7 @@ def index_vars_to_types(entry, slice_ok=True): if isinstance(entry, Variable) and entry.type in scal_types: return entry.type - elif isinstance(entry, Type) and entry in scal_types: + elif issubtype(entry, Type) and entry in scal_types: return entry if ( @@ -581,7 +582,7 @@ def index_vars_to_types(entry, slice_ok=True): and all(entry.type.broadcastable) ): return aes.get_scalar_type(entry.type.dtype) - elif isinstance(entry, Type) and entry in tensor_types and all(entry.broadcastable): + elif issubtype(entry, Type) and entry in tensor_types and all(entry.broadcastable): return aes.get_scalar_type(entry.dtype) elif slice_ok and isinstance(entry, slice): a = entry.start @@ -673,7 +674,7 @@ def as_nontensor_scalar(a: Variable) -> aes.ScalarVariable: # Since aes.as_scalar does not know about tensor types (it would # create a circular import) , this method converts either a # TensorVariable or a ScalarVariable to a scalar. 
- if isinstance(a, Variable) and isinstance(a.type, TensorType): + if isinstance(a, Variable) and issubtype(a.type, TensorType): return aesara.tensor.scalar_from_tensor(a) else: return aes.as_scalar(a) @@ -708,9 +709,7 @@ def make_node(self, x, *inputs): if len(idx_list) > x.type.ndim: raise IndexError("too many indices for array") - input_types = get_slice_elements( - idx_list, lambda entry: isinstance(entry, Type) - ) + input_types = get_slice_elements(idx_list, lambda entry: issubtype(entry, Type)) assert len(inputs) == len(input_types) @@ -924,7 +923,7 @@ def init_entry(entry, depth=0): inc_spec_pos(1) if depth == 0: is_slice.append(0) - elif isinstance(entry, Type): + elif issubtype(entry, Type): init_cmds.append( "subtensor_spec[%i] = %s;" % (spec_pos(), inputs[input_pos()]) ) @@ -1123,7 +1122,7 @@ def helper_c_code_cache_version(): return (9,) def c_code(self, node, name, inputs, outputs, sub): # DEBUG - if not isinstance(node.inputs[0].type, TensorType): + if not issubtype(node.inputs[0].type, TensorType): raise NotImplementedError() x = inputs[0] @@ -1209,7 +1208,7 @@ def _process(self, idxs, op_inputs, pstate): sidxs = [] getattr(pstate, "precedence", None) for entry in idxs: - if isinstance(entry, aes.ScalarType): + if issubtype(entry, aes.ScalarType): with set_precedence(pstate): sidxs.append(pstate.pprinter.process(inputs.pop())) elif isinstance(entry, slice): @@ -1535,9 +1534,7 @@ def make_node(self, x, y, *inputs): if len(idx_list) > x.type.ndim: raise IndexError("too many indices for array") - input_types = get_slice_elements( - idx_list, lambda entry: isinstance(entry, Type) - ) + input_types = get_slice_elements(idx_list, lambda entry: issubtype(entry, Type)) if len(inputs) != len(input_types): raise IndexError( "Not enough inputs to fill in the Subtensor template.", inputs, idx_list @@ -1559,7 +1556,7 @@ def perform(self, node, inputs, out_): indices = list(reversed(inputs[2:])) def _convert(entry): - if isinstance(entry, Type): + if issubtype(entry, Type): return indices.pop() elif isinstance(entry, slice): return slice( @@ -1709,7 +1706,7 @@ def do_type_checking(self, node): """ - if not isinstance(node.inputs[0].type, TensorType): + if not issubtype(node.inputs[0].type, TensorType): raise NotImplementedError() def c_code_cache_version(self): @@ -2507,9 +2504,9 @@ def as_index_variable(idx): return NoneConst.clone() if isinstance(idx, slice): return make_slice(idx) - if isinstance(idx, Variable) and isinstance(idx.type, SliceType): + if isinstance(idx, Variable) and issubtype(idx.type, SliceType): return idx - if isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT): + if isinstance(idx, Variable) and issubtype(idx.type, NoneTypeT): return idx idx = as_tensor_variable(idx) if idx.type.dtype not in discrete_dtypes: @@ -2602,7 +2599,7 @@ def infer_shape(self, fgraph, node, ishapes): ) # The `ishapes` entries for `SliceType`s will be None, and # we need to give `indexed_result_shape` the actual slices. - if isinstance(getattr(idx, "type", None), SliceType): + if issubtype(getattr(idx, "type", None), SliceType): index_shapes[i] = idx res_shape = indexed_result_shape( @@ -2619,7 +2616,7 @@ def perform(self, node, inputs, out_): # indexing, so __getitem__ will not return a copy. 
# Since no view_map is set, we need to copy the returned value if not any( - isinstance(v.type, TensorType) and v.ndim > 0 for v in node.inputs[1:] + issubtype(v.type, TensorType) and v.ndim > 0 for v in node.inputs[1:] ): rval = rval.copy() out[0] = rval diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 89e4522cb4..bad04d3050 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -1,6 +1,6 @@ import logging import warnings -from typing import Iterable, Optional, Tuple, Union +from typing import Iterable, Optional, Union import numpy as np @@ -8,8 +8,9 @@ from aesara import scalar as aes from aesara.configdefaults import config from aesara.graph.basic import Variable -from aesara.graph.type import DataType, NewTypeMeta, ShapeType -from aesara.link.c.type import CType +from aesara.graph.type import DataType, Props, ShapeType +from aesara.issubtype import issubtype +from aesara.link.c.type import CType, CTypeMeta from aesara.misc.safe_asarray import _asarray from aesara.utils import apply_across_args @@ -47,18 +48,18 @@ } -class TensorType(CType[np.ndarray]): +class TensorTypeMeta(CTypeMeta): r"""Symbolic `Type` representing `numpy.ndarray`\s.""" - __props__: Tuple[str, ...] = ("dtype", "shape") + shape: Props[DataType] = None + dtype: Props[ShapeType] = None ndim: int - shape: ShapeType - dtype: DataType dtype_specs_map = dtype_specs_map context_name = "cpu" filter_checks_isfinite = False + name = None """ When this is ``True``, strict filtering rejects data containing ``numpy.nan`` or ``numpy.inf`` entries. (Used in `DebugMode`) @@ -131,7 +132,7 @@ def clone( dtype = self.dtype if shape is None: shape = self.shape - return type(self).subtype(dtype, shape, name=self.name) + return TensorType.subtype(dtype, shape, name=self.name) def filter(self, data, strict=False, allow_downcast=None): """Convert `data` to something which can be associated to a `TensorVariable`. @@ -313,7 +314,7 @@ def in_same_class(self, otype): """ if ( - isinstance(otype, TensorType) + issubtype(otype, TensorType) and otype.dtype == self.dtype and otype.broadcastable == self.broadcastable ): @@ -377,15 +378,6 @@ def values_eq_approx( ): return values_eq_approx(a, b, allow_remove_inf, allow_remove_nan, rtol, atol) - def __eq__(self, other): - if type(self) != type(other): - return NotImplemented - - return other.dtype == self.dtype and other.shape == self.shape - - def __hash__(self): - return hash((type(self), self.dtype, self.shape)) - @property def broadcastable(self): """A boolean tuple indicating which dimensions have a shape equal to one.""" @@ -628,19 +620,27 @@ def c_code_cache_version(self): return () -class DenseTypeMeta(NewTypeMeta): - def __instancecheck__(self, o): - if type(o) == TensorType or isinstance(o, DenseTypeMeta): - return True - return False +class TensorType(CType, metaclass=TensorTypeMeta): + pass -class DenseTensorType(TensorType, metaclass=DenseTypeMeta): +class DenseTypeMeta(TensorTypeMeta): r"""A `Type` for dense tensors. Instances of this class and `TensorType`\s are considered dense `Type`\s. 
""" + def __subclasscheck__(self, subclass): + if getattr(subclass, "base_type", None) == TensorType or issubclass( + subclass, DenseTypeMeta + ): + return True + return False + + +class DenseTensorType(TensorType, metaclass=DenseTypeMeta): + pass + def values_eq_approx( a, b, allow_remove_inf=False, allow_remove_nan=False, rtol=None, atol=None diff --git a/aesara/tensor/type_other.py b/aesara/tensor/type_other.py index cd9eae18a8..5405eeb77e 100644 --- a/aesara/tensor/type_other.py +++ b/aesara/tensor/type_other.py @@ -9,7 +9,8 @@ from aesara.gradient import DisconnectedType from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op -from aesara.link.c.type import Generic, Type +from aesara.graph.type import NewTypeMeta +from aesara.link.c.type import Generic, GenericMeta, Type from aesara.tensor.type import integer_dtypes @@ -51,9 +52,9 @@ def grad(self, inputs, grads): make_slice = MakeSlice() -class SliceType(Type[slice]): +class SliceTypeMeta(NewTypeMeta): def clone(self, **kwargs): - return type(self).subtype() + return self.subtype() def filter(self, x, strict=False, allow_downcast=None): if isinstance(x, slice): @@ -64,18 +65,16 @@ def filter(self, x, strict=False, allow_downcast=None): def __str__(self): return "slice" - def __eq__(self, other): - return type(self) == type(other) - - def __hash__(self): - return hash(type(self)) - @staticmethod def may_share_memory(a, b): # Slices never shared memory between object return isinstance(a, slice) and a is b +class SliceType(Type, metaclass=SliceTypeMeta): + pass + + slicetype = SliceType.subtype() @@ -121,7 +120,7 @@ def as_symbolic_slice(x, **kwargs): return SliceConstant(slicetype, x) -class NoneTypeT(Generic): +class NoneTypeTMeta(GenericMeta): """ Inherit from Generic to have c code working. 
@@ -140,6 +139,10 @@ def may_share_memory(a, b): return False +class NoneTypeT(Generic, metaclass=NoneTypeTMeta): + pass + + none_type_t = NoneTypeT.subtype() NoneConst = Constant(none_type_t, None, name="NoneConst") diff --git a/aesara/tensor/var.py b/aesara/tensor/var.py index 0fcb8a52ed..c75a9e67ed 100644 --- a/aesara/tensor/var.py +++ b/aesara/tensor/var.py @@ -10,11 +10,10 @@ from aesara import tensor as at from aesara.configdefaults import config from aesara.graph.basic import Constant, OptionalApplyType, Variable -from aesara.graph.type import NewTypeMeta from aesara.scalar import ComplexError, IntegerDivisionError from aesara.tensor import _get_vector_length, as_tensor_variable from aesara.tensor.exceptions import AdvancedIndexingError -from aesara.tensor.type import TensorType +from aesara.tensor.type import TensorType, TensorTypeMeta from aesara.tensor.type_other import NoneConst from aesara.tensor.utils import hash_from_ndarray @@ -1068,7 +1067,7 @@ def __deepcopy__(self, memo): TensorType.constant_type = TensorConstant -class DenseVariableMeta(NewTypeMeta): +class DenseVariableMeta(TensorTypeMeta): def __instancecheck__(self, o): if type(o) == TensorVariable or isinstance(o, DenseVariableMeta): return True @@ -1083,7 +1082,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta): """ -class DenseConstantMeta(NewTypeMeta): +class DenseConstantMeta(TensorTypeMeta): def __instancecheck__(self, o): if type(o) == TensorConstant or isinstance(o, DenseConstantMeta): return True diff --git a/aesara/typed_list/basic.py b/aesara/typed_list/basic.py index 470d3ff242..7aba28b445 100644 --- a/aesara/typed_list/basic.py +++ b/aesara/typed_list/basic.py @@ -1,6 +1,7 @@ import numpy as np import aesara.tensor as at +from aesara import issubtype from aesara.compile.debugmode import _lessbroken_deepcopy from aesara.configdefaults import config from aesara.graph.basic import Apply, Constant, Variable @@ -72,7 +73,7 @@ class GetItem(COp): __props__ = () def make_node(self, x, index): - assert isinstance(x.type, TypedListType) + assert issubtype(x.type, TypedListType) if not isinstance(index, Variable): if isinstance(index, slice): index = Constant(SliceType.subtype(), index) @@ -80,7 +81,7 @@ def make_node(self, x, index): else: index = at.constant(index, ndim=0, dtype="int64") return Apply(self, [x, index], [x.ttype()]) - if isinstance(index.type, SliceType): + if issubtype(index.type, SliceType): return Apply(self, [x, index], [x.type()]) elif isinstance(index, TensorVariable) and index.ndim == 0: assert index.dtype == "int64" @@ -148,7 +149,7 @@ def __init__(self, inplace=False): self.view_map = {0: [0]} def make_node(self, x, toAppend): - assert isinstance(x.type, TypedListType) + assert issubtype(x.type, TypedListType) assert x.ttype == toAppend.type, (x.ttype, toAppend.type) return Apply(self, [x, toAppend], [x.type()]) @@ -231,7 +232,7 @@ def __init__(self, inplace=False): self.view_map = {0: [0]} def make_node(self, x, toAppend): - assert isinstance(x.type, TypedListType) + assert issubtype(x.type, TypedListType) assert toAppend.type.is_super(x.type) return Apply(self, [x, toAppend], [x.type()]) @@ -320,7 +321,7 @@ def __init__(self, inplace=False): self.view_map = {0: [0]} def make_node(self, x, index, toInsert): - assert isinstance(x.type, TypedListType) + assert issubtype(x.type, TypedListType) assert x.ttype == toInsert.type if not isinstance(index, Variable): index = at.constant(index, ndim=0, dtype="int64") @@ -405,7 +406,7 @@ def __init__(self, inplace=False): 
self.view_map = {0: [0]} def make_node(self, x, toRemove): - assert isinstance(x.type, TypedListType) + assert issubtype(x.type, TypedListType) assert x.ttype == toRemove.type return Apply(self, [x, toRemove], [x.type()]) @@ -462,7 +463,7 @@ def __init__(self, inplace=False): self.view_map = {0: [0]} def make_node(self, x): - assert isinstance(x.type, TypedListType) + assert issubtype(x.type, TypedListType) return Apply(self, [x], [x.type()]) def perform(self, node, inp, outputs): @@ -526,7 +527,7 @@ class Index(Op): __props__ = () def make_node(self, x, elem): - assert isinstance(x.type, TypedListType) + assert issubtype(x.type, TypedListType) assert x.ttype == elem.type return Apply(self, [x, elem], [scalar()]) @@ -555,7 +556,7 @@ class Count(Op): __props__ = () def make_node(self, x, elem): - assert isinstance(x.type, TypedListType) + assert issubtype(x.type, TypedListType) assert x.ttype == elem.type return Apply(self, [x, elem], [scalar()]) @@ -602,7 +603,7 @@ class Length(COp): __props__ = () def make_node(self, x): - assert isinstance(x.type, TypedListType) + assert issubtype(x.type, TypedListType) return Apply(self, [x], [scalar(dtype="int64")]) def perform(self, node, x, outputs): diff --git a/aesara/typed_list/type.py b/aesara/typed_list/type.py index 059c64105c..8bae344920 100644 --- a/aesara/typed_list/type.py +++ b/aesara/typed_list/type.py @@ -1,7 +1,11 @@ -from aesara.link.c.type import CType, Type +from typing import Any +from aesara import issubtype +from aesara.graph.type import Props +from aesara.link.c.type import CType, CTypeMeta, Type -class TypedListType(CType): + +class TypedListTypeMeta(CTypeMeta): """ Parameters @@ -14,14 +18,14 @@ class TypedListType(CType): """ - __props__ = ("ttype",) + ttype: Props[Any] = None @classmethod def type_parameters(cls, ttype, depth=0): if depth < 0: raise ValueError("Please specify a depth superior or" "equal to 0") - if not isinstance(ttype, Type): + if not issubtype(ttype, Type): raise TypeError("Expected an Aesara Type") if depth > 0: @@ -62,7 +66,7 @@ def get_depth(self): Utilitary function to get the 0 based level of the list. """ - if isinstance(self.ttype, TypedListType): + if issubtype(self.ttype, TypedListType): return self.ttype.get_depth() + 1 else: return 0 @@ -139,3 +143,7 @@ def c_code_cache_version(self): dtype = property(lambda self: self.ttype) ndim = property(lambda self: self.ttype.ndim + 1) + + +class TypedListType(CType, metaclass=TypedListTypeMeta): + pass diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index e103e92b9b..988f388782 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -20,6 +20,7 @@ from aesara.tensor.math import sum as at_sum from aesara.tensor.math import tanh from aesara.tensor.type import ( + TensorType, dmatrix, dscalar, dscalars, @@ -984,6 +985,15 @@ def test_deepcopy_shared_container(self): assert f[a] == 1 assert fc[ac] == 2 + def test_pickle_simple(self): + tt = pickle.loads(pickle.dumps(TensorType)) + assert tt == TensorType + + subt = TensorType.subtype(shape=(1, 2), dtype="floatX") + dumps = pickle.dumps(subt) + pst = pickle.loads(dumps) + assert pst == subt + def test_pickle(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") @@ -1000,7 +1010,8 @@ def test_pickle(self): try: # Note that here we also test protocol 0 on purpose, since it # should work (even though one should not use it). 
- g = pickle.loads(pickle.dumps(f, protocol=0)) + dump = pickle.dumps(f, protocol=0) + g = pickle.loads(dump) g = pickle.loads(pickle.dumps(f, protocol=-1)) except NotImplementedError as e: if e[0].startswith("DebugMode is not picklable"): @@ -1138,7 +1149,8 @@ def test_multiple_functions(self): for i in range(4): assert nl[i] != ol[i] assert nl[i].type == ol[i].type - assert nl[i].type is not ol[i].type + # TODO: is this a strict requirement? It doesn't make sense in the context of the new type system + # assert nl[i].type is not ol[i].type # see if the implicit input got stored assert ol[3].owner.inputs[1] is s diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 223180d006..ab4ac7035a 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -14,6 +14,7 @@ from aesara.graph.null_type import NullType from aesara.graph.rewriting.utils import rewrite_graph from aesara.graph.utils import MissingInputError +from aesara.issubtype import issubtype from aesara.printing import debugprint from aesara.tensor.basic import as_tensor from aesara.tensor.math import dot, exp @@ -244,10 +245,10 @@ def go2(inps, gs): disconnected_inputs="ignore", null_gradients="return", ) - assert isinstance(dx2.type, TensorType) + assert issubtype(dx2.type, TensorType) assert dx2.ndim == 1 - assert isinstance(dw2.type, NullType) - assert isinstance(db2.type, DisconnectedType) + assert issubtype(dw2.type, NullType) + assert issubtype(db2.type, DisconnectedType) @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index 3b9a43c955..cab22a8a88 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -715,9 +715,8 @@ def make_node(self, v): if not isinstance(v, Variable): v = at.as_tensor_variable(v) assert v.type.ndim == 1 - type_class = type(v.type) - out_r_type = type_class.subtype(dtype=v.dtype, shape=(True, False)) - out_c_type = type_class.subtype(dtype=v.dtype, shape=(False, True)) + out_r_type = v.type.subtype(dtype=v.dtype, shape=(True, False)) + out_c_type = v.type.subtype(dtype=v.dtype, shape=(False, True)) return Apply(self, [v], [out_r_type(), out_c_type()]) def perform(self, node, inp, out): diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index c98f1317b2..8ea6991589 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -74,6 +74,8 @@ def test_cons(): tt1 = TensorType.subtype("float32", [True, False]) + # TODO new types: This doesn't fit with types being classes, since `TensorType(...)` should construct an object of + # type `TensorType` rather than a subtype like `t1` assert car(tt1) == TensorType assert cdr(tt1) == ("float32", (1, None)) @@ -255,7 +257,10 @@ def test_unify_Type(): assert s == {} # `Type`, `ExpressionTuple` - s = unify(t1, etuple(TensorType, "float64", (1, None))) + # TODO new types: This doesn't fit with types being classes, since `TensorType(...)` should construct an object of + # type `TensorType` rather than a subtype like `t1` + et = etuple(TensorType, "float64", (1, None)) + s = unify(t1, et) assert s == {} from aesara.scalar.basic import ScalarType diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index fb0118cdce..1fb318c691 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -1,5 +1,6 @@ import pickle from itertools import count +from typing import Any import numpy as np 
import pytest @@ -28,7 +29,8 @@ walk, ) from aesara.graph.op import Op -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Props, Type +from aesara.issubtype import issubtype from aesara.tensor.math import max_and_argmax from aesara.tensor.type import ( TensorType, @@ -45,17 +47,14 @@ from tests.graph.utils import MyInnerGraphOp -class MyType(Type): - __props__ = ("thingy",) - - def __init__(self, thingy): - self.thingy = thingy +class MyTypeMeta(NewTypeMeta): + thingy: Props[Any] = None def filter(self, *args, **kwargs): raise NotImplementedError() def __eq__(self, other): - return isinstance(other, MyType) and other.thingy == self.thingy + return isinstance(other, MyTypeMeta) and other.thingy == self.thingy def __hash__(self): return hash((type(self), self.thingy)) @@ -67,6 +66,10 @@ def __repr__(self): return f"R{self.thingy}" +class MyType(Type, metaclass=MyTypeMeta): + pass + + def MyVariable(thingy): return Variable(MyType.subtype(thingy), None, None) @@ -78,7 +81,7 @@ class MyOp(Op): def make_node(self, *inputs): for input in inputs: assert isinstance(input, Variable) - assert isinstance(input.type, MyType) + assert issubtype(input.type, MyType) outputs = [MyVariable(sum(input.type.thingy for input in inputs))] return Apply(self, list(inputs), outputs) diff --git a/tests/graph/test_compute_test_value.py b/tests/graph/test_compute_test_value.py index 3fc3bed7b6..b18c46f041 100644 --- a/tests/graph/test_compute_test_value.py +++ b/tests/graph/test_compute_test_value.py @@ -8,7 +8,7 @@ from aesara.graph import utils from aesara.graph.basic import Apply from aesara.graph.op import Op -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type from aesara.link.c.op import COp from aesara.tensor.math import _allclose, dot from aesara.tensor.type import fmatrix, iscalar, matrix, vector @@ -46,10 +46,13 @@ def perform(self, *args, **kwargs): class TestComputeTestValue: def test_destroy_map(self): - class SomeType(Type): + class SomeTypeMeta(NewTypeMeta): def filter(self, data, strict=False, allow_downcast=None): return data + class SomeType(Type, metaclass=SomeTypeMeta): + pass + class InplaceOp(Op): __props__ = () diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py index 5dddfe75a1..005da2e26f 100644 --- a/tests/graph/test_destroyhandler.py +++ b/tests/graph/test_destroyhandler.py @@ -15,8 +15,9 @@ SubstitutionNodeRewriter, WalkingGraphRewriter, ) -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type from aesara.graph.utils import InconsistencyError +from aesara.issubtype import issubtype from tests.unittest_tools import assertFailure_fast @@ -37,12 +38,19 @@ def as_variable(x): return x -class MyType(Type): +class MyTypeMeta(NewTypeMeta): def filter(self, data): return data def __eq__(self, other): - return isinstance(other, MyType) + return isinstance(other, MyTypeMeta) + + def __hash__(self): + return hash(MyTypeMeta) + + +class MyType(Type, metaclass=MyTypeMeta): + pass def MyVariable(name): @@ -85,7 +93,7 @@ def make_node(self, *inputs): assert len(inputs) == self.nin inputs = list(map(as_variable, inputs)) for input in inputs: - if not isinstance(input.type, MyType): + if not issubtype(input.type, MyType): raise Exception("Error 1") outputs = [MyVariable(self.name + "_R") for i in range(self.nout)] return Apply(self, inputs, outputs) diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py index 06a48d1e10..2c2a1da641 100644 --- 
a/tests/graph/test_features.py +++ b/tests/graph/test_features.py @@ -4,17 +4,15 @@ from aesara.graph.features import Feature, NodeFinder, ReplaceValidate from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Props, Type +from aesara.issubtype import issubtype from tests.graph.utils import MyVariable, op1 class TestNodeFinder: def test_straightforward(self): - class MyType(Type): - __props__ = ("name",) - - def __init__(self, name): - self.name = name + class MyTypeMeta(NewTypeMeta): + name: Props[str] = None def filter(self, *args, **kwargs): raise NotImplementedError() @@ -26,7 +24,13 @@ def __repr__(self): return self.name def __eq__(self, other): - return isinstance(other, MyType) + return isinstance(other, MyTypeMeta) + + def __hash__(self): + return hash(MyTypeMeta) + + class MyType(Type, metaclass=MyTypeMeta): + pass class MyOp(Op): @@ -44,7 +48,7 @@ def as_variable(x): assert len(inputs) == self.nin inputs = list(map(as_variable, inputs)) for input in inputs: - if not isinstance(input.type, MyType): + if not issubtype(input.type, MyType): raise Exception("Error 1") outputs = [MyType.subtype(self.name + "_R")()] return Apply(self, inputs, outputs) diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index 3b755b142d..aa683be252 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -1,3 +1,5 @@ +from typing import Any + import numpy as np import pytest @@ -8,8 +10,9 @@ from aesara.configdefaults import config from aesara.graph.basic import Apply, Variable from aesara.graph.op import Op -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Props, Type from aesara.graph.utils import TestValueError +from aesara.issubtype import issubtype from aesara.link.c.type import Generic from aesara.tensor.math import log from aesara.tensor.type import dmatrix, dscalar, dvector, vector @@ -20,15 +23,15 @@ def as_variable(x): return x -class MyType(Type): - __props__ = ("thingy",) - - def __init__(self, thingy): - self.thingy = thingy +class MyTypeMeta(NewTypeMeta): + thingy: Props[Any] = None def __eq__(self, other): return type(other) == type(self) and other.thingy == self.thingy + def __hash__(self): + return hash((MyTypeMeta, self.thingy)) + def __str__(self): return str(self.thingy) @@ -53,6 +56,10 @@ def may_share_memory(a, b): return False +class MyType(Type, metaclass=MyTypeMeta): + pass + + class MyOp(Op): __props__ = () @@ -60,7 +67,7 @@ class MyOp(Op): def make_node(self, *inputs): inputs = list(map(as_variable, inputs)) for input in inputs: - if not isinstance(input.type, MyType): + if not issubtype(input.type, MyType): raise Exception("Error 1") outputs = [MyType.subtype(sum(input.type.thingy for input in inputs))()] return Apply(self, inputs, outputs) @@ -99,9 +106,9 @@ def test_sanity_0(self): # validate def test_validate(self): try: - MyOp( - Generic.subtype()(), MyType.subtype(1)() - ) # MyOp requires MyType instances + gen = Generic.subtype()() + myt = MyType.subtype(1)() + MyOp(gen, myt) # MyOp requires MyType instances raise Exception("Expected an exception") except Exception as e: if str(e) != "Error 1": diff --git a/tests/graph/test_types.py b/tests/graph/test_types.py index 6c37e3ecc2..f0b8a1496e 100644 --- a/tests/graph/test_types.py +++ b/tests/graph/test_types.py @@ -1,20 +1,23 @@ +from typing import Any + import pytest from aesara.graph.basic import Variable -from aesara.graph.type import Type - +from aesara.graph.type import 
NewTypeMeta, Props, Type +from aesara.issubtype import issubtype -class MyType(Type): - __props__ = ("thingy",) - def __init__(self, thingy): - self.thingy = thingy +class MyTypeMeta(NewTypeMeta): + thingy: Props[Any] = None def filter(self, *args, **kwargs): raise NotImplementedError() def __eq__(self, other): - return isinstance(other, MyType) and other.thingy == self.thingy + return isinstance(other, MyTypeMeta) and other.thingy == self.thingy + + def __hash__(self): + return hash((MyTypeMeta, self.thingy)) def __str__(self): return f"R{self.thingy}" @@ -23,12 +26,20 @@ def __repr__(self): return f"R{self.thingy}" -class MyType2(MyType): +class MyType(Type, metaclass=MyTypeMeta): + pass + + +class MyTypeMeta2(MyTypeMeta): def is_super(self, other): if self.thingy <= other.thingy: return True +class MyType2(Type, metaclass=MyTypeMeta2): + pass + + def test_is_super(): t1 = MyType.subtype(1) t2 = MyType.subtype(2) @@ -64,4 +75,4 @@ def test_convert_variable(): def test_default_clone(): mt = MyType.subtype(1) - assert isinstance(mt.clone(1), MyType) + assert issubtype(mt.clone(1), MyType) diff --git a/tests/graph/utils.py b/tests/graph/utils.py index 6122d5ceae..aeb6831a73 100644 --- a/tests/graph/utils.py +++ b/tests/graph/utils.py @@ -3,7 +3,8 @@ from aesara.graph.basic import Apply, Constant, NominalVariable, Variable, clone_replace from aesara.graph.fg import FunctionGraph from aesara.graph.op import HasInnerGraph, Op -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type +from aesara.issubtype import issubtype def is_variable(x): @@ -12,34 +13,42 @@ def is_variable(x): return x -class MyType(Type): +class MyTypeMeta(NewTypeMeta): def filter(self, data): return data def __eq__(self, other): - return isinstance(other, MyType) + return isinstance(other, MyTypeMeta) def __hash__(self): - return hash(MyType) + return hash(MyTypeMeta) def __repr__(self): return "MyType()" -class MyType2(Type): +class MyType(Type, metaclass=MyTypeMeta): + pass + + +class MyTypeMeta2(NewTypeMeta): def filter(self, data): return data def __eq__(self, other): - return isinstance(other, MyType2) + return isinstance(other, MyTypeMeta2) def __hash__(self): - return hash(MyType) + return hash(MyTypeMeta) def __repr__(self): return "MyType2()" +class MyType2(Type, metaclass=MyTypeMeta2): + pass + + def MyVariable(name): return Variable(MyType.subtype(), None, None, name=name) @@ -64,7 +73,7 @@ def __init__(self, name, dmap=None, x=None, n_outs=1): def make_node(self, *inputs): inputs = list(map(is_variable, inputs)) for input in inputs: - if not isinstance(input.type, MyType): + if not issubtype(input.type, MyType): raise Exception("Error 1") outputs = [MyType.subtype()() for i in range(self.n_outs)] return Apply(self, inputs, outputs) @@ -98,7 +107,7 @@ class MyOpCastType2(MyOp): def make_node(self, *inputs): inputs = list(map(is_variable, inputs)) for input in inputs: - if not isinstance(input.type, MyType): + if not issubtype(input.type, MyType): raise Exception("Error 1") outputs = [MyType2.subtype()()] diff --git a/tests/link/c/test_basic.py b/tests/link/c/test_basic.py index a3e61de4af..37f7debcd1 100644 --- a/tests/link/c/test_basic.py +++ b/tests/link/c/test_basic.py @@ -10,7 +10,7 @@ from aesara.link.basic import PerformLinker from aesara.link.c.basic import CLinker, DualLinker, OpWiseCLinker from aesara.link.c.op import COp -from aesara.link.c.type import CType +from aesara.link.c.type import CType, CTypeMeta from aesara.tensor.type import iscalar, matrix, vector from 
tests.link.test_link import make_function @@ -20,7 +20,7 @@ def as_variable(x): return x -class TDouble(CType): +class TDoubleMeta(CTypeMeta): def filter(self, data, strict=False, allow_downcast=False): return float(data) @@ -72,11 +72,9 @@ def c_cleanup(self, name, sub): def c_code_cache_version(self): return (1,) - def __eq__(self, other): - return type(self) == type(other) - def __hash__(self): - return hash(type(self)) +class TDouble(CType, metaclass=TDoubleMeta): + pass tdouble = TDouble.subtype() diff --git a/tests/link/c/test_cmodule.py b/tests/link/c/test_cmodule.py index 6fef007537..63521fadbc 100644 --- a/tests/link/c/test_cmodule.py +++ b/tests/link/c/test_cmodule.py @@ -38,7 +38,7 @@ def c_code(self, node, name, inames, onames, sub): (iname,) = inames (oname,) = onames fail = sub["fail"] - itype = node.inputs[0].type.__class__ + itype = node.inputs[0].type.base_type if itype in self.c_code_and_version: code, version = self.c_code_and_version[itype] rand = np.random.random() diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index d48dd56194..3ed493e978 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -19,7 +19,7 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op, get_test_value from aesara.graph.rewriting.db import RewriteDatabaseQuery -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type from aesara.ifelse import ifelse from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch import numba_typify @@ -31,15 +31,19 @@ from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape -class MyType(Type): +class MyTypeMeta(NewTypeMeta): def filter(self, data): return data def __eq__(self, other): - return isinstance(other, MyType) + return isinstance(other, MyTypeMeta) def __hash__(self): - return hash(MyType) + return hash(MyTypeMeta) + + +class MyType(Type, metaclass=MyTypeMeta): + pass class MyOp(Op): diff --git a/tests/link/test_link.py b/tests/link/test_link.py index 51ce905ca0..77363bff66 100644 --- a/tests/link/test_link.py +++ b/tests/link/test_link.py @@ -8,7 +8,7 @@ from aesara.graph import fg from aesara.graph.basic import Apply, Constant, Variable, clone from aesara.graph.op import Op -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type from aesara.link.basic import Container, Linker, PerformLinker, WrapLinker from aesara.link.c.basic import OpWiseCLinker from aesara.tensor.type import matrix, scalar @@ -62,11 +62,15 @@ def as_variable(x): return x -class TDouble(Type): +class TDoubleMeta(NewTypeMeta): def filter(self, data): return float(data) +class TDouble(Type, metaclass=TDoubleMeta): + pass + + tdouble = TDouble.subtype() diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 04df4a5705..e52bff6ac0 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -5,6 +5,7 @@ import tests.unittest_tools as utt from aesara.compile.mode import Mode from aesara.graph.fg import FunctionGraph +from aesara.issubtype import issubtype from aesara.link.c.basic import DualLinker from aesara.scalar.basic import ( ComplexError, @@ -492,13 +493,13 @@ def test_mean(mode): def test_shape(): a = float32("a") - assert isinstance(a.type, ScalarType) + assert issubtype(a.type, ScalarType) assert a.shape.type.ndim == 1 assert a.shape.type.shape == (0,) assert a.shape.type.dtype == "int64" b = constant(2, name="b") - assert isinstance(b.type, ScalarType) + assert 
issubtype(b.type, ScalarType) assert b.shape.type.ndim == 1 assert b.shape.type.shape == (0,) assert b.shape.type.dtype == "int64" diff --git a/tests/sparse/test_var.py b/tests/sparse/test_var.py index 75936d70e2..a49073b645 100644 --- a/tests/sparse/test_var.py +++ b/tests/sparse/test_var.py @@ -7,6 +7,7 @@ import aesara import aesara.sparse as sparse import aesara.tensor as at +from aesara.issubtype import issubtype from aesara.sparse.type import SparseTensorType from aesara.tensor.type import DenseTensorType @@ -99,7 +100,7 @@ def test_unary(self, method, exp_type, cm, x): else: z_outs = z - assert all(isinstance(out.type, exp_type) for out in z_outs) + assert all(issubtype(out.type, exp_type) for out in z_outs) f = aesara.function([x], z, on_unused_input="ignore", allow_input_downcast=True) @@ -155,7 +156,7 @@ def test_binary(self, method, exp_type): else: z_outs = z - assert all(isinstance(out.type, exp_type) for out in z_outs) + assert all(issubtype(out.type, exp_type) for out in z_outs) f = aesara.function([x, y], z) res = f( @@ -177,7 +178,7 @@ def test_reshape(self): with pytest.warns(UserWarning, match=".*converted to dense.*"): z = x.reshape((3, 2)) - assert isinstance(z.type, DenseTensorType) + assert issubtype(z.type, DenseTensorType) f = aesara.function([x], z) exp_res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) @@ -190,7 +191,7 @@ def test_dimshuffle(self): with pytest.warns(UserWarning, match=".*converted to dense.*"): z = x.dimshuffle((1, 0)) - assert isinstance(z.type, DenseTensorType) + assert issubtype(z.type, DenseTensorType) f = aesara.function([x], z) exp_res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) @@ -201,7 +202,7 @@ def test_getitem(self): x = sparse.csr_from_dense(x) z = x[:, :2] - assert isinstance(z.type, SparseTensorType) + assert issubtype(z.type, SparseTensorType) f = aesara.function([x], z) exp_res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) @@ -214,7 +215,7 @@ def test_dot(self): y = sparse.csr_from_dense(y) z = x.__dot__(y) - assert isinstance(z.type, SparseTensorType) + assert issubtype(z.type, SparseTensorType) f = aesara.function([x, y], z) exp_res = f( @@ -230,7 +231,7 @@ def test_repeat(self): with pytest.warns(UserWarning, match=".*converted to dense.*"): z = x.repeat(2, axis=1) - assert isinstance(z.type, DenseTensorType) + assert issubtype(z.type, DenseTensorType) f = aesara.function([x], z) exp_res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index e4fc4da146..7157a07197 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -14,7 +14,7 @@ from aesara.graph.op import Op from aesara.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in from aesara.graph.rewriting.utils import rewrite_graph -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type from aesara.tensor.basic import as_tensor_variable from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.math import add, exp, maximum @@ -504,14 +504,14 @@ def test_local_Shape_i_of_broadcastable(): assert fgraph.outputs[0].data == 1 # A test for a non-`TensorType` - class MyType(Type): + class MyTypeMeta(NewTypeMeta): ndim = 1 def filter(self, *args, **kwargs): raise NotImplementedError() - def __eq__(self, other): - return isinstance(other, MyType) and other.thingy == self.thingy + class MyType(Type, metaclass=MyTypeMeta): + pass class MyVariable(Variable): pass diff --git a/tests/tensor/rewriting/test_subtensor.py 
b/tests/tensor/rewriting/test_subtensor.py index bdfa28c6b3..b73a989103 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -13,7 +13,7 @@ from aesara.graph.rewriting.basic import check_stack_trace from aesara.graph.rewriting.db import RewriteDatabaseQuery from aesara.graph.rewriting.utils import rewrite_graph -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type from aesara.raise_op import Assert from aesara.tensor import inplace from aesara.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector @@ -1932,12 +1932,12 @@ def test_local_subtensor_shape_constant(): assert np.array_equal(res.data, [1, 1]) # A test for a non-`TensorType` - class MyType(Type): + class MyTypeMeta(NewTypeMeta): def filter(self, *args, **kwargs): raise NotImplementedError() - def __eq__(self, other): - return isinstance(other, MyType) and other.thingy == self.thingy + class MyType(Type, metaclass=MyTypeMeta): + pass x = shape(Variable(MyType.subtype(), None, None))[0] diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 858831b932..7322181283 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -13,6 +13,7 @@ from aesara.configdefaults import config from aesara.graph.basic import Apply, Variable from aesara.graph.fg import FunctionGraph +from aesara.issubtype import issubtype from aesara.link.basic import PerformLinker from aesara.link.c.basic import CLinker, OpWiseCLinker from aesara.tensor import as_tensor_variable @@ -861,7 +862,7 @@ def test_shape_types(self): (out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)]) - assert all(isinstance(v.type, TensorType) for v in out_shape) + assert all(issubtype(v.type, TensorType) for v in out_shape) def test_static_shape_unary(self): x = tensor("float64", shape=(None, 0, 1, 5)) diff --git a/tests/tensor/test_merge.py b/tests/tensor/test_merge.py index dd45bf3e0b..3ba424f45c 100644 --- a/tests/tensor/test_merge.py +++ b/tests/tensor/test_merge.py @@ -5,7 +5,8 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op from aesara.graph.rewriting.basic import MergeOptimizer -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type +from aesara.issubtype import issubtype def is_variable(x): @@ -14,12 +15,16 @@ def is_variable(x): return x -class MyType(Type): +class MyTypeMeta(NewTypeMeta): def filter(self, data): return data def __eq__(self, other): - return isinstance(other, MyType) + return isinstance(other, MyTypeMeta) + + +class MyType(Type, metaclass=MyTypeMeta): + pass class MyOp(Op): @@ -33,7 +38,7 @@ def __init__(self, name, dmap=None, x=None): def make_node(self, *inputs): inputs = list(map(is_variable, inputs)) for input in inputs: - if not isinstance(input.type, MyType): + if not issubtype(input.type, MyType): raise Exception("Error 1") outputs = [MyType.subtype()()] return Apply(self, inputs, outputs) diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 15bcb95c17..0bd89197cc 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -2,12 +2,12 @@ import pytest import aesara -from aesara import Mode, function, grad +from aesara import Mode, function, grad, issubtype from aesara.compile.ops import DeepCopyOp from aesara.configdefaults import config from aesara.graph.basic import Variable from aesara.graph.fg import FunctionGraph -from aesara.graph.type import Type +from aesara.graph.type import 
NewTypeMeta, Type from aesara.misc.safe_asarray import _asarray from aesara.tensor import as_tensor_variable, get_vector_length, row from aesara.tensor.basic import MakeVector, constant @@ -60,12 +60,12 @@ def test_shape_basic(): s = shape(lscalar()) assert s.type.broadcastable == (False,) - class MyType(Type): + class MyTypeMeta(NewTypeMeta): def filter(self, *args, **kwargs): raise NotImplementedError() - def __eq__(self, other): - return isinstance(other, MyType) and other.thingy == self.thingy + class MyType(Type, metaclass=MyTypeMeta): + pass s = shape(Variable(MyType.subtype(), None)) assert s.type.broadcastable == (False,) @@ -445,7 +445,7 @@ def test_bad_shape(self): with pytest.raises(AssertionError, match="SpecifyShape:.*"): f(xval) - assert isinstance( + assert issubtype( [n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0] .inputs[0] .type, @@ -455,7 +455,7 @@ def test_bad_shape(self): x = matrix() xval = np.random.random((2, 3)).astype(config.floatX) f = aesara.function([x], specify_shape(x, 2, 3), mode=self.mode) - assert isinstance( + assert issubtype( [n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0] .inputs[0] .type, diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 98b3984310..72700fae9b 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -9,6 +9,7 @@ import aesara import aesara.scalar as scal import aesara.tensor.basic as at +from aesara import issubtype from aesara.compile import DeepCopyOp, shared from aesara.compile.io import In from aesara.configdefaults import config @@ -2606,9 +2607,9 @@ def test_index_vars_to_types(): index_vars_to_types(1) res = index_vars_to_types(iscalar) - assert isinstance(res, scal.ScalarType) + assert issubtype(res, scal.ScalarType) x = scal.constant(1, dtype=np.uint8) - assert isinstance(x.type, scal.ScalarType) + assert issubtype(x.type, scal.ScalarType) res = index_vars_to_types(x) assert res == x.type diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py index 0c82337370..557ce15b55 100644 --- a/tests/tensor/test_type.py +++ b/tests/tensor/test_type.py @@ -243,7 +243,8 @@ def test_fixed_shape_basic(): assert t1.broadcastable == (True,) t2 = t1.clone() - assert t1 is not t2 + # TODO: is this requirement necessary? It doesn't make sense with the new types + # assert t1 is not t2 assert t1 == t2 t2 = t1.clone(dtype="float32", shape=(2, 4)) diff --git a/tests/test_raise_op.py b/tests/test_raise_op.py index 8236ba28f1..93b213b5cf 100644 --- a/tests/test_raise_op.py +++ b/tests/test_raise_op.py @@ -6,6 +6,7 @@ import aesara.tensor as at from aesara.compile.mode import OPT_FAST_RUN, Mode from aesara.graph.basic import Constant, equal_computations +from aesara.issubtype import issubtype from aesara.raise_op import Assert, CheckAndRaise, assert_op from aesara.scalar.basic import ScalarType, float64 from aesara.sparse import as_sparse_variable @@ -116,8 +117,8 @@ def test_perform_CheckAndRaise_scalar(linker): conds = (val > 0, val > 3) y = check_and_raise(val, *conds) - assert all(isinstance(i.type, ScalarType) for i in y.owner.inputs) - assert isinstance(y.type, ScalarType) + assert all(issubtype(i.type, ScalarType) for i in y.owner.inputs) + assert issubtype(y.type, ScalarType) mode = Mode(linker=linker) y_fn = aesara.function([val], y, mode=mode) From 0813e3475f6d55732ce234eeeb5715dd8e054df6 Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Tue, 27 Sep 2022 14:21:43 -0500 Subject: [PATCH 08/21] Make compilation modes configurable in compare_numba_and_py --- tests/link/numba/test_basic.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 3ed493e978..801d6aab24 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -102,7 +102,7 @@ def compare_shape_dtype(x, y): return x.shape == y.shape and x.dtype == y.dtype -def eval_python_only(fn_inputs, fgraph, inputs): +def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode): """Evaluate the Numba implementation in pure Python for coverage purposes.""" def py_tuple_setitem(t, i, v): @@ -168,13 +168,15 @@ def inner_vec(*args): aesara_numba_fn = function( fn_inputs, fgraph.outputs, - mode=numba_mode, + mode=mode, accept_inplace=True, ) _ = aesara_numba_fn(*inputs) -def compare_numba_and_py(fgraph, inputs, assert_fn=None): +def compare_numba_and_py( + fgraph, inputs, assert_fn=None, numba_mode=numba_mode, py_mode=py_mode +): """Function to compare python graph output and Numba compiled output for testing equality In the tests below computational graphs are defined in Aesara. These graphs are then passed to @@ -215,7 +217,7 @@ def assert_fn(x, y): numba_res = aesara_numba_fn(*inputs) # Get some coverage - eval_python_only(fn_inputs, fgraph, inputs) + eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode) if len(fgraph.outputs) > 1: for j, p in zip(numba_res, py_res): From 21c46ac10625586bf1fc883b1e15dc7424605574 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 21 Sep 2022 13:20:33 -0500 Subject: [PATCH 09/21] Enable Numba bounds checking during testing --- conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/conftest.py b/conftest.py index bd4f8bab67..048c0662c9 100644 --- a/conftest.py +++ b/conftest.py @@ -10,6 +10,7 @@ def pytest_sessionstart(session): "warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,cmodule__warn_no_version=True", ] ) + os.environ["NUMBA_BOUNDSCHECK"] = "1" def pytest_addoption(parser): From 47ecc5f825bbd3a8e8c8f9aa1e1b915fdf44dae3 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 27 Sep 2022 17:04:45 -0500 Subject: [PATCH 10/21] Use to_scalar in numba_funcify_ScalarFromTensor This should prevent errors when the input is already a Numba scalar, and it will use Numba's type information to selectively apply the scalar conversion. --- aesara/link/numba/dispatch/tensor_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aesara/link/numba/dispatch/tensor_basic.py b/aesara/link/numba/dispatch/tensor_basic.py index 9a9578fcc1..942e9cb709 100644 --- a/aesara/link/numba/dispatch/tensor_basic.py +++ b/aesara/link/numba/dispatch/tensor_basic.py @@ -230,6 +230,6 @@ def tensor_from_scalar(x): def numba_funcify_ScalarFromTensor(op, **kwargs): @numba_basic.numba_njit(inline="always") def scalar_from_tensor(x): - return x.item() + return numba_basic.to_scalar(x) return scalar_from_tensor From ff1ad55aa96ec01f2ed58d9f08ea0963a9d5f227 Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Fri, 23 Sep 2022 15:38:27 -0500 Subject: [PATCH 11/21] Allow FunctionGraph arguments in create_numba_signature --- aesara/link/numba/dispatch/basic.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index b82278e923..c20636ef03 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from functools import singledispatch from textwrap import dedent +from typing import Union import numba import numba.np.unsafe.ndarray as numba_ndarray @@ -97,11 +98,13 @@ def get_numba_type( def create_numba_signature( - node: Apply, force_scalar: bool = False, reduce_to_scalar: bool = False + node_or_fgraph: Union[FunctionGraph, Apply], + force_scalar: bool = False, + reduce_to_scalar: bool = False, ) -> numba.types.Type: - """Create a Numba type for the signature of an ``Apply`` node.""" + """Create a Numba type for the signature of an `Apply` node or `FunctionGraph`.""" input_types = [] - for inp in node.inputs: + for inp in node_or_fgraph.inputs: input_types.append( get_numba_type( inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar @@ -109,7 +112,7 @@ def create_numba_signature( ) output_types = [] - for out in node.outputs: + for out in node_or_fgraph.outputs: output_types.append( get_numba_type( out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar From 26bcc9b55eacb6c6ba129172b517e4504e8b15d3 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 27 Sep 2022 18:06:14 -0500 Subject: [PATCH 12/21] Stop using auto_name for transpilation Using the `auto_name` values will result in cache misses when caching is based on the generated source code, so we're not going to use it. 
--- aesara/link/numba/dispatch/basic.py | 10 +++++----- aesara/link/numba/dispatch/scan.py | 2 +- aesara/link/utils.py | 15 +++++++++------ tests/link/jax/test_basic.py | 17 ----------------- tests/link/test_link.py | 1 - tests/link/test_utils.py | 16 ++++++++++------ 6 files changed, 25 insertions(+), 36 deletions(-) diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index c20636ef03..8febee82a8 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -394,10 +394,14 @@ def numba_funcify_FunctionGraph( def create_index_func(node, objmode=False): """Create a Python function that assembles and uses an index on an array.""" + unique_names = unique_name_generator( + ["subtensor", "incsubtensor", "z"], suffix_sep="_" + ) + def convert_indices(indices, entry): if indices and issubtype(entry, Type): rval = indices.pop(0) - return rval.auto_name + return unique_names(rval) elif isinstance(entry, slice): return ( f"slice({convert_indices(indices, entry.start)}, " @@ -414,10 +418,6 @@ def convert_indices(indices, entry): ) index_start_idx = 1 + int(set_or_inc) - unique_names = unique_name_generator( - ["subtensor", "incsubtensor", "z"], suffix_sep="_" - ) - input_names = [unique_names(v, force_unique=True) for v in node.inputs] op_indices = list(node.inputs[index_start_idx:]) idx_list = getattr(node.op, "idx_list", None) diff --git a/aesara/link/numba/dispatch/scan.py b/aesara/link/numba/dispatch/scan.py index 2a1a6c2735..c817baca0f 100644 --- a/aesara/link/numba/dispatch/scan.py +++ b/aesara/link/numba/dispatch/scan.py @@ -54,7 +54,7 @@ def numba_funcify_Scan(op, node, **kwargs): p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot - input_names = [f"{n.auto_name}_{i}" for i, n in enumerate(node.inputs[1:])] + input_names = [f"outer_in_{i}" for i, n in enumerate(node.inputs[1:])] outer_in_seqs_names = input_names[:n_seqs] outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot] outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot] diff --git a/aesara/link/utils.py b/aesara/link/utils.py index ba118ef170..8bc82bafda 100644 --- a/aesara/link/utils.py +++ b/aesara/link/utils.py @@ -613,8 +613,8 @@ def compile_function_src( def get_name_for_object(x: Any) -> str: """Get the name for an arbitrary object.""" - if isinstance(x, Variable): - name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else "" + if isinstance(x, Variable) and x.name: + name = re.sub("[^0-9a-zA-Z]+", "_", x.name) name = ( name if ( @@ -622,19 +622,22 @@ def get_name_for_object(x: Any) -> str: and not iskeyword(name) and name not in dir(builtins) ) - else x.auto_name + else "" ) else: - name = getattr(x, "__name__", "") + name = re.sub(r"(? 
Callable: """Create a function that generates unique names.""" diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index 56dfd18d35..00df1cf52c 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -89,23 +89,6 @@ def compare_jax_and_py( return jax_res -def test_jax_FunctionGraph_names(): - import inspect - - from aesara.link.jax.dispatch import jax_funcify - - x = scalar("1x") - y = scalar("_") - z = scalar() - q = scalar("def") - - out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False) - out_jx = jax_funcify(out_fg) - sig = inspect.signature(out_jx) - assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(sig.parameters.keys()) - assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4) - - def test_jax_FunctionGraph_once(): """Make sure that an output is only computed once when it's referenced multiple times.""" from aesara.link.jax.dispatch import jax_funcify diff --git a/tests/link/test_link.py b/tests/link/test_link.py index 77363bff66..30fa604d76 100644 --- a/tests/link/test_link.py +++ b/tests/link/test_link.py @@ -224,7 +224,6 @@ def wrap(fgraph, i, node, th): def test_sort_schedule_fn(): - import aesara from aesara.graph.sched import make_depends, sort_schedule_fn x = matrix("x") diff --git a/tests/link/test_utils.py b/tests/link/test_utils.py index efd8e62ab7..f32ab7e298 100644 --- a/tests/link/test_utils.py +++ b/tests/link/test_utils.py @@ -12,7 +12,7 @@ get_name_for_object, unique_name_generator, ) -from aesara.scalar.basic import Add +from aesara.scalar.basic import Add, float64 from aesara.tensor.elemwise import Elemwise from aesara.tensor.type import scalar, vector from aesara.tensor.type_other import NoneConst @@ -42,7 +42,7 @@ def test_fgraph_to_python_names(): x = scalar("1x") y = scalar("_") - z = scalar() + z = float64() q = scalar("def") r = NoneConst @@ -50,9 +50,13 @@ def test_fgraph_to_python_names(): out_jx = fgraph_to_python(out_fg, to_python) sig = inspect.signature(out_jx) - assert (x.auto_name, "_", z.auto_name, q.auto_name, r.name) == tuple( - sig.parameters.keys() - ) + assert ( + "tensor_variable", + "_", + "scalar_variable", + "tensor_variable_1", + r.name, + ) == tuple(sig.parameters.keys()) assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5) obj = object() @@ -191,7 +195,7 @@ def test_unique_name_generator(): q_name_1 = unique_names(q) q_name_2 = unique_names(q) - assert q_name_1 == q_name_2 == q.auto_name + assert q_name_1 == q_name_2 == "tensor_variable" unique_names = unique_name_generator() From 8e67e043b30c8b8d2cc314b69b25e6ba93641e69 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 20 Sep 2022 23:02:40 -0500 Subject: [PATCH 13/21] Do not overwrite arguments in Numba's Scan implementation --- aesara/link/numba/dispatch/scan.py | 56 ++++++++++++++++-------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/aesara/link/numba/dispatch/scan.py b/aesara/link/numba/dispatch/scan.py index c817baca0f..555b38a587 100644 --- a/aesara/link/numba/dispatch/scan.py +++ b/aesara/link/numba/dispatch/scan.py @@ -1,3 +1,5 @@ +from textwrap import dedent, indent + import numpy as np from numba import types from numba.extending import overload @@ -72,11 +74,10 @@ def numba_funcify_Scan(op, node, **kwargs): allocate_mem_to_nit_sot = "" for _name in outer_in_seqs_names: - # A sequence with multiple taps is provided as multiple modified - # input sequences to the Scan Op sliced appropriately - # to keep following the logic of a normal sequence. 
- index = "[i]" - inner_in_indexed.append(_name + index) + # A sequence with multiple taps is provided as multiple modified input + # sequences--all sliced so as to keep following the logic of a normal + # sequence. + inner_in_indexed.append(f"{_name}[i]") name_to_input_map = dict(zip(input_names, node.inputs[1:])) mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps)) @@ -88,31 +89,34 @@ def numba_funcify_Scan(op, node, **kwargs): for _tap in curr_taps: index = idx_to_str(_tap - min_tap) - inner_in_indexed.append(_name + index) + inner_in_indexed.append(f"{_name}{index}") inner_out_name_to_index[_name] = -min_tap if _name in outer_in_sit_sot_names: - # Note that the outputs with single taps which are not - # -1 are (for instance taps = [-2]) are classified - # as mit-sot so the code for handling sit-sots remains - # constant as follows - index = "[i]" - inner_in_indexed.append(_name + index) + # Note that the outputs with single, non-`-1` taps are (e.g. `taps + # = [-2]`) are classified as mit-sot, so the code for handling + # sit-sots remains constant as follows + inner_in_indexed.append(f"{_name}[i]") inner_out_name_to_index[_name] = 1 if _name in outer_in_nit_sot_names: - inner_out_name_to_index[_name] = 0 - # In case of nit-sots we are provided shape of the array - # instead of actual arrays like other cases, hence we - # allocate space for the results accordingly. + output_name = f"{_name}_nitsot_storage" + inner_out_name_to_index[output_name] = 0 + # In case of nit-sots we are provided the shape of the array + # instead of actual arrays (like other cases), hence we allocate + # space for the results accordingly. curr_nit_sot_position = input_names.index(_name) - n_seqs curr_nit_sot = inner_fg.outputs[curr_nit_sot_position] mem_shape = ["1"] * curr_nit_sot.ndim curr_dtype = curr_nit_sot.type.numpy_dtype.name - allocate_mem_to_nit_sot += f""" - {_name} = [np.zeros(({create_arg_string(mem_shape)}), dtype=np.{curr_dtype})]*{_name}.item() -""" + allocate_mem_to_nit_sot += dedent( + f""" + {output_name} = [ + np.empty(({create_arg_string(mem_shape)},), dtype=np.{curr_dtype}) for i in range({_name}.item()) + ]""" + ) + # The non_seqs are passed to inner function as-is inner_in_indexed += outer_in_non_seqs_names inner_out_indexed = [ @@ -121,7 +125,7 @@ def numba_funcify_Scan(op, node, **kwargs): while_logic = "" if op.info.as_while: - # The inner function will be returning a boolean as last argument + # The inner function will return a boolean as the last value inner_out_indexed.append("while_flag") while_logic += """ if while_flag: @@ -137,18 +141,18 @@ def numba_funcify_Scan(op, node, **kwargs): global_env = locals() global_env["np"] = np + output_names = outer_in_mit_sot_names + outer_in_sit_sot_names + output_names += [f"{n}_nitsot_storage" for n in outer_in_nit_sot_names] + scan_op_src = f""" def scan(n_steps, {", ".join(input_names)}): -{allocate_mem_to_nit_sot} +{indent(allocate_mem_to_nit_sot, " " * 4)} + for i in range(n_steps): inner_args = {create_tuple_string(inner_in_indexed)} {create_tuple_string(inner_out_indexed)} = numba_at_inner_func(*inner_args) {while_logic} - return {create_arg_string( - outer_in_mit_sot_names + - outer_in_sit_sot_names + - outer_in_nit_sot_names - )} + return {create_arg_string(output_names)} """ scalar_op_fn = compile_function_src( scan_op_src, "scan", {**globals(), **global_env} From e7b8e9bc28aba1dcf2c3bb5d5a90c897e25ccd8b Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Wed, 21 Sep 2022 14:59:20 -0500 Subject: [PATCH 14/21] Fix storage handling in numba_funcify_Scan --- aesara/link/numba/dispatch/scan.py | 366 ++++++++++++++++++++--------- tests/link/numba/test_scan.py | 118 +++++++++- 2 files changed, 368 insertions(+), 116 deletions(-) diff --git a/aesara/link/numba/dispatch/scan.py b/aesara/link/numba/dispatch/scan.py index 555b38a587..564615b1e3 100644 --- a/aesara/link/numba/dispatch/scan.py +++ b/aesara/link/numba/dispatch/scan.py @@ -1,10 +1,11 @@ +from itertools import groupby from textwrap import dedent, indent +from typing import Dict, List, Optional, Tuple import numpy as np from numba import types from numba.extending import overload -from aesara.graph.fg import FunctionGraph from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch.basic import ( create_arg_string, @@ -15,13 +16,23 @@ from aesara.scan.op import Scan -def idx_to_str(idx): - res = "[i" - if idx < 0: - res += str(idx) - elif idx > 0: - res += "+" + str(idx) - return res + "]" +def idx_to_str( + array_name: str, offset: int, size: Optional[str] = None, idx_symbol: str = "i" +) -> str: + if offset < 0: + indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}" + elif offset > 0: + indices = f"{idx_symbol} + {offset}" + else: + indices = idx_symbol + + if size: + # TODO FIXME: The `Scan` `Op` should tell us which outputs are computed + # in this way. We shouldn't have to waste run-time efforts in order to + # compensate for this poor `Op`/rewrite design and implementation. + indices = f"({indices}) % {size}" + + return f"{array_name}[{indices}]" @overload(range) @@ -36,124 +47,267 @@ def range_arr(x): @numba_funcify.register(Scan) def numba_funcify_Scan(op, node, **kwargs): - inner_fg = FunctionGraph(op.inner_inputs, op.inner_outputs) - numba_at_inner_func = numba_basic.numba_njit(numba_funcify(inner_fg, **kwargs)) + scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) n_seqs = op.info.n_seqs - n_mit_mot = op.info.n_mit_mot - n_mit_sot = op.info.n_mit_sot - n_nit_sot = op.info.n_nit_sot - n_sit_sot = op.info.n_sit_sot - tap_array = op.info.tap_array - n_shared_outs = op.info.n_shared_outs - mit_mot_in_taps = tuple(tap_array[:n_mit_mot]) - mit_sot_in_taps = tuple(tap_array[n_mit_mot : n_mit_mot + n_mit_sot]) - - p_in_mit_mot = n_seqs - p_in_mit_sot = p_in_mit_mot + n_mit_mot - p_in_sit_sot = p_in_mit_sot + n_mit_sot - p_outer_in_shared = p_in_sit_sot + n_sit_sot - p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs - p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot - - input_names = [f"outer_in_{i}" for i, n in enumerate(node.inputs[1:])] - outer_in_seqs_names = input_names[:n_seqs] - outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot] - outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot] - outer_in_sit_sot_names = input_names[p_in_sit_sot : p_in_sit_sot + n_sit_sot] - outer_in_shared_names = input_names[ - p_outer_in_shared : p_outer_in_shared + n_shared_outs - ] - outer_in_nit_sot_names = input_names[ - p_outer_in_nit_sot : p_outer_in_nit_sot + n_nit_sot - ] - outer_in_feedback_names = input_names[n_seqs:p_outer_in_non_seqs] - outer_in_non_seqs_names = input_names[p_outer_in_non_seqs:] - inner_in_indexed = [] - allocate_mem_to_nit_sot = "" + outer_in_names_to_vars = { + (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs) + } + outer_in_names = list(outer_in_names_to_vars.keys()) + outer_in_seqs_names = 
op.outer_seqs(outer_in_names) + outer_in_mit_mot_names = op.outer_mitmot(outer_in_names) + outer_in_mit_sot_names = op.outer_mitsot(outer_in_names) + outer_in_sit_sot_names = op.outer_sitsot(outer_in_names) + outer_in_nit_sot_names = op.outer_nitsot(outer_in_names) + outer_in_outtap_names = ( + outer_in_mit_mot_names + + outer_in_mit_sot_names + + outer_in_sit_sot_names + + outer_in_nit_sot_names + ) + outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names) + + inner_in_to_index_offset: List[Tuple[str, Optional[int], Optional[int]]] = [] + allocate_taps_storage: List[str] = [] - for _name in outer_in_seqs_names: + for outer_in_name in outer_in_seqs_names: # A sequence with multiple taps is provided as multiple modified input # sequences--all sliced so as to keep following the logic of a normal # sequence. - inner_in_indexed.append(f"{_name}[i]") - - name_to_input_map = dict(zip(input_names, node.inputs[1:])) - mit_sot_name_to_taps = dict(zip(outer_in_mit_sot_names, mit_sot_in_taps)) - inner_out_name_to_index = {} - for _name in outer_in_feedback_names: - if _name in outer_in_mit_sot_names: - curr_taps = mit_sot_name_to_taps[_name] - min_tap = min(curr_taps) - - for _tap in curr_taps: - index = idx_to_str(_tap - min_tap) - inner_in_indexed.append(f"{_name}{index}") - - inner_out_name_to_index[_name] = -min_tap - - if _name in outer_in_sit_sot_names: - # Note that the outputs with single, non-`-1` taps are (e.g. `taps - # = [-2]`) are classified as mit-sot, so the code for handling - # sit-sots remains constant as follows - inner_in_indexed.append(f"{_name}[i]") - inner_out_name_to_index[_name] = 1 - - if _name in outer_in_nit_sot_names: - output_name = f"{_name}_nitsot_storage" - inner_out_name_to_index[output_name] = 0 - # In case of nit-sots we are provided the shape of the array - # instead of actual arrays (like other cases), hence we allocate - # space for the results accordingly. - curr_nit_sot_position = input_names.index(_name) - n_seqs - curr_nit_sot = inner_fg.outputs[curr_nit_sot_position] - mem_shape = ["1"] * curr_nit_sot.ndim - curr_dtype = curr_nit_sot.type.numpy_dtype.name - allocate_mem_to_nit_sot += dedent( + inner_in_to_index_offset.append((outer_in_name, 0, None)) + + inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict( + zip( + outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names, + op.info.mit_mot_in_slices + + op.info.mit_sot_in_slices + + op.info.sit_sot_in_slices, + ) + ) + inner_in_names_to_output_taps: Dict[str, Optional[Tuple[int, ...]]] = dict( + zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices) + ) + + inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))] + + # Maps storage array names to their tap values (i.e. maximum absolute tap + # value) and storage sizes + inner_out_name_to_taps_storage: List[Tuple[str, int, Optional[str]]] = [] + outer_in_to_storage_name: Dict[str, str] = {} + outer_in_sot_names = set( + outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names + ) + inner_out_post_processing_stmts: List[str] = [] + for outer_in_name in outer_in_outtap_names: + outer_in_var = outer_in_names_to_vars[outer_in_name] + + if outer_in_name in outer_in_sot_names: + if outer_in_name in outer_in_mit_mot_names: + storage_name = f"{outer_in_name}_mitmot_storage" + elif outer_in_name in outer_in_mit_sot_names: + storage_name = f"{outer_in_name}_mitsot_storage" + else: + # Note that the outputs with single, non-`-1` taps are (e.g. 
`taps + # = [-2]`) are classified as mit-sot, so the code for handling + # sit-sots remains constant as follows + storage_name = f"{outer_in_name}_sitsot_storage" + + output_idx = len(outer_in_to_storage_name) + outer_in_to_storage_name[outer_in_name] = storage_name + + input_taps = inner_in_names_to_input_taps[outer_in_name] + tap_storage_size = -min(input_taps) + assert tap_storage_size >= 0 + + storage_size_name = f"{outer_in_name}_len" + + for in_tap in input_taps: + tap_offset = in_tap + tap_storage_size + assert tap_offset >= 0 + # In truncated storage situations (i.e. created by + # `save_mem_new_scan`), the taps and output storage overlap, + # instead of the standard situation in which the output storage + # is large enough to contain both the initial taps values and + # the output storage. + inner_in_to_index_offset.append( + (outer_in_name, tap_offset, storage_size_name) + ) + + output_taps = inner_in_names_to_output_taps.get( + outer_in_name, [tap_storage_size] + ) + for out_tap in output_taps: + inner_out_name_to_taps_storage.append( + (storage_name, out_tap, storage_size_name) + ) + + if output_idx in node.op.destroy_map: + storage_alloc_stmt = f"{storage_name} = {outer_in_name}" + else: + storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})" + + storage_alloc_stmt = dedent( f""" - {output_name} = [ - np.empty(({create_arg_string(mem_shape)},), dtype=np.{curr_dtype}) for i in range({_name}.item()) - ]""" + # {outer_in_var.type} + {storage_size_name} = {outer_in_name}.shape[0] + {storage_alloc_stmt} + """ + ).strip() + + allocate_taps_storage.append(storage_alloc_stmt) + + elif outer_in_name in outer_in_nit_sot_names: + # This is a special case in which there are no outer-inputs used + # for outer-output storage, so we need to create our own storage + # from scratch. + + storage_name = f"{outer_in_name}_nitsot_storage" + outer_in_to_storage_name[outer_in_name] = storage_name + + storage_size_name = f"{outer_in_name}_len" + inner_out_name_to_taps_storage.append((storage_name, 0, storage_size_name)) + + # In case of nit-sots we are provided the length of the array in + # the iteration dimension instead of actual arrays, hence we + # allocate space for the results accordingly. + curr_nit_sot_position = outer_in_names[1:].index(outer_in_name) - n_seqs + curr_nit_sot = op.inner_outputs[curr_nit_sot_position] + needs_alloc = curr_nit_sot.ndim > 0 + + storage_shape = create_tuple_string( + [storage_size_name] + ["0"] * curr_nit_sot.ndim + ) + storage_dtype = curr_nit_sot.type.numpy_dtype.name + + allocate_taps_storage.append( + dedent( + f""" + # {curr_nit_sot.type} + {storage_size_name} = to_numba_scalar({outer_in_name}) + {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype}) + """ + ).strip() ) - # The non_seqs are passed to inner function as-is - inner_in_indexed += outer_in_non_seqs_names - inner_out_indexed = [ - _name + idx_to_str(idx) for _name, idx in inner_out_name_to_index.items() + if needs_alloc: + allocate_taps_storage.append(f"{outer_in_name}_ready = False") + + # In this case, we don't know the shape of the output storage + # array until we get some output from the inner-function. 
+ # With the following we add delayed output storage initialization: + inner_out_name = inner_output_names[curr_nit_sot_position] + inner_out_post_processing_stmts.append( + dedent( + f""" + if not {outer_in_name}_ready: + {storage_name} = np.empty(({storage_size_name},) + {inner_out_name}.shape, dtype=np.{storage_dtype}) + {outer_in_name}_ready = True + """ + ).strip() + ) + + # The non_seqs are passed to the inner function as-is + for name in outer_in_non_seqs_names: + inner_in_to_index_offset.append((name, None, None)) + + inner_out_storage_indexed = [ + name if taps is None else idx_to_str(name, taps, size=size) + for (name, taps, size) in inner_out_name_to_taps_storage ] - while_logic = "" + output_storage_post_processing_stmts: List[str] = [] + + for outer_in_name, grp_vals in groupby( + inner_out_name_to_taps_storage, lambda x: x[0] + ): + + _, tap_sizes, storage_sizes = zip(*grp_vals) + + tap_size = max(tap_sizes) + storage_size = storage_sizes[0] + + if op.info.as_while: + # While loops need to truncate the output storage to a length given + # by the number of iterations performed. + output_storage_post_processing_stmts.append( + dedent( + f""" + if i + {tap_size} < {storage_size}: + {storage_size} = i + {tap_size} + {outer_in_name} = {outer_in_name}[:{storage_size}] + """ + ).strip() + ) + + # Rotate the storage so that the last computed value is at the end of + # the storage array. + # This is needed when the output storage array does not have a length + # equal to the number of taps plus `n_steps`. + output_storage_post_processing_stmts.append( + dedent( + f""" + {outer_in_name}_shift = (i + {tap_size}) % ({storage_size}) + if {outer_in_name}_shift > 0: + {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift] + {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:] + {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left)) + """ + ).strip() + ) + if op.info.as_while: # The inner function will return a boolean as the last value - inner_out_indexed.append("while_flag") - while_logic += """ - if while_flag: - """ - for _name, idx in inner_out_name_to_index.items(): - while_logic += f""" - {_name} = {_name}[:i+{idx+1}] - """ - while_logic += """ - break - """ - - global_env = locals() - global_env["np"] = np + inner_out_storage_indexed.append("cond") - output_names = outer_in_mit_sot_names + outer_in_sit_sot_names - output_names += [f"{n}_nitsot_storage" for n in outer_in_nit_sot_names] + output_names = [outer_in_to_storage_name[n] for n in outer_in_outtap_names] + + # Construct the inner-input expressions + inner_inputs: List[str] = [] + for outer_in_name, tap_offset, size in inner_in_to_index_offset: + storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name) + indexed_inner_in_str = ( + idx_to_str(storage_name, tap_offset, size=size) + if tap_offset is not None + else storage_name + ) + # if outer_in_names_to_vars[outer_in_name].type.ndim - 1 <= 0: + # # Convert scalar inner-inputs to Numba scalars + # indexed_inner_in_str = f"to_numba_scalar({indexed_inner_in_str})" + inner_inputs.append(indexed_inner_in_str) + + inner_inputs = create_arg_string(inner_inputs) + inner_outputs = create_tuple_string(inner_output_names) + input_storage_block = "\n".join(allocate_taps_storage) + output_storage_post_processing_block = "\n".join( + output_storage_post_processing_stmts + ) + inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) scan_op_src = f""" -def scan(n_steps, {", ".join(input_names)}): 
-{indent(allocate_mem_to_nit_sot, " " * 4)} +def scan({", ".join(outer_in_names)}): + +{indent(input_storage_block, " " * 4)} + + i = 0 + cond = False + while i < n_steps and not cond: + {inner_outputs} = scan_inner_func({inner_inputs}) +{indent(inner_out_post_processing_block, " " * 8)} + {create_tuple_string(inner_out_storage_indexed)} = {inner_outputs} + i += 1 + +{indent(output_storage_post_processing_block, " " * 4)} - for i in range(n_steps): - inner_args = {create_tuple_string(inner_in_indexed)} - {create_tuple_string(inner_out_indexed)} = numba_at_inner_func(*inner_args) -{while_logic} return {create_arg_string(output_names)} """ + + global_env = { + "scan_inner_func": scan_inner_func, + "to_numba_scalar": numba_basic.to_scalar, + } + global_env["np"] = np + scalar_op_fn = compile_function_src( scan_op_src, "scan", {**globals(), **global_env} ) diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 6ef2257caa..99444ca787 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -1,27 +1,29 @@ import numpy as np +import pytest import aesara.tensor as at -from aesara import config +from aesara import config, grad +from aesara.compile.mode import Mode, get_mode from aesara.graph.fg import FunctionGraph from aesara.scan.basic import scan from aesara.scan.utils import until +from tests import unittest_tools as utt from tests.link.numba.test_basic import compare_numba_and_py -rng = np.random.default_rng(42849) - - def test_scan_multiple_output(): """Test a scan implementation of a SEIR model. SEIR model definition: - S[t+1] = S[t] - B[t] - E[t+1] = E[t] +B[t] - C[t] - I[t+1] = I[t+1] + C[t] - D[t] - B[t] ~ Binom(S[t], beta) - C[t] ~ Binom(E[t], gamma) - D[t] ~ Binom(I[t], delta) + S[t+1] = S[t] - B[t] + E[t+1] = E[t] + B[t] - C[t] + I[t+1] = I[t+1] + C[t] - D[t] + + B[t] ~ Binom(S[t], beta) + C[t] ~ Binom(E[t], gamma) + D[t] ~ Binom(I[t], delta) + """ def binomln(n, k): @@ -198,3 +200,99 @@ def power_step(prior_result, x): test_input_vals = (np.array([1.0, 2.0]),) compare_numba_and_py(out_fg, test_input_vals) + + +def test_scan_save_mem_basic(): + """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite.""" + k = at.iscalar("k") + A = at.dvector("A") + + result, _ = scan( + fn=lambda prior_result, A: prior_result * A, + outputs_info=at.ones_like(A), + non_sequences=A, + n_steps=k, + ) + + numba_mode = get_mode("NUMBA") # .including("scan_save_mem") + py_mode = Mode("py").including("scan_save_mem") + + out_fg = FunctionGraph([A, k], [result]) + test_input_vals = (np.arange(10, dtype=np.int32), 2) + compare_numba_and_py( + out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode + ) + test_input_vals = (np.arange(10, dtype=np.int32), 4) + compare_numba_and_py( + out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode + ) + + +@pytest.mark.parametrize("n_steps_val", [1, 5]) +def test_scan_save_mem_2(n_steps_val): + def f_pow2(x_tm2, x_tm1): + return 2 * x_tm1 + x_tm2 + + init_x = at.dvector("init_x") + n_steps = at.iscalar("n_steps") + output, _ = scan( + f_pow2, + sequences=[], + outputs_info=[{"initial": init_x, "taps": [-2, -1]}], + non_sequences=[], + n_steps=n_steps, + ) + + state_val = np.array([1.0, 2.0]) + + numba_mode = get_mode("NUMBA") # .including("scan_save_mem") + py_mode = Mode("py").including("scan_save_mem") + + out_fg = FunctionGraph([init_x, n_steps], [output]) + test_input_vals = (state_val, n_steps_val) + + compare_numba_and_py( + out_fg, test_input_vals, numba_mode=numba_mode, 
py_mode=py_mode + ) + + +def test_grad_sitsot(): + def get_sum_of_grad(inp): + scan_outputs, updates = scan( + fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA" + ) + return grad(scan_outputs.sum(), inp).sum() + + floatX = config.floatX + inputs_test_values = [ + np.random.default_rng(utt.fetch_seed()).random(3).astype(floatX) + ] + utt.verify_grad(get_sum_of_grad, inputs_test_values, mode="NUMBA") + + +def test_mitmots_basic(): + + init_x = at.dvector() + seq = at.dvector() + + def inner_fct(seq, state_old, state_current): + return state_old * 2 + state_current + seq + + out, _ = scan( + inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]} + ) + + g_outs = grad(out.sum(), [seq, init_x]) + + numba_mode = get_mode("NUMBA").including("scan_save_mem") + py_mode = Mode("py").including("scan_save_mem") + + out_fg = FunctionGraph([seq, init_x], g_outs) + + seq_val = np.arange(3) + init_x_val = np.r_[-2, -1] + test_input_vals = (seq_val, init_x_val) + + compare_numba_and_py( + out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode + ) From 9d07ce9bcaf82f8e953df7709f21c574e2e3191b Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 4 Oct 2022 18:02:08 -0500 Subject: [PATCH 15/21] Make fgraph_to_python process constant FunctionGraph outputs correctly --- aesara/link/utils.py | 25 ++++++++++++++++++++----- tests/link/test_utils.py | 13 +++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/aesara/link/utils.py b/aesara/link/utils.py index 8bc82bafda..4fac5e9d0b 100644 --- a/aesara/link/utils.py +++ b/aesara/link/utils.py @@ -8,7 +8,7 @@ from keyword import iskeyword from operator import itemgetter from tempfile import NamedTemporaryFile -from textwrap import indent +from textwrap import dedent, indent from typing import ( TYPE_CHECKING, Any, @@ -767,6 +767,19 @@ def fgraph_to_python( assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})" body_assigns.append(f"{assign_comment_str}\n{assign_str}") + # Handle `Constant`-only outputs (these don't have associated `Apply` + # nodes, so the above isn't applicable) + for out in fgraph.outputs: + if isinstance(out, Constant): + local_input_name = unique_name(out) + if local_input_name not in global_env: + global_env[local_input_name] = type_conversion_fn( + storage_map[out][0], + variable=out, + storage=storage_map[out], + **kwargs, + ) + fgraph_input_names = [unique_name(v) for v in fgraph.inputs] fgraph_output_names = [unique_name(v) for v in fgraph.outputs] joined_body_assigns = indent("\n".join(body_assigns), " ") @@ -778,11 +791,13 @@ def fgraph_to_python( else: fgraph_return_src = ", ".join(fgraph_output_names) - fgraph_def_src = f""" -def {fgraph_name}({", ".join(fgraph_input_names)}): -{joined_body_assigns} - return {fgraph_return_src} + fgraph_def_src = dedent( + f""" + def {fgraph_name}({", ".join(fgraph_input_names)}): + {indent(joined_body_assigns, " " * 4)} + return {fgraph_return_src} """ + ).strip() if local_env is None: local_env = locals() diff --git a/tests/link/test_utils.py b/tests/link/test_utils.py index f32ab7e298..407d399552 100644 --- a/tests/link/test_utils.py +++ b/tests/link/test_utils.py @@ -13,6 +13,7 @@ unique_name_generator, ) from aesara.scalar.basic import Add, float64 +from aesara.tensor import constant from aesara.tensor.elemwise import Elemwise from aesara.tensor.type import scalar, vector from aesara.tensor.type_other import NoneConst @@ -163,6 +164,18 @@ def func(*args, op=op): ) +def 
test_fgraph_to_python_constant_outputs(): + """Make sure that constant outputs are handled properly.""" + + y = constant(1) + + out_fg = FunctionGraph([], [y], clone=False) + + out_py = fgraph_to_python(out_fg, to_python) + + assert out_py()[0] is y.data + + def test_unique_name_generator(): unique_names = unique_name_generator(["blah"], suffix_sep="_") From a25afa345478ec1c74d30e78e1cfffc4242db6a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 31 Aug 2022 14:16:18 -0600 Subject: [PATCH 16/21] Add gufunc signature to `RandomVariable`\'s docstrings Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- aesara/tensor/random/basic.py | 189 +++++++++++++++++++++++++++- doc/library/tensor/random/basic.rst | 2 +- 2 files changed, 189 insertions(+), 2 deletions(-) diff --git a/aesara/tensor/random/basic.py b/aesara/tensor/random/basic.py index 11824ab940..1fa6c9a00b 100644 --- a/aesara/tensor/random/basic.py +++ b/aesara/tensor/random/basic.py @@ -95,6 +95,11 @@ def __call__(self, low=0.0, high=1.0, size=None, **kwargs): The results are undefined when `high < low`. + Signature + --------- + + `(), () -> ()` + Parameters ---------- low @@ -142,6 +147,11 @@ class TriangularRV(RandomVariable): def __call__(self, left, mode, right, size=None, **kwargs): r"""Draw samples from a triangular distribution. + Signature + --------- + + `(), (), () -> ()` + Parameters ---------- left @@ -192,6 +202,11 @@ class BetaRV(RandomVariable): def __call__(self, alpha, beta, size=None, **kwargs): r"""Draw samples from a beta distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- alpha @@ -233,6 +248,11 @@ class NormalRV(RandomVariable): def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): r"""Draw samples from a normal distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- loc @@ -266,6 +286,11 @@ class StandardNormalRV(NormalRV): def __call__(self, size=None, **kwargs): """Draw samples from a standard normal distribution. + Signature + --------- + + `nil -> ()` + Parameters ---------- size @@ -303,6 +328,11 @@ class HalfNormalRV(ScipyRandomVariable): def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): r"""Draw samples from a half-normal distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- loc @@ -363,6 +393,11 @@ class LogNormalRV(RandomVariable): def __call__(self, mean=0.0, sigma=1.0, size=None, **kwargs): r"""Draw sample from a lognormal distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- mean @@ -409,6 +444,11 @@ class GammaRV(ScipyRandomVariable): def __call__(self, shape, rate, size=None, **kwargs): r"""Draw samples from a gamma distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- shape @@ -462,6 +502,11 @@ class ChiSquareRV(RandomVariable): def __call__(self, df, size=None, **kwargs): r"""Draw samples from a chisquare distribution. + Signature + --------- + + `() -> ()` + Parameters ---------- df @@ -501,6 +546,11 @@ class ParetoRV(ScipyRandomVariable): def __call__(self, b, scale=1.0, size=None, **kwargs): r"""Draw samples from a pareto distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- b @@ -552,6 +602,11 @@ def __call__( ) -> RandomVariable: r"""Draw samples from a gumbel distribution. 
+ Signature + --------- + + `(), () -> ()` + Parameters ---------- loc @@ -602,6 +657,11 @@ class ExponentialRV(RandomVariable): def __call__(self, scale=1.0, size=None, **kwargs): r"""Draw samples from an exponential distribution. + Signature + --------- + + `() -> ()` + Parameters ---------- scale @@ -640,6 +700,11 @@ class WeibullRV(RandomVariable): def __call__(self, shape, size=None, **kwargs): r"""Draw samples from a weibull distribution. + Signature + --------- + + `() -> ()` + Parameters ---------- shape @@ -679,6 +744,12 @@ class LogisticRV(RandomVariable): def __call__(self, loc=0, scale=1, size=None, **kwargs): r"""Draw samples from a logistic distribution. + Signature + --------- + + `(), () -> ()` + + Parameters ---------- loc @@ -721,6 +792,11 @@ class VonMisesRV(RandomVariable): def __call__(self, mu, kappa, size=None, **kwargs): r"""Draw samples from a von Mises distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- mu @@ -783,6 +859,11 @@ class MvNormalRV(RandomVariable): def __call__(self, mean=None, cov=None, size=None, **kwargs): r""" "Draw samples from a multivariate normal distribution. + Signature + --------- + + `(n), (n,n) -> (n)` + Parameters ---------- mean @@ -857,6 +938,11 @@ class DirichletRV(RandomVariable): def __call__(self, alphas, size=None, **kwargs): r"""Draw samples from a dirichlet distribution. + Signature + --------- + + `(k) -> (k)` + Parameters ---------- alphas @@ -917,6 +1003,11 @@ class PoissonRV(RandomVariable): def __call__(self, lam=1.0, size=None, **kwargs): r"""Draw samples from a poisson distribution. + Signature + --------- + + `() -> ()` + Parameters ---------- lam @@ -957,6 +1048,11 @@ class GeometricRV(RandomVariable): def __call__(self, p, size=None, **kwargs): r"""Draw samples from a geometric distribution. + Signature + --------- + + `() -> ()` + Parameters ---------- p @@ -994,6 +1090,11 @@ class HyperGeometricRV(RandomVariable): def __call__(self, ngood, nbad, nsample, size=None, **kwargs): r"""Draw samples from a geometric distribution. + Signature + --------- + + `(), (), () -> ()` + Parameters ---------- ngood @@ -1037,6 +1138,11 @@ class CauchyRV(ScipyRandomVariable): def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): r"""Draw samples from a Cauchy distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- loc @@ -1082,6 +1188,11 @@ class HalfCauchyRV(ScipyRandomVariable): def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): r"""Draw samples from a half-Cauchy distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- loc @@ -1131,6 +1242,11 @@ class InvGammaRV(ScipyRandomVariable): def __call__(self, shape, scale, size=None, **kwargs): r"""Draw samples from an inverse-gamma distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- shape @@ -1176,6 +1292,11 @@ class WaldRV(RandomVariable): def __call__(self, mean=1.0, scale=1.0, size=None, **kwargs): r"""Draw samples from a Wald distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- mean @@ -1218,6 +1339,11 @@ class TruncExponentialRV(ScipyRandomVariable): def __call__(self, b, loc=0.0, scale=1.0, size=None, **kwargs): r"""Draw samples from a truncated exponential distribution. + Signature + --------- + + `(), (), () -> ()` + Parameters ---------- b @@ -1273,6 +1399,11 @@ class BernoulliRV(ScipyRandomVariable): def __call__(self, p, size=None, **kwargs): r"""Draw samples from a Bernoulli distribution. 
+ Signature + --------- + + `() -> ()` + Parameters ---------- p @@ -1315,6 +1446,12 @@ class LaplaceRV(RandomVariable): def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): r"""Draw samples from a Laplace distribution. + Signature + --------- + + `(), () -> ()` + + Parameters ---------- loc @@ -1355,6 +1492,11 @@ class BinomialRV(RandomVariable): def __call__(self, n, p, size=None, **kwargs): r"""Draw samples from a binomial distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- n @@ -1394,6 +1536,11 @@ class NegBinomialRV(ScipyRandomVariable): def __call__(self, n, p, size=None, **kwargs): r"""Draw samples from a negative binomial distribution. + Signature + --------- + + `(), () -> ()` + Parameters ---------- n @@ -1444,6 +1591,11 @@ class BetaBinomialRV(ScipyRandomVariable): def __call__(self, n, a, b, size=None, **kwargs): r"""Draw samples from a beta-binomial distribution. + Signature + --------- + + `(), (), () -> ()` + Parameters ---------- n @@ -1490,6 +1642,11 @@ class GenGammaRV(ScipyRandomVariable): def __call__(self, alpha=1.0, p=1.0, lambd=1.0, size=None, **kwargs): r"""Draw samples from a generalized gamma distribution. + Signature + --------- + + `(), (), () -> ()` + Parameters ---------- alpha @@ -1547,6 +1704,11 @@ class MultinomialRV(RandomVariable): def __call__(self, n, p, size=None, **kwargs): r"""Draw samples from a discrete multinomial distribution. + Signature + --------- + + `(), (n) -> (n)` + Parameters ---------- n @@ -1614,6 +1776,11 @@ class CategoricalRV(RandomVariable): def __call__(self, p, size=None, **kwargs): r"""Draw samples from a discrete categorical distribution. + Signature + --------- + + `(j) -> ()` + Parameters ---------- p @@ -1665,6 +1832,11 @@ class RandIntRV(RandomVariable): def __call__(self, low, high=None, size=None, **kwargs): r"""Draw samples from a discrete uniform distribution. + Signature + --------- + + `() -> ()` + Parameters ---------- low @@ -1712,6 +1884,11 @@ class IntegersRV(RandomVariable): def __call__(self, low, high=None, size=None, **kwargs): r"""Draw samples from a discrete uniform distribution. + Signature + --------- + + `() -> ()` + Parameters ---------- low @@ -1772,6 +1949,11 @@ def _infer_shape(self, size, dist_params, param_shapes=None): def __call__(self, a, size=None, replace=True, p=None, **kwargs): r"""Generate a random sample from an array. + Signature + --------- + + `(x) -> ()` + Parameters ---------- a @@ -1782,7 +1964,7 @@ def __call__(self, a, size=None, replace=True, p=None, **kwargs): k` independent samples are returned. Default is `None`, in which case a single sample is returned. replace - When ``True``, sampling is performed with replacement. + When `True`, sampling is performed with replacement. p The probabilities associated with each entry in `a`. If not given, all elements have equal probability. @@ -1832,6 +2014,11 @@ def _infer_shape(self, size, dist_params, param_shapes=None): def __call__(self, x, **kwargs): r"""Randomly permute a sequence or a range of values. + Signature + --------- + + `(x) -> (x)` + Parameters ---------- x diff --git a/doc/library/tensor/random/basic.rst b/doc/library/tensor/random/basic.rst index bc88aaf900..9dbf310193 100644 --- a/doc/library/tensor/random/basic.rst +++ b/doc/library/tensor/random/basic.rst @@ -47,7 +47,7 @@ Reference Distributions ============== -Aesara can produce :class:`RandomVariable`\s that draw samples from many different statistical distributions, using the following :class:`Op`\s. 
+Aesara can produce :class:`RandomVariable`\s that draw samples from many different statistical distributions, using the following :class:`Op`\s. The :class:`RandomVariable`\s behave similarly to NumPy's *Generalized Universal Functions* (or `gufunc`\s): they support "core" random variable :class:`Op`\s that map distinctly shaped inputs to potentially non-scalar outputs. We document this behavior in the following with `gufunc`-like signatures.
 
 .. autoclass:: aesara.tensor.random.basic.UniformRV
    :members: __call__
 
From 6302849bce6249739700021d4ce438a9e2ed3295 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Fri, 7 Oct 2022 14:02:59 -0500
Subject: [PATCH 17/21] Update pre-commit hooks

---
 .pre-commit-config.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7a16f40178..5ad1173961 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
         )$
     -   id: check-merge-conflict
 -   repo: https://github.com/psf/black
-    rev: 22.8.0
+    rev: 22.10.0
     hooks:
     -   id: black
         language_version: python3
@@ -47,7 +47,7 @@ repos:
         )$
         args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
 -   repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.971
+    rev: v0.982
     hooks:
     -   id: mypy
         additional_dependencies:

From 6198d7ac2bd56e66e0c5b2284dddca3084f3b410 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Wed, 5 Oct 2022 23:53:50 -0500
Subject: [PATCH 18/21] Fix make_numba_random_fn RandomStateType check

---
 aesara/link/numba/dispatch/random.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/aesara/link/numba/dispatch/random.py b/aesara/link/numba/dispatch/random.py
index b41f83e221..bb968f44c8 100644
--- a/aesara/link/numba/dispatch/random.py
+++ b/aesara/link/numba/dispatch/random.py
@@ -20,7 +20,6 @@
 )
 from aesara.tensor.basic import get_vector_length
 from aesara.tensor.random.type import RandomStateType
-from aesara.tensor.random.var import RandomStateSharedVariable


 class RandomStateNumbaType(types.Type):
@@ -96,7 +95,7 @@ def make_numba_random_fn(node, np_random_func):
     The functions generated here add parameter broadcasting and the ``size``
     argument to the Numba-supported scalar ``np.random`` functions.
     """
-    if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
+    if not isinstance(node.inputs[0].type, RandomStateType):
         raise TypeError("Numba does not support NumPy `Generator`s")

     tuple_size = int(get_vector_length(node.inputs[1]))

From 2799bad862981a1adc0f317ec9bf59a3ee3647bc Mon Sep 17 00:00:00 2001
From: "Brandon T.
Willard" Date: Fri, 7 Oct 2022 01:06:43 -0500 Subject: [PATCH 19/21] Support updates and more input types in compare_numba_and_py --- tests/link/numba/test_basic.py | 47 ++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 801d6aab24..9b8fc75255 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -1,5 +1,6 @@ import contextlib import inspect +from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple, Union from unittest import mock import numba @@ -30,6 +31,10 @@ from aesara.tensor.elemwise import Elemwise from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +if TYPE_CHECKING: + from aesara.graph.basic import Variable + from aesara.tensor import TensorLike + class MyTypeMeta(NewTypeMeta): def filter(self, data): @@ -102,7 +107,7 @@ def compare_shape_dtype(x, y): return x.shape == y.shape and x.dtype == y.dtype -def eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode): +def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): """Evaluate the Numba implementation in pure Python for coverage purposes.""" def py_tuple_setitem(t, i, v): @@ -167,7 +172,7 @@ def inner_vec(*args): aesara_numba_fn = function( fn_inputs, - fgraph.outputs, + fn_outputs, mode=mode, accept_inplace=True, ) @@ -175,7 +180,12 @@ def inner_vec(*args): def compare_numba_and_py( - fgraph, inputs, assert_fn=None, numba_mode=numba_mode, py_mode=py_mode + fgraph: Union[FunctionGraph, Tuple[Sequence["Variable"], Sequence["Variable"]]], + inputs: Sequence["TensorLike"], + assert_fn: Optional[Callable] = None, + numba_mode=numba_mode, + py_mode=py_mode, + updates=None, ): """Function to compare python graph output and Numba compiled output for testing equality @@ -185,13 +195,15 @@ def compare_numba_and_py( Parameters ---------- - fgraph: FunctionGraph - Aesara function Graph object - inputs: iter - Inputs for function graph - assert_fn: func, opt + fgraph + `FunctionGraph` or inputs to compare. + inputs + Numeric inputs to be passed to the compiled graphs. + assert_fn Assert function used to check for equality between python and Numba. If not - provided uses np.testing.assert_allclose + provided uses `np.testing.assert_allclose`. + updates + Updates to be passed to `aesara.function`. """ if assert_fn is None: @@ -201,25 +213,32 @@ def assert_fn(x, y): x, y ) - fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] + if isinstance(fgraph, tuple): + fn_inputs, fn_outputs = fgraph + else: + fn_inputs = fgraph.inputs + fn_outputs = fgraph.outputs + + fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)] aesara_py_fn = function( - fn_inputs, fgraph.outputs, mode=py_mode, accept_inplace=True + fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates ) py_res = aesara_py_fn(*inputs) aesara_numba_fn = function( fn_inputs, - fgraph.outputs, + fn_outputs, mode=numba_mode, accept_inplace=True, + updates=updates, ) numba_res = aesara_numba_fn(*inputs) # Get some coverage - eval_python_only(fn_inputs, fgraph, inputs, mode=numba_mode) + eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode) - if len(fgraph.outputs) > 1: + if len(fn_outputs) > 1: for j, p in zip(numba_res, py_res): assert_fn(j, p) else: From 41eeeb654026470057d53a9225b845bdb4560328 Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Thu, 6 Oct 2022 00:33:25 -0500 Subject: [PATCH 20/21] Add support for shared inputs in numba_funcify_Scan --- aesara/link/numba/dispatch/scan.py | 318 ++++++++++++++++------------- tests/link/numba/test_random.py | 2 +- tests/link/numba/test_scan.py | 177 +++++++++++++--- 3 files changed, 329 insertions(+), 168 deletions(-) diff --git a/aesara/link/numba/dispatch/scan.py b/aesara/link/numba/dispatch/scan.py index 564615b1e3..89af6c15e4 100644 --- a/aesara/link/numba/dispatch/scan.py +++ b/aesara/link/numba/dispatch/scan.py @@ -1,4 +1,3 @@ -from itertools import groupby from textwrap import dedent, indent from typing import Dict, List, Optional, Tuple @@ -14,6 +13,7 @@ ) from aesara.link.utils import compile_function_src from aesara.scan.op import Scan +from aesara.tensor.type import TensorType def idx_to_str( @@ -49,8 +49,6 @@ def range_arr(x): def numba_funcify_Scan(op, node, **kwargs): scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) - n_seqs = op.info.n_seqs - outer_in_names_to_vars = { (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs) } @@ -60,22 +58,63 @@ def numba_funcify_Scan(op, node, **kwargs): outer_in_mit_sot_names = op.outer_mitsot(outer_in_names) outer_in_sit_sot_names = op.outer_sitsot(outer_in_names) outer_in_nit_sot_names = op.outer_nitsot(outer_in_names) + outer_in_shared_names = op.outer_shared(outer_in_names) + outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names) + + # These are all the outer-input names that have produce outputs/have output + # taps (i.e. they have inner-outputs and corresponding outer-outputs). + # Outer-outputs are ordered as follows: + # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs outer_in_outtap_names = ( outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names + outer_in_nit_sot_names + + outer_in_shared_names ) - outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names) - inner_in_to_index_offset: List[Tuple[str, Optional[int], Optional[int]]] = [] - allocate_taps_storage: List[str] = [] + # We create distinct variables for/references to the storage arrays for + # each output. + outer_in_to_storage_name: Dict[str, str] = {} + for outer_in_name in outer_in_mit_mot_names: + outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_mitmot_storage" + + for outer_in_name in outer_in_mit_sot_names: + outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_mitsot_storage" + + for outer_in_name in outer_in_sit_sot_names: + outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_sitsot_storage" + + for outer_in_name in outer_in_nit_sot_names: + outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage" + + for outer_in_name in outer_in_shared_names: + outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage" + + outer_output_names = list(outer_in_to_storage_name.values()) + assert len(outer_output_names) == len(node.outputs) + + # Construct the inner-input expressions (e.g. indexed storage expressions) + # Inner-inputs are ordered as follows: + # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + + # shared-inputs + non-sequences. 
+ inner_in_exprs: List[str] = [] + + def add_inner_in_expr( + outer_in_name: str, tap_offset: Optional[int], storage_size_var: Optional[str] + ): + """Construct an inner-input expression.""" + storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name) + indexed_inner_in_str = ( + storage_name + if tap_offset is None + else idx_to_str(storage_name, tap_offset, size=storage_size_var) + ) + inner_in_exprs.append(indexed_inner_in_str) for outer_in_name in outer_in_seqs_names: - # A sequence with multiple taps is provided as multiple modified input - # sequences--all sliced so as to keep following the logic of a normal - # sequence. - inner_in_to_index_offset.append((outer_in_name, 0, None)) + # These outer-inputs are indexed without offsets or storage wrap-around + add_inner_in_expr(outer_in_name, 0, None) inner_in_names_to_input_taps: Dict[str, Tuple[int]] = dict( zip( @@ -89,201 +128,202 @@ def numba_funcify_Scan(op, node, **kwargs): zip(outer_in_mit_mot_names, op.info.mit_mot_out_slices) ) + # Inner-outputs consist of: + # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + + # shared-outputs [+ while-condition] inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))] - # Maps storage array names to their tap values (i.e. maximum absolute tap - # value) and storage sizes - inner_out_name_to_taps_storage: List[Tuple[str, int, Optional[str]]] = [] - outer_in_to_storage_name: Dict[str, str] = {} - outer_in_sot_names = set( - outer_in_mit_mot_names + outer_in_mit_sot_names + outer_in_sit_sot_names - ) + # inner_out_shared_names = op.inner_shared_outs(inner_output_names) + + # The assignment statements that copy inner-outputs into the outer-outputs + # storage + inner_out_to_outer_in_stmts: List[str] = [] + + # Special statements that perform storage truncation for `while`-loops and + # rotation for initially truncated storage. + output_storage_post_proc_stmts: List[str] = [] + + # In truncated storage situations (e.g. created by `save_mem_new_scan`), + # the taps and output storage overlap, instead of the standard situation in + # which the output storage is large enough to contain both the initial taps + # values and the output storage. In this truncated case, we use the + # storage array like a circular buffer, and that's why we need to track the + # storage size along with the taps length/indexing offset. + def add_output_storage_post_proc_stmt( + outer_in_name: str, tap_sizes: Tuple[int], storage_size: str + ): + + tap_size = max(tap_sizes) + + if op.info.as_while: + # While loops need to truncate the output storage to a length given + # by the number of iterations performed. + output_storage_post_proc_stmts.append( + dedent( + f""" + if i + {tap_size} < {storage_size}: + {storage_size} = i + {tap_size} + {outer_in_name} = {outer_in_name}[:{storage_size}] + """ + ).strip() + ) + + # Rotate the storage so that the last computed value is at the end of + # the storage array. + # This is needed when the output storage array does not have a length + # equal to the number of taps plus `n_steps`. 
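+        # For example, with a tap size of 2, a storage length of 5, and i == 4
+        # after the loop, the shift is (4 + 2) % 5 == 1, so the single leading
+        # element is moved to the end of the array.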
+ output_storage_post_proc_stmts.append( + dedent( + f""" + {outer_in_name}_shift = (i + {tap_size}) % ({storage_size}) + if {outer_in_name}_shift > 0: + {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift] + {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:] + {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left)) + """ + ).strip() + ) + + # Special in-loop statements that create (nit-sot) storage arrays after a + # single iteration is performed. This is necessary because we don't know + # the exact shapes of the storage arrays that need to be allocated until + # after an iteration is performed. inner_out_post_processing_stmts: List[str] = [] + + # Storage allocation statements + # For output storage allocated/provided by the inputs, these statements + # will either construct aliases between the input names and the entries in + # `outer_in_to_storage_name` or assign the latter to expressions that + # create copies of those storage inputs. + # In the nit-sot case, empty dummy arrays are assigned to the storage + # variables and updated later by the statements in + # `inner_out_post_processing_stmts`. + storage_alloc_stmts: List[str] = [] + for outer_in_name in outer_in_outtap_names: outer_in_var = outer_in_names_to_vars[outer_in_name] - if outer_in_name in outer_in_sot_names: - if outer_in_name in outer_in_mit_mot_names: - storage_name = f"{outer_in_name}_mitmot_storage" - elif outer_in_name in outer_in_mit_sot_names: - storage_name = f"{outer_in_name}_mitsot_storage" - else: - # Note that the outputs with single, non-`-1` taps are (e.g. `taps - # = [-2]`) are classified as mit-sot, so the code for handling - # sit-sots remains constant as follows - storage_name = f"{outer_in_name}_sitsot_storage" + if outer_in_name not in outer_in_nit_sot_names: - output_idx = len(outer_in_to_storage_name) - outer_in_to_storage_name[outer_in_name] = storage_name + storage_name = outer_in_to_storage_name[outer_in_name] - input_taps = inner_in_names_to_input_taps[outer_in_name] - tap_storage_size = -min(input_taps) - assert tap_storage_size >= 0 + is_tensor_type = isinstance(outer_in_var.type, TensorType) + if is_tensor_type: + storage_size_name = f"{outer_in_name}_len" + storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]" + input_taps = inner_in_names_to_input_taps[outer_in_name] + tap_storage_size = -min(input_taps) + assert tap_storage_size >= 0 - storage_size_name = f"{outer_in_name}_len" + for in_tap in input_taps: + tap_offset = in_tap + tap_storage_size + assert tap_offset >= 0 + add_inner_in_expr(outer_in_name, tap_offset, storage_size_name) - for in_tap in input_taps: - tap_offset = in_tap + tap_storage_size - assert tap_offset >= 0 - # In truncated storage situations (i.e. created by - # `save_mem_new_scan`), the taps and output storage overlap, - # instead of the standard situation in which the output storage - # is large enough to contain both the initial taps values and - # the output storage. 
- inner_in_to_index_offset.append( - (outer_in_name, tap_offset, storage_size_name) + output_taps = inner_in_names_to_output_taps.get( + outer_in_name, [tap_storage_size] ) + for out_tap in output_taps: + inner_out_to_outer_in_stmts.append( + idx_to_str(storage_name, out_tap, size=storage_size_name) + ) - output_taps = inner_in_names_to_output_taps.get( - outer_in_name, [tap_storage_size] - ) - for out_tap in output_taps: - inner_out_name_to_taps_storage.append( - (storage_name, out_tap, storage_size_name) + add_output_storage_post_proc_stmt( + storage_name, output_taps, storage_size_name ) - if output_idx in node.op.destroy_map: + else: + storage_size_stmt = "" + add_inner_in_expr(outer_in_name, None, None) + inner_out_to_outer_in_stmts.append(storage_name) + + output_idx = outer_output_names.index(storage_name) + if output_idx in node.op.destroy_map or not is_tensor_type: storage_alloc_stmt = f"{storage_name} = {outer_in_name}" else: storage_alloc_stmt = f"{storage_name} = np.copy({outer_in_name})" storage_alloc_stmt = dedent( f""" - # {outer_in_var.type} - {storage_size_name} = {outer_in_name}.shape[0] + {storage_size_stmt} {storage_alloc_stmt} """ ).strip() - allocate_taps_storage.append(storage_alloc_stmt) + storage_alloc_stmts.append(storage_alloc_stmt) + + else: + assert outer_in_name in outer_in_nit_sot_names - elif outer_in_name in outer_in_nit_sot_names: # This is a special case in which there are no outer-inputs used # for outer-output storage, so we need to create our own storage # from scratch. - - storage_name = f"{outer_in_name}_nitsot_storage" - outer_in_to_storage_name[outer_in_name] = storage_name - + storage_name = outer_in_to_storage_name[outer_in_name] storage_size_name = f"{outer_in_name}_len" - inner_out_name_to_taps_storage.append((storage_name, 0, storage_size_name)) + + inner_out_to_outer_in_stmts.append( + idx_to_str(storage_name, 0, size=storage_size_name) + ) + add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name) # In case of nit-sots we are provided the length of the array in # the iteration dimension instead of actual arrays, hence we # allocate space for the results accordingly. - curr_nit_sot_position = outer_in_names[1:].index(outer_in_name) - n_seqs - curr_nit_sot = op.inner_outputs[curr_nit_sot_position] - needs_alloc = curr_nit_sot.ndim > 0 + curr_nit_sot_position = outer_in_nit_sot_names.index(outer_in_name) + curr_nit_sot = op.inner_nitsot_outs(op.inner_outputs)[curr_nit_sot_position] storage_shape = create_tuple_string( [storage_size_name] + ["0"] * curr_nit_sot.ndim ) storage_dtype = curr_nit_sot.type.numpy_dtype.name - allocate_taps_storage.append( + storage_alloc_stmts.append( dedent( f""" - # {curr_nit_sot.type} {storage_size_name} = to_numba_scalar({outer_in_name}) {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype}) """ ).strip() ) - if needs_alloc: - allocate_taps_storage.append(f"{outer_in_name}_ready = False") + if curr_nit_sot.type.ndim > 0: + storage_alloc_stmts.append(f"{outer_in_name}_ready = False") # In this case, we don't know the shape of the output storage # array until we get some output from the inner-function. 
# With the following we add delayed output storage initialization: - inner_out_name = inner_output_names[curr_nit_sot_position] + inner_out_name = op.inner_nitsot_outs(inner_output_names)[ + curr_nit_sot_position + ] inner_out_post_processing_stmts.append( dedent( f""" if not {outer_in_name}_ready: - {storage_name} = np.empty(({storage_size_name},) + {inner_out_name}.shape, dtype=np.{storage_dtype}) + {storage_name} = np.empty(({storage_size_name},) + np.shape({inner_out_name}), dtype=np.{storage_dtype}) {outer_in_name}_ready = True """ ).strip() ) - # The non_seqs are passed to the inner function as-is for name in outer_in_non_seqs_names: - inner_in_to_index_offset.append((name, None, None)) - - inner_out_storage_indexed = [ - name if taps is None else idx_to_str(name, taps, size=size) - for (name, taps, size) in inner_out_name_to_taps_storage - ] - - output_storage_post_processing_stmts: List[str] = [] - - for outer_in_name, grp_vals in groupby( - inner_out_name_to_taps_storage, lambda x: x[0] - ): - - _, tap_sizes, storage_sizes = zip(*grp_vals) - - tap_size = max(tap_sizes) - storage_size = storage_sizes[0] - - if op.info.as_while: - # While loops need to truncate the output storage to a length given - # by the number of iterations performed. - output_storage_post_processing_stmts.append( - dedent( - f""" - if i + {tap_size} < {storage_size}: - {storage_size} = i + {tap_size} - {outer_in_name} = {outer_in_name}[:{storage_size}] - """ - ).strip() - ) - - # Rotate the storage so that the last computed value is at the end of - # the storage array. - # This is needed when the output storage array does not have a length - # equal to the number of taps plus `n_steps`. - output_storage_post_processing_stmts.append( - dedent( - f""" - {outer_in_name}_shift = (i + {tap_size}) % ({storage_size}) - if {outer_in_name}_shift > 0: - {outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift] - {outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:] - {outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left)) - """ - ).strip() - ) + add_inner_in_expr(name, None, None) if op.info.as_while: # The inner function will return a boolean as the last value - inner_out_storage_indexed.append("cond") + inner_out_to_outer_in_stmts.append("cond") - output_names = [outer_in_to_storage_name[n] for n in outer_in_outtap_names] + assert len(inner_in_exprs) == len(op.fgraph.inputs) - # Construct the inner-input expressions - inner_inputs: List[str] = [] - for outer_in_name, tap_offset, size in inner_in_to_index_offset: - storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name) - indexed_inner_in_str = ( - idx_to_str(storage_name, tap_offset, size=size) - if tap_offset is not None - else storage_name - ) - # if outer_in_names_to_vars[outer_in_name].type.ndim - 1 <= 0: - # # Convert scalar inner-inputs to Numba scalars - # indexed_inner_in_str = f"to_numba_scalar({indexed_inner_in_str})" - inner_inputs.append(indexed_inner_in_str) - - inner_inputs = create_arg_string(inner_inputs) + inner_in_args = create_arg_string(inner_in_exprs) inner_outputs = create_tuple_string(inner_output_names) - input_storage_block = "\n".join(allocate_taps_storage) - output_storage_post_processing_block = "\n".join( - output_storage_post_processing_stmts - ) + input_storage_block = "\n".join(storage_alloc_stmts) + output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts) inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) + 
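+    # Build the assignments that copy each inner-output into its corresponding
+    # outer-output storage entry (or into `cond` for `while`-loops); these are
+    # spliced into the loop body of `scan_op_src` below.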
inner_out_to_outer_out_stmts = "\n".join( + [f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)] + ) + scan_op_src = f""" def scan({", ".join(outer_in_names)}): @@ -292,14 +332,14 @@ def scan({", ".join(outer_in_names)}): i = 0 cond = False while i < n_steps and not cond: - {inner_outputs} = scan_inner_func({inner_inputs}) + {inner_outputs} = scan_inner_func({inner_in_args}) {indent(inner_out_post_processing_block, " " * 8)} - {create_tuple_string(inner_out_storage_indexed)} = {inner_outputs} +{indent(inner_out_to_outer_out_stmts, " " * 8)} i += 1 {indent(output_storage_post_processing_block, " " * 4)} - return {create_arg_string(output_names)} + return {create_arg_string(outer_output_names)} """ global_env = { diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index 25d77a5a66..b859919829 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -554,7 +554,7 @@ def test_DirichletRV(a, size, cm): a_val = a.tag.test_value # For coverage purposes only... - eval_python_only([a], FunctionGraph(outputs=[g], clone=False), [a_val]) + eval_python_only([a], [g], [a_val]) all_samples = [] for i in range(1000): diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py index 99444ca787..2095587bfc 100644 --- a/tests/link/numba/test_scan.py +++ b/tests/link/numba/test_scan.py @@ -2,15 +2,160 @@ import pytest import aesara.tensor as at -from aesara import config, grad +from aesara import config, function, grad from aesara.compile.mode import Mode, get_mode from aesara.graph.fg import FunctionGraph from aesara.scan.basic import scan +from aesara.scan.op import Scan from aesara.scan.utils import until +from aesara.tensor.random.utils import RandomStream from tests import unittest_tools as utt from tests.link.numba.test_basic import compare_numba_and_py +@pytest.mark.parametrize( + "fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check", + [ + # sequences + ( + lambda a_t: 2 * a_t, + [at.dvector("a")], + [{}], + [], + None, + [np.arange(10)], + None, + lambda op: op.info.n_seqs > 0, + ), + # nit-sot + ( + lambda: at.as_tensor(2.0), + [], + [{}], + [], + 3, + [], + None, + lambda op: op.info.n_nit_sot > 0, + ), + # nit-sot, non_seq + ( + lambda c: at.as_tensor(2.0) * c, + [], + [{}], + [at.dscalar("c")], + 3, + [1.0], + None, + lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0, + ), + # sit-sot + ( + lambda a_tm1: 2 * a_tm1, + [], + [{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}], + [], + 3, + [], + None, + lambda op: op.info.n_sit_sot > 0, + ), + # sit-sot, while + ( + lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)), + [], + [{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}], + [], + 3, + [], + None, + lambda op: op.info.n_sit_sot > 0, + ), + # nit-sot, shared input/output + ( + lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal( + 0, 1, name="a" + ), + [], + [{}], + [], + 3, + [], + [np.array([-1.63408257, 0.18046406, 2.43265803])], + lambda op: op.info.n_shared_outs > 0, + ), + # mit-sot (that's also a type of sit-sot) + ( + lambda a_tm1: 2 * a_tm1, + [], + [{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}], + [], + 6, + [], + None, + lambda op: op.info.n_mit_sot > 0, + ), + # mit-sot + ( + lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1), + [], + [ + {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}, + {"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}, + ], + [], + 10, + [], + None, + 
lambda op: op.info.n_mit_sot > 0, + ), + ], +) +def test_xit_xot_types( + fn, + sequences, + outputs_info, + non_sequences, + n_steps, + input_vals, + output_vals, + op_check, +): + """Test basic xit-xot configurations.""" + res, updates = scan( + fn, + sequences=sequences, + outputs_info=outputs_info, + non_sequences=non_sequences, + n_steps=n_steps, + strict=True, + mode=Mode(linker="py", optimizer=None), + ) + + if not isinstance(res, list): + res = [res] + + # Get rid of any `Subtensor` indexing on the `Scan` outputs + res = [r.owner.inputs[0] if not isinstance(r.owner.op, Scan) else r for r in res] + + scan_op = res[0].owner.op + assert isinstance(scan_op, Scan) + + _ = op_check(scan_op) + + if output_vals is None: + compare_numba_and_py( + (sequences + non_sequences, res), input_vals, updates=updates + ) + else: + numba_mode = get_mode("NUMBA") + numba_fn = function( + sequences + non_sequences, res, mode=numba_mode, updates=updates + ) + res_val = numba_fn(*input_vals) + assert np.allclose(res_val, output_vals) + + def test_scan_multiple_output(): """Test a scan implementation of a SEIR model. @@ -202,34 +347,10 @@ def power_step(prior_result, x): compare_numba_and_py(out_fg, test_input_vals) -def test_scan_save_mem_basic(): +@pytest.mark.parametrize("n_steps_val", [1, 5]) +def test_scan_save_mem_basic(n_steps_val): """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite.""" - k = at.iscalar("k") - A = at.dvector("A") - - result, _ = scan( - fn=lambda prior_result, A: prior_result * A, - outputs_info=at.ones_like(A), - non_sequences=A, - n_steps=k, - ) - - numba_mode = get_mode("NUMBA") # .including("scan_save_mem") - py_mode = Mode("py").including("scan_save_mem") - - out_fg = FunctionGraph([A, k], [result]) - test_input_vals = (np.arange(10, dtype=np.int32), 2) - compare_numba_and_py( - out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode - ) - test_input_vals = (np.arange(10, dtype=np.int32), 4) - compare_numba_and_py( - out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode - ) - -@pytest.mark.parametrize("n_steps_val", [1, 5]) -def test_scan_save_mem_2(n_steps_val): def f_pow2(x_tm2, x_tm1): return 2 * x_tm1 + x_tm2 @@ -245,7 +366,7 @@ def f_pow2(x_tm2, x_tm1): state_val = np.array([1.0, 2.0]) - numba_mode = get_mode("NUMBA") # .including("scan_save_mem") + numba_mode = get_mode("NUMBA").including("scan_save_mem") py_mode = Mode("py").including("scan_save_mem") out_fg = FunctionGraph([init_x, n_steps], [output]) From 4ef874cd78e6230f3e330baa16d93d6a7cd866c1 Mon Sep 17 00:00:00 2001 From: Markus Schmaus Date: Mon, 10 Oct 2022 19:06:20 +0200 Subject: [PATCH 21/21] change `issubtype` to `isinstance` on metaclass --- aesara/__init__.py | 4 +- aesara/compile/builders.py | 25 ++++++------ aesara/compile/compiledir.py | 7 +--- aesara/compile/debugmode.py | 14 +++---- aesara/compile/nanguardmode.py | 5 +-- aesara/compile/ops.py | 7 ++-- aesara/gradient.py | 41 ++++++++++--------- aesara/graph/basic.py | 26 ++++++------ aesara/graph/fg.py | 6 +-- aesara/graph/op.py | 5 +-- aesara/graph/type.py | 10 ++--- aesara/issubtype.py | 13 ------ aesara/link/c/op.py | 5 +-- aesara/link/c/params_type.py | 9 ++--- aesara/link/c/type.py | 2 +- aesara/link/jax/linker.py | 7 ++-- aesara/link/numba/dispatch/basic.py | 13 +++--- aesara/link/numba/linker.py | 5 +-- aesara/raise_op.py | 11 +++-- aesara/sandbox/rng_mrg.py | 6 +-- aesara/scalar/basic.py | 12 +++--- aesara/scan/basic.py | 10 +++-- aesara/scan/op.py | 59 +++++++++++++++------------ 
aesara/sparse/basic.py | 29 ++++++++------ aesara/sparse/type.py | 7 ++-- aesara/tensor/basic.py | 37 ++++++++++------- aesara/tensor/blas.py | 18 ++++----- aesara/tensor/elemwise.py | 17 ++++---- aesara/tensor/math.py | 13 +++--- aesara/tensor/nlinalg.py | 5 +-- aesara/tensor/nnet/basic.py | 11 +++-- aesara/tensor/nnet/batchnorm.py | 47 +++++++++++----------- aesara/tensor/nnet/rewriting.py | 39 +++++++++++++----- aesara/tensor/random/basic.py | 13 +++--- aesara/tensor/random/op.py | 9 +++-- aesara/tensor/rewriting/basic.py | 5 +-- aesara/tensor/rewriting/shape.py | 5 +-- aesara/tensor/rewriting/subtensor.py | 19 ++++----- aesara/tensor/shape.py | 15 ++++--- aesara/tensor/subtensor.py | 60 ++++++++++++++++------------ aesara/tensor/type.py | 24 +++++++---- aesara/typed_list/basic.py | 25 ++++++------ aesara/typed_list/type.py | 9 ++--- tests/compile/test_builders.py | 19 +++++---- tests/graph/test_basic.py | 3 +- tests/graph/test_destroyhandler.py | 3 +- tests/graph/test_features.py | 3 +- tests/graph/test_op.py | 3 +- tests/graph/test_types.py | 3 +- tests/graph/utils.py | 5 +-- tests/scalar/test_basic.py | 6 +-- tests/sparse/test_var.py | 21 +++++----- tests/tensor/test_elemwise.py | 4 +- tests/tensor/test_merge.py | 3 +- tests/tensor/test_shape.py | 9 +++-- tests/tensor/test_subtensor.py | 5 +-- tests/test_raise_op.py | 7 ++-- 57 files changed, 418 insertions(+), 385 deletions(-) delete mode 100644 aesara/issubtype.py diff --git a/aesara/__init__.py b/aesara/__init__.py index 658a273e86..4974f303c3 100644 --- a/aesara/__init__.py +++ b/aesara/__init__.py @@ -29,8 +29,6 @@ from functools import singledispatch from typing import Any, NoReturn, Optional -from aesara.issubtype import issubtype - aesara_logger = logging.getLogger("aesara") logging_default_handler = logging.StreamHandler() @@ -153,7 +151,7 @@ def get_scalar_constant_value(v): """ # Is it necessary to test for presence of aesara.sparse at runtime? sparse = globals().get("sparse") - if sparse and issubtype(v.type, sparse.SparseTensorType): + if sparse and isinstance(v.type, sparse.SparseTensorTypeMeta): if v.owner is not None and isinstance(v.owner.op, sparse.CSM): data = v.owner.inputs[0] return tensor.get_scalar_constant_value(data) diff --git a/aesara/compile/builders.py b/aesara/compile/builders.py index 84914231b7..5f7db2135e 100644 --- a/aesara/compile/builders.py +++ b/aesara/compile/builders.py @@ -10,7 +10,7 @@ from aesara.compile.mode import optdb from aesara.compile.sharedvalue import SharedVariable from aesara.configdefaults import config -from aesara.gradient import DisconnectedType, Rop, grad +from aesara.gradient import DisconnectedTypeMeta, Rop, grad from aesara.graph.basic import ( Apply, Constant, @@ -22,11 +22,10 @@ replace_nominals_with_dummies, ) from aesara.graph.fg import FunctionGraph -from aesara.graph.null_type import NullType +from aesara.graph.null_type import NullTypeMeta from aesara.graph.op import HasInnerGraph, Op from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.graph.utils import MissingInputError -from aesara.issubtype import issubtype from aesara.tensor.rewriting.shape import ShapeFeature @@ -211,7 +210,7 @@ def _filter_grad_var(grad, inp): # # For now, this converts NullType or DisconnectedType into zeros_like. 
# other types are unmodified: overrider_var -> None - if issubtype(grad.type, (NullType, DisconnectedType)): + if isinstance(grad.type, (NullTypeMeta, DisconnectedTypeMeta)): if hasattr(inp, "zeros_like"): return inp.zeros_like(), grad else: @@ -222,9 +221,9 @@ def _filter_grad_var(grad, inp): @staticmethod def _filter_rop_var(inpJ, out): # mostly similar to _filter_grad_var - if issubtype(inpJ.type, NullType): + if isinstance(inpJ.type, NullTypeMeta): return out.zeros_like(), inpJ - if issubtype(inpJ.type, DisconnectedType): + if isinstance(inpJ.type, DisconnectedTypeMeta): # since R_op does not have DisconnectedType yet, we will just # make them zeros. return out.zeros_like(), None @@ -503,7 +502,7 @@ def lop_op(inps, grads): all_grads_l = list(all_grads_l) all_grads_ov_l = list(all_grads_ov_l) elif isinstance(lop_op, Variable): - if issubtype(lop_op.type, (DisconnectedType, NullType)): + if isinstance(lop_op.type, (NullTypeMeta, DisconnectedTypeMeta)): all_grads_l = [inp.zeros_like() for inp in local_inputs] all_grads_ov_l = [lop_op.type() for _ in range(inp_len)] else: @@ -530,7 +529,7 @@ def lop_op(inps, grads): all_grads_l.append(gnext) all_grads_ov_l.append(gnext_ov) elif isinstance(fn_gov, Variable): - if issubtype(fn_gov.type, (DisconnectedType, NullType)): + if isinstance(fn_gov.type, (NullTypeMeta, DisconnectedTypeMeta)): all_grads_l.append(inp.zeros_like()) all_grads_ov_l.append(fn_gov.type()) else: @@ -615,10 +614,10 @@ def _recompute_rop_op(self): all_rops_l = list(all_rops_l) all_rops_ov_l = list(all_rops_ov_l) elif isinstance(rop_op, Variable): - if issubtype(rop_op.type, NullType): + if isinstance(rop_op.type, NullTypeMeta): all_rops_l = [inp.zeros_like() for inp in local_inputs] all_rops_ov_l = [rop_op.type() for _ in range(out_len)] - elif issubtype(rop_op.type, DisconnectedType): + elif isinstance(rop_op.type, DisconnectedTypeMeta): all_rops_l = [inp.zeros_like() for inp in local_inputs] all_rops_ov_l = [None] * out_len else: @@ -645,10 +644,10 @@ def _recompute_rop_op(self): all_rops_l.append(rnext) all_rops_ov_l.append(rnext_ov) elif isinstance(fn_rov, Variable): - if issubtype(fn_rov.type, NullType): + if isinstance(fn_rov.type, NullTypeMeta): all_rops_l.append(out.zeros_like()) all_rops_ov_l.append(fn_rov.type()) - if issubtype(fn_rov.type, DisconnectedType): + if isinstance(fn_rov.type, DisconnectedTypeMeta): all_rops_l.append(out.zeros_like()) all_rops_ov_l.append(None) else: @@ -858,7 +857,7 @@ def connection_pattern(self, node): # cpmat_self &= out_is_disconnected for i, t in enumerate(self._lop_op_stypes_l): if t is not None: - if issubtype(t.type, DisconnectedType): + if isinstance(t.type, DisconnectedTypeMeta): for o in range(out_len): cpmat_self[i][o] = False for o in range(out_len): diff --git a/aesara/compile/compiledir.py b/aesara/compile/compiledir.py index 790490464c..fdcda747db 100644 --- a/aesara/compile/compiledir.py +++ b/aesara/compile/compiledir.py @@ -1,6 +1,3 @@ -from aesara import issubtype - - """ This module contains housekeeping functions for cleaning/purging the "compiledir". It is used by the "aesara-cache" CLI tool, located in the /bin folder of the repository. 
@@ -14,7 +11,7 @@ from aesara.configdefaults import config from aesara.graph.op import Op -from aesara.link.c.type import CType +from aesara.link.c.type import CType, CTypeMeta from aesara.utils import flatten @@ -134,7 +131,7 @@ def print_compiledir_content(): zeros_op += 1 else: types = list( - {x for x in flatten(keydata.keys) if issubtype(x, CType)} + {x for x in flatten(keydata.keys) if isinstance(x, CTypeMeta)} ) compile_start = compile_end = float("nan") for fn in os.listdir(os.path.join(compiledir, dir)): diff --git a/aesara/compile/debugmode.py b/aesara/compile/debugmode.py index 54f715a0d4..697fd14062 100644 --- a/aesara/compile/debugmode.py +++ b/aesara/compile/debugmode.py @@ -20,7 +20,6 @@ import numpy as np import aesara -from aesara import issubtype from aesara.compile.function.types import ( Function, FunctionMaker, @@ -40,6 +39,7 @@ from aesara.link.utils import map_storage, raise_with_op from aesara.printing import _debugprint from aesara.tensor import TensorType +from aesara.tensor.type import TensorTypeMeta from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function @@ -793,7 +793,7 @@ def _get_preallocated_maps( for r in considered_outputs: # There is no risk to overwrite inputs, since r does not work # inplace. - if issubtype(r.type, TensorType): + if isinstance(r.type, TensorTypeMeta): reuse_outputs[r][...] = np.asarray(def_val).astype(r.type.dtype) if reuse_outputs: @@ -806,7 +806,7 @@ def _get_preallocated_maps( if "c_contiguous" in prealloc_modes or "ALL" in prealloc_modes: c_cont_outputs = {} for r in considered_outputs: - if issubtype(r.type, TensorType): + if isinstance(r.type, TensorTypeMeta): # Build a C-contiguous buffer new_buf = r.type.value_zeros(r_vals[r].shape) assert new_buf.flags["C_CONTIGUOUS"] @@ -823,7 +823,7 @@ def _get_preallocated_maps( if "f_contiguous" in prealloc_modes or "ALL" in prealloc_modes: f_cont_outputs = {} for r in considered_outputs: - if issubtype(r.type, TensorType): + if isinstance(r.type, TensorTypeMeta): new_buf = np.zeros( shape=r_vals[r].shape, dtype=r_vals[r].dtype, order="F" ) @@ -851,7 +851,7 @@ def _get_preallocated_maps( max_ndim = 0 rev_out_broadcastable = [] for r in considered_outputs: - if issubtype(r.type, TensorType): + if isinstance(r.type, TensorTypeMeta): if max_ndim < r.ndim: rev_out_broadcastable += [True] * (r.ndim - max_ndim) max_ndim = r.ndim @@ -866,7 +866,7 @@ def _get_preallocated_maps( # Initial allocation init_strided = {} for r in considered_outputs: - if issubtype(r.type, TensorType): + if isinstance(r.type, TensorTypeMeta): # Create a buffer twice as large in every dimension, # except if broadcastable, or for dimensions above # config.DebugMode__check_preallocated_output_ndim @@ -945,7 +945,7 @@ def _get_preallocated_maps( name = f"wrong_size{tuple(shape_diff)}" for r in considered_outputs: - if issubtype(r.type, TensorType): + if isinstance(r.type, TensorTypeMeta): r_shape_diff = shape_diff[: r.ndim] out_shape = [ max((s + sd), 0) diff --git a/aesara/compile/nanguardmode.py b/aesara/compile/nanguardmode.py index de165a926a..a665d3ea50 100644 --- a/aesara/compile/nanguardmode.py +++ b/aesara/compile/nanguardmode.py @@ -5,9 +5,9 @@ import numpy as np import aesara -from aesara import issubtype from aesara.compile.mode import Mode from aesara.configdefaults import config +from aesara.tensor.random.type import RandomTypeMeta from aesara.tensor.type import discrete_dtypes @@ -31,13 +31,12 @@ def _is_numeric_value(arr, var): """ from aesara.link.c.type import _cdata_type - 
from aesara.tensor.random.type import RandomType if isinstance(arr, _cdata_type): return False elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)): return False - elif var and issubtype(var.type, RandomType): + elif var and isinstance(var.type, RandomTypeMeta): return False elif isinstance(arr, slice): return False diff --git a/aesara/compile/ops.py b/aesara/compile/ops.py index 1f3897f0d5..75a77a15a9 100644 --- a/aesara/compile/ops.py +++ b/aesara/compile/ops.py @@ -12,9 +12,8 @@ from aesara.graph.basic import Apply from aesara.graph.op import Op -from aesara.issubtype import issubtype from aesara.link.c.op import COp -from aesara.link.c.type import CType +from aesara.link.c.type import CTypeMeta def register_view_op_c_code(type, code, version=()): @@ -312,11 +311,11 @@ def numpy_dot(a, b): """ if not isinstance(itypes, (list, tuple)): itypes = [itypes] - if not all(issubtype(t, CType) for t in itypes): + if not all(isinstance(t, CTypeMeta) for t in itypes): raise TypeError("itypes has to be a list of Aesara types") if not isinstance(otypes, (list, tuple)): otypes = [otypes] - if not all(issubtype(t, CType) for t in otypes): + if not all(isinstance(t, CTypeMeta) for t in otypes): raise TypeError("otypes has to be a list of Aesara types") # make sure they are lists and not tuples diff --git a/aesara/gradient.py b/aesara/gradient.py index 50b4dc80ac..2a8aec96de 100644 --- a/aesara/gradient.py +++ b/aesara/gradient.py @@ -25,10 +25,9 @@ from aesara.configdefaults import config from aesara.graph import utils from aesara.graph.basic import Apply, NominalVariable, Variable -from aesara.graph.null_type import NullType, null_type +from aesara.graph.null_type import NullType, NullTypeMeta, null_type from aesara.graph.op import get_test_values from aesara.graph.type import NewTypeMeta, Type -from aesara.issubtype import issubtype if TYPE_CHECKING: @@ -503,7 +502,7 @@ def grad( if known_grads is None: raise ValueError("cost and known_grads can't both be None.") - if cost is not None and issubtype(cost.type, NullType): + if cost is not None and isinstance(cost.type, NullTypeMeta): raise ValueError( "Can't differentiate a NaN cost. " f"Cost is NaN because {cost.type.why_null}" @@ -567,8 +566,8 @@ def grad( " or sparse aesara variable" ) - if not issubtype( - g_var.type, (NullType, DisconnectedType) + if not isinstance( + g_var.type, (NullTypeMeta, DisconnectedTypeMeta) ) and "float" not in str(g_var.type.dtype): raise TypeError( "Gradients must always be NullType, " @@ -632,14 +631,14 @@ def handle_disconnected(var): rval: MutableSequence[Optional[Variable]] = list(_rval) for i in range(len(_rval)): - if issubtype(_rval[i].type, NullType): + if isinstance(_rval[i].type, NullTypeMeta): if null_gradients == "raise": raise NullTypeGradError( f"`grad` encountered a NaN. 
{_rval[i].type.why_null}" ) else: assert null_gradients == "return" - if issubtype(_rval[i].type, DisconnectedType): + if isinstance(_rval[i].type, DisconnectedTypeMeta): handle_disconnected(_rval[i]) if return_disconnected == "zero": rval[i] = _float_zeros_like(_wrt[i]) @@ -1064,7 +1063,7 @@ def access_term_cache(node): # list of bools indicating if each output is connected to the cost outputs_connected = [ - not issubtype(g.type, DisconnectedType) for g in output_grads + not isinstance(g.type, DisconnectedTypeMeta) for g in output_grads ] connection_pattern = _node_to_pattern(node) @@ -1091,7 +1090,9 @@ def access_term_cache(node): ] # List of bools indicating if each output is NullType - ograd_is_nan = [issubtype(output.type, NullType) for output in output_grads] + ograd_is_nan = [ + isinstance(output.type, NullTypeMeta) for output in output_grads + ] # List of bools indicating if each input only has NullType outputs only_connected_to_nan = [ @@ -1200,7 +1201,7 @@ def try_to_copy_if_needed(var): orig_output, new_output_grad = packed if not hasattr(orig_output, "shape"): continue - if issubtype(new_output_grad.type, DisconnectedType): + if isinstance(new_output_grad.type, DisconnectedTypeMeta): continue for orig_output_v, new_output_grad_v in get_test_values(*packed): o_shape = orig_output_v.shape @@ -1228,7 +1229,7 @@ def try_to_copy_if_needed(var): # return the sparse grad for optimization reason. # for ig, i in zip(input_grads, inputs): - # if (not issubtype(ig.type, (DisconnectedType, NullType)) and + # if (not isinstance(ig.type, (NullTypeMeta, DisconnectedTypeMeta)) and # type(ig.type) != type(i.type)): # raise ValueError( # "%s returned the wrong type for gradient terms." @@ -1249,7 +1250,9 @@ def try_to_copy_if_needed(var): if ( ograd_is_nan[out_idx] and connection_pattern[inp_idx][out_idx] - and not issubtype(input_grads[inp_idx].type, DisconnectedType) + and not isinstance( + input_grads[inp_idx].type, DisconnectedTypeMeta + ) ): input_grads[inp_idx] = output_grads[out_idx] @@ -1303,7 +1306,7 @@ def try_to_copy_if_needed(var): f"of shape {i_shape}" ) - if not issubtype(term.type, (NullType, DisconnectedType)): + if not isinstance(term.type, (NullTypeMeta, DisconnectedTypeMeta)): if term.type.dtype not in aesara.tensor.type.float_dtypes: raise TypeError( str(node.op) + ".grad illegally " @@ -1312,7 +1315,7 @@ def try_to_copy_if_needed(var): ) if only_connected_to_nan[i]: - assert issubtype(term.type, NullType) + assert isinstance(term.type, NullTypeMeta) if only_connected_to_int[i]: # This term has only integer outputs and we know @@ -1348,7 +1351,7 @@ def try_to_copy_if_needed(var): for i, (ipt, ig, connected) in enumerate( zip(inputs, input_grads, inputs_connected) ): - actually_connected = not issubtype(ig.type, DisconnectedType) + actually_connected = not isinstance(ig.type, DisconnectedTypeMeta) if actually_connected and not connected: msg = ( @@ -1395,12 +1398,12 @@ def access_grad_cache(var): " Variable instance." 
) - if issubtype(term.type, NullType): + if isinstance(term.type, NullTypeMeta): null_terms.append(term) continue # Don't try to sum up DisconnectedType placeholders - if issubtype(term.type, DisconnectedType): + if isinstance(term.type, DisconnectedTypeMeta): continue if hasattr(var, "ndim") and term.ndim != var.ndim: @@ -2102,9 +2105,9 @@ def _is_zero(x): """ if not hasattr(x, "type"): return np.all(x == 0.0) - if issubtype(x.type, NullType): + if isinstance(x.type, NullTypeMeta): return "no" - if issubtype(x.type, DisconnectedType): + if isinstance(x.type, DisconnectedTypeMeta): return "yes" no_constant_value = True diff --git a/aesara/graph/basic.py b/aesara/graph/basic.py index ae97742432..b314201778 100644 --- a/aesara/graph/basic.py +++ b/aesara/graph/basic.py @@ -44,12 +44,12 @@ if TYPE_CHECKING: from aesara.graph.op import Op - from aesara.graph.type import Type + from aesara.graph.type import Type, NewTypeMeta OpType = TypeVar("OpType", bound="Op") OptionalApplyType = TypeVar("OptionalApplyType", None, "Apply", covariant=True) -_TypeType = TypeVar("_TypeType", bound="Type") +_TypeMeta = TypeVar("_TypeMeta", bound="NewTypeMeta") _IdType = TypeVar("_IdType", bound=Hashable) T = TypeVar("T", bound="Node") @@ -315,7 +315,7 @@ def params_type(self): return self.op.params_type -class Variable(Node, Generic[_TypeType, OptionalApplyType]): +class Variable(Node, Generic[_TypeMeta, OptionalApplyType]): r""" A :term:`Variable` is a node in an expression graph that represents a variable. @@ -433,7 +433,7 @@ def index(self, value): def __init__( self, - type: _TypeType, + type: _TypeMeta, owner: OptionalApplyType, index: Optional[int] = None, name: Optional[str] = None, @@ -618,10 +618,10 @@ def __getstate__(self): return d -class AtomicVariable(Variable[_TypeType, None]): +class AtomicVariable(Variable[_TypeMeta, None]): """A node type that has no ancestors and should never be considered an input to a graph.""" - def __init__(self, type: _TypeType, **kwargs): + def __init__(self, type: _TypeMeta, **kwargs): super().__init__(type, None, None, **kwargs) @abc.abstractmethod @@ -657,12 +657,12 @@ def index(self, value): raise ValueError("AtomicVariable instances cannot have an index.") -class NominalVariable(AtomicVariable[_TypeType]): +class NominalVariable(AtomicVariable[_TypeMeta]): """A variable that enables alpha-equivalent comparisons.""" - __instances__: Dict[Tuple["Type", Hashable], "NominalVariable"] = {} + __instances__: Dict[Tuple["NewTypeMeta", Hashable], "NominalVariable"] = {} - def __new__(cls, id: _IdType, typ: _TypeType, **kwargs): + def __new__(cls, id: _IdType, typ: _TypeMeta, **kwargs): if (typ, id) not in cls.__instances__: var_type = typ.variable_type type_name = f"Nominal{var_type.__name__}" @@ -682,7 +682,7 @@ def _str(self): return cls.__instances__[(typ, id)] - def __init__(self, id: _IdType, typ: _TypeType, **kwargs): + def __init__(self, id: _IdType, typ: _TypeMeta, **kwargs): self.id = id super().__init__(typ, **kwargs) @@ -705,11 +705,11 @@ def __hash__(self): def __repr__(self): return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})" - def signature(self) -> Tuple[_TypeType, _IdType]: + def signature(self) -> Tuple[_TypeMeta, _IdType]: return (self.type, self.id) -class Constant(AtomicVariable[_TypeType]): +class Constant(AtomicVariable[_TypeMeta]): """A `Variable` with a fixed `data` field. `Constant` nodes make numerous optimizations possible (e.g. 
constant @@ -724,7 +724,7 @@ class Constant(AtomicVariable[_TypeType]): # __slots__ = ['data'] - def __init__(self, type: _TypeType, data: Any, name: Optional[str] = None): + def __init__(self, type: _TypeMeta, data: Any, name: Optional[str] = None): super().__init__(type, name=name) self.data = type.filter(data) add_tag_trace(self) diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index dc538e3158..1bfdbea53d 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -23,8 +23,8 @@ from aesara.graph.basic import as_string as graph_as_string from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate +from aesara.graph.null_type import NullTypeMeta from aesara.graph.utils import MetaObject, MissingInputError, TestValueError -from aesara.issubtype import issubtype from aesara.misc.ordered_set import OrderedSet @@ -308,9 +308,9 @@ def import_var( and not isinstance(var, AtomicVariable) and var not in self.inputs ): - from aesara.graph.null_type import NullType + pass - if issubtype(var.type, NullType): + if isinstance(var.type, NullTypeMeta): raise TypeError( f"Computation graph contains a NaN. {var.type.why_null}" ) diff --git a/aesara/graph/op.py b/aesara/graph/op.py index 60dee3cc56..72a6623b5d 100644 --- a/aesara/graph/op.py +++ b/aesara/graph/op.py @@ -29,8 +29,7 @@ add_tag_trace, get_variable_trace_string, ) -from aesara.issubtype import issubtype -from aesara.link.c.params_type import Params, ParamsType +from aesara.link.c.params_type import Params, ParamsType, ParamsTypeMeta if TYPE_CHECKING: @@ -478,7 +477,7 @@ def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool: def get_params(self, node: Apply) -> Params: """Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`.""" - if issubtype(self.params_type, ParamsType): + if isinstance(self.params_type, ParamsTypeMeta): wrapper = self.params_type if not all(hasattr(self, field) for field in wrapper.fields): # Let's print missing attributes for debugging. diff --git a/aesara/graph/type.py b/aesara/graph/type.py index 8b6699e1b7..04c6870130 100644 --- a/aesara/graph/type.py +++ b/aesara/graph/type.py @@ -50,7 +50,7 @@ class NewTypeMeta(ABCMeta): """ _prop_names: tuple[str, ...] = tuple() - _subclass_cache = dict() + _subclass_cache: dict[Any, "NewTypeMeta"] = dict() _base_type: Optional["NewTypeMeta"] = None _type_parameters: dict[str, Any] = dict() @@ -112,7 +112,7 @@ def subtype_params(cls, params): NewTypeMeta._subclass_cache[key] = res return res - def __call__(self, name: Optional[Text] = None) -> Any: + def __call__(self, name: Optional[Text] = None) -> Any: # type: ignore """Return a new `Variable` instance of Type `self`. Parameters @@ -132,7 +132,7 @@ def type_parameters(cls, *args, **kwargs): def create(cls, **kwargs): MetaType(f"{cls.__name__}[{kwargs}]", (cls,), kwargs) - def in_same_class(self, otype: "Type") -> Optional[bool]: + def in_same_class(self, otype: "NewTypeMeta") -> Optional[bool]: """Determine if another `Type` represents a subset from the same "class" of types represented by `self`. A "class" of types could be something like "float64 tensors with four @@ -152,7 +152,7 @@ def in_same_class(self, otype: "Type") -> Optional[bool]: return False - def is_super(self, otype: "Type") -> Optional[bool]: + def is_super(self, otype: "NewTypeMeta") -> Optional[bool]: """Determine if `self` is a supertype of `otype`. 
This method effectively implements the type relation ``>``. @@ -309,7 +309,7 @@ def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type: """ return self.constant_type(type=self, data=value, name=name) - def clone(self, *args, **kwargs) -> "Type": + def clone(self, *args, **kwargs) -> "NewTypeMeta": """Clone a copy of this type with the given arguments/keyword values, if any.""" return self.subtype(*args, **kwargs) diff --git a/aesara/issubtype.py b/aesara/issubtype.py deleted file mode 100644 index e52c02ffa0..0000000000 --- a/aesara/issubtype.py +++ /dev/null @@ -1,13 +0,0 @@ -def issubtype(x, typ): - if not isinstance(typ, tuple): - typ = (typ,) - - for t in typ: - if isinstance(x, type): - if issubclass(x, t): - return True - else: - if isinstance(x, typ): - return True - - return False diff --git a/aesara/link/c/op.py b/aesara/link/c/op.py index 07a959e028..2a3b05af26 100644 --- a/aesara/link/c/op.py +++ b/aesara/link/c/op.py @@ -25,9 +25,8 @@ from aesara.graph.op import ComputeMapType, Op, StorageMapType, ThunkType from aesara.graph.type import HasDataType from aesara.graph.utils import MethodNotDefined -from aesara.issubtype import issubtype from aesara.link.c.interface import CLinkerOp -from aesara.link.c.params_type import ParamsType +from aesara.link.c.params_type import ParamsTypeMeta from aesara.utils import hash_from_code @@ -433,7 +432,7 @@ def __get_op_params(self) -> List[Tuple[str, Any]]: """ params: List[Tuple[str, Any]] = [] - if issubtype(self.params_type, ParamsType): + if isinstance(self.params_type, ParamsTypeMeta): wrapper = self.params_type params.append(("PARAMS_TYPE", wrapper.name)) for i in range(wrapper.length): diff --git a/aesara/link/c/params_type.py b/aesara/link/c/params_type.py index 7f028f5bd3..edc55b80b1 100644 --- a/aesara/link/c/params_type.py +++ b/aesara/link/c/params_type.py @@ -120,8 +120,7 @@ def __init__(value_attr1, value_attr2): from aesara.graph.type import Props from aesara.graph.utils import MethodNotDefined -from aesara.issubtype import issubtype -from aesara.link.c.type import CType, CTypeMeta, EnumType +from aesara.link.c.type import CType, CTypeMeta, EnumTypeMeta # Set of C and C++ keywords as defined (at March 2nd, 2017) in the pages below: @@ -256,7 +255,7 @@ class Params(dict): """ def __init__(self, params_type, **kwargs): - if not issubtype(params_type, ParamsType): + if not isinstance(params_type, ParamsTypeMeta): raise TypeError("Params: 1st constructor argument should be a ParamsType.") for field in params_type.fields: if field not in kwargs: @@ -369,7 +368,7 @@ def type_parameters(cls, **kwargs): ) type_instance = kwargs[attribute_name] type_name = type_instance.__class__.__name__ - if not issubtype(type_instance, CType): + if not isinstance(type_instance, CTypeMeta): raise TypeError( 'ParamsType: attribute "%s" should inherit from Aesara CType, got "%s".' % (attribute_name, type_name) @@ -382,7 +381,7 @@ def type_parameters(cls, **kwargs): params["_const_to_enum"] = {} params["_alias_to_enum"] = {} - enum_types = [t for t in params["types"] if issubtype(t, EnumType)] + enum_types = [t for t in params["types"] if isinstance(t, EnumTypeMeta)] if enum_types: # We don't want same enum names in different enum types. 
if sum(len(t) for t in enum_types) != len( diff --git a/aesara/link/c/type.py b/aesara/link/c/type.py index 63491606e6..820f63f26a 100644 --- a/aesara/link/c/type.py +++ b/aesara/link/c/type.py @@ -10,7 +10,7 @@ D = TypeVar("D") -T = TypeVar("T", bound=Type) +T = TypeVar("T", bound=NewTypeMeta) class CTypeMeta(NewTypeMeta, CLinkerType): diff --git a/aesara/link/jax/linker.py b/aesara/link/jax/linker.py index 615a20e166..0f63b8c2b9 100644 --- a/aesara/link/jax/linker.py +++ b/aesara/link/jax/linker.py @@ -2,7 +2,6 @@ from numpy.random import Generator, RandomState -from aesara import issubtype from aesara.compile.sharedvalue import SharedVariable, shared from aesara.graph.basic import Constant from aesara.link.basic import JITLinker @@ -13,12 +12,14 @@ class JAXLinker(JITLinker): def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from aesara.link.jax.dispatch import jax_funcify - from aesara.tensor.random.type import RandomType + from aesara.tensor.random.type import RandomTypeMeta shared_rng_inputs = [ inp for inp in fgraph.inputs - if (isinstance(inp, SharedVariable) and issubtype(inp.type, RandomType)) + if ( + isinstance(inp, SharedVariable) and isinstance(inp.type, RandomTypeMeta) + ) ] # Replace any shared RNG inputs so that their values can be updated in place diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index 8febee82a8..b24dc3b8c3 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -20,15 +20,14 @@ from aesara.compile.ops import DeepCopyOp from aesara.graph.basic import Apply, NoParams from aesara.graph.fg import FunctionGraph -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type from aesara.ifelse import IfElse -from aesara.issubtype import issubtype from aesara.link.utils import ( compile_function_src, fgraph_to_python, unique_name_generator, ) -from aesara.scalar.basic import ScalarType +from aesara.scalar.basic import ScalarTypeMeta from aesara.scalar.math import Softplus from aesara.tensor.blas import BatchedDot from aesara.tensor.math import Dot @@ -42,7 +41,7 @@ IncSubtensor, Subtensor, ) -from aesara.tensor.type import TensorType +from aesara.tensor.type import TensorTypeMeta from aesara.tensor.type_other import MakeSlice, NoneConst @@ -81,7 +80,7 @@ def get_numba_type( Return Numba scalars for zero dimensional :class:`TensorType`\s. 
""" - if issubtype(aesara_type, TensorType): + if isinstance(aesara_type, TensorTypeMeta): dtype = aesara_type.numpy_dtype numba_dtype = numba.from_dtype(dtype) if force_scalar or ( @@ -89,7 +88,7 @@ def get_numba_type( ): return numba_dtype return numba.types.Array(numba_dtype, aesara_type.ndim, layout) - elif issubtype(aesara_type, ScalarType): + elif isinstance(aesara_type, ScalarTypeMeta): dtype = np.dtype(aesara_type.dtype) numba_dtype = numba.from_dtype(dtype) return numba_dtype @@ -399,7 +398,7 @@ def create_index_func(node, objmode=False): ) def convert_indices(indices, entry): - if indices and issubtype(entry, Type): + if indices and isinstance(entry, NewTypeMeta): rval = indices.pop(0) return unique_names(rval) elif isinstance(entry, slice): diff --git a/aesara/link/numba/linker.py b/aesara/link/numba/linker.py index eef671aa13..c9abfc72bf 100644 --- a/aesara/link/numba/linker.py +++ b/aesara/link/numba/linker.py @@ -3,7 +3,6 @@ import numpy as np import aesara -from aesara import issubtype from aesara.link.basic import JITLinker @@ -15,8 +14,8 @@ class NumbaLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" def output_filter(self, var: "Variable", out: Any) -> Any: - if not isinstance(var, np.ndarray) and issubtype( - var.type, aesara.tensor.TensorType + if not isinstance(var, np.ndarray) and isinstance( + var.type, aesara.tensor.TensorTypeMeta ): return np.asarray(out, dtype=var.type.dtype) diff --git a/aesara/raise_op.py b/aesara/raise_op.py index f062c7c274..81b5d39609 100644 --- a/aesara/raise_op.py +++ b/aesara/raise_op.py @@ -7,12 +7,11 @@ from aesara.gradient import DisconnectedType from aesara.graph.basic import Apply, Variable -from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.link.c.type import Generic, GenericMeta -from aesara.scalar.basic import ScalarType -from aesara.tensor.type import DenseTensorType +from aesara.scalar.basic import ScalarTypeMeta +from aesara.tensor.type import DenseTensorTypeMeta class ExceptionTypeMeta(GenericMeta): @@ -107,7 +106,7 @@ def connection_pattern(self, node): return [[1]] + [[0]] * (len(node.inputs) - 1) def c_code(self, node, name, inames, onames, props): - if not issubtype(node.inputs[0].type, (DenseTensorType, ScalarType)): + if not isinstance(node.inputs[0].type, (DenseTensorTypeMeta, ScalarTypeMeta)): raise NotImplementedError( f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}" ) @@ -119,7 +118,7 @@ def c_code(self, node, name, inames, onames, props): msg = self.msg.replace('"', '\\"').replace("\n", "\\n") for idx, cond_name in enumerate(cond_names): - if issubtype(node.inputs[0].type, DenseTensorType): + if isinstance(node.inputs[0].type, DenseTensorTypeMeta): check.append( f""" if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{ @@ -146,7 +145,7 @@ def c_code(self, node, name, inames, onames, props): check = "\n".join(check) - if issubtype(node.inputs[0].type, DenseTensorType): + if isinstance(node.inputs[0].type, DenseTensorTypeMeta): res = f""" {check} Py_XDECREF({out_name}); diff --git a/aesara/sandbox/rng_mrg.py b/aesara/sandbox/rng_mrg.py index 5bf26d746a..e62c0101b0 100644 --- a/aesara/sandbox/rng_mrg.py +++ b/aesara/sandbox/rng_mrg.py @@ -17,7 +17,7 @@ import numpy as np -from aesara import function, gradient, issubtype +from aesara import function, gradient from aesara import scalar as aes from aesara import shared from aesara import tensor as at @@ -34,7 +34,7 @@ from 
aesara.tensor import as_tensor_variable, cast, get_vector_length from aesara.tensor.math import cos, log, prod, sin, sqrt from aesara.tensor.shape import reshape -from aesara.tensor.type import TensorType, iscalar, ivector, lmatrix +from aesara.tensor.type import TensorType, TensorTypeMeta, iscalar, ivector, lmatrix warnings.warn( @@ -536,7 +536,7 @@ def c_support_code(self, **kwargs): def c_code(self, node, name, inp, out, sub): # If we try to use the C code here with something else than a # TensorType, something is wrong. - assert issubtype(node.inputs[0].type, TensorType) + assert isinstance(node.inputs[0].type, TensorTypeMeta) if self.output_type.dtype == "float16": # C code is not tested, fall back to Python raise NotImplementedError() diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py index 66aaf38e28..96710ac507 100644 --- a/aesara/scalar/basic.py +++ b/aesara/scalar/basic.py @@ -30,7 +30,6 @@ from aesara.graph.rewriting.basic import MergeOptimizer from aesara.graph.type import DataType, Props from aesara.graph.utils import MetaObject, MethodNotDefined -from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.type import CType, CTypeMeta from aesara.misc.safe_asarray import _asarray @@ -282,10 +281,9 @@ class ScalarTypeMeta(CTypeMeta): """ - dtype: Props[Any] = None + dtype: Props[DataType] = None ndim = 0 shape = () - dtype: DataType @classmethod def type_parameters(cls, dtype): @@ -870,7 +868,7 @@ def constant(x, name=None, dtype=None) -> ScalarConstant: def as_scalar(x: Any, name: Optional[str] = None) -> ScalarVariable: from aesara.tensor.basic import scalar_from_tensor - from aesara.tensor.type import TensorType + from aesara.tensor.type import TensorTypeMeta if isinstance(x, Apply): if len(x.outputs) != 1: @@ -884,7 +882,7 @@ def as_scalar(x: Any, name: Optional[str] = None) -> ScalarVariable: if isinstance(x, Variable): if isinstance(x, ScalarVariable): return x - elif issubtype(x.type, TensorType) and x.type.ndim == 0: + elif isinstance(x.type, TensorTypeMeta) and x.type.ndim == 0: return scalar_from_tensor(x) else: raise TypeError(f"Cannot convert {x} to a scalar type") @@ -1125,7 +1123,7 @@ def output_types(self, types): if hasattr(self, "output_types_preference"): variables = self.output_types_preference(*types) if not isinstance(variables, (list, tuple)) or any( - not issubtype(x, CType) for x in variables + not isinstance(x, CTypeMeta) for x in variables ): raise TypeError( "output_types_preference should return a list or a tuple of types", @@ -2450,7 +2448,7 @@ def grad(self, inputs, gout): # CASTING OPERATIONS class Cast(UnaryScalarOp): def __init__(self, o_type, name=None): - if not issubtype(o_type, ScalarType): + if not isinstance(o_type, ScalarTypeMeta): raise TypeError(o_type) super().__init__(specific_out(o_type), name=name) self.o_type = o_type diff --git a/aesara/scan/basic.py b/aesara/scan/basic.py index ba87155693..9e98c1b31b 100644 --- a/aesara/scan/basic.py +++ b/aesara/scan/basic.py @@ -9,14 +9,13 @@ from aesara.graph.basic import Constant, Variable, clone_replace, graph_inputs from aesara.graph.op import get_test_value from aesara.graph.utils import MissingInputError, TestValueError -from aesara.issubtype import issubtype from aesara.scan.op import Scan, ScanInfo from aesara.scan.utils import expand_empty, safe_new, until from aesara.tensor.basic import get_scalar_constant_value from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.math import minimum from aesara.tensor.shape import 
shape_padleft, unbroadcast -from aesara.tensor.type import TensorType, integer_dtypes +from aesara.tensor.type import TensorTypeMeta, integer_dtypes from aesara.updates import OrderedUpdates @@ -881,7 +880,10 @@ def wrap_into_list(x): # then, if we return the output as given by the innner function # this will represent only a slice and it will have one # dimension less. - if issubtype(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1: + if ( + isinstance(inner_out.type, TensorTypeMeta) + and return_steps.get(pos, 0) != 1 + ): outputs[pos] = unbroadcast(shape_padleft(inner_out), 0) if not return_list and len(outputs) == 1: @@ -1007,7 +1009,7 @@ def wrap_into_list(x): inner_replacements[input.variable] = new_var - if issubtype(new_var.type, TensorType): + if isinstance(new_var.type, TensorTypeMeta): sit_sot_inner_inputs.append(new_var) sit_sot_scan_inputs.append( expand_empty( diff --git a/aesara/scan/op.py b/aesara/scan/op.py index 21166fe1af..aa2c575101 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -62,7 +62,14 @@ from aesara.compile.mode import Mode, get_default_mode, get_mode from aesara.compile.profiling import register_profiler_printer from aesara.configdefaults import config -from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined +from aesara.gradient import ( + DisconnectedType, + DisconnectedTypeMeta, + NullType, + Rop, + grad, + grad_undefined, +) from aesara.graph.basic import ( Apply, Constant, @@ -76,9 +83,9 @@ ) from aesara.graph.features import NoOutputFromInplace from aesara.graph.fg import FunctionGraph +from aesara.graph.null_type import NullTypeMeta from aesara.graph.op import HasInnerGraph, Op from aesara.graph.utils import InconsistencyError, MissingInputError -from aesara.issubtype import issubtype from aesara.link.c.basic import CLinker from aesara.link.c.exceptions import MissingGXX from aesara.link.utils import raise_with_op @@ -87,7 +94,7 @@ from aesara.tensor.basic import as_tensor_variable from aesara.tensor.math import minimum from aesara.tensor.shape import Shape_i -from aesara.tensor.type import TensorType, integer_dtypes +from aesara.tensor.type import TensorType, TensorTypeMeta, integer_dtypes from aesara.tensor.var import TensorVariable @@ -1235,7 +1242,7 @@ def make_node(self, *inputs): # strange NumPy behavior: vector_ndarray[int] return a NumPy # scalar and not a NumPy ndarray of 0 dimensions. 
def is_cpu_vector(s): - return issubtype(s.type, TensorType) and s.ndim == 1 + return isinstance(s.type, TensorTypeMeta) and s.ndim == 1 self.vector_seqs = [ is_cpu_vector(seq) for seq in new_inputs[1 : 1 + self.info.n_seqs] @@ -1247,7 +1254,7 @@ def is_cpu_vector(s): ] ] self.vector_outs += [ - issubtype(t.type, TensorType) and t.ndim == 0 + isinstance(t.type, TensorTypeMeta) and t.ndim == 0 for t in self.outer_nitsot_outs(self.inner_outputs) ] @@ -2575,7 +2582,9 @@ def compute_all_gradients(known_grads): info.n_seqs + pos ] - if not issubtype(dC_douts[outer_oidx].type, DisconnectedType): + if not isinstance( + dC_douts[outer_oidx].type, DisconnectedTypeMeta + ): dtypes.append(dC_douts[outer_oidx].dtype) if dtypes: new_dtype = aesara.scalar.upcast(*dtypes) @@ -2583,7 +2592,7 @@ def compute_all_gradients(known_grads): new_dtype = config.floatX dC_dXt = safe_new(Xt, dtype=new_dtype) else: - if issubtype(dC_douts[idx].type, DisconnectedType): + if isinstance(dC_douts[idx].type, DisconnectedTypeMeta): continue dC_dXt = safe_new(dC_douts[idx][0]) dC_dXts.append(dC_dXt) @@ -2598,7 +2607,7 @@ def compute_all_gradients(known_grads): known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] dc_dxts_idx += 1 else: - if issubtype(dC_douts[i].type, DisconnectedType): + if isinstance(dC_douts[i].type, DisconnectedTypeMeta): continue else: if diff_outputs[i] in known_grads: @@ -2646,10 +2655,10 @@ def compute_all_gradients(known_grads): dC_dXtm1s.append(safe_new(x)) for dx, dC_dXtm1 in enumerate(dC_dXtm1s): - if issubtype(dC_dinps_t[dx + info.n_seqs].type, NullType): + if isinstance(dC_dinps_t[dx + info.n_seqs].type, NullTypeMeta): # The accumulated gradient is undefined pass - elif issubtype(dC_dXtm1.type, NullType): + elif isinstance(dC_dXtm1.type, NullTypeMeta): # The new gradient is undefined, this makes the accumulated # gradient undefined as weell dC_dinps_t[dx + info.n_seqs] = dC_dXtm1 @@ -2678,7 +2687,7 @@ def compute_all_gradients(known_grads): outer_inp_seqs.append(nw_seq) outer_inp_seqs += [x[:-1][::-1] for x in self.outer_sitsot_outs(outs)] for x in self.outer_nitsot_outs(dC_douts): - if not issubtype(x.type, DisconnectedType): + if not isinstance(x.type, DisconnectedTypeMeta): if info.as_while: # equivalent to x[:n_steps][::-1] outer_inp_seqs.append(x[n_steps - 1 :: -1]) @@ -2738,7 +2747,7 @@ def compute_all_gradients(known_grads): n_mitmot_inps = 0 for idx, taps in enumerate(info.mit_mot_in_slices): - if issubtype(dC_douts[idx].type, DisconnectedType): + if isinstance(dC_douts[idx].type, DisconnectedTypeMeta): out = outs[idx] outer_inp_mitmot.append(at.zeros_like(out)) else: @@ -2763,7 +2772,7 @@ def compute_all_gradients(known_grads): if tap not in mitmot_inp_taps[idx]: inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs]) - if issubtype(dC_dinps_t[ins_pos].type, NullType): + if isinstance(dC_dinps_t[ins_pos].type, NullTypeMeta): # We cannot use Null in the inner graph, so we # use a zero tensor of the appropriate shape instead. 
inner_out_mitmot.append( @@ -2815,7 +2824,7 @@ def compute_all_gradients(known_grads): offset = info.n_mit_mot for idx, taps in enumerate(info.mit_sot_in_slices): - if issubtype(dC_douts[idx + offset].type, DisconnectedType): + if isinstance(dC_douts[idx + offset].type, DisconnectedTypeMeta): outer_inp_mitmot.append(outs[idx + offset].zeros_like()) else: outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) @@ -2832,7 +2841,7 @@ def compute_all_gradients(known_grads): tap = -tap inner_inp_mitmot.append(dC_dXtm1s[ins_pos - info.n_seqs]) - if issubtype(dC_dinps_t[ins_pos].type, NullType): + if isinstance(dC_dinps_t[ins_pos].type, NullTypeMeta): # We cannot use Null in the inner graph, so we # use a zero tensor of the appropriate shape instead. inner_out_mitmot.append( @@ -2870,10 +2879,10 @@ def compute_all_gradients(known_grads): mitmot_inp_taps.append([0, 1]) mitmot_out_taps.append([1]) through_shared = False - if not issubtype(dC_douts[idx + offset].type, DisconnectedType): + if not isinstance(dC_douts[idx + offset].type, DisconnectedTypeMeta): outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) else: - if issubtype(dC_dinps_t[ins_pos].type, NullType): + if isinstance(dC_dinps_t[ins_pos].type, NullTypeMeta): # Cannot use dC_dinps_t[ins_pos].dtype, so we use # floatX instead, as it is a dummy value that will not # be used anyway. @@ -2887,7 +2896,7 @@ def compute_all_gradients(known_grads): ) ) - if issubtype(dC_dinps_t[ins_pos].type, NullType): + if isinstance(dC_dinps_t[ins_pos].type, NullTypeMeta): # We cannot use Null in the inner graph, so we # use a zero tensor of the appropriate shape instead. inner_out_mitmot.append( @@ -2901,7 +2910,7 @@ def compute_all_gradients(known_grads): for _sh in self.inner_shared(self_inputs) ) - if issubtype(dC_dinps_t[ins_pos].type, NullType): + if isinstance(dC_dinps_t[ins_pos].type, NullTypeMeta): type_outs.append(dC_dinps_t[ins_pos].type.why_null) elif through_shared: type_outs.append("through_shared") @@ -2927,7 +2936,7 @@ def compute_all_gradients(known_grads): for _sh in self.inner_shared(self_inputs): if _sh in graph_inputs([vl]): through_shared = True - if issubtype(vl.type, NullType): + if isinstance(vl.type, NullTypeMeta): type_outs.append(vl.type.why_null) # Replace the inner output with a zero tensor of # the right shape @@ -2946,7 +2955,7 @@ def compute_all_gradients(known_grads): for _sh in self.inner_shared(self_inputs): if _sh in graph_inputs([vl]): through_shared = True - if issubtype(vl.type, NullType): + if isinstance(vl.type, NullTypeMeta): type_outs.append(vl.type.why_null) # Replace the inner output with a zero tensor of # the right shape @@ -2965,7 +2974,7 @@ def compute_all_gradients(known_grads): outer_inp_sitsot = [] for _idx, y in enumerate(inner_inp_sitsot): x = self.outer_non_seqs(inputs)[_idx] - if issubtype(y.type, NullType): + if isinstance(y.type, NullTypeMeta): # Cannot use dC_dXtm1s.dtype, so we use floatX instead. 
outer_inp_sitsot.append( at.zeros( @@ -3106,7 +3115,7 @@ def compute_all_gradients(known_grads): disconnected = True connected_flags = self.connection_pattern(node)[idx + start] for dC_dout, connected in zip(dC_douts, connected_flags): - if not issubtype(dC_dout.type, DisconnectedType) and connected: + if not isinstance(dC_dout.type, DisconnectedTypeMeta) and connected: disconnected = False if disconnected: gradients.append(DisconnectedType.subtype()()) @@ -3149,8 +3158,8 @@ def compute_all_gradients(known_grads): for idx in range(len(gradients)): disconnected = True for kdx in range(len(node.outputs)): - if connection_pattern[idx][kdx] and not issubtype( - dC_douts[kdx].type, DisconnectedType + if connection_pattern[idx][kdx] and not isinstance( + dC_douts[kdx].type, DisconnectedTypeMeta ): disconnected = False if disconnected: diff --git a/aesara/sparse/basic.py b/aesara/sparse/basic.py index cc4407f1e2..e83bc00449 100644 --- a/aesara/sparse/basic.py +++ b/aesara/sparse/basic.py @@ -16,14 +16,18 @@ import aesara from aesara import scalar as aes from aesara.configdefaults import config -from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined +from aesara.gradient import ( + DisconnectedType, + DisconnectedTypeMeta, + grad_not_implemented, + grad_undefined, +) from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op -from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.type import generic from aesara.misc.safe_asarray import _asarray -from aesara.sparse.type import SparseTensorType, _is_sparse +from aesara.sparse.type import SparseTensorType, SparseTensorTypeMeta, _is_sparse from aesara.sparse.utils import hash_from_sparse from aesara.tensor import basic as at from aesara.tensor.basic import Split @@ -47,7 +51,7 @@ trunc, ) from aesara.tensor.shape import shape, specify_broadcastable -from aesara.tensor.type import TensorType +from aesara.tensor.type import TensorType, TensorTypeMeta from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes from aesara.tensor.type import iscalar, ivector, scalar, tensor, vector @@ -86,7 +90,7 @@ def _is_sparse_variable(x): "or TensorType, for instance), not ", x, ) - return issubtype(x.type, SparseTensorType) + return isinstance(x.type, SparseTensorTypeMeta) def _is_dense_variable(x): @@ -106,7 +110,7 @@ def _is_dense_variable(x): "TensorType, for instance), not ", x, ) - return issubtype(x.type, TensorType) + return isinstance(x.type, TensorTypeMeta) def _is_dense(x): @@ -161,7 +165,7 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs): else: x = x.outputs[0] if isinstance(x, Variable): - if not issubtype(x.type, SparseTensorType): + if not isinstance(x.type, SparseTensorTypeMeta): raise TypeError( "Variable type field must be a SparseTensorType.", x, x.type ) @@ -265,7 +269,8 @@ def to_dense(self, *args, **kwargs): self = self.toarray() new_args = [ arg.toarray() - if hasattr(arg, "type") and issubtype(arg.type, SparseTensorType) + if hasattr(arg, "type") + and isinstance(arg.type, SparseTensorTypeMeta) else arg for arg in args ] @@ -619,7 +624,7 @@ def grad(self, inputs, g): # g[1:] is connected, or this grad method wouldn't have been # called, so we should report zeros (csm,) = inputs - if issubtype(g[0].type, DisconnectedType): + if isinstance(g[0].type, DisconnectedTypeMeta): return [csm.zeros_like()] data, indices, indptr, _shape = csm_properties(csm) @@ 
-981,7 +986,7 @@ def __str__(self): return f"{self.__class__.__name__}{{structured_grad={self.sparse_grad}}}" def __call__(self, x): - if not issubtype(x.type, SparseTensorType): + if not isinstance(x.type, SparseTensorTypeMeta): return x return super().__call__(x) @@ -1053,7 +1058,7 @@ def __str__(self): return f"{self.__class__.__name__}{{{self.format}}}" def __call__(self, x): - if issubtype(x.type, SparseTensorType): + if isinstance(x.type, SparseTensorTypeMeta): return x return super().__call__(x) @@ -3499,7 +3504,7 @@ def perform(self, node, inputs, outputs): ) variable = a * b - if issubtype(node.outputs[0].type, SparseTensorType): + if isinstance(node.outputs[0].type, SparseTensorTypeMeta): assert _is_sparse(variable) out[0] = variable return diff --git a/aesara/sparse/type.py b/aesara/sparse/type.py index 9e03a14e88..eed8706ace 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -8,8 +8,7 @@ from aesara import scalar as aes from aesara.graph.basic import Variable from aesara.graph.type import Props -from aesara.issubtype import issubtype -from aesara.tensor.type import DenseTensorType, TensorType, TensorTypeMeta +from aesara.tensor.type import DenseTensorTypeMeta, TensorType, TensorTypeMeta SparsityTypes = Literal["csr", "csc", "bsr"] @@ -74,7 +73,7 @@ def type_parameters( shape: Optional[Iterable[Optional[Union[bool, int]]]] = None, name: Optional[str] = None, broadcastable: Optional[Iterable[bool]] = None, - ): + ): #type: ignore[override] if shape is None and broadcastable is None: shape = (None, None) @@ -165,7 +164,7 @@ def convert_variable(self, var): return res if not isinstance(res.type, type(self)): - if issubtype(res.type, DenseTensorType): + if isinstance(res.type, DenseTensorTypeMeta): if self.format == "csr": from aesara.sparse.basic import csr_from_dense diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index d9df5f79a9..829edc32e1 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -23,13 +23,17 @@ import aesara.scalar.sharedvar from aesara import compile, config, printing from aesara import scalar as aes -from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined +from aesara.gradient import ( + DisconnectedType, + DisconnectedTypeMeta, + grad_not_implemented, + grad_undefined, +) from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op from aesara.graph.rewriting.utils import rewrite_graph -from aesara.graph.type import Type -from aesara.issubtype import issubtype +from aesara.graph.type import NewTypeMeta from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.misc.safe_asarray import _asarray @@ -58,6 +62,7 @@ ) from aesara.tensor.type import ( TensorType, + TensorTypeMeta, discrete_dtypes, float_dtypes, int_dtypes, @@ -97,7 +102,7 @@ def _as_tensor_Scalar(x, name, ndim, **kwargs): @_as_tensor_variable.register(Variable) def _as_tensor_Variable(x, name, ndim, **kwargs): - if not issubtype(x.type, TensorType): + if not isinstance(x.type, TensorTypeMeta): raise TypeError( f"Tensor type field must be a TensorType; found {type(x.type)}." 
) @@ -314,9 +319,9 @@ def get_scalar_constant_value( except ValueError: raise NotScalarConstantError() - from aesara.sparse.type import SparseTensorType + from aesara.sparse.type import SparseTensorTypeMeta - if issubtype(v.type, SparseTensorType): + if isinstance(v.type, SparseTensorTypeMeta): raise NotScalarConstantError() return data @@ -435,7 +440,7 @@ def get_scalar_constant_value( var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:] ): idx = v.owner.op.idx_list[0] - if issubtype(idx, Type): + if isinstance(idx, NewTypeMeta): idx = get_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) @@ -470,7 +475,7 @@ def get_scalar_constant_value( ): idx = v.owner.op.idx_list[0] - if issubtype(idx, Type): + if isinstance(idx, NewTypeMeta): idx = get_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) @@ -492,7 +497,7 @@ def get_scalar_constant_value( op = owner.op idx_list = op.idx_list idx = idx_list[0] - if issubtype(idx, Type): + if isinstance(idx, NewTypeMeta): idx = get_scalar_constant_value( owner.inputs[1], max_recur=max_recur ) @@ -537,7 +542,7 @@ class TensorFromScalar(COp): __props__ = () def make_node(self, s): - if not issubtype(s.type, aes.ScalarType): + if not isinstance(s.type, aes.ScalarTypeMeta): raise TypeError("Input must be a `ScalarType` `Type`") return Apply(self, [s], [tensor(dtype=s.type.dtype, shape=())]) @@ -597,7 +602,7 @@ def __call__(self, *args, **kwargs) -> ScalarVariable: return type_cast(ScalarVariable, super().__call__(*args, **kwargs)) def make_node(self, t): - if not issubtype(t.type, TensorType) or t.type.ndim > 0: + if not isinstance(t.type, TensorTypeMeta) or t.type.ndim > 0: raise TypeError("Input must be a scalar `TensorType`") return Apply( @@ -1952,7 +1957,7 @@ def grad(self, inputs, g_outputs): x, axis, n = inputs outputs = self(*inputs, return_list=True) # If all the output gradients are disconnected, then so are the inputs - if builtins.all(issubtype(g.type, DisconnectedType) for g in g_outputs): + if builtins.all(isinstance(g.type, DisconnectedTypeMeta) for g in g_outputs): return [ DisconnectedType.subtype()(), grad_undefined(self, 1, axis), @@ -1961,7 +1966,7 @@ def grad(self, inputs, g_outputs): # Else, we have to make them zeros before joining them new_g_outputs = [] for o, g in zip(outputs, g_outputs): - if issubtype(g.type, DisconnectedType): + if isinstance(g.type, DisconnectedTypeMeta): new_g_outputs.append(o.zeros_like()) else: new_g_outputs.append(g) @@ -2612,7 +2617,11 @@ def stack(*tensors, **kwargs): if all( # In case there are explicit ints in tensors isinstance(t, (np.number, float, int, builtins.complex)) - or (isinstance(t, Variable) and issubtype(t.type, TensorType) and t.ndim == 0) + or ( + isinstance(t, Variable) + and isinstance(t.type, TensorTypeMeta) + and t.ndim == 0 + ) for t in tensors ): # in case there is direct int diff --git a/aesara/tensor/blas.py b/aesara/tensor/blas.py index 3717ddb5b2..5d6db16ad3 100644 --- a/aesara/tensor/blas.py +++ b/aesara/tensor/blas.py @@ -131,8 +131,6 @@ import numpy as np -from aesara.issubtype import issubtype - try: import numpy.__config__ # noqa @@ -168,7 +166,7 @@ from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift from aesara.tensor.shape import specify_broadcastable from aesara.tensor.type import ( - DenseTensorType, + DenseTensorTypeMeta, integer_dtypes, tensor, values_eq_approx_remove_inf_nan, @@ -272,7 +270,7 @@ def make_node(self, y, alpha, A, x, beta): inputs = [y, alpha, A, x, beta] - if any(not issubtype(i.type, DenseTensorType) for 
i in inputs): + if any(not isinstance(i.type, DenseTensorTypeMeta) for i in inputs): raise NotImplementedError("Only dense tensor types are supported") return Apply(self, inputs, [y.type()]) @@ -374,7 +372,7 @@ def make_node(self, A, alpha, x, y): raise TypeError("only float and complex types supported", x.dtype) inputs = [A, alpha, x, y] - if any(not issubtype(i.type, DenseTensorType) for i in inputs): + if any(not isinstance(i.type, DenseTensorTypeMeta) for i in inputs): raise NotImplementedError("Only dense tensor types are supported") return Apply(self, inputs, [A.type()]) @@ -935,7 +933,7 @@ def __getstate__(self): def make_node(self, *inputs): inputs = list(map(at.as_tensor_variable, inputs)) - if any(not issubtype(i.type, DenseTensorType) for i in inputs): + if any(not isinstance(i.type, DenseTensorTypeMeta) for i in inputs): raise NotImplementedError("Only dense tensor types are supported") if len(inputs) != 5: @@ -1673,7 +1671,7 @@ def make_node(self, x, y): x = at.as_tensor_variable(x) y = at.as_tensor_variable(y) - if any(not issubtype(i.type, DenseTensorType) for i in (x, y)): + if any(not isinstance(i.type, DenseTensorTypeMeta) for i in (x, y)): raise NotImplementedError("Only dense tensor types are supported") dtypes = ("float16", "float32", "float64", "complex64", "complex128") @@ -1762,7 +1760,7 @@ def local_dot_to_dot22(fgraph, node): if not isinstance(node.op, Dot): return - if any(not issubtype(i.type, DenseTensorType) for i in node.inputs): + if any(not isinstance(i.type, DenseTensorTypeMeta) for i in node.inputs): return False x, y = node.inputs @@ -1970,7 +1968,7 @@ class Dot22Scalar(GemmRelated): def make_node(self, x, y, a): - if any(not issubtype(i.type, DenseTensorType) for i in (x, y, a)): + if any(not isinstance(i.type, DenseTensorTypeMeta) for i in (x, y, a)): raise NotImplementedError("Only dense tensor types are supported") if a.ndim != 0: @@ -2194,7 +2192,7 @@ class BatchedDot(COp): def make_node(self, *inputs): inputs = list(map(at.as_tensor_variable, inputs)) - if any(not issubtype(i.type, DenseTensorType) for i in inputs): + if any(not isinstance(i.type, DenseTensorTypeMeta) for i in inputs): raise NotImplementedError("Only dense tensor types are supported") if len(inputs) != 2: diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py index 910681bbea..7699b1d511 100644 --- a/aesara/tensor/elemwise.py +++ b/aesara/tensor/elemwise.py @@ -5,11 +5,10 @@ import aesara.tensor.basic from aesara.configdefaults import config -from aesara.gradient import DisconnectedType +from aesara.gradient import DisconnectedTypeMeta from aesara.graph.basic import Apply -from aesara.graph.null_type import NullType +from aesara.graph.null_type import NullTypeMeta from aesara.graph.utils import MethodNotDefined -from aesara.issubtype import issubtype from aesara.link.c.basic import failure_code from aesara.link.c.op import COp, ExternalCOp, OpenMPOp from aesara.link.c.params_type import ParamsType @@ -521,7 +520,9 @@ def R_op(self, inputs, eval_points): # the right thing to do .. have to talk to Ian and James # about it - if bgrads[jdx] is None or issubtype(bgrads[jdx].type, DisconnectedType): + if bgrads[jdx] is None or isinstance( + bgrads[jdx].type, DisconnectedTypeMeta + ): pass elif eval_point is not None: if rop_out is None: @@ -557,7 +558,7 @@ def L_op(self, inputs, outs, ograds): # this op did the right thing. 
new_rval = [] for elem, ipt in zip(rval, inputs): - if issubtype(elem.type, (NullType, DisconnectedType)): + if isinstance(elem.type, (NullTypeMeta, DisconnectedTypeMeta)): new_rval.append(elem) else: elem = ipt.zeros_like() @@ -569,7 +570,7 @@ def L_op(self, inputs, outs, ograds): # sum out the broadcasted dimensions for i, ipt in enumerate(inputs): - if issubtype(rval[i].type, (NullType, DisconnectedType)): + if isinstance(rval[i].type, (NullTypeMeta, DisconnectedTypeMeta)): continue # List of all the dimensions that are broadcastable for input[i] so @@ -593,7 +594,7 @@ def _bgrad(self, inputs, outputs, ograds): with config.change_flags(compute_test_value="off"): def as_scalar(t): - if issubtype(t.type, (NullType, DisconnectedType)): + if isinstance(t.type, (NullTypeMeta, DisconnectedTypeMeta)): return t return get_scalar_type(t.type.dtype)() @@ -617,7 +618,7 @@ def as_scalar(t): def transform(r): # From a graph of ScalarOps, make a graph of Broadcast ops. - if issubtype(r.type, (NullType, DisconnectedType)): + if isinstance(r.type, (NullTypeMeta, DisconnectedTypeMeta)): return r if r in scalar_inputs: return inputs[scalar_inputs.index(r)] diff --git a/aesara/tensor/math.py b/aesara/tensor/math.py index 20dde349ec..f6b9fae96d 100644 --- a/aesara/tensor/math.py +++ b/aesara/tensor/math.py @@ -6,10 +6,9 @@ from aesara import config, printing from aesara import scalar as aes -from aesara.gradient import DisconnectedType +from aesara.gradient import DisconnectedType, DisconnectedTypeMeta from aesara.graph.basic import Apply, Variable from aesara.graph.op import Op -from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.link.c.type import Generic @@ -35,7 +34,7 @@ ) from aesara.tensor.shape import shape, specify_broadcastable from aesara.tensor.type import ( - DenseTensorType, + DenseTensorTypeMeta, TensorType, complex_dtypes, continuous_dtypes, @@ -303,8 +302,8 @@ def grad(self, inp, grads): axis = as_tensor_variable(self.axis) g_max, g_max_idx = grads - g_max_disconnected = issubtype(g_max.type, DisconnectedType) - g_max_idx_disconnected = issubtype(g_max_idx.type, DisconnectedType) + g_max_disconnected = isinstance(g_max.type, DisconnectedTypeMeta) + g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedTypeMeta) # if the op is totally disconnected, so are its inputs if g_max_disconnected and g_max_idx_disconnected: @@ -2090,7 +2089,9 @@ def dense_dot(a, b): """ a, b = as_tensor_variable(a), as_tensor_variable(b) - if not issubtype(a.type, DenseTensorType) or not issubtype(b.type, DenseTensorType): + if not isinstance(a.type, DenseTensorTypeMeta) or not isinstance( + b.type, DenseTensorTypeMeta + ): raise TypeError("The dense dot product is only supported for dense types") if a.ndim == 0 or b.ndim == 0: diff --git a/aesara/tensor/nlinalg.py b/aesara/tensor/nlinalg.py index f4fbcbb809..354aee2c9b 100644 --- a/aesara/tensor/nlinalg.py +++ b/aesara/tensor/nlinalg.py @@ -4,10 +4,9 @@ import numpy as np from aesara import scalar as aes -from aesara.gradient import DisconnectedType +from aesara.gradient import DisconnectedTypeMeta from aesara.graph.basic import Apply from aesara.graph.op import Op -from aesara.issubtype import issubtype from aesara.tensor import basic as at from aesara.tensor import math as tm from aesara.tensor.basic import as_tensor_variable, extract_diag @@ -326,7 +325,7 @@ def grad(self, inputs, g_outputs): def _zero_disconnected(outputs, grads): l = [] for o, g in zip(outputs, grads): - 
if issubtype(g.type, DisconnectedType): + if isinstance(g.type, DisconnectedTypeMeta): l.append(o.zeros_like()) else: l.append(g) diff --git a/aesara/tensor/nnet/basic.py b/aesara/tensor/nnet/basic.py index 06fb2a1643..429acd834b 100644 --- a/aesara/tensor/nnet/basic.py +++ b/aesara/tensor/nnet/basic.py @@ -15,11 +15,10 @@ import aesara from aesara import scalar as aes from aesara.compile import optdb -from aesara.gradient import DisconnectedType, grad_not_implemented +from aesara.gradient import DisconnectedType, DisconnectedTypeMeta, grad_not_implemented from aesara.graph.basic import Apply from aesara.graph.op import Op from aesara.graph.rewriting.basic import copy_stack_trace, graph_rewriter, node_rewriter -from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.raise_op import Assert from aesara.scalar import UnaryScalarOp @@ -127,7 +126,7 @@ def L_op(self, inp, outputs, grads): x, b = inp (g_sm,) = grads - if issubtype(g_sm.type, DisconnectedType): + if isinstance(g_sm.type, DisconnectedTypeMeta): return [DisconnectedType.subtype()(), DisconnectedType.subtype()()] dx = softmax_grad_legacy(g_sm, outputs[0]) @@ -1422,19 +1421,19 @@ def grad(self, inp, grads): db_terms = [] d_idx_terms = [] - if not issubtype(g_nll.type, DisconnectedType): + if not isinstance(g_nll.type, DisconnectedTypeMeta): nll, sm = crossentropy_softmax_1hot_with_bias(x, b, y_idx) dx = crossentropy_softmax_1hot_with_bias_dx(g_nll, sm, y_idx) db = at_sum(dx, axis=[0]) dx_terms.append(dx) db_terms.append(db) - if not issubtype(g_sm.type, DisconnectedType): + if not isinstance(g_sm.type, DisconnectedTypeMeta): dx, db = softmax_with_bias.L_op((x, b), [softmax_with_bias(x, b)], (g_sm,)) dx_terms.append(dx) db_terms.append(db) - if not issubtype(g_am.type, DisconnectedType): + if not isinstance(g_am.type, DisconnectedTypeMeta): dx_terms.append(x.zeros_like()) db_terms.append(b.zeros_like()) d_idx_terms.append(y_idx.zeros_like()) diff --git a/aesara/tensor/nnet/batchnorm.py b/aesara/tensor/nnet/batchnorm.py index f2a5b1e577..3cc7af68cf 100644 --- a/aesara/tensor/nnet/batchnorm.py +++ b/aesara/tensor/nnet/batchnorm.py @@ -5,7 +5,6 @@ from aesara.graph.basic import Apply from aesara.graph.op import Op from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter -from aesara.issubtype import issubtype from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div from aesara.tensor import basic as at from aesara.tensor.basic import as_tensor_variable @@ -14,7 +13,7 @@ from aesara.tensor.math import sum as at_sum from aesara.tensor.rewriting.basic import register_specialize_device from aesara.tensor.shape import specify_broadcastable -from aesara.tensor.type import TensorType +from aesara.tensor.type import TensorTypeMeta class BNComposite(Composite): @@ -691,7 +690,7 @@ def grad(self, inp, grads): g_wrt_x_mean = 0 g_wrt_x_invstd = 0 - if not issubtype(ddinputs.type, aesara.gradient.DisconnectedType): + if not isinstance(ddinputs.type, aesara.gradient.DisconnectedTypeMeta): ccc = scale * (ddinputs - mean(ddinputs, axis=self.axes, keepdims=True)) ddd = (x_invstd**3) * ( ccc * mean(dy * x_diff, axis=self.axes, keepdims=True) @@ -722,7 +721,7 @@ def grad(self, inp, grads): keepdims=True, ) - if not issubtype(ddscale.type, aesara.gradient.DisconnectedType): + if not isinstance(ddscale.type, aesara.gradient.DisconnectedTypeMeta): g_wrt_x = g_wrt_x + (x_invstd * ddscale * dy) g_wrt_dy = g_wrt_dy + (x_invstd * ddscale * x_diff) g_wrt_x_mean = g_wrt_x_mean - ( @@ -732,7 +731,7 @@ 
def grad(self, inp, grads): ddscale * at_sum(dy * x_diff, axis=self.axes, keepdims=True) ) - if not issubtype(ddbias.type, aesara.gradient.DisconnectedType): + if not isinstance(ddbias.type, aesara.gradient.DisconnectedTypeMeta): g_wrt_dy = g_wrt_dy + at.fill(dy, ddbias) # depending on which output gradients are given, @@ -796,17 +795,17 @@ def local_abstract_batch_norm_train(fgraph, node): if min(axes) < 0 or max(axes) > x.ndim: return None if ( - not issubtype(x.type, TensorType) - or not issubtype(scale.type, TensorType) - or not issubtype(bias.type, TensorType) - or not issubtype(epsilon.type, TensorType) - or not issubtype(running_average_factor.type, TensorType) + not isinstance(x.type, TensorTypeMeta) + or not isinstance(scale.type, TensorTypeMeta) + or not isinstance(bias.type, TensorTypeMeta) + or not isinstance(epsilon.type, TensorTypeMeta) + or not isinstance(running_average_factor.type, TensorTypeMeta) ): return None # optional running_mean and running_var - if len(node.inputs) > 5 and not issubtype(node.inputs[5].type, TensorType): + if len(node.inputs) > 5 and not isinstance(node.inputs[5].type, TensorTypeMeta): return None - if len(node.inputs) > 6 and not issubtype(node.inputs[6].type, TensorType): + if len(node.inputs) > 6 and not isinstance(node.inputs[6].type, TensorTypeMeta): return None mean = x.mean(axes, keepdims=True) @@ -850,12 +849,12 @@ def local_abstract_batch_norm_train_grad(fgraph, node): if min(axes) < 0 or max(axes) > x.ndim: return None if ( - not issubtype(x.type, TensorType) - or not issubtype(dy.type, TensorType) - or not issubtype(scale.type, TensorType) - or not issubtype(x_mean.type, TensorType) - or not issubtype(x_invstd.type, TensorType) - or not issubtype(epsilon.type, TensorType) + not isinstance(x.type, TensorTypeMeta) + or not isinstance(dy.type, TensorTypeMeta) + or not isinstance(scale.type, TensorTypeMeta) + or not isinstance(x_mean.type, TensorTypeMeta) + or not isinstance(x_invstd.type, TensorTypeMeta) + or not isinstance(epsilon.type, TensorTypeMeta) ): return None @@ -882,12 +881,12 @@ def local_abstract_batch_norm_inference(fgraph, node): x, scale, bias, estimated_mean, estimated_variance, epsilon = node.inputs if ( - not issubtype(x.type, TensorType) - or not issubtype(scale.type, TensorType) - or not issubtype(bias.type, TensorType) - or not issubtype(estimated_mean.type, TensorType) - or not issubtype(estimated_variance.type, TensorType) - or not issubtype(epsilon.type, TensorType) + not isinstance(x.type, TensorTypeMeta) + or not isinstance(scale.type, TensorTypeMeta) + or not isinstance(bias.type, TensorTypeMeta) + or not isinstance(estimated_mean.type, TensorTypeMeta) + or not isinstance(estimated_variance.type, TensorTypeMeta) + or not isinstance(epsilon.type, TensorTypeMeta) ): return None diff --git a/aesara/tensor/nnet/rewriting.py b/aesara/tensor/nnet/rewriting.py index 682c68ed4f..b549c5fccd 100644 --- a/aesara/tensor/nnet/rewriting.py +++ b/aesara/tensor/nnet/rewriting.py @@ -13,7 +13,6 @@ in2out, node_rewriter, ) -from aesara.issubtype import issubtype from aesara.tensor.nnet.abstract_conv import ( AbstractConv2d, AbstractConv2d_gradInputs, @@ -35,7 +34,7 @@ from aesara.tensor.nnet.corr import CorrMM, CorrMM_gradInputs, CorrMM_gradWeights from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGradWeights from aesara.tensor.rewriting.basic import register_specialize_device -from aesara.tensor.type import TensorType +from aesara.tensor.type import TensorTypeMeta @node_rewriter([SparseBlockGemv], 
inplace=True) @@ -96,7 +95,9 @@ def local_abstractconv_gemm(fgraph, node): if not isinstance(node.op, AbstractConv2d): return None img, kern = node.inputs - if not issubtype(img.type, TensorType) or not issubtype(kern.type, TensorType): + if not isinstance(img.type, TensorTypeMeta) or not isinstance( + kern.type, TensorTypeMeta + ): return None # need to flip the kernel if necessary @@ -124,7 +125,9 @@ def local_abstractconv3d_gemm(fgraph, node): if not isinstance(node.op, AbstractConv3d): return None img, kern = node.inputs - if not issubtype(img.type, TensorType) or not issubtype(kern.type, TensorType): + if not isinstance(img.type, TensorTypeMeta) or not isinstance( + kern.type, TensorTypeMeta + ): return None # need to flip the kernel if necessary @@ -150,7 +153,9 @@ def local_abstractconv_gradweight_gemm(fgraph, node): if not isinstance(node.op, AbstractConv2d_gradWeights): return None img, topgrad, shape = node.inputs - if not issubtype(img.type, TensorType) or not issubtype(topgrad.type, TensorType): + if not isinstance(img.type, TensorTypeMeta) or not isinstance( + topgrad.type, TensorTypeMeta + ): return None rval = CorrMM_gradWeights( @@ -180,7 +185,9 @@ def local_abstractconv3d_gradweight_gemm(fgraph, node): if not isinstance(node.op, AbstractConv3d_gradWeights): return None img, topgrad, shape = node.inputs - if not issubtype(img.type, TensorType) or not issubtype(topgrad.type, TensorType): + if not isinstance(img.type, TensorTypeMeta) or not isinstance( + topgrad.type, TensorTypeMeta + ): return None rval = Corr3dMMGradWeights( @@ -208,7 +215,9 @@ def local_abstractconv_gradinputs_gemm(fgraph, node): if not isinstance(node.op, AbstractConv2d_gradInputs): return None kern, topgrad, shape = node.inputs - if not issubtype(kern.type, TensorType) or not issubtype(topgrad.type, TensorType): + if not isinstance(kern.type, TensorTypeMeta) or not isinstance( + topgrad.type, TensorTypeMeta + ): return None # need to flip the kernel if necessary @@ -236,7 +245,9 @@ def local_abstractconv3d_gradinputs_gemm(fgraph, node): if not isinstance(node.op, AbstractConv3d_gradInputs): return None kern, topgrad, shape = node.inputs - if not issubtype(kern.type, TensorType) or not issubtype(topgrad.type, TensorType): + if not isinstance(kern.type, TensorTypeMeta) or not isinstance( + topgrad.type, TensorTypeMeta + ): return None # need to flip the kernel if necessary @@ -260,7 +271,9 @@ def local_conv2d_cpu(fgraph, node): return None img, kern = node.inputs - if not issubtype(img.type, TensorType) or not issubtype(kern.type, TensorType): + if not isinstance(img.type, TensorTypeMeta) or not isinstance( + kern.type, TensorTypeMeta + ): return None if node.op.border_mode not in ("full", "valid"): return None @@ -295,7 +308,9 @@ def local_conv2d_gradweight_cpu(fgraph, node): img, topgrad, shape = node.inputs - if not issubtype(img.type, TensorType) or not issubtype(topgrad.type, TensorType): + if not isinstance(img.type, TensorTypeMeta) or not isinstance( + topgrad.type, TensorTypeMeta + ): return None if node.op.border_mode not in ("full", "valid"): return None @@ -404,7 +419,9 @@ def local_conv2d_gradinputs_cpu(fgraph, node): kern, topgrad, shape = node.inputs - if not issubtype(kern.type, TensorType) or not issubtype(topgrad.type, TensorType): + if not isinstance(kern.type, TensorTypeMeta) or not isinstance( + topgrad.type, TensorTypeMeta + ): return None if node.op.border_mode not in ("full", "valid"): return None diff --git a/aesara/tensor/random/basic.py b/aesara/tensor/random/basic.py index 
1fa6c9a00b..317c2c9ec4 100644 --- a/aesara/tensor/random/basic.py +++ b/aesara/tensor/random/basic.py @@ -5,10 +5,9 @@ import scipy.stats as stats import aesara -from aesara import issubtype from aesara.tensor.basic import as_tensor_variable from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params -from aesara.tensor.random.type import RandomGeneratorType, RandomStateType +from aesara.tensor.random.type import RandomGeneratorTypeMeta, RandomStateTypeMeta from aesara.tensor.random.utils import broadcast_params from aesara.tensor.random.var import ( RandomGeneratorSharedVariable, @@ -1859,8 +1858,8 @@ def __call__(self, low, high=None, size=None, **kwargs): return super().__call__(low, high, size=size, **kwargs) def make_node(self, rng, *args, **kwargs): - if not issubtype( - getattr(rng, "type", None), (RandomStateType, RandomStateSharedVariable) + if not isinstance( + getattr(rng, "type", None), (RandomStateTypeMeta, RandomStateSharedVariable) ): raise TypeError("`randint` is only available for `RandomStateType`s") return super().make_node(rng, *args, **kwargs) @@ -1909,9 +1908,9 @@ def __call__(self, low, high=None, size=None, **kwargs): return super().__call__(low, high, size=size, **kwargs) def make_node(self, rng, *args, **kwargs): - if not issubtype( + if not isinstance( getattr(rng, "type", None), - (RandomGeneratorType, RandomGeneratorSharedVariable), + (RandomGeneratorTypeMeta, RandomGeneratorSharedVariable), ): raise TypeError("`integers` is only available for `RandomGeneratorType`s") return super().make_node(rng, *args, **kwargs) @@ -1939,7 +1938,7 @@ def _supp_shape_from_params(self, *args, **kwargs): def _infer_shape(self, size, dist_params, param_shapes=None): (a, p, _) = dist_params - if issubtype(p.type, aesara.tensor.type_other.NoneTypeT): + if isinstance(p.type, aesara.tensor.type_other.NoneTypeTMeta): shape = super()._infer_shape(size, (a,), param_shapes) else: shape = super()._infer_shape(size, (a, p), param_shapes) diff --git a/aesara/tensor/random/op.py b/aesara/tensor/random/op.py index a07c18e68c..84952142b6 100644 --- a/aesara/tensor/random/op.py +++ b/aesara/tensor/random/op.py @@ -7,7 +7,6 @@ from aesara.configdefaults import config from aesara.graph.basic import Apply, Variable from aesara.graph.op import Op -from aesara.issubtype import issubtype from aesara.misc.safe_asarray import _asarray from aesara.scalar import ScalarVariable from aesara.tensor.basic import ( @@ -17,7 +16,11 @@ get_vector_length, infer_broadcastable, ) -from aesara.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType +from aesara.tensor.random.type import ( + RandomGeneratorType, + RandomStateType, + RandomTypeMeta, +) from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes from aesara.tensor.shape import shape_tuple from aesara.tensor.type import TensorType, all_dtypes @@ -317,7 +320,7 @@ def make_node(self, rng, size, dtype, *dist_params): if rng is None: rng = aesara.shared(np.random.default_rng()) - elif not issubtype(rng.type, RandomType): + elif not isinstance(rng.type, RandomTypeMeta): raise TypeError( "The type of rng should be an instance of either RandomGeneratorType or RandomStateType" ) diff --git a/aesara/tensor/rewriting/basic.py b/aesara/tensor/rewriting/basic.py index 009735b7e0..81f1e2a039 100644 --- a/aesara/tensor/rewriting/basic.py +++ b/aesara/tensor/rewriting/basic.py @@ -18,7 +18,6 @@ node_rewriter, ) from aesara.graph.rewriting.db import RewriteDatabase -from aesara.issubtype import 
issubtype from aesara.raise_op import Assert, CheckAndRaise, assert_op from aesara.tensor.basic import ( Alloc, @@ -48,7 +47,7 @@ from aesara.tensor.math import eq from aesara.tensor.shape import Shape_i from aesara.tensor.sort import TopKOp -from aesara.tensor.type import DenseTensorType, TensorType +from aesara.tensor.type import DenseTensorTypeMeta, TensorType from aesara.tensor.var import TensorConstant from aesara.utils import NoDuplicateOptWarningFilter @@ -1153,7 +1152,7 @@ def constant_folding(fgraph, node): # TODO: `Type` itself should provide an interface for constructing # instances appropriate for a given constant. # TODO: Add handling for sparse types. - if issubtype(output.type, DenseTensorType): + if isinstance(output.type, DenseTensorTypeMeta): output_type = TensorType.subtype( output.type.dtype, tuple(s == 1 for s in data.shape), diff --git a/aesara/tensor/rewriting/shape.py b/aesara/tensor/rewriting/shape.py index 6f1df9d3f1..b511e53c33 100644 --- a/aesara/tensor/rewriting/shape.py +++ b/aesara/tensor/rewriting/shape.py @@ -18,7 +18,6 @@ node_rewriter, ) from aesara.graph.utils import InconsistencyError, get_variable_trace_string -from aesara.issubtype import issubtype from aesara.tensor.basic import ( MakeVector, as_tensor_variable, @@ -48,7 +47,7 @@ unbroadcast, ) from aesara.tensor.subtensor import Subtensor, get_idx_list -from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes +from aesara.tensor.type import TensorTypeMeta, discrete_dtypes, integer_dtypes from aesara.tensor.type_other import NoneConst @@ -1032,7 +1031,7 @@ def local_Shape_i_of_broadcastable(fgraph, node): shape_arg = node.inputs[0] - if not issubtype(shape_arg.type, TensorType): + if not isinstance(shape_arg.type, TensorTypeMeta): return False if shape_arg.broadcastable[node.op.i]: diff --git a/aesara/tensor/rewriting/subtensor.py b/aesara/tensor/rewriting/subtensor.py index 63060b409d..a1fdd40cca 100644 --- a/aesara/tensor/rewriting/subtensor.py +++ b/aesara/tensor/rewriting/subtensor.py @@ -13,7 +13,6 @@ in2out, node_rewriter, ) -from aesara.issubtype import issubtype from aesara.raise_op import Assert from aesara.tensor.basic import ( Alloc, @@ -80,8 +79,8 @@ inc_subtensor, indices_from_subtensor, ) -from aesara.tensor.type import TensorType -from aesara.tensor.type_other import NoneTypeT, SliceConstant, SliceType +from aesara.tensor.type import TensorTypeMeta +from aesara.tensor.type_other import NoneTypeTMeta, SliceConstant, SliceTypeMeta from aesara.tensor.var import TensorConstant, TensorVariable @@ -166,10 +165,11 @@ def is_full_slice(x): or (isinstance(x, SliceConstant) and x.value == slice(None)) or ( not isinstance(x, SliceConstant) - and issubtype(getattr(x, "type", None), SliceType) + and isinstance(getattr(x, "type", None), SliceTypeMeta) and x.owner is not None and all( - issubtype(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs + isinstance(getattr(i, "type", None), NoneTypeTMeta) + for i in x.owner.inputs ) ) ): @@ -561,7 +561,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node): remove_dim = [] node_inputs_idx = 1 for dim, elem in enumerate(idx): - if issubtype(elem, aes.ScalarType): + if isinstance(elem, aes.ScalarTypeMeta): # The idx is a ScalarType, ie a Type. 
This means the actual index # is contained in node.inputs[1] dim_index = node.inputs[node_inputs_idx] @@ -735,7 +735,7 @@ def local_subtensor_make_vector(fgraph, node): if isinstance(node.op, Subtensor): (idx,) = node.op.idx_list - if issubtype(idx, (aes.ScalarType, TensorType)): + if isinstance(idx, (aes.ScalarTypeMeta, TensorTypeMeta)): old_idx, idx = idx, node.inputs[1] assert idx.type.is_super(old_idx) elif isinstance(node.op, AdvancedSubtensor1): @@ -1603,7 +1603,7 @@ def local_subtensor_shape_constant(fgraph, node): assert idx_val != np.newaxis - if not issubtype(shape_arg.type, TensorType): + if not isinstance(shape_arg.type, TensorTypeMeta): return False shape_parts = shape_arg.type.broadcastable[idx_val] @@ -1637,7 +1637,8 @@ def local_subtensor_SpecifyShape_lift(fgraph, node): indices = get_idx_list(node.inputs, node.op.idx_list) if any( - isinstance(index, slice) or issubtype(getattr(index, "type", None), SliceType) + isinstance(index, slice) + or isinstance(getattr(index, "type", None), SliceTypeMeta) for index in indices ): return False diff --git a/aesara/tensor/shape.py b/aesara/tensor/shape.py index f2e3768799..98b5907a85 100644 --- a/aesara/tensor/shape.py +++ b/aesara/tensor/shape.py @@ -8,7 +8,6 @@ import aesara from aesara.gradient import DisconnectedType from aesara.graph.basic import Apply, Variable -from aesara.issubtype import issubtype from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.misc.safe_asarray import _asarray @@ -17,7 +16,13 @@ from aesara.tensor import basic as at from aesara.tensor import get_vector_length from aesara.tensor.exceptions import NotScalarConstantError -from aesara.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor +from aesara.tensor.type import ( + DenseTensorTypeMeta, + TensorType, + TensorTypeMeta, + int_dtypes, + tensor, +) from aesara.tensor.type_other import NoneConst from aesara.tensor.var import TensorConstant, TensorVariable @@ -65,7 +70,7 @@ def make_node(self, x): if not isinstance(x, Variable): x = at.as_tensor_variable(x) - if issubtype(x.type, TensorType): + if isinstance(x.type, TensorTypeMeta): out_var = TensorType.subtype("int64", (x.type.ndim,))() else: out_var = aesara.tensor.type.lvector() @@ -145,7 +150,7 @@ def shape(x: Union[np.ndarray, Number, Variable]) -> Variable: x_type = x.type - if issubtype(x_type, TensorType) and all(s is not None for s in x_type.shape): + if isinstance(x_type, TensorTypeMeta) and all(s is not None for s in x_type.shape): res = at.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64) else: res = _shape(x) @@ -474,7 +479,7 @@ def R_op(self, inputs, eval_points): return self.make_node(eval_points[0], *inputs[1:]).outputs def c_code(self, node, name, i_names, o_names, sub): - if not issubtype(node.inputs[0].type, DenseTensorType): + if not isinstance(node.inputs[0].type, DenseTensorTypeMeta): raise NotImplementedError( f"Specify_shape c_code not implemented for input type {node.inputs[0].type}" ) diff --git a/aesara/tensor/subtensor.py b/aesara/tensor/subtensor.py index 86434fcda5..9a7b19fd62 100644 --- a/aesara/tensor/subtensor.py +++ b/aesara/tensor/subtensor.py @@ -12,9 +12,8 @@ from aesara.gradient import DisconnectedType from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op -from aesara.graph.type import Type +from aesara.graph.type import NewTypeMeta, Type from aesara.graph.utils import MethodNotDefined -from aesara.issubtype import issubtype from aesara.link.c.op import COp from 
aesara.link.c.params_type import ParamsType from aesara.misc.safe_asarray import _asarray @@ -32,6 +31,7 @@ from aesara.tensor.shape import Reshape, specify_broadcastable from aesara.tensor.type import ( TensorType, + TensorTypeMeta, bscalar, complex_dtypes, cscalar, @@ -49,7 +49,7 @@ wscalar, zscalar, ) -from aesara.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice +from aesara.tensor.type_other import NoneConst, NoneTypeTMeta, SliceTypeMeta, make_slice _logger = logging.getLogger("aesara.tensor.subtensor") @@ -109,7 +109,7 @@ def indices_from_subtensor( def convert_indices(indices, entry): """Reconstruct ``*Subtensor*`` index input parameter entries.""" - if indices and issubtype(entry, Type): + if indices and isinstance(entry, NewTypeMeta): rval = indices.pop(0) return rval elif isinstance(entry, slice): @@ -164,13 +164,13 @@ def as_index_literal( ------ NotScalarConstantError """ - if idx == np.newaxis or issubtype(getattr(idx, "type", None), NoneTypeT): + if idx == np.newaxis or isinstance(getattr(idx, "type", None), NoneTypeTMeta): return np.newaxis if isinstance(idx, Constant): return idx.data.item() if isinstance(idx, np.ndarray) else idx.data - if issubtype(getattr(idx, "type", None), SliceType): + if isinstance(getattr(idx, "type", None), SliceTypeMeta): idx = slice(*idx.owner.inputs) if isinstance(idx, slice): @@ -399,8 +399,8 @@ def is_basic_idx(idx): integer can indicate advanced indexing. """ - return isinstance(idx, (slice, type(None))) or issubtype( - getattr(idx, "type", None), (SliceType, NoneTypeT) + return isinstance(idx, (slice, type(None))) or isinstance( + getattr(idx, "type", None), (SliceTypeMeta, NoneTypeTMeta) ) @@ -422,7 +422,7 @@ def basic_shape(shape, indices): for idx, n in zip(indices, shape): if isinstance(idx, slice): res_shape += (slice_len(idx, n),) - elif issubtype(getattr(idx, "type", None), SliceType): + elif isinstance(getattr(idx, "type", None), SliceTypeMeta): if idx.owner: idx_inputs = idx.owner.inputs else: @@ -430,7 +430,7 @@ def basic_shape(shape, indices): res_shape += (slice_len(slice(*idx_inputs), n),) elif idx is None: res_shape += (aes.ScalarConstant(aes.int64, 1),) - elif issubtype(getattr(idx, "type", None), NoneTypeT): + elif isinstance(getattr(idx, "type", None), NoneTypeTMeta): res_shape += (aes.ScalarConstant(aes.int64, 1),) else: raise ValueError(f"Invalid index type: {idx}") @@ -454,8 +454,8 @@ def group_indices(indices): for idx in grp_indices: # We "zip" the dimension number to each index, which means we can't # count indices that add new axes - if (idx is not None) and not issubtype( - getattr(idx, "type", None), NoneTypeT + if (idx is not None) and not isinstance( + getattr(idx, "type", None), NoneTypeTMeta ): dim_num += 1 @@ -573,7 +573,7 @@ def index_vars_to_types(entry, slice_ok=True): if isinstance(entry, Variable) and entry.type in scal_types: return entry.type - elif issubtype(entry, Type) and entry in scal_types: + elif isinstance(entry, NewTypeMeta) and entry in scal_types: return entry if ( @@ -582,7 +582,11 @@ def index_vars_to_types(entry, slice_ok=True): and all(entry.type.broadcastable) ): return aes.get_scalar_type(entry.type.dtype) - elif issubtype(entry, Type) and entry in tensor_types and all(entry.broadcastable): + elif ( + isinstance(entry, NewTypeMeta) + and entry in tensor_types + and all(entry.broadcastable) + ): return aes.get_scalar_type(entry.dtype) elif slice_ok and isinstance(entry, slice): a = entry.start @@ -674,7 +678,7 @@ def as_nontensor_scalar(a: Variable) -> 
aes.ScalarVariable: # Since aes.as_scalar does not know about tensor types (it would # create a circular import) , this method converts either a # TensorVariable or a ScalarVariable to a scalar. - if isinstance(a, Variable) and issubtype(a.type, TensorType): + if isinstance(a, Variable) and isinstance(a.type, TensorTypeMeta): return aesara.tensor.scalar_from_tensor(a) else: return aes.as_scalar(a) @@ -709,7 +713,9 @@ def make_node(self, x, *inputs): if len(idx_list) > x.type.ndim: raise IndexError("too many indices for array") - input_types = get_slice_elements(idx_list, lambda entry: issubtype(entry, Type)) + input_types = get_slice_elements( + idx_list, lambda entry: isinstance(entry, NewTypeMeta) + ) assert len(inputs) == len(input_types) @@ -923,7 +929,7 @@ def init_entry(entry, depth=0): inc_spec_pos(1) if depth == 0: is_slice.append(0) - elif issubtype(entry, Type): + elif isinstance(entry, NewTypeMeta): init_cmds.append( "subtensor_spec[%i] = %s;" % (spec_pos(), inputs[input_pos()]) ) @@ -1122,7 +1128,7 @@ def helper_c_code_cache_version(): return (9,) def c_code(self, node, name, inputs, outputs, sub): # DEBUG - if not issubtype(node.inputs[0].type, TensorType): + if not isinstance(node.inputs[0].type, TensorTypeMeta): raise NotImplementedError() x = inputs[0] @@ -1208,7 +1214,7 @@ def _process(self, idxs, op_inputs, pstate): sidxs = [] getattr(pstate, "precedence", None) for entry in idxs: - if issubtype(entry, aes.ScalarType): + if isinstance(entry, aes.ScalarTypeMeta): with set_precedence(pstate): sidxs.append(pstate.pprinter.process(inputs.pop())) elif isinstance(entry, slice): @@ -1534,7 +1540,9 @@ def make_node(self, x, y, *inputs): if len(idx_list) > x.type.ndim: raise IndexError("too many indices for array") - input_types = get_slice_elements(idx_list, lambda entry: issubtype(entry, Type)) + input_types = get_slice_elements( + idx_list, lambda entry: isinstance(entry, NewTypeMeta) + ) if len(inputs) != len(input_types): raise IndexError( "Not enough inputs to fill in the Subtensor template.", inputs, idx_list @@ -1556,7 +1564,7 @@ def perform(self, node, inputs, out_): indices = list(reversed(inputs[2:])) def _convert(entry): - if issubtype(entry, Type): + if isinstance(entry, NewTypeMeta): return indices.pop() elif isinstance(entry, slice): return slice( @@ -1706,7 +1714,7 @@ def do_type_checking(self, node): """ - if not issubtype(node.inputs[0].type, TensorType): + if not isinstance(node.inputs[0].type, TensorTypeMeta): raise NotImplementedError() def c_code_cache_version(self): @@ -2504,9 +2512,9 @@ def as_index_variable(idx): return NoneConst.clone() if isinstance(idx, slice): return make_slice(idx) - if isinstance(idx, Variable) and issubtype(idx.type, SliceType): + if isinstance(idx, Variable) and isinstance(idx.type, SliceTypeMeta): return idx - if isinstance(idx, Variable) and issubtype(idx.type, NoneTypeT): + if isinstance(idx, Variable) and isinstance(idx.type, NoneTypeTMeta): return idx idx = as_tensor_variable(idx) if idx.type.dtype not in discrete_dtypes: @@ -2599,7 +2607,7 @@ def infer_shape(self, fgraph, node, ishapes): ) # The `ishapes` entries for `SliceType`s will be None, and # we need to give `indexed_result_shape` the actual slices. - if issubtype(getattr(idx, "type", None), SliceType): + if isinstance(getattr(idx, "type", None), SliceTypeMeta): index_shapes[i] = idx res_shape = indexed_result_shape( @@ -2616,7 +2624,7 @@ def perform(self, node, inputs, out_): # indexing, so __getitem__ will not return a copy. 
# Since no view_map is set, we need to copy the returned value if not any( - issubtype(v.type, TensorType) and v.ndim > 0 for v in node.inputs[1:] + isinstance(v.type, TensorTypeMeta) and v.ndim > 0 for v in node.inputs[1:] ): rval = rval.copy() out[0] = rval diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index bad04d3050..91cfeec54d 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -9,7 +9,6 @@ from aesara.configdefaults import config from aesara.graph.basic import Variable from aesara.graph.type import DataType, Props, ShapeType -from aesara.issubtype import issubtype from aesara.link.c.type import CType, CTypeMeta from aesara.misc.safe_asarray import _asarray from aesara.utils import apply_across_args @@ -51,8 +50,9 @@ class TensorTypeMeta(CTypeMeta): r"""Symbolic `Type` representing `numpy.ndarray`\s.""" - shape: Props[DataType] = None - dtype: Props[ShapeType] = None + shape: Props[ShapeType] = None + dtype: Props[DataType] = None + numpy_dtype: Props[DataType] = None ndim: int @@ -314,7 +314,7 @@ def in_same_class(self, otype): """ if ( - issubtype(otype, TensorType) + isinstance(otype, TensorTypeMeta) and otype.dtype == self.dtype and otype.broadcastable == self.broadcastable ): @@ -624,7 +624,16 @@ class TensorType(CType, metaclass=TensorTypeMeta): pass -class DenseTypeMeta(TensorTypeMeta): +class DenseTensorTypeMetaMeta(type): + def __instancecheck__(self, instance): + if type(instance) == TensorTypeMeta or isinstance( + instance, DenseTensorTypeMetaMeta + ): + return True + return False + + +class DenseTensorTypeMeta(TensorTypeMeta, metaclass=DenseTensorTypeMetaMeta): r"""A `Type` for dense tensors. Instances of this class and `TensorType`\s are considered dense `Type`\s. @@ -632,13 +641,13 @@ class DenseTypeMeta(TensorTypeMeta): def __subclasscheck__(self, subclass): if getattr(subclass, "base_type", None) == TensorType or issubclass( - subclass, DenseTypeMeta + subclass, DenseTensorTypeMeta ): return True return False -class DenseTensorType(TensorType, metaclass=DenseTypeMeta): +class DenseTensorType(TensorType, metaclass=DenseTensorTypeMeta): pass @@ -1111,6 +1120,7 @@ def tensor7(name=None, dtype=None): __all__ = [ "TensorType", + "TensorTypeMeta", "bcol", "bmatrix", "brow", diff --git a/aesara/typed_list/basic.py b/aesara/typed_list/basic.py index 7aba28b445..d03d904d6a 100644 --- a/aesara/typed_list/basic.py +++ b/aesara/typed_list/basic.py @@ -1,16 +1,15 @@ import numpy as np import aesara.tensor as at -from aesara import issubtype from aesara.compile.debugmode import _lessbroken_deepcopy from aesara.configdefaults import config from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op from aesara.link.c.op import COp from aesara.tensor.type import scalar -from aesara.tensor.type_other import SliceType +from aesara.tensor.type_other import SliceType, SliceTypeMeta from aesara.tensor.var import TensorVariable -from aesara.typed_list.type import TypedListType +from aesara.typed_list.type import TypedListType, TypedListTypeMeta class _typed_list_py_operators: @@ -73,7 +72,7 @@ class GetItem(COp): __props__ = () def make_node(self, x, index): - assert issubtype(x.type, TypedListType) + assert isinstance(x.type, TypedListTypeMeta) if not isinstance(index, Variable): if isinstance(index, slice): index = Constant(SliceType.subtype(), index) @@ -81,7 +80,7 @@ def make_node(self, x, index): else: index = at.constant(index, ndim=0, dtype="int64") return Apply(self, [x, index], [x.ttype()]) - if issubtype(index.type, 
SliceType): + if isinstance(index.type, SliceTypeMeta): return Apply(self, [x, index], [x.type()]) elif isinstance(index, TensorVariable) and index.ndim == 0: assert index.dtype == "int64" @@ -149,7 +148,7 @@ def __init__(self, inplace=False): self.view_map = {0: [0]} def make_node(self, x, toAppend): - assert issubtype(x.type, TypedListType) + assert isinstance(x.type, TypedListTypeMeta) assert x.ttype == toAppend.type, (x.ttype, toAppend.type) return Apply(self, [x, toAppend], [x.type()]) @@ -232,7 +231,7 @@ def __init__(self, inplace=False): self.view_map = {0: [0]} def make_node(self, x, toAppend): - assert issubtype(x.type, TypedListType) + assert isinstance(x.type, TypedListTypeMeta) assert toAppend.type.is_super(x.type) return Apply(self, [x, toAppend], [x.type()]) @@ -321,7 +320,7 @@ def __init__(self, inplace=False): self.view_map = {0: [0]} def make_node(self, x, index, toInsert): - assert issubtype(x.type, TypedListType) + assert isinstance(x.type, TypedListTypeMeta) assert x.ttype == toInsert.type if not isinstance(index, Variable): index = at.constant(index, ndim=0, dtype="int64") @@ -406,7 +405,7 @@ def __init__(self, inplace=False): self.view_map = {0: [0]} def make_node(self, x, toRemove): - assert issubtype(x.type, TypedListType) + assert isinstance(x.type, TypedListTypeMeta) assert x.ttype == toRemove.type return Apply(self, [x, toRemove], [x.type()]) @@ -463,7 +462,7 @@ def __init__(self, inplace=False): self.view_map = {0: [0]} def make_node(self, x): - assert issubtype(x.type, TypedListType) + assert isinstance(x.type, TypedListTypeMeta) return Apply(self, [x], [x.type()]) def perform(self, node, inp, outputs): @@ -527,7 +526,7 @@ class Index(Op): __props__ = () def make_node(self, x, elem): - assert issubtype(x.type, TypedListType) + assert isinstance(x.type, TypedListTypeMeta) assert x.ttype == elem.type return Apply(self, [x, elem], [scalar()]) @@ -556,7 +555,7 @@ class Count(Op): __props__ = () def make_node(self, x, elem): - assert issubtype(x.type, TypedListType) + assert isinstance(x.type, TypedListTypeMeta) assert x.ttype == elem.type return Apply(self, [x, elem], [scalar()]) @@ -603,7 +602,7 @@ class Length(COp): __props__ = () def make_node(self, x): - assert issubtype(x.type, TypedListType) + assert isinstance(x.type, TypedListTypeMeta) return Apply(self, [x], [scalar(dtype="int64")]) def perform(self, node, x, outputs): diff --git a/aesara/typed_list/type.py b/aesara/typed_list/type.py index 8bae344920..b8b9ea20c4 100644 --- a/aesara/typed_list/type.py +++ b/aesara/typed_list/type.py @@ -1,8 +1,7 @@ from typing import Any -from aesara import issubtype -from aesara.graph.type import Props -from aesara.link.c.type import CType, CTypeMeta, Type +from aesara.graph.type import NewTypeMeta, Props +from aesara.link.c.type import CType, CTypeMeta class TypedListTypeMeta(CTypeMeta): @@ -25,7 +24,7 @@ def type_parameters(cls, ttype, depth=0): if depth < 0: raise ValueError("Please specify a depth superior or" "equal to 0") - if not issubtype(ttype, Type): + if not isinstance(ttype, NewTypeMeta): raise TypeError("Expected an Aesara Type") if depth > 0: @@ -66,7 +65,7 @@ def get_depth(self): Utilitary function to get the 0 based level of the list. 
""" - if issubtype(self.ttype, TypedListType): + if isinstance(self.ttype, TypedListTypeMeta): return self.ttype.get_depth() + 1 else: return 0 diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index ab4ac7035a..1ac9e32f32 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -8,13 +8,18 @@ from aesara.compile.builders import OpFromGraph from aesara.compile.function import function from aesara.configdefaults import config -from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad +from aesara.gradient import ( + DisconnectedType, + DisconnectedTypeMeta, + Rop, + disconnected_type, + grad, +) from aesara.graph.basic import equal_computations from aesara.graph.fg import FunctionGraph -from aesara.graph.null_type import NullType +from aesara.graph.null_type import NullType, NullTypeMeta from aesara.graph.rewriting.utils import rewrite_graph from aesara.graph.utils import MissingInputError -from aesara.issubtype import issubtype from aesara.printing import debugprint from aesara.tensor.basic import as_tensor from aesara.tensor.math import dot, exp @@ -24,7 +29,7 @@ from aesara.tensor.random.utils import RandomStream from aesara.tensor.rewriting.shape import ShapeOptimizer from aesara.tensor.shape import specify_shape -from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors +from aesara.tensor.type import TensorTypeMeta, matrices, matrix, scalar, vector, vectors from tests import unittest_tools from tests.graph.utils import MyVariable @@ -245,10 +250,10 @@ def go2(inps, gs): disconnected_inputs="ignore", null_gradients="return", ) - assert issubtype(dx2.type, TensorType) + assert isinstance(dx2.type, TensorTypeMeta) assert dx2.ndim == 1 - assert issubtype(dw2.type, NullType) - assert issubtype(db2.type, DisconnectedType) + assert isinstance(dw2.type, NullTypeMeta) + assert isinstance(db2.type, DisconnectedTypeMeta) @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 1fb318c691..c7fd5d5303 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -30,7 +30,6 @@ ) from aesara.graph.op import Op from aesara.graph.type import NewTypeMeta, Props, Type -from aesara.issubtype import issubtype from aesara.tensor.math import max_and_argmax from aesara.tensor.type import ( TensorType, @@ -81,7 +80,7 @@ class MyOp(Op): def make_node(self, *inputs): for input in inputs: assert isinstance(input, Variable) - assert issubtype(input.type, MyType) + assert issubclass(input.type, MyType) outputs = [MyVariable(sum(input.type.thingy for input in inputs))] return Apply(self, list(inputs), outputs) diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py index 005da2e26f..a15579d4f4 100644 --- a/tests/graph/test_destroyhandler.py +++ b/tests/graph/test_destroyhandler.py @@ -17,7 +17,6 @@ ) from aesara.graph.type import NewTypeMeta, Type from aesara.graph.utils import InconsistencyError -from aesara.issubtype import issubtype from tests.unittest_tools import assertFailure_fast @@ -93,7 +92,7 @@ def make_node(self, *inputs): assert len(inputs) == self.nin inputs = list(map(as_variable, inputs)) for input in inputs: - if not issubtype(input.type, MyType): + if not issubclass(input.type, MyType): raise Exception("Error 1") outputs = [MyVariable(self.name + "_R") for i in range(self.nout)] return Apply(self, inputs, outputs) diff --git a/tests/graph/test_features.py 
b/tests/graph/test_features.py index 2c2a1da641..610c78d6bf 100644 --- a/tests/graph/test_features.py +++ b/tests/graph/test_features.py @@ -5,7 +5,6 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op from aesara.graph.type import NewTypeMeta, Props, Type -from aesara.issubtype import issubtype from tests.graph.utils import MyVariable, op1 @@ -48,7 +47,7 @@ def as_variable(x): assert len(inputs) == self.nin inputs = list(map(as_variable, inputs)) for input in inputs: - if not issubtype(input.type, MyType): + if not issubclass(input.type, MyType): raise Exception("Error 1") outputs = [MyType.subtype(self.name + "_R")()] return Apply(self, inputs, outputs) diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index aa683be252..a2bb0a2c01 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -12,7 +12,6 @@ from aesara.graph.op import Op from aesara.graph.type import NewTypeMeta, Props, Type from aesara.graph.utils import TestValueError -from aesara.issubtype import issubtype from aesara.link.c.type import Generic from aesara.tensor.math import log from aesara.tensor.type import dmatrix, dscalar, dvector, vector @@ -67,7 +66,7 @@ class MyOp(Op): def make_node(self, *inputs): inputs = list(map(as_variable, inputs)) for input in inputs: - if not issubtype(input.type, MyType): + if not issubclass(input.type, MyType): raise Exception("Error 1") outputs = [MyType.subtype(sum(input.type.thingy for input in inputs))()] return Apply(self, inputs, outputs) diff --git a/tests/graph/test_types.py b/tests/graph/test_types.py index f0b8a1496e..29b186248e 100644 --- a/tests/graph/test_types.py +++ b/tests/graph/test_types.py @@ -4,7 +4,6 @@ from aesara.graph.basic import Variable from aesara.graph.type import NewTypeMeta, Props, Type -from aesara.issubtype import issubtype class MyTypeMeta(NewTypeMeta): @@ -75,4 +74,4 @@ def test_convert_variable(): def test_default_clone(): mt = MyType.subtype(1) - assert issubtype(mt.clone(1), MyType) + assert issubclass(mt.clone(1), MyType) diff --git a/tests/graph/utils.py b/tests/graph/utils.py index aeb6831a73..587eb84bd2 100644 --- a/tests/graph/utils.py +++ b/tests/graph/utils.py @@ -4,7 +4,6 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import HasInnerGraph, Op from aesara.graph.type import NewTypeMeta, Type -from aesara.issubtype import issubtype def is_variable(x): @@ -73,7 +72,7 @@ def __init__(self, name, dmap=None, x=None, n_outs=1): def make_node(self, *inputs): inputs = list(map(is_variable, inputs)) for input in inputs: - if not issubtype(input.type, MyType): + if not issubclass(input.type, MyType): raise Exception("Error 1") outputs = [MyType.subtype()() for i in range(self.n_outs)] return Apply(self, inputs, outputs) @@ -107,7 +106,7 @@ class MyOpCastType2(MyOp): def make_node(self, *inputs): inputs = list(map(is_variable, inputs)) for input in inputs: - if not issubtype(input.type, MyType): + if not issubclass(input.type, MyType): raise Exception("Error 1") outputs = [MyType2.subtype()()] diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index e52bff6ac0..6e2abffddd 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -5,13 +5,13 @@ import tests.unittest_tools as utt from aesara.compile.mode import Mode from aesara.graph.fg import FunctionGraph -from aesara.issubtype import issubtype from aesara.link.c.basic import DualLinker from aesara.scalar.basic import ( ComplexError, Composite, InRange, ScalarType, + ScalarTypeMeta, add, and_, arccos, @@ 
-493,13 +493,13 @@ def test_mean(mode): def test_shape(): a = float32("a") - assert issubtype(a.type, ScalarType) + assert isinstance(a.type, ScalarTypeMeta) assert a.shape.type.ndim == 1 assert a.shape.type.shape == (0,) assert a.shape.type.dtype == "int64" b = constant(2, name="b") - assert issubtype(b.type, ScalarType) + assert isinstance(b.type, ScalarTypeMeta) assert b.shape.type.ndim == 1 assert b.shape.type.shape == (0,) assert b.shape.type.dtype == "int64" diff --git a/tests/sparse/test_var.py b/tests/sparse/test_var.py index a49073b645..b5c6cae266 100644 --- a/tests/sparse/test_var.py +++ b/tests/sparse/test_var.py @@ -7,9 +7,8 @@ import aesara import aesara.sparse as sparse import aesara.tensor as at -from aesara.issubtype import issubtype -from aesara.sparse.type import SparseTensorType -from aesara.tensor.type import DenseTensorType +from aesara.sparse.type import SparseTensorType, SparseTensorTypeMeta +from aesara.tensor.type import DenseTensorType, DenseTensorTypeMeta class TestSparseVariable: @@ -100,7 +99,8 @@ def test_unary(self, method, exp_type, cm, x): else: z_outs = z - assert all(issubtype(out.type, exp_type) for out in z_outs) + # TODO: Maybe exp_type should already be the Meta class + assert all(isinstance(out.type, type(exp_type)) for out in z_outs) f = aesara.function([x], z, on_unused_input="ignore", allow_input_downcast=True) @@ -156,7 +156,8 @@ def test_binary(self, method, exp_type): else: z_outs = z - assert all(issubtype(out.type, exp_type) for out in z_outs) + # TODO: maybe exp_type should already be the Meta class + assert all(isinstance(out.type, type(exp_type)) for out in z_outs) f = aesara.function([x, y], z) res = f( @@ -178,7 +179,7 @@ def test_reshape(self): with pytest.warns(UserWarning, match=".*converted to dense.*"): z = x.reshape((3, 2)) - assert issubtype(z.type, DenseTensorType) + assert isinstance(z.type, DenseTensorTypeMeta) f = aesara.function([x], z) exp_res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) @@ -191,7 +192,7 @@ def test_dimshuffle(self): with pytest.warns(UserWarning, match=".*converted to dense.*"): z = x.dimshuffle((1, 0)) - assert issubtype(z.type, DenseTensorType) + assert isinstance(z.type, DenseTensorTypeMeta) f = aesara.function([x], z) exp_res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) @@ -202,7 +203,7 @@ def test_getitem(self): x = sparse.csr_from_dense(x) z = x[:, :2] - assert issubtype(z.type, SparseTensorType) + assert isinstance(z.type, SparseTensorTypeMeta) f = aesara.function([x], z) exp_res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) @@ -215,7 +216,7 @@ def test_dot(self): y = sparse.csr_from_dense(y) z = x.__dot__(y) - assert issubtype(z.type, SparseTensorType) + assert isinstance(z.type, SparseTensorTypeMeta) f = aesara.function([x, y], z) exp_res = f( @@ -231,7 +232,7 @@ def test_repeat(self): with pytest.warns(UserWarning, match=".*converted to dense.*"): z = x.repeat(2, axis=1) - assert issubtype(z.type, DenseTensorType) + assert isinstance(z.type, DenseTensorTypeMeta) f = aesara.function([x], z) exp_res = f([[1.1, 0.0, 2.0], [-1.0, 0.0, 0.0]]) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 7322181283..0cde1dda58 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -13,7 +13,6 @@ from aesara.configdefaults import config from aesara.graph.basic import Apply, Variable from aesara.graph.fg import FunctionGraph -from aesara.issubtype import issubtype from aesara.link.basic import PerformLinker from aesara.link.c.basic import CLinker, OpWiseCLinker from 
aesara.tensor import as_tensor_variable @@ -25,6 +24,7 @@ from aesara.tensor.math import exp from aesara.tensor.type import ( TensorType, + TensorTypeMeta, bmatrix, bscalar, discrete_dtypes, @@ -862,7 +862,7 @@ def test_shape_types(self): (out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)]) - assert all(issubtype(v.type, TensorType) for v in out_shape) + assert all(isinstance(v.type, TensorTypeMeta) for v in out_shape) def test_static_shape_unary(self): x = tensor("float64", shape=(None, 0, 1, 5)) diff --git a/tests/tensor/test_merge.py b/tests/tensor/test_merge.py index 3ba424f45c..e331b40631 100644 --- a/tests/tensor/test_merge.py +++ b/tests/tensor/test_merge.py @@ -6,7 +6,6 @@ from aesara.graph.op import Op from aesara.graph.rewriting.basic import MergeOptimizer from aesara.graph.type import NewTypeMeta, Type -from aesara.issubtype import issubtype def is_variable(x): @@ -38,7 +37,7 @@ def __init__(self, name, dmap=None, x=None): def make_node(self, *inputs): inputs = list(map(is_variable, inputs)) for input in inputs: - if not issubtype(input.type, MyType): + if not issubclass(input.type, MyType): raise Exception("Error 1") outputs = [MyType.subtype()()] return Apply(self, inputs, outputs) diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 0bd89197cc..cf7e422dee 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -2,7 +2,7 @@ import pytest import aesara -from aesara import Mode, function, grad, issubtype +from aesara import Mode, function, grad from aesara.compile.ops import DeepCopyOp from aesara.configdefaults import config from aesara.graph.basic import Variable @@ -29,6 +29,7 @@ from aesara.tensor.subtensor import Subtensor from aesara.tensor.type import ( TensorType, + TensorTypeMeta, dmatrix, dtensor4, dvector, @@ -341,7 +342,7 @@ def test_shape_i_hash(): class TestSpecifyShape(utt.InferShapeTester): mode = None - input_type = TensorType + input_type = TensorTypeMeta def test_check_inputs(self): with pytest.raises(TypeError, match="must be integer types"): @@ -445,7 +446,7 @@ def test_bad_shape(self): with pytest.raises(AssertionError, match="SpecifyShape:.*"): f(xval) - assert issubtype( + assert isinstance( [n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0] .inputs[0] .type, @@ -455,7 +456,7 @@ def test_bad_shape(self): x = matrix() xval = np.random.random((2, 3)).astype(config.floatX) f = aesara.function([x], specify_shape(x, 2, 3), mode=self.mode) - assert issubtype( + assert isinstance( [n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0] .inputs[0] .type, diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 72700fae9b..05bc1e2a41 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -9,7 +9,6 @@ import aesara import aesara.scalar as scal import aesara.tensor.basic as at -from aesara import issubtype from aesara.compile import DeepCopyOp, shared from aesara.compile.io import In from aesara.configdefaults import config @@ -2607,9 +2606,9 @@ def test_index_vars_to_types(): index_vars_to_types(1) res = index_vars_to_types(iscalar) - assert issubtype(res, scal.ScalarType) + assert isinstance(res, scal.ScalarTypeMeta) x = scal.constant(1, dtype=np.uint8) - assert issubtype(x.type, scal.ScalarType) + assert isinstance(x.type, scal.ScalarTypeMeta) res = index_vars_to_types(x) assert res == x.type diff --git a/tests/test_raise_op.py b/tests/test_raise_op.py index 93b213b5cf..b43e437c47 100644 --- 
a/tests/test_raise_op.py +++ b/tests/test_raise_op.py @@ -6,9 +6,8 @@ import aesara.tensor as at from aesara.compile.mode import OPT_FAST_RUN, Mode from aesara.graph.basic import Constant, equal_computations -from aesara.issubtype import issubtype from aesara.raise_op import Assert, CheckAndRaise, assert_op -from aesara.scalar.basic import ScalarType, float64 +from aesara.scalar.basic import ScalarTypeMeta, float64 from aesara.sparse import as_sparse_variable from tests import unittest_tools as utt @@ -117,8 +116,8 @@ def test_perform_CheckAndRaise_scalar(linker): conds = (val > 0, val > 3) y = check_and_raise(val, *conds) - assert all(issubtype(i.type, ScalarType) for i in y.owner.inputs) - assert issubtype(y.type, ScalarType) + assert all(isinstance(i.type, ScalarTypeMeta) for i in y.owner.inputs) + assert isinstance(y.type, ScalarTypeMeta) mode = Mode(linker=linker) y_fn = aesara.function([val], y, mode=mode)