Skip to content

Commit eb256a4

Browse files
committed
add reverse-mode autodiff for smooth dynamics
1 parent ada89af commit eb256a4

File tree

8 files changed

+1087
-91
lines changed

8 files changed

+1087
-91
lines changed

mujoco_warp/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@
4646
from mujoco_warp._src.forward import rungekutta4 as rungekutta4
4747
from mujoco_warp._src.forward import step1 as step1
4848
from mujoco_warp._src.forward import step2 as step2
49+
from mujoco_warp._src.grad import SMOOTH_GRAD_FIELDS as SMOOTH_GRAD_FIELDS
50+
from mujoco_warp._src.grad import diff_forward as diff_forward
51+
from mujoco_warp._src.grad import diff_step as diff_step
52+
from mujoco_warp._src.grad import disable_grad as disable_grad
53+
from mujoco_warp._src.grad import enable_grad as enable_grad
54+
from mujoco_warp._src.grad import make_diff_data as make_diff_data
4955
from mujoco_warp._src.inverse import inverse as inverse
5056
from mujoco_warp._src.io import create_render_context as create_render_context
5157
from mujoco_warp._src.io import get_data_into as get_data_into

mujoco_warp/_src/adjoint.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""custom adjoint definitions for MuJoCo Warp autodifferentiation.
2+
3+
This module centralizes all ``@wp.func_grad`` registrations. It must be
4+
imported before any tape recording so that custom gradients are registered
5+
with Warp's AD system.
6+
7+
Import this module via ``grad.py``; do not import it directly.
8+
"""
9+
10+
import warp as wp
11+
12+
from mujoco_warp._src import math
13+
14+
15+
@wp.func_grad(math.quat_integrate)
def _quat_integrate_grad(q: wp.quat, v: wp.vec3, dt: float, adj_ret: wp.quat):
  """Custom adjoint avoiding gradient singularity at |v|=0.

  Registered through ``wp.func_grad`` as the gradient of
  ``math.quat_integrate``. The forward chain re-evaluated here is::

    q_rot  = (cos(dt*|v|/2), sin(dt*|v|/2) / |v| * v)
    q_n    = q / |q|
    q_res  = mul_quat(q_n, q_rot)
    result = q_res / |q_res|

  and the adjoint is propagated backwards through it by hand. Quaternions are
  indexed with the scalar part at index 0 throughout this function.

  NOTE(review): this assumes the chain above matches what
  ``math.quat_integrate`` actually evaluates; confirm against its definition.

  Args:
    q: quaternion argument of the forward call.
    v: 3-vector argument of the forward call; ``dt * |v|`` is used as the
      rotation angle.
    dt: scalar multiplying ``|v|`` in the rotation angle.
    adj_ret: incoming adjoint of the forward function's return value.

  The results are accumulated into ``wp.adjoint[q]``, ``wp.adjoint[v]`` and
  ``wp.adjoint[dt]``; nothing is returned.
  """
  # threshold below which |v|, |q| and |q_res| are treated as degenerate
  EPS = float(1e-10)
  norm_v = wp.length(v)
  norm_v_sq = norm_v * norm_v
  half_angle = dt * norm_v * 0.5

  # sinc-safe rotation quaternion construction
  if norm_v > EPS:
    s_over_nv = wp.sin(half_angle) / norm_v  # sin(dt|v|/2) / |v|
    c = wp.cos(half_angle)
    # d(s_over_nv)/dv_j = ds_coeff * v_j
    ds_coeff = (c * dt * 0.5 - s_over_nv) / norm_v_sq
  else:
    # |v| -> 0 limits of the expressions above; avoids dividing by |v|
    s_over_nv = dt * 0.5
    c = 1.0
    # Taylor limit: (c*dt/2 - s_over_nv) / |v|^2 -> -dt^3/24
    ds_coeff = -dt * dt * dt / 24.0

  q_rot = wp.quat(
    c,
    s_over_nv * v[0],
    s_over_nv * v[1],
    s_over_nv * v[2],
  )

  # recompute forward intermediates
  q_len = wp.length(q)
  # clamp the length so a zero quaternion does not produce a division by zero
  q_inv_len = 1.0 / wp.max(q_len, EPS)
  q_n = wp.quat(
    q[0] * q_inv_len,
    q[1] * q_inv_len,
    q[2] * q_inv_len,
    q[3] * q_inv_len,
  )

  q_res = math.mul_quat(q_n, q_rot)
  res_len = wp.length(q_res)
  res_inv = 1.0 / wp.max(res_len, EPS)

  # result = normalize(q_res)
  # adj_q_res_k = adj_ret_k / |q_res| - q_res_k * dot(adj_ret, q_res) / |q_res|^3
  dot_ar = adj_ret[0] * q_res[0] + adj_ret[1] * q_res[1] + adj_ret[2] * q_res[2] + adj_ret[3] * q_res[3]
  res_inv3 = res_inv * res_inv * res_inv
  adj_qr = wp.quat(
    adj_ret[0] * res_inv - q_res[0] * dot_ar * res_inv3,
    adj_ret[1] * res_inv - q_res[1] * dot_ar * res_inv3,
    adj_ret[2] * res_inv - q_res[2] * dot_ar * res_inv3,
    adj_ret[3] * res_inv - q_res[3] * dot_ar * res_inv3,
  )

  # q_res = mul_quat(q_n, q_rot)
  # adj_q_n = mul_quat(adj_qr, conj(q_rot))
  # adj_q_rot = mul_quat(conj(q_n), adj_qr)
  q_rot_conj = wp.quat(q_rot[0], -q_rot[1], -q_rot[2], -q_rot[3])
  adj_qn = math.mul_quat(adj_qr, q_rot_conj)

  q_n_conj = wp.quat(q_n[0], -q_n[1], -q_n[2], -q_n[3])
  adj_q_rot = math.mul_quat(q_n_conj, adj_qr)

  # q_rot = (c, s_over_nv * v)
  # d(c)/dv_j = -s_over_nv * dt/2 * v_j
  # d(s_over_nv * v_i)/dv_j = ds_coeff * v_j * v_i + s_over_nv * delta_ij
  sv_dot = adj_q_rot[1] * v[0] + adj_q_rot[2] * v[1] + adj_q_rot[3] * v[2]
  # every term proportional to v_j is gathered into one shared coefficient
  common = -s_over_nv * dt * 0.5 * adj_q_rot[0] + ds_coeff * sv_dot
  adj_v_val = wp.vec3(
    common * v[0] + s_over_nv * adj_q_rot[1],
    common * v[1] + s_over_nv * adj_q_rot[2],
    common * v[2] + s_over_nv * adj_q_rot[3],
  )

  # adj_dt from q_rot dependency on dt
  # d(c)/d(dt) = -sin(half_angle) * norm_v / 2
  # d(s_over_nv * v_i)/dt = (c / 2) * v_i
  adj_dt_val = adj_q_rot[0] * (-wp.sin(half_angle) * norm_v * 0.5)
  adj_dt_val += sv_dot * c * 0.5

  # q_n = normalize(q)
  # adj_q_k = adj_qn_k / |q| - q_k * dot(adj_qn, q) / |q|^3
  dot_aqn = adj_qn[0] * q[0] + adj_qn[1] * q[1] + adj_qn[2] * q[2] + adj_qn[3] * q[3]
  q_inv_len3 = q_inv_len * q_inv_len * q_inv_len
  adj_q_val = wp.quat(
    adj_qn[0] * q_inv_len - q[0] * dot_aqn * q_inv_len3,
    adj_qn[1] * q_inv_len - q[1] * dot_aqn * q_inv_len3,
    adj_qn[2] * q_inv_len - q[2] * dot_aqn * q_inv_len3,
    adj_qn[3] * q_inv_len - q[3] * dot_aqn * q_inv_len3,
  )

  # accumulate adjoints
  wp.adjoint[q] += adj_q_val
  wp.adjoint[v] += adj_v_val
  wp.adjoint[dt] += adj_dt_val

mujoco_warp/_src/derivative.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from mujoco_warp._src.types import vec10f
2525
from mujoco_warp._src.warp_util import event_scope
2626

27-
wp.set_module_options({"enable_backward": False})
27+
wp.set_module_options({"enable_backward": True})
2828

2929

3030
@wp.kernel

mujoco_warp/_src/forward.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from mujoco_warp._src.warp_util import cache_kernel
4545
from mujoco_warp._src.warp_util import event_scope
4646

47-
wp.set_module_options({"enable_backward": False})
47+
wp.set_module_options({"enable_backward": True})
4848

4949

5050
@wp.kernel
@@ -235,6 +235,12 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
235235
"""Advance state and time given activation derivatives and acceleration."""
236236
# TODO(team): can we assume static timesteps?
237237

238+
# Clone arrays used as both input and output so that Warp's tape retains the
239+
# original values for correct reverse-mode AD.
240+
act_in = wp.clone(d.act)
241+
qvel_prev = wp.clone(d.qvel)
242+
qpos_prev = wp.clone(d.qpos)
243+
238244
# advance activations
239245
wp.launch(
240246
_next_activation,
@@ -247,7 +253,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
247253
m.actuator_actlimited,
248254
m.actuator_dynprm,
249255
m.actuator_actrange,
250-
d.act,
256+
act_in,
251257
d.act_dot,
252258
1.0,
253259
True,
@@ -258,7 +264,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
258264
wp.launch(
259265
_next_velocity,
260266
dim=(d.nworld, m.nv),
261-
inputs=[m.opt.timestep, d.qvel, qacc, 1.0],
267+
inputs=[m.opt.timestep, qvel_prev, qacc, 1.0],
262268
outputs=[d.qvel],
263269
)
264270

@@ -268,7 +274,7 @@ def _advance(m: Model, d: Data, qacc: wp.array, qvel: Optional[wp.array] = None)
268274
wp.launch(
269275
_next_position,
270276
dim=(d.nworld, m.njnt),
271-
inputs=[m.opt.timestep, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, d.qpos, qvel_in, 1.0],
277+
inputs=[m.opt.timestep, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, qpos_prev, qvel_in, 1.0],
272278
outputs=[d.qpos],
273279
)
274280

@@ -782,9 +788,9 @@ def _tendon_actuator_force_clamp(
782788
actfrcrange = tendon_actfrcrange[worldid % tendon_actfrcrange.shape[0], tenid]
783789

784790
if ten_actfrc < actfrcrange[0]:
785-
actuator_force_out[worldid, actid] *= actfrcrange[0] / ten_actfrc
791+
actuator_force_out[worldid, actid] = actuator_force_out[worldid, actid] * (actfrcrange[0] / ten_actfrc)
786792
elif ten_actfrc > actfrcrange[1]:
787-
actuator_force_out[worldid, actid] *= actfrcrange[1] / ten_actfrc
793+
actuator_force_out[worldid, actid] = actuator_force_out[worldid, actid] * (actfrcrange[1] / ten_actfrc)
788794

789795

790796
@wp.kernel
@@ -919,6 +925,8 @@ def fwd_actuation(m: Model, d: Data):
919925
],
920926
outputs=[d.qfrc_actuator],
921927
)
928+
# clone to break input/output aliasing for correct AD
929+
qfrc_actuator_in = wp.clone(d.qfrc_actuator)
922930
wp.launch(
923931
_qfrc_actuator_gravcomp_limits,
924932
dim=(d.nworld, m.nv),
@@ -929,7 +937,7 @@ def fwd_actuation(m: Model, d: Data):
929937
m.jnt_actfrcrange,
930938
m.dof_jntid,
931939
d.qfrc_gravcomp,
932-
d.qfrc_actuator,
940+
qfrc_actuator_in,
933941
],
934942
outputs=[d.qfrc_actuator],
935943
)

mujoco_warp/_src/grad.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""Autodifferentiation coordination for MuJoCo Warp.
2+
3+
This module provides utilities for enabling Warp's tape-based reverse-mode
4+
automatic differentiation through the MuJoCo Warp physics pipeline.
5+
6+
Usage::
7+
8+
import mujoco_warp as mjw
9+
10+
d = mjw.make_diff_data(mjm) # Data with gradient tracking
11+
tape = wp.Tape()
12+
with tape:
13+
mjw.step(m, d)
14+
wp.launch(loss_kernel, dim=1, inputs=[d.xpos, target, loss])
15+
tape.backward(loss=loss)
16+
grad_ctrl = d.ctrl.grad
17+
"""
18+
19+
from typing import Callable, Optional, Sequence
20+
21+
import warp as wp
22+
23+
from mujoco_warp._src import adjoint as _adjoint # noqa: F401 (register custom adjoints)
24+
from mujoco_warp._src import io
25+
from mujoco_warp._src.forward import forward
26+
from mujoco_warp._src.forward import step
27+
from mujoco_warp._src.types import Data
28+
from mujoco_warp._src.types import Model
29+
30+
SMOOTH_GRAD_FIELDS: tuple = (
31+
# primary state, user-controlled inputs
32+
"qpos",
33+
"qvel",
34+
"ctrl",
35+
"act",
36+
"mocap_pos",
37+
"mocap_quat",
38+
"xfrc_applied",
39+
"qfrc_applied",
40+
# position-dependent outputs
41+
"xpos",
42+
"xquat",
43+
"xmat",
44+
"xipos",
45+
"ximat",
46+
"xanchor",
47+
"xaxis",
48+
"geom_xpos",
49+
"geom_xmat",
50+
"site_xpos",
51+
"site_xmat",
52+
"subtree_com",
53+
"cinert",
54+
"crb",
55+
"cdof",
56+
# velocity-dependent outputs
57+
"cdof_dot",
58+
"cvel",
59+
"subtree_linvel",
60+
"subtree_angmom",
61+
"actuator_velocity",
62+
"ten_velocity",
63+
# body-level intermediate quantities
64+
"cacc",
65+
"cfrc_int",
66+
"cfrc_ext",
67+
# force/acceleration outputs
68+
"qfrc_bias",
69+
"qfrc_spring",
70+
"qfrc_damper",
71+
"qfrc_gravcomp",
72+
"qfrc_fluid",
73+
"qfrc_passive",
74+
"qfrc_actuator",
75+
"qfrc_smooth",
76+
"qacc",
77+
"qacc_smooth",
78+
"actuator_force",
79+
"act_dot",
80+
# inertia matrix
81+
"qM",
82+
"qLD",
83+
"qLDiagInv",
84+
# tendon
85+
"ten_J",
86+
"ten_length",
87+
# actuator
88+
"actuator_length",
89+
"actuator_moment",
90+
# sensor
91+
"sensordata",
92+
)
93+
94+
95+
def enable_grad(d: Data, fields: Optional[Sequence[str]] = None) -> None:
  """Turns on ``requires_grad`` for the selected Data arrays.

  Args:
    d: Data object whose attributes are looked up by name.
    fields: attribute names to enable. ``None`` selects every name in
      ``SMOOTH_GRAD_FIELDS``. Names that are missing on ``d`` or that do not
      hold a ``wp.array`` are skipped silently.
  """
  selected = SMOOTH_GRAD_FIELDS if fields is None else fields
  for field_name in selected:
    candidate = getattr(d, field_name, None)
    # best-effort: ignore absent attributes and anything that is not an array
    if candidate is None or not isinstance(candidate, wp.array):
      continue
    candidate.requires_grad = True
