Describe the issue:
Due to the changes in #1582 (part of #1550), the following code block from https://pytensor.readthedocs.io/en/latest/gallery/optimize/root.html fails:
Reproducible code example:
import pytensor
import pytensor.tensor as pt
import numpy as np
from pytensor.tensor.special import betaln

# Parameters of the (beta-binomial) wage offer distribution
n, a, b = pt.scalars('n a b'.split())
w_min, w_max = pt.scalars('w_min w_max'.split())

w_support = pt.linspace(w_min, w_max, n + 1)
k = pt.floor(w_support)

# Beta-binomial probabilities over the wage support, computed in log space
ln_n_choose_k = -pt.log(n + 1) - betaln(n - k + 1, k + 1)
q_probs = pt.exp(ln_n_choose_k + betaln(k + a, n - k + b) - betaln(a, b))

dist_args = [n, a, b, w_min, w_max]
f = pytensor.function(dist_args, [w_support, q_probs])
dist_params = {'n': 50, 'a': 200, 'b': 100, 'w_min': 10, 'w_max': 60}

c = pt.dscalar('c')  # Unemployment benefit
β = pt.dscalar('β')  # Discount factor

# Initial value function guess
v0 = pt.dvector('v0')

# Fixed-point operator
T = pt.maximum(w_support / (1 - β), c + β * pt.dot(v0, q_probs))

# Solve T(v) = v as a root-finding problem in the residual T - v0
v_star, success = pt.optimize.root(equations=T - v0, variables=v0, method='hybr')

fn = pytensor.function([v0, c, β, *dist_args], [w_support, v_star, success])

c_value = 25
beta_value = 0.99
v0_value = np.zeros(dist_params['n'] + 1)
w_values, v_star_val, success_flag = fn(v0_value, c_value, beta_value, **dist_params)

# Reservation wage implied by the fixed point
w_bar = (1 - β) * (c + β * pt.dot(v_star, q_probs))

# We want to study how the reservation wage responds to the unemployment
# benefit and to patience
w_grads = pt.grad(w_bar, [c, β])  # <-- raises the TypeError below
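For reference, the gradient that `RootOp.L_op` is assembling here follows the implicit function theorem. Writing the residual as f(v, θ) = T(v, θ) - v with root v* and parameters θ = (c, β, n, a, b, w_min, w_max), the standard result it should implement (a sketch of the math, not the literal code path) is:

$$
\frac{\partial v^{*}}{\partial \theta}
  = -\left( \left. \frac{\partial f}{\partial v} \right|_{v = v^{*}} \right)^{-1}
    \left. \frac{\partial f}{\partial \theta} \right|_{v = v^{*}}
$$

`implict_optimization_grads` builds the ∂f/∂θ block by concatenating one Jacobian column per parameter, and that `concatenate` call is where the traceback below fails.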
Error message:
TypeError Traceback (most recent call last)
Cell In[1], line 47
44 w_bar = (1 - β) * (c + β * pt.dot(v_star, q_probs))
46 # We want to study the impact of change in unemployment and patience on the reserve wage
---> 47 w_grads = pt.grad(w_bar, [c, β])
File ~/git/pytensor/pytensor/gradient.py:747, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
744 if hasattr(g.type, "dtype"):
745 assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 747 _rval: Sequence[Variable] = _populate_grad_dict(
748 var_to_app_to_idx, grad_dict, _wrt, cost_name
749 )
751 rval: MutableSequence[Variable | None] = list(_rval)
753 for i in range(len(_rval)):
File ~/git/pytensor/pytensor/gradient.py:1541, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
1538 # end if cache miss
1539 return grad_dict[var]
-> 1541 rval = [access_grad_cache(elem) for elem in wrt]
1543 return rval
File ~/git/pytensor/pytensor/gradient.py:1496, in _populate_grad_dict.<locals>.access_grad_cache(var)
1494 for node in node_to_idx:
1495 for idx in node_to_idx[node]:
-> 1496 term = access_term_cache(node)[idx]
1498 if not isinstance(term, Variable):
1499 raise TypeError(
1500 f"{node.op}.grad returned {type(term)}, expected"
1501 " Variable instance."
1502 )
File ~/git/pytensor/pytensor/gradient.py:1326, in _populate_grad_dict.<locals>.access_term_cache(node)
1318 if o_shape != g_shape:
1319 raise ValueError(
1320 "Got a gradient of shape "
1321 + str(o_shape)
1322 + " on an output of shape "
1323 + str(g_shape)
1324 )
-> 1326 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
1328 if input_grads is None:
1329 raise TypeError(
1330 f"{node.op}.grad returned NoneType, expected iterable."
1331 )
File ~/git/pytensor/pytensor/tensor/optimize.py:913, in RootOp.L_op(self, inputs, outputs, output_grads)
904 df_dx = (
905 jacobian(inner_fx, inner_x, vectorize=True)
906 if not self.jac
907 else self.fgraph.outputs[1]
908 )
909 df_dtheta_columns = jacobian(
910 inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
911 )
--> 913 grad_wrt_args = implict_optimization_grads(
914 df_dx=df_dx,
915 df_dtheta_columns=df_dtheta_columns,
916 args=args,
917 x_star=x_star,
918 output_grad=output_grad,
919 fgraph=self.fgraph,
920 )
922 return [zeros_like(x), *grad_wrt_args]
File ~/git/pytensor/pytensor/tensor/optimize.py:333, in implict_optimization_grads(df_dx, df_dtheta_columns, args, x_star, output_grad, fgraph)
290 r"""
291 Compute gradients of an optimization problem with respect to its parameters.
292
(...) 329 The function graph that contains the inputs and outputs of the optimization problem.
330 """
331 df_dx = cast(TensorVariable, df_dx)
--> 333 df_dtheta = concatenate(
334 [
335 atleast_2d(jac_col, left=False)
336 for jac_col in cast(list[TensorVariable], df_dtheta_columns)
337 ],
338 axis=-1,
339 )
341 replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
343 df_dx_star, df_dtheta_star = cast(
344 list[TensorVariable],
345 graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
346 )
File ~/git/pytensor/pytensor/tensor/basic.py:2998, in concatenate(tensor_list, axis)
2991 if not isinstance(tensor_list, tuple | list):
2992 raise TypeError(
2993 "The 'tensors' argument must be either a tuple "
2994 "or a list, make sure you did not forget () or [] around "
2995 "arguments of concatenate.",
2996 tensor_list,
2997 )
-> 2998 return join(axis, *tensor_list)
File ~/git/pytensor/pytensor/tensor/basic.py:2812, in join(axis, *tensors_list)
2810 return tensors_list[0]
2811 else:
-> 2812 return _join(axis, *tensors_list)
File ~/git/pytensor/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
249 def __call__(
250 self, *inputs: Any, name=None, return_list=False, **kwargs
251 ) -> Variable | list[Variable]:
252 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
253
254 This method is just a wrapper around :meth:`Op.make_node`.
(...) 291
292 """
--> 293 node = self.make_node(*inputs, **kwargs)
294 if name is not None:
295 if len(node.outputs) == 1:
File ~/git/pytensor/pytensor/tensor/basic.py:2505, in Join.make_node(self, axis, *tensors)
2502 ndim = tensors[0].type.ndim
2504 if not builtins.all(x.ndim == ndim for x in tensors):
-> 2505 raise TypeError(
2506 "Only tensors with the same number of dimensions can be joined. "
2507 f"Input ndims were: {[x.ndim for x in tensors]}"
2508 )
2510 try:
2511 static_axis = int(get_scalar_constant_value(axis))
TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [3, 2, 2, 2]
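The final TypeError is easy to trigger in isolation: `Join.make_node` refuses inputs of mixed rank at graph-construction time, and here one of the Jacobian columns comes out 3-d while the rest are 2-d. A minimal sketch of the same failure (the variable names are stand-ins, not the actual Jacobian graphs):

import pytensor.tensor as pt

jac_3d = pt.tensor3('jac_3d')  # stand-in for the offending 3-d column
jac_2d = pt.matrix('jac_2d')   # stand-ins for the expected 2-d columns

# Raises at graph-construction time:
# TypeError: Only tensors with the same number of dimensions can be joined.
# Input ndims were: [3, 2, 2, 2]
pt.concatenate([jac_3d, jac_2d, jac_2d, jac_2d], axis=-1)

Note that `atleast_2d(jac_col, left=False)` in `implict_optimization_grads` only pads columns up to 2-d; it cannot reduce a 3-d one, so presumably one of the `jacobian(..., vectorize=True)` calls in `RootOp.L_op` is producing an extra dimension for one of the parameters.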
PyTensor version information:
PyTensor branch: #1582
PyMC: main branch (as of writing)
Context for the issue:
No response