
Commit f228149

maresbricardoV94 authored and committed
Fix mypy errors on main
1 parent 7092f55 commit f228149

8 files changed, +22 -16 lines changed


pytensor/link/jax/linker.py

Lines changed: 3 additions & 1 deletion

@@ -9,8 +9,10 @@
 class JAXLinker(JITLinker):
     """A `Linker` that JIT-compiles NumPy-based operations using JAX."""
 
+    scalar_shape_inputs: tuple[int, ...]
+
     def __init__(self, *args, **kwargs):
-        self.scalar_shape_inputs: tuple[int] = ()  # type: ignore[annotation-unchecked]
+        self.scalar_shape_inputs = ()
         super().__init__(*args, **kwargs)
 
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
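
This change swaps an annotated assignment inside `__init__` for a class-level attribute declaration (and widens the annotation from `tuple[int]`, a 1-tuple, to `tuple[int, ...]`, any length). By default mypy does not check the bodies of untyped functions, so an annotation on `self.scalar_shape_inputs` inside the unannotated `__init__` only produces an `annotation-unchecked` note; declaring the attribute on the class gets it checked with no suppression comment at all. A minimal sketch of the pattern (class names are illustrative, not from the codebase):

    class Before:
        def __init__(self, *args, **kwargs):
            # mypy note: "By default the bodies of untyped functions are not
            # checked" [annotation-unchecked]
            self.items: tuple[int, ...] = ()

    class After:
        # Class-level declaration: checked at every use site, no note emitted.
        items: tuple[int, ...]

        def __init__(self, *args, **kwargs):
            self.items = ()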

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 4 additions & 4 deletions

@@ -517,17 +517,17 @@ def make_loop_call(
     output_slices = []
     for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True):
         core_ndim = output_type.ndim - len(bc)
-        size_type = output.shape.type.element  # type: ignore
-        output_shape = cgutils.unpack_tuple(builder, output.shape)  # type: ignore
-        output_strides = cgutils.unpack_tuple(builder, output.strides)  # type: ignore
+        size_type = output.shape.type.element  # pyright: ignore[reportAttributeAccessIssue]
+        output_shape = cgutils.unpack_tuple(builder, output.shape)  # pyright: ignore[reportAttributeAccessIssue]
+        output_strides = cgutils.unpack_tuple(builder, output.strides)  # pyright: ignore[reportAttributeAccessIssue]
 
         idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [
             zero
         ] * core_ndim
         ptr = cgutils.get_item_pointer2(
             context,
             builder,
-            output.data,  # type:ignore
+            output.data,
             output_shape,
             output_strides,
             output_type.layout,
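
Swapping bare `# type: ignore` comments for `# pyright: ignore[reportAttributeAccessIssue]` narrows the suppression in two ways: it targets one checker (mypy treats `# pyright:` comments as ordinary comments, while pyright honors both forms) and one diagnostic code. Presumably mypy has no complaint at these sites, so the blanket ignores were only serving pyright and would trip mypy's `--warn-unused-ignores`. A toy sketch of the difference, with the attribute error manufactured via a plain `object` annotation:

    val: object = ()  # runtime value is a tuple, but the declared type is object

    a = val.count  # flagged by mypy [attr-defined] and pyright [reportAttributeAccessIssue]
    b = val.count  # type: ignore  # blanket: silences both checkers, every error code
    c = val.count  # pyright: ignore[reportAttributeAccessIssue]  # pyright only; mypy would still flag this toy line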

pytensor/npy_2_compat.py

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@
 
 
 if using_numpy_2:
-    ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version()
+    ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version()  # type: ignore[attr-defined]
 else:
     ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version()  # type: ignore[attr-defined]
 
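
Both branches now carry a scoped `# type: ignore[attr-defined]`: `_multiarray_umath._get_ndarray_c_version` is a private, unstubbed NumPy internal, and whichever NumPy is installed, its stubs describe at most one of the two namespaces (`np._core` on NumPy >= 2.0, `np.core` before). Scoping the ignore to `attr-defined` keeps every other error code live on the line. A hedged sketch of how the `using_numpy_2` gate is plausibly defined elsewhere in the module:

    import numpy as np

    # Assumption: the module derives this flag from the installed version,
    # e.g. via NumpyVersion; the exact definition is not shown in this diff.
    using_numpy_2 = np.lib.NumpyVersion(np.__version__) >= "2.0.0"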

pytensor/scan/utils.py

Lines changed: 1 addition & 1 deletion

@@ -109,7 +109,7 @@ def safe_new(
     except TestValueError:
         pass
 
-    return nw_x
+    return type_cast(Variable, nw_x)
 
 
 class until:
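
`type_cast` here is presumably `typing.cast` imported under an alias (a common dodge when the name `cast` is already taken in the module). The cast fixes the declared return type for mypy at zero runtime cost: `typing.cast` performs no check and simply returns its second argument. A minimal sketch of the idiom, using a hypothetical function:

    from typing import cast as type_cast

    def first_or_default(items: list[object], default: int = 0) -> int:
        # The surrounding logic guarantees an int here, but the checker
        # cannot prove it, so we cast instead of ignoring the line.
        value = items[0] if items else default
        return type_cast(int, value)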

pytensor/tensor/einsum.py

Lines changed: 8 additions & 4 deletions

@@ -597,10 +597,14 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
         # Numpy einsum_path requires arrays even though only the shapes matter
         # It's not trivial to duck-type our way around because of internal call to `asanyarray`
         *[np.empty(shape) for shape in shapes],
-        einsum_call=True,  # Not part of public API
+        # einsum_call is not part of public API
+        einsum_call=True,  # type: ignore[arg-type]
         optimize="optimal",
-    )  # type: ignore
-    np_path = tuple(contraction[0] for contraction in contraction_list)
+    )
+    np_path: PATH | tuple[tuple[int, ...]] = tuple(
+        contraction[0]  # type: ignore[misc]
+        for contraction in contraction_list
+    )
 
     if len(np_path) == 1 and len(np_path[0]) > 2:
         # When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing

@@ -610,7 +614,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
             subscripts, tensor_operands, path
         )
     else:
-        path = np_path
+        path = cast(PATH, np_path)
 
     optimized = True
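
Two techniques replace the call-wide `# type: ignore` here. First, the suppression moves onto the specific offending expressions with narrow codes: `arg-type` for the undocumented `einsum_call=True` keyword, `misc` inside the generator. Second, `np_path` gets an explicit union annotation so that the later `path = cast(PATH, np_path)` narrows the type only in the branch where it is known to hold. Assuming `PATH` is the module's alias for an einsum contraction path, shaped roughly like this sketch (illustrative, not copied from the module):

    from typing import cast

    PATH = tuple[tuple[int] | tuple[int, int], ...]

    np_path: PATH | tuple[tuple[int, ...]] = ((0, 1), (0, 1))
    # In the else-branch every entry is known to be a pair, so the union is
    # narrowed with an explicit cast rather than a blanket ignore:
    path = cast(PATH, np_path)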

pytensor/tensor/random/rewriting/numba.py

Lines changed: 2 additions & 2 deletions

@@ -53,10 +53,10 @@ def introduce_explicit_core_shape_rv(fgraph, node):
     # ← dirichlet_rv{"(a)->(a)"}.1 [id F]
     # └─ ···
     """
-    op: RandomVariable = node.op  # type: ignore[annotation-unchecked]
+    op: RandomVariable = node.op
 
     next_rng, rv = node.outputs
-    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
+    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)
     if shape_feature:
         core_shape = [
             shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp))

pytensor/tensor/rewriting/blockwise.py

Lines changed: 1 addition & 1 deletion

@@ -102,7 +102,7 @@ def local_blockwise_alloc(fgraph, node):
     This is critical to remove many unnecessary Blockwise, or to reduce the work done by it
     """
 
-    op: Blockwise = node.op  # type: ignore
+    op: Blockwise = node.op
 
     batch_ndim = op.batch_ndim(node)
     if not batch_ndim:

pytensor/tensor/rewriting/numba.py

Lines changed: 2 additions & 2 deletions

@@ -65,10 +65,10 @@ def introduce_explicit_core_shape_blockwise(fgraph, node):
     # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
     # └─ ···
     """
-    op: Blockwise = node.op  # type: ignore[annotation-unchecked]
+    op: Blockwise = node.op
     batch_ndim = op.batch_ndim(node)
 
-    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
+    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)
     if shape_feature:
         core_shapes = [
             [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
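
The last three files all delete ignore comments (`annotation-unchecked` twice, plus one bare `# type: ignore`) that apparently no longer suppress anything, presumably because these annotated assignments are now checked and valid as written. Stale suppressions are worth removing promptly: if mypy runs with `--warn-unused-ignores` (or `warn_unused_ignores = True` in its config), an ignore on a clean line becomes an error of its own. A toy illustration:

    x: int = 1  # type: ignore  # with --warn-unused-ignores: error: Unused "type: ignore" comment
    y: int = 1  # nothing to suppress, nothing to warn about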