103+
104+
105+
def disable_grad(d: Data) -> None:
  """Disables gradient tracking on all Data arrays.

  Visits every name in ``SMOOTH_GRAD_FIELDS`` and, in addition, every
  attribute stored directly on ``d``. The second set matters because
  ``enable_grad`` accepts arbitrary field names: arrays enabled through a
  custom ``fields`` list must be switched off here as well, otherwise they
  would keep ``requires_grad = True`` after this call.

  Arrays nested inside other attribute objects of ``d`` are not visited.

  Args:
    d: Data object whose ``wp.array`` attributes get ``requires_grad = False``.
  """
  # dict.fromkeys de-duplicates the names while keeping a deterministic order
  names = dict.fromkeys(SMOOTH_GRAD_FIELDS)
  # objects without an instance __dict__ simply contribute no extra names
  names.update(dict.fromkeys(getattr(d, "__dict__", ())))
  for name in names:
    arr = getattr(d, name, None)
    if isinstance(arr, wp.array):
      arr.requires_grad = False
111+
112+
113+
def make_diff_data(
  mjm,
  nworld: int = 1,
  grad_fields: Optional[Sequence[str]] = None,
  **kwargs,
) -> Data:
  """Builds a Data object and enables gradient tracking on it.

  Args:
    mjm: model object forwarded unchanged to ``io.make_data``.
    nworld: forwarded to ``io.make_data`` as ``nworld``.
    grad_fields: forwarded to ``enable_grad`` as ``fields``; ``None`` selects
      its default field set.
    **kwargs: extra keyword arguments forwarded to ``io.make_data``.

  Returns:
    The Data object returned by ``io.make_data`` after ``enable_grad`` ran on it.
  """
  data = io.make_data(mjm, nworld=nworld, **kwargs)
  enable_grad(data, fields=grad_fields)
  return data
