|
| 1 | +"""Custom adjoint definitions for MuJoCo Warp automatic differentiation. |
| 2 | +
|
| 3 | +This module centralizes all ``@wp.func_grad`` registrations. It must be |
| 4 | +imported before any tape recording so that custom gradients are registered |
| 5 | +with Warp's AD system. |
| 6 | +
|
| 7 | +Import this module via ``grad.py``; do not import it directly. |
| 8 | +""" |
| 9 | + |
| 10 | +import warp as wp |
| 11 | + |
| 12 | +from mujoco_warp._src import math |
| 13 | + |
| 14 | + |
@wp.func_grad(math.quat_integrate)
def _quat_integrate_grad(q: wp.quat, v: wp.vec3, dt: float, adj_ret: wp.quat):
  """Adjoint of ``math.quat_integrate`` that stays finite at |v| = 0.

  The forward computation is rebuilt locally as::

    q_n   = q / max(|q|, EPS)
    q_rot = (cos(dt |v| / 2), sin(dt |v| / 2) / |v| * v)
    q_res = mul_quat(q_n, q_rot)
    ret   = q_res / max(|q_res|, EPS)

  and ``adj_ret`` is pulled back through these steps in reverse order. Unless
  |v| > EPS, the ratio sin(dt |v| / 2) / |v| and its derivative coefficient are
  replaced by their |v| -> 0 limits (dt / 2 and -dt^3 / 24), so no division by
  |v| takes place.

  NOTE(review): assumes the chain above is equivalent to what
  ``math.quat_integrate`` evaluates in the forward pass -- confirm against it.

  Args:
    q: quaternion argument of the forward call; component 0 is handled as the
      scalar part (conjugation below negates components 1..3).
    v: vector argument of the forward call; the rotation angle is dt * |v|.
    dt: scalar step argument of the forward call.
    adj_ret: incoming adjoint of the forward return value.

  The results are accumulated into ``wp.adjoint[q]``, ``wp.adjoint[v]`` and
  ``wp.adjoint[dt]``; nothing is returned.
  """
  EPS = float(1e-10)

  # ---- rebuild forward intermediates ----

  # unit version of q, with the length clamped away from zero
  q_norm = wp.length(q)
  inv_q_norm = 1.0 / wp.max(q_norm, EPS)
  q_unit = wp.quat(q[0] * inv_q_norm, q[1] * inv_q_norm, q[2] * inv_q_norm, q[3] * inv_q_norm)

  speed = wp.length(v)
  speed_sq = speed * speed
  half_angle = dt * speed * 0.5

  # sinc = sin(dt|v|/2) / |v|, cos_half = cos(dt|v|/2)
  # dsinc is defined by d(sinc)/dv_j = dsinc * v_j
  if speed > EPS:
    sinc = wp.sin(half_angle) / speed
    cos_half = wp.cos(half_angle)
    dsinc = (cos_half * dt * 0.5 - sinc) / speed_sq
  else:
    # limits for |v| -> 0: sinc -> dt/2, cos -> 1, dsinc -> -dt^3/24
    sinc = dt * 0.5
    cos_half = 1.0
    dsinc = -dt * dt * dt / 24.0

  rot = wp.quat(cos_half, sinc * v[0], sinc * v[1], sinc * v[2])

  prod = math.mul_quat(q_unit, rot)
  prod_norm = wp.length(prod)
  inv_prod_norm = 1.0 / wp.max(prod_norm, EPS)

  # ---- reverse sweep ----

  # ret = prod / |prod|
  #   adj_prod_k = adj_ret_k / |prod| - prod_k * dot(adj_ret, prod) / |prod|^3
  ret_proj = adj_ret[0] * prod[0] + adj_ret[1] * prod[1] + adj_ret[2] * prod[2] + adj_ret[3] * prod[3]
  inv_prod_norm3 = inv_prod_norm * inv_prod_norm * inv_prod_norm
  adj_prod = wp.quat(
    adj_ret[0] * inv_prod_norm - prod[0] * ret_proj * inv_prod_norm3,
    adj_ret[1] * inv_prod_norm - prod[1] * ret_proj * inv_prod_norm3,
    adj_ret[2] * inv_prod_norm - prod[2] * ret_proj * inv_prod_norm3,
    adj_ret[3] * inv_prod_norm - prod[3] * ret_proj * inv_prod_norm3,
  )

  # prod = mul_quat(q_unit, rot)
  #   adj_q_unit = mul_quat(adj_prod, conj(rot))
  #   adj_rot    = mul_quat(conj(q_unit), adj_prod)
  rot_conj = wp.quat(rot[0], -rot[1], -rot[2], -rot[3])
  q_unit_conj = wp.quat(q_unit[0], -q_unit[1], -q_unit[2], -q_unit[3])
  adj_q_unit = math.mul_quat(adj_prod, rot_conj)
  adj_rot = math.mul_quat(q_unit_conj, adj_prod)

  # projection of the vector part of adj_rot onto v, shared by the v and dt adjoints
  vec_proj = adj_rot[1] * v[0] + adj_rot[2] * v[1] + adj_rot[3] * v[2]

  # rot depends on dt through both components:
  #   d(cos_half)/d(dt)   = -sin(half_angle) * |v| / 2
  #   d(sinc * v_i)/d(dt) = (cos_half / 2) * v_i
  grad_dt = adj_rot[0] * (-wp.sin(half_angle) * speed * 0.5) + vec_proj * cos_half * 0.5

  # rot depends on v through both components:
  #   d(cos_half)/dv_j   = -sinc * dt/2 * v_j
  #   d(sinc * v_i)/dv_j = dsinc * v_j * v_i + sinc * delta_ij
  along_v = -sinc * dt * 0.5 * adj_rot[0] + dsinc * vec_proj
  grad_v = wp.vec3(
    along_v * v[0] + sinc * adj_rot[1],
    along_v * v[1] + sinc * adj_rot[2],
    along_v * v[2] + sinc * adj_rot[3],
  )

  # q_unit = q / |q|
  #   adj_q_k = adj_q_unit_k / |q| - q_k * dot(adj_q_unit, q) / |q|^3
  q_proj = adj_q_unit[0] * q[0] + adj_q_unit[1] * q[1] + adj_q_unit[2] * q[2] + adj_q_unit[3] * q[3]
  inv_q_norm3 = inv_q_norm * inv_q_norm * inv_q_norm
  grad_q = wp.quat(
    adj_q_unit[0] * inv_q_norm - q[0] * q_proj * inv_q_norm3,
    adj_q_unit[1] * inv_q_norm - q[1] * q_proj * inv_q_norm3,
    adj_q_unit[2] * inv_q_norm - q[2] * q_proj * inv_q_norm3,
    adj_q_unit[3] * inv_q_norm - q[3] * q_proj * inv_q_norm3,
  )

  # add contributions onto the input adjoints
  wp.adjoint[q] += grad_q
  wp.adjoint[v] += grad_v
  wp.adjoint[dt] += grad_dt
0 commit comments