Computational complexity of a hessian-vector product #32285
-
Could you share a code snippet showing exactly what you're doing?
-
Thank you for the quick answer! Here is a minimal example that reproduces the behaviour I am looking at:

```python
import jax
import jax.numpy as jnp


def cost_fn(x1: jnp.ndarray, x2: jnp.ndarray, x3: jnp.ndarray) -> float:
    # some tensor contractions
    matrix = x1
    d = 20
    for _ in range(d):
        matrix = matrix @ x1
    return (x2 @ matrix @ x3).real


key = jax.random.PRNGKey(42)
key1, key2, key3, key4, key5, key6 = jax.random.split(key, 6)

# tensor dimensions
n1 = int(2 * 1e2)
n2 = int(1e2)
n3 = int(1e2)

# Complex matrix and vectors
x1 = jax.random.normal(key1, shape=(n1 // 2, n1 // 2)) + 1j * jax.random.normal(key2, shape=(n1 // 2, n1 // 2))
x2 = jax.random.normal(key3, shape=(n2,)) + 1j * jax.random.normal(key4, shape=(n2,))
x3 = jax.random.normal(key5, shape=(n3,)) + 1j * jax.random.normal(key6, shape=(n3,))

# Create tangent vectors for JVP calculations
v1 = jax.random.normal(key1, shape=(n1 // 2, n1 // 2)) + 1j * jax.random.normal(key2, shape=(n1 // 2, n1 // 2))
v2 = jax.random.normal(key3, shape=(n2,)) + 1j * jax.random.normal(key4, shape=(n2,))
v3 = jax.random.normal(key5, shape=(n3,)) + 1j * jax.random.normal(key6, shape=(n3,))

# Define individual functions, each depending on a single argument
f_1 = lambda x: cost_fn(x, x2, x3)
grad_x1 = jax.grad(f_1)
f_2 = lambda x: cost_fn(x1, x, x3)
grad_x2 = jax.grad(f_2)
f_3 = lambda x: cost_fn(x1, x2, x)
grad_x3 = jax.grad(f_3)
```

From trying to reproduce this behaviour, it seems like the difference in the times given by …
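For reference, here is a sketch of how the three Hessian-vector products could be timed on top of the snippet above (the `time_hvp` helper and the repeat count are illustrative choices, not the exact measurement code):

```python
import time

import jax


def time_hvp(f, x, v, n_repeats=10):
    """Average wall-clock time of a forward-over-reverse Hessian-vector product."""
    # JVP of the gradient of f, evaluated at x along the tangent direction v
    hvp_fn = jax.jit(lambda x, v: jax.jvp(jax.grad(f), (x,), (v,))[1])
    hvp_fn(x, v).block_until_ready()  # compile and warm up before timing
    start = time.perf_counter()
    for _ in range(n_repeats):
        hvp_fn(x, v).block_until_ready()
    return (time.perf_counter() - start) / n_repeats


for name, f, x, v in [("x1", f_1, x1, v1), ("x2", f_2, x2, v2), ("x3", f_3, x3, v3)]:
    print(f"HVP w.r.t. {name}: {time_hvp(f, x, v):.4f} s per call")
```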
-
Hey there!
I have a theoretical question:
I am currently computing the first and second order approximations of a scalar-valued cost function:
$f(x_1, x_2, x_3)\colon \mathbb{C}^{d_1} \times \mathbb{C}^{d_2} \times \mathbb{C}^{d_3} \to \mathbb{R}$, where $d_1 \gg d_2, d_3$.
In particular, I am using `jax.jvp(jax.grad(g_i), (x_i,), (v_i,))` to compute the Hessian-vector product $\partial^2 g_i(x_i) \cdot [v_i]$, where $g_i$ is obtained from $f$ by fixing all the other arguments $x_j$ with $j \neq i$. Based on automatic differentiation theory (e.g. The Elements of Differentiable Programming, Blondel & Roulet), I would expect a complexity that depends linearly on the complexity $C$ of evaluating the cost function, as also mentioned in your autodiff cookbook.
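Concretely, this is the pattern I mean (a minimal sketch; the names `hvp` and `g_1` are only for illustration):

```python
import jax


def hvp(g, x, v):
    # Forward-over-reverse: a JVP of the gradient of g gives the Hessian-vector product.
    return jax.jvp(jax.grad(g), (x,), (v,))[1]


# e.g. differentiating f only with respect to its first argument:
# g_1 = lambda x: f(x, x2, x3)
# hvp(g_1, x1, v1)
```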
However, when computing the Hessian-vector product for each $x_i$, I also see a significant dependence on which $x_i$ I am differentiating with respect to, as their sizes vary considerably.
Is this in agreement with the theory of automatic differentiation, or is it due to some detail of the implementation? If it is related to the theory, I would greatly appreciate it if you know of a source where this is discussed and can share it.
Thank you in advance!