123+
124+
125+
def diff_step(
  m: Model,
  d: Data,
  loss_fn: Callable[[Model, Data], wp.array],
) -> wp.Tape:
  """Records one ``step`` plus a loss on a tape and backpropagates through it.

  Args:
    m: model passed to ``step`` and to ``loss_fn``.
    d: data passed to ``step`` and to ``loss_fn``.
    loss_fn: callable invoked as ``loss_fn(m, d)`` while the tape is recording;
      its return value is handed to ``Tape.backward`` as ``loss``.

  Returns:
    The tape, after ``backward`` has been called on it.
  """
  recorder = wp.Tape()
  with recorder:
    # both the physics step and the loss evaluation must be captured
    step(m, d)
    objective = loss_fn(m, d)
  recorder.backward(loss=objective)
  return recorder
137+
138+
139+
def diff_forward(
  m: Model,
  d: Data,
  loss_fn: Callable[[Model, Data], wp.array],
) -> wp.Tape:
  """Records ``forward`` (no integration) plus a loss on a tape and backpropagates.

  Args:
    m: model passed to ``forward`` and to ``loss_fn``.
    d: data passed to ``forward`` and to ``loss_fn``.
    loss_fn: callable invoked as ``loss_fn(m, d)`` while the tape is recording;
      its return value is handed to ``Tape.backward`` as ``loss``.

  Returns:
    The tape, after ``backward`` has been called on it.
  """
  recorder = wp.Tape()
  with recorder:
    # both the forward pass and the loss evaluation must be captured
    forward(m, d)
    objective = loss_fn(m, d)
  recorder.backward(loss=objective)
  return recorder

0 commit comments

Comments
 (0)