BUG: Example in notebook for optimize no longer works #1586

@Michal-Novomestsky

Description

Describe the issue:

Due to changes introduced in #1582 (part of #1550), the following code block from https://pytensor.readthedocs.io/en/latest/gallery/optimize/root.html fails:

Reproducible code example:

import pytensor
import pytensor.tensor as pt

import numpy as np
from pytensor.tensor.special import betaln

n, a, b = pt.scalars('n a b'.split())
w_min, w_max = pt.scalars('w_min w_max'.split())

w_support = pt.linspace(w_min, w_max, n+1)

# Beta-binomial pmf for the wage-offer distribution over the support grid
k = pt.floor(w_support)
ln_n_choose_k = -pt.log(n + 1) - betaln(n - k + 1, k + 1)
q_probs = pt.exp(ln_n_choose_k + betaln(k + a, n - k + b) - betaln(a, b))

dist_args = [n, a, b, w_min, w_max]

dist_params = {'n':50, 'a':200, 'b':100, 'w_min':10, 'w_max':60}

c = pt.dscalar('c') # Unemployment benefit
β = pt.dscalar('β') # Discount factor

# initial value function guess
v0 = pt.dvector('v0') 

# Fixed-point operator
T = pt.maximum(w_support / (1 - β), c + β * pt.dot(v0, q_probs))

v_star, success = pt.optimize.root(equations=T - v0,
                                   variables=v0,
                                   method='hybr')

fn = pytensor.function([v0, c, β, *dist_args],
                       [w_support, v_star, success])

c_value = 25
beta_value = 0.99
v0_value = np.zeros(dist_params['n'] + 1)

w_values, v_star_val, success_flag = fn(v0_value, c_value, beta_value, **dist_params)

# Reservation wage implied by the fixed point
w_bar = (1 - β) * (c + β * pt.dot(v_star, q_probs))

# We want to study how changes in the unemployment benefit and patience affect the reservation wage
w_grads = pt.grad(w_bar, [c, β])  # <- raises the TypeError below
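
As far as I can tell from the traceback, pt.optimize.root differentiates through the solution via the implicit function theorem: with residuals f(x*, θ) = 0 at the root x*, the parameter gradient is

    ∂x*/∂θ = -(∂f/∂x)^(-1) ∂f/∂θ,  evaluated at x = x*

so RootOp.L_op builds one Jacobian of the residuals per parameter and concatenates them before applying the output gradient. That concatenate is where it fails.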

Error message:

TypeError                                 Traceback (most recent call last)
Cell In[1], line 47
     44 w_bar = (1 - β) * (c + β * pt.dot(v_star, q_probs))
     46 # We want to study the impact of change in unemployment and patience on the reserve wage 
---> 47 w_grads = pt.grad(w_bar, [c, β])

File ~/git/pytensor/pytensor/gradient.py:747, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    744     if hasattr(g.type, "dtype"):
    745         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 747 _rval: Sequence[Variable] = _populate_grad_dict(
    748     var_to_app_to_idx, grad_dict, _wrt, cost_name
    749 )
    751 rval: MutableSequence[Variable | None] = list(_rval)
    753 for i in range(len(_rval)):

File ~/git/pytensor/pytensor/gradient.py:1541, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1538     # end if cache miss
   1539     return grad_dict[var]
-> 1541 rval = [access_grad_cache(elem) for elem in wrt]
   1543 return rval

File ~/git/pytensor/pytensor/gradient.py:1496, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1494 for node in node_to_idx:
   1495     for idx in node_to_idx[node]:
-> 1496         term = access_term_cache(node)[idx]
   1498         if not isinstance(term, Variable):
   1499             raise TypeError(
   1500                 f"{node.op}.grad returned {type(term)}, expected"
   1501                 " Variable instance."
   1502             )

File ~/git/pytensor/pytensor/gradient.py:1326, in _populate_grad_dict.<locals>.access_term_cache(node)
   1318         if o_shape != g_shape:
   1319             raise ValueError(
   1320                 "Got a gradient of shape "
   1321                 + str(o_shape)
   1322                 + " on an output of shape "
   1323                 + str(g_shape)
   1324             )
-> 1326 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1328 if input_grads is None:
   1329     raise TypeError(
   1330         f"{node.op}.grad returned NoneType, expected iterable."
   1331     )

File ~/git/pytensor/pytensor/tensor/optimize.py:913, in RootOp.L_op(self, inputs, outputs, output_grads)
    904 df_dx = (
    905     jacobian(inner_fx, inner_x, vectorize=True)
    906     if not self.jac
    907     else self.fgraph.outputs[1]
    908 )
    909 df_dtheta_columns = jacobian(
    910     inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
    911 )
--> 913 grad_wrt_args = implict_optimization_grads(
    914     df_dx=df_dx,
    915     df_dtheta_columns=df_dtheta_columns,
    916     args=args,
    917     x_star=x_star,
    918     output_grad=output_grad,
    919     fgraph=self.fgraph,
    920 )
    922 return [zeros_like(x), *grad_wrt_args]

File ~/git/pytensor/pytensor/tensor/optimize.py:333, in implict_optimization_grads(df_dx, df_dtheta_columns, args, x_star, output_grad, fgraph)
    290 r"""
    291 Compute gradients of an optimization problem with respect to its parameters.
    292 
   (...)    329     The function graph that contains the inputs and outputs of the optimization problem.
    330 """
    331 df_dx = cast(TensorVariable, df_dx)
--> 333 df_dtheta = concatenate(
    334     [
    335         atleast_2d(jac_col, left=False)
    336         for jac_col in cast(list[TensorVariable], df_dtheta_columns)
    337     ],
    338     axis=-1,
    339 )
    341 replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
    343 df_dx_star, df_dtheta_star = cast(
    344     list[TensorVariable],
    345     graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
    346 )

File ~/git/pytensor/pytensor/tensor/basic.py:2998, in concatenate(tensor_list, axis)
   2991 if not isinstance(tensor_list, tuple | list):
   2992     raise TypeError(
   2993         "The 'tensors' argument must be either a tuple "
   2994         "or a list, make sure you did not forget () or [] around "
   2995         "arguments of concatenate.",
   2996         tensor_list,
   2997     )
-> 2998 return join(axis, *tensor_list)

File ~/git/pytensor/pytensor/tensor/basic.py:2812, in join(axis, *tensors_list)
   2810     return tensors_list[0]
   2811 else:
-> 2812     return _join(axis, *tensors_list)

File ~/git/pytensor/pytensor/graph/op.py:293, in Op.__call__(self, name, return_list, *inputs, **kwargs)
    249 def __call__(
    250     self, *inputs: Any, name=None, return_list=False, **kwargs
    251 ) -> Variable | list[Variable]:
    252     r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    253 
    254     This method is just a wrapper around :meth:`Op.make_node`.
   (...)    291 
    292     """
--> 293     node = self.make_node(*inputs, **kwargs)
    294     if name is not None:
    295         if len(node.outputs) == 1:

File ~/git/pytensor/pytensor/tensor/basic.py:2505, in Join.make_node(self, axis, *tensors)
   2502 ndim = tensors[0].type.ndim
   2504 if not builtins.all(x.ndim == ndim for x in tensors):
-> 2505     raise TypeError(
   2506         "Only tensors with the same number of dimensions can be joined. "
   2507         f"Input ndims were: {[x.ndim for x in tensors]}"
   2508     )
   2510 try:
   2511     static_axis = int(get_scalar_constant_value(axis))

TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [3, 2, 2, 2]
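
The join failure can be reproduced in isolation. A minimal sketch of my own (not from the notebook): one per-parameter Jacobian comes out with ndim 3 while the others are ndim 2, and atleast_2d(jac_col, left=False) only ever adds dimensions, never removes them, so the mismatch survives to the concatenate:

import pytensor.tensor as pt

jac_3d = pt.tensor3("jac_3d")  # stands in for the ndim-3 Jacobian column
jac_2d = pt.matrix("jac_2d")   # stands in for the ndim-2 columns

# Raises TypeError: Only tensors with the same number of dimensions can be joined.
pt.concatenate([jac_3d, jac_2d], axis=-1)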

PyTensor version information:

PyTensor branch: #1582
PyMC: main branch (as of writing)

Context for the issue:

No response
