Skip to content

Commit ee37c24

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 35df4b2 commit ee37c24

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

pytensor/tensor/optimize.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,9 @@ def __init__(
819819
self.fgraph = FunctionGraph([variables, *args], [equations])
820820

821821
if jac:
822-
jac_wrt_x = jacobian(self.fgraph.outputs[0], self.fgraph.inputs[0], vectorize=True)
822+
jac_wrt_x = jacobian(
823+
self.fgraph.outputs[0], self.fgraph.inputs[0], vectorize=True
824+
)
823825
self.fgraph.add_output(atleast_2d(jac_wrt_x))
824826

825827
self.jac = jac
@@ -899,8 +901,14 @@ def L_op(
899901
inner_x, *inner_args = self.fgraph.inputs
900902
inner_fx = self.fgraph.outputs[0]
901903

902-
df_dx = jacobian(inner_fx, inner_x, vectorize=True) if not self.jac else self.fgraph.outputs[1]
903-
df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True)
904+
df_dx = (
905+
jacobian(inner_fx, inner_x, vectorize=True)
906+
if not self.jac
907+
else self.fgraph.outputs[1]
908+
)
909+
df_dtheta_columns = jacobian(
910+
inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
911+
)
904912

905913
grad_wrt_args = implict_optimization_grads(
906914
df_dx=df_dx,

0 commit comments

Comments
 (0)