diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py
index 87ef6c92a..b5a94ee57 100644
--- a/mujoco_warp/__init__.py
+++ b/mujoco_warp/__init__.py
@@ -28,6 +28,7 @@
from ._src.collision_sdf import sdf_narrowphase as sdf_narrowphase
from ._src.constraint import make_constraint as make_constraint
from ._src.derivative import deriv_smooth_vel as deriv_smooth_vel
+from ._src.derivative import transition_fd as transition_fd
from ._src.forward import euler as euler
from ._src.forward import forward as forward
from ._src.forward import fwd_acceleration as fwd_acceleration
diff --git a/mujoco_warp/_src/derivative.py b/mujoco_warp/_src/derivative.py
index 87df3a8b7..9c7548bd9 100644
--- a/mujoco_warp/_src/derivative.py
+++ b/mujoco_warp/_src/derivative.py
@@ -15,11 +15,14 @@
import warp as wp
+from . import forward
+from . import math
from .types import BiasType
from .types import Data
from .types import DisableBit
from .types import DynType
from .types import GainType
+from .types import JointType
from .types import Model
from .types import TileSet
from .types import vec10f
@@ -290,3 +293,550 @@ def deriv_smooth_vel(m: Model, d: Data, out: wp.array2d(dtype=float)):
)
# TODO(team): rne derivative
+
+
+@wp.kernel
+def _get_state(
+  # Model:
+  nq: int,
+  nv: int,
+  na: int,
+  # Data in:
+  qpos_in: wp.array2d(dtype=float),
+  qvel_in: wp.array2d(dtype=float),
+  act_in: wp.array2d(dtype=float),
+  # Out:
+  state_out: wp.array2d(dtype=float),
+):
+  # pack the per-world arrays into one flat state row: [qpos (nq) | qvel (nv) | act (na)]
+  # one thread handles one world
+  worldid = wp.tid()
+
+  # column offsets of the qvel and act segments inside a state row
+  qvel_adr = nq
+  act_adr = nq + nv
+
+  for k in range(na):
+    state_out[worldid, act_adr + k] = act_in[worldid, k]
+
+  for k in range(nq):
+    state_out[worldid, k] = qpos_in[worldid, k]
+    # qvel is copied in the same pass, so this only reaches every dof when nv <= nq
+    if k < nv:
+      state_out[worldid, qvel_adr + k] = qvel_in[worldid, k]
+
+
+@wp.kernel
+def _set_state(
+  # Model:
+  nq: int,
+  nv: int,
+  na: int,
+  # In:
+  state_in: wp.array2d(dtype=float),
+  # Data out:
+  qpos_out: wp.array2d(dtype=float),
+  qvel_out: wp.array2d(dtype=float),
+  act_out: wp.array2d(dtype=float),
+):
+  # unpack one flat state row [qpos (nq) | qvel (nv) | act (na)] back into the per-world arrays
+  # one thread handles one world
+  worldid = wp.tid()
+
+  # column offsets of the qvel and act segments inside a state row
+  qvel_adr = nq
+  act_adr = nq + nv
+
+  for k in range(na):
+    act_out[worldid, k] = state_in[worldid, act_adr + k]
+
+  for k in range(nq):
+    qpos_out[worldid, k] = state_in[worldid, k]
+    # qvel is restored in the same pass, so this only reaches every dof when nv <= nq
+    if k < nv:
+      qvel_out[worldid, k] = state_in[worldid, qvel_adr + k]
+
+
+@wp.kernel
+def _state_diff(
+  # Model:
+  nq: int,
+  nv: int,
+  na: int,
+  njnt: int,
+  jnt_type: wp.array(dtype=int),
+  jnt_qposadr: wp.array(dtype=int),
+  jnt_dofadr: wp.array(dtype=int),
+  # In:
+  state1_in: wp.array2d(dtype=float),
+  state2_in: wp.array2d(dtype=float),
+  inv_h: float,
+  # Out:
+  ds_out: wp.array2d(dtype=float),
+):
+  # difference of two packed states, scaled by inv_h: ds = (state2 - state1) * inv_h
+  # inputs are laid out as [qpos (nq) | qvel (nv) | act (na)]
+  # output is laid out as [position difference (nv) | velocity difference (nv) | activation difference (na)]
+  worldid = wp.tid()
+
+  # velocity segment: plain element-wise difference
+  for k in range(nv):
+    ds_out[worldid, nv + k] = (state2_in[worldid, nq + k] - state1_in[worldid, nq + k]) * inv_h
+
+  # activation segment: plain element-wise difference
+  for k in range(na):
+    ds_out[worldid, 2 * nv + k] = (state2_in[worldid, nq + nv + k] - state1_in[worldid, nq + nv + k]) * inv_h
+
+  # position segment: one entry per dof, so quaternion joints go through quat_sub
+  for jntid in range(njnt):
+    qadr = jnt_qposadr[jntid]
+    vadr = jnt_dofadr[jntid]
+    kind = jnt_type[jntid]
+
+    if kind == JointType.BALL:
+      # the 4 qpos entries of a ball joint form a quaternion mapped onto 3 dofs
+      rot1 = wp.quat(
+        state1_in[worldid, qadr + 0],
+        state1_in[worldid, qadr + 1],
+        state1_in[worldid, qadr + 2],
+        state1_in[worldid, qadr + 3],
+      )
+      rot2 = wp.quat(
+        state2_in[worldid, qadr + 0],
+        state2_in[worldid, qadr + 1],
+        state2_in[worldid, qadr + 2],
+        state2_in[worldid, qadr + 3],
+      )
+      drot = math.quat_sub(rot2, rot1)
+      ds_out[worldid, vadr + 0] = drot[0] * inv_h
+      ds_out[worldid, vadr + 1] = drot[1] * inv_h
+      ds_out[worldid, vadr + 2] = drot[2] * inv_h
+    elif kind == JointType.FREE:
+      # first 3 qpos entries: translation, differenced directly
+      for axis in range(3):
+        ds_out[worldid, vadr + axis] = (state2_in[worldid, qadr + axis] - state1_in[worldid, qadr + axis]) * inv_h
+
+      # remaining 4 qpos entries: quaternion mapped onto the last 3 dofs
+      quat_adr = qadr + 3
+      rot_adr = vadr + 3
+      rot1 = wp.quat(
+        state1_in[worldid, quat_adr + 0],
+        state1_in[worldid, quat_adr + 1],
+        state1_in[worldid, quat_adr + 2],
+        state1_in[worldid, quat_adr + 3],
+      )
+      rot2 = wp.quat(
+        state2_in[worldid, quat_adr + 0],
+        state2_in[worldid, quat_adr + 1],
+        state2_in[worldid, quat_adr + 2],
+        state2_in[worldid, quat_adr + 3],
+      )
+      drot = math.quat_sub(rot2, rot1)
+      ds_out[worldid, rot_adr + 0] = drot[0] * inv_h
+      ds_out[worldid, rot_adr + 1] = drot[1] * inv_h
+      ds_out[worldid, rot_adr + 2] = drot[2] * inv_h
+    else:  # SLIDE, HINGE
+      # scalar joints: one qpos entry maps to one dof
+      ds_out[worldid, vadr] = (state2_in[worldid, qadr] - state1_in[worldid, qadr]) * inv_h
+
+
+@wp.kernel
+def _perturb_position(
+  # Model:
+  nq: int,
+  njnt: int,
+  jnt_type: wp.array(dtype=int),
+  jnt_qposadr: wp.array(dtype=int),
+  jnt_dofadr: wp.array(dtype=int),
+  # Data in:
+  qpos_in: wp.array2d(dtype=float),
+  # In:
+  dof_idx: int,
+  eps: float,
+  # Data out:
+  qpos_out: wp.array2d(dtype=float),
+):
+  # write a copy of qpos_in to qpos_out with the single dof dof_idx nudged by eps
+  worldid = wp.tid()
+
+  # start from an unmodified copy of the input configuration
+  for k in range(nq):
+    qpos_out[worldid, k] = qpos_in[worldid, k]
+
+  # scan the joints for the one that owns dof_idx and nudge it
+  for jntid in range(njnt):
+    kind = jnt_type[jntid]
+    qadr = jnt_qposadr[jntid]
+    # dof index relative to the first dof of this joint
+    local_dof = dof_idx - jnt_dofadr[jntid]
+
+    if kind == JointType.FREE:
+      if local_dof >= 0 and local_dof < 3:
+        # translational dof: add eps to the matching qpos entry
+        qpos_out[worldid, qadr + local_dof] += eps
+      elif local_dof >= 3 and local_dof < 6:
+        # rotational dof: integrate the quaternion along one unit axis by eps
+        quat_adr = qadr + 3
+        rot = wp.quat(
+          qpos_in[worldid, quat_adr + 0],
+          qpos_in[worldid, quat_adr + 1],
+          qpos_in[worldid, quat_adr + 2],
+          qpos_in[worldid, quat_adr + 3],
+        )
+        rot_dof = local_dof - 3
+        if rot_dof == 2:
+          axis = wp.vec3(0.0, 0.0, 1.0)
+        elif rot_dof == 1:
+          axis = wp.vec3(0.0, 1.0, 0.0)
+        else:
+          axis = wp.vec3(1.0, 0.0, 0.0)
+        rot_new = math.quat_integrate(rot, axis, eps)
+        qpos_out[worldid, quat_adr + 0] = rot_new[0]
+        qpos_out[worldid, quat_adr + 1] = rot_new[1]
+        qpos_out[worldid, quat_adr + 2] = rot_new[2]
+        qpos_out[worldid, quat_adr + 3] = rot_new[3]
+    elif kind == JointType.BALL:
+      if local_dof >= 0 and local_dof < 3:
+        # all 3 dofs are rotational: integrate the quaternion along one unit axis by eps
+        rot = wp.quat(
+          qpos_in[worldid, qadr + 0],
+          qpos_in[worldid, qadr + 1],
+          qpos_in[worldid, qadr + 2],
+          qpos_in[worldid, qadr + 3],
+        )
+        if local_dof == 2:
+          axis = wp.vec3(0.0, 0.0, 1.0)
+        elif local_dof == 1:
+          axis = wp.vec3(0.0, 1.0, 0.0)
+        else:
+          axis = wp.vec3(1.0, 0.0, 0.0)
+        rot_new = math.quat_integrate(rot, axis, eps)
+        qpos_out[worldid, qadr + 0] = rot_new[0]
+        qpos_out[worldid, qadr + 1] = rot_new[1]
+        qpos_out[worldid, qadr + 2] = rot_new[2]
+        qpos_out[worldid, qadr + 3] = rot_new[3]
+    else:  # SLIDE, HINGE
+      # scalar joint: its single dof maps to its single qpos entry
+      if local_dof == 0:
+        qpos_out[worldid, qadr] += eps
+
+
+@wp.kernel
+def _perturb_array(
+  # In:
+  idx: int,
+  eps: float,
+  arr_in: wp.array2d(dtype=float),
+  # Out:
+  arr_out: wp.array2d(dtype=float),
+):
+  # copy one row of arr_in to arr_out, adding eps to column idx only
+  worldid = wp.tid()
+  ncol = arr_in.shape[1]
+  for k in range(ncol):
+    src = arr_in[worldid, k]
+    if k != idx:
+      arr_out[worldid, k] = src
+    else:
+      arr_out[worldid, k] = src + eps
+
+
+@wp.kernel
+def _diff_vectors(
+  # In:
+  x1_in: wp.array2d(dtype=float),
+  x2_in: wp.array2d(dtype=float),
+  inv_h: float,
+  n: int,
+  # Out:
+  dx_out: wp.array2d(dtype=float),
+):
+  # scaled difference of the first n columns: dx = (x2 - x1) * inv_h
+  # operand order matters: x1_in is subtracted from x2_in
+  worldid = wp.tid()
+  for k in range(n):
+    minuend = x2_in[worldid, k]
+    subtrahend = x1_in[worldid, k]
+    dx_out[worldid, k] = (minuend - subtrahend) * inv_h
+
+
+@wp.kernel
+def _copy_to_jacobian_col(
+  # In:
+  col_in: wp.array2d(dtype=float),
+  col_idx: int,
+  nrow: int,
+  # Out:
+  jac_out: wp.array3d(dtype=float),
+):
+  # scatter the first nrow entries of this world's vector into column col_idx of its matrix
+  worldid = wp.tid()
+  for row in range(nrow):
+    value = col_in[worldid, row]
+    jac_out[worldid, row, col_idx] = value
+
+
+@event_scope
+def transition_fd(
+  m: Model,
+  d: Data,
+  eps: float,
+  centered: bool = False,
+  A: wp.array3d(dtype=float) = None,
+  B: wp.array3d(dtype=float) = None,
+  C: wp.array3d(dtype=float) = None,
+  D: wp.array3d(dtype=float) = None,
+):
+  """Finite differenced transition matrices (control theory notation).
+
+  Computes: d(x_next) = A*dx + B*du, d(sensor) = C*dx + D*du
+  where x = [qpos_diff, qvel, act] is the state in tangent space and qpos_diff is
+  the joint-aware position difference with one entry per dof (nv entries).
+
+  Each requested column is obtained by nudging one input by eps (and by -eps when
+  centered), stepping, and differencing against the baseline step (forward) or the
+  -eps step (centered). Outputs left as None are skipped; sensor data is only
+  recorded when C or D is given. d is stepped repeatedly during the call; qpos,
+  qvel, act and ctrl are restored before returning.
+
+  Args:
+    m: model
+    d: data
+    eps: finite difference epsilon
+    centered: if True, use centered differences
+    A: output state transition matrix (nworld, ndx, ndx) where ndx = 2*nv+na
+    B: output control transition matrix (nworld, ndx, nu)
+    C: output state observation matrix (nworld, nsensordata, ndx)
+    D: output control observation matrix (nworld, nsensordata, nu)
+
+  Raises:
+    ZeroDivisionError: if eps is zero.
+  """
+  # TODO(team): add option for scratch memory
+  # NOTE(review): only qpos, qvel, act and ctrl are saved and restored here; anything else
+  # that forward.step advances (presumably simulation time and solver warmstart) is left
+  # as the last step wrote it - confirm this is acceptable for callers.
+
+  nq, nv, na, nu = m.nq, m.nv, m.na, m.nu
+  ns = m.nsensordata
+  ndx = 2 * nv + na
+  nworld = d.nworld
+  state_size = nq + nv + na
+
+  # difference scale, evaluated first so that eps == 0 fails before d is stepped
+  inv_h = 1.0 / (2.0 * eps) if centered else 1.0 / eps
+
+  # skip sensor computations if not requested
+  skip_sensor = C is None and D is None
+
+  # reference inputs that every rollout starts from
+  state0 = wp.zeros((nworld, state_size), dtype=float)
+  ctrl0 = wp.zeros((nworld, nu), dtype=float) if nu > 0 else None
+  wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[state0])
+  if nu > 0:
+    wp.copy(ctrl0, d.ctrl)
+
+  # rollout results: baseline, +eps and (centered only) -eps
+  next_state = wp.zeros((nworld, state_size), dtype=float)
+  next_plus = wp.zeros((nworld, state_size), dtype=float)
+  next_minus = wp.zeros((nworld, state_size), dtype=float) if centered else None
+  sensor0 = wp.zeros((nworld, ns), dtype=float) if not skip_sensor else None
+  sensor_plus = wp.zeros((nworld, ns), dtype=float) if not skip_sensor else None
+  sensor_minus = wp.zeros((nworld, ns), dtype=float) if not skip_sensor and centered else None
+
+  # per-column scratch
+  ds = wp.zeros((nworld, ndx), dtype=float)
+  dsensor = wp.zeros((nworld, ns), dtype=float) if not skip_sensor else None
+
+  def restore():
+    # put qpos, qvel, act and ctrl back to the values captured on entry
+    wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+    if nu > 0:
+      wp.copy(d.ctrl, ctrl0)
+
+  def rollout(next_out, sensor_out):
+    # step from whatever is currently in d, record the result, then restore the inputs
+    forward.step(m, d)
+    wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_out])
+    if not skip_sensor:
+      wp.copy(sensor_out, d.sensordata)
+    restore()
+
+  # baseline step from the unperturbed inputs
+  rollout(next_state, sensor0)
+
+  # nudges applied per column, and the result the +eps rollout is differenced against
+  nudges = [(eps, next_plus, sensor_plus)]
+  if centered:
+    nudges.append((-eps, next_minus, sensor_minus))
+    base_state, base_sensor = next_minus, sensor_minus
+  else:
+    base_state, base_sensor = next_state, sensor0
+
+  def store_column(col_idx, state_jac, sensor_jac):
+    # write column col_idx of the requested Jacobians; both kernels compute
+    # (second - first) * inv_h, so the base result always goes first and +eps second
+    if state_jac is not None:
+      wp.launch(
+        _state_diff,
+        dim=nworld,
+        inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, base_state, next_plus, inv_h],
+        outputs=[ds],
+      )
+      wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[ds, col_idx, ndx], outputs=[state_jac])
+    if sensor_jac is not None:
+      wp.launch(_diff_vectors, dim=nworld, inputs=[base_sensor, sensor_plus, inv_h, ns], outputs=[dsensor])
+      wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[dsensor, col_idx, ns], outputs=[sensor_jac])
+
+  # finite difference controls: columns [0, nu) of B and D
+  if (B is not None or D is not None) and nu > 0:
+    ctrl_temp = wp.zeros((nworld, nu), dtype=float)
+    for i in range(nu):
+      for delta, next_out, sensor_out in nudges:
+        wp.launch(_perturb_array, dim=nworld, inputs=[i, delta, ctrl0], outputs=[ctrl_temp])
+        wp.copy(d.ctrl, ctrl_temp)
+        rollout(next_out, sensor_out)
+      store_column(i, B, D)
+
+  if A is not None or C is not None:
+    # finite difference activations: columns [2*nv, 2*nv + na) of A and C
+    if na > 0:
+      act0 = wp.zeros((nworld, na), dtype=float)
+      wp.copy(act0, d.act)
+      act_temp = wp.zeros((nworld, na), dtype=float)
+      for i in range(na):
+        for delta, next_out, sensor_out in nudges:
+          wp.launch(_perturb_array, dim=nworld, inputs=[i, delta, act0], outputs=[act_temp])
+          wp.copy(d.act, act_temp)
+          rollout(next_out, sensor_out)
+        store_column(2 * nv + i, A, C)
+
+    # finite difference velocities: columns [nv, 2*nv) of A and C
+    qvel0 = wp.zeros((nworld, nv), dtype=float)
+    wp.copy(qvel0, d.qvel)
+    qvel_temp = wp.zeros((nworld, nv), dtype=float)
+    for i in range(nv):
+      for delta, next_out, sensor_out in nudges:
+        wp.launch(_perturb_array, dim=nworld, inputs=[i, delta, qvel0], outputs=[qvel_temp])
+        wp.copy(d.qvel, qvel_temp)
+        rollout(next_out, sensor_out)
+      store_column(nv + i, A, C)
+
+    # finite difference positions: columns [0, nv) of A and C, nudged per dof
+    qpos_perturbed = wp.zeros((nworld, nq), dtype=float)
+    for i in range(nv):
+      for delta, next_out, sensor_out in nudges:
+        # d.qpos holds the reference configuration here because every rollout restores it
+        wp.launch(
+          _perturb_position,
+          dim=nworld,
+          inputs=[nq, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, d.qpos, i, delta],
+          outputs=[qpos_perturbed],
+        )
+        wp.copy(d.qpos, qpos_perturbed)
+        rollout(next_out, sensor_out)
+      store_column(i, A, C)
+
+  # restore final state
+  restore()
diff --git a/mujoco_warp/_src/derivative_test.py b/mujoco_warp/_src/derivative_test.py
index 6ce0e5c7b..5c390498d 100644
--- a/mujoco_warp/_src/derivative_test.py
+++ b/mujoco_warp/_src/derivative_test.py
@@ -24,13 +24,9 @@
import mujoco_warp as mjw
from mujoco_warp import test_data
-# tolerance for difference between MuJoCo and mjwarp smooth calculations - mostly
-# due to float precision
-_TOLERANCE = 5e-5
-
def _assert_eq(a, b, name):
+ # asserts a and b are close; name is only used to label the failure message
+ # the same tol value is passed as both the absolute and the relative tolerance
- tol = _TOLERANCE * 10 # avoid test noise
+ tol = 1e-4
err_msg = f"mismatch: {name}"
np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
@@ -118,6 +114,179 @@ def test_smooth_vel(self, jacobian):
_assert_eq(mjw_out, mj_out, "qM - dt * qDeriv")
+ # runs twice: centered=False (forward differences) and centered=True (centered differences)
+ @parameterized.parameters(False, True)
+ def test_transition_fd_linear_system(self, centered):
+ """Tests A and B matrices match MuJoCo mjd_transitionFD."""
+ # simple linear system with 3 slide joints
+ # NOTE(review): the model XML below is blank in this view - presumably stripped
+ # during extraction; confirm against the original file
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ keyframe=0,
+ )
+
+ # larger eps needed for float32 precision
+ eps = 1e-3
+ # size of the tangent-space state: 2*nv + na
+ ndx = 2 * mjm.nv + mjm.na
+
+ # mujoco reference
+ A_mj = np.zeros((ndx, ndx))
+ B_mj = np.zeros((ndx, mjm.nu))
+ mujoco.mjd_transitionFD(mjm, mjd, eps, centered, A_mj, B_mj, None, None)
+
+ # mujoco warp
+ # outputs carry a leading axis of size 1 (presumably one world in this fixture)
+ A_mjw = wp.zeros((1, ndx, ndx), dtype=float)
+ B_mjw = wp.zeros((1, ndx, mjm.nu), dtype=float)
+ mjw.transition_fd(m, d, eps, centered, A_mjw, B_mjw, None, None)
+
+ # compare the first (index 0) slice against the reference
+ _assert_eq(A_mjw.numpy()[0], A_mj, "A")
+ _assert_eq(B_mjw.numpy()[0], B_mj, "B")
+
+ # runs twice: centered=False (forward differences) and centered=True (centered differences)
+ @parameterized.parameters(False, True)
+ def test_transition_fd_sensor_derivatives(self, centered):
+ """Tests C and D matrices against MuJoCo mjd_transitionFD."""
+ # NOTE(review): the model XML below is blank in this view - presumably stripped
+ # during extraction; confirm against the original file
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ )
+
+ # larger eps needed for float32 precision
+ eps = 1e-3
+ nv = mjm.nv
+ nu = mjm.nu
+ ns = mjm.nsensordata
+ # size of the tangent-space state: 2*nv + na
+ ndx = 2 * nv + mjm.na
+
+ # mujoco reference
+ # A and B are passed as None, so only the sensor Jacobians are requested
+ C_mj = np.zeros((ns, ndx))
+ D_mj = np.zeros((ns, nu))
+ mujoco.mjd_transitionFD(mjm, mjd, eps, centered, None, None, C_mj, D_mj)
+
+ # mujoco warp
+ # outputs carry a leading axis of size 1 (presumably one world in this fixture)
+ C_mjw = wp.zeros((1, ns, ndx), dtype=float)
+ D_mjw = wp.zeros((1, ns, nu), dtype=float)
+ mjw.transition_fd(m, d, eps, centered, None, None, C_mjw, D_mjw)
+
+ # compare the first (index 0) slice against the reference
+ _assert_eq(C_mjw.numpy()[0], C_mj, "C")
+ _assert_eq(D_mjw.numpy()[0], D_mj, "D")
+
+ # runs twice: centered=False (forward differences) and centered=True (centered differences)
+ @parameterized.parameters(False, True)
+ def test_transition_fd_clamped_ctrl(self, centered):
+ """Tests that B matrix is zero when ctrl is at or beyond limits."""
+ # NOTE(review): the model XML below is blank in this view - presumably stripped
+ # during extraction, so the control range it declares cannot be confirmed here
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ )
+
+ eps = 1e-3
+ nv = mjm.nv
+ nu = mjm.nu
+ # size of the tangent-space state: 2*nv + na
+ ndx = 2 * nv + mjm.na
+
+ # set ctrl beyond limits
+ # only the beyond-limits case is exercised; the same value is written to both backends
+ mjd.ctrl[0] = 2.0
+ d.ctrl.fill_(2.0)
+
+ # mujoco reference - B should be zero
+ B_mj = np.zeros((ndx, nu))
+ mujoco.mjd_transitionFD(mjm, mjd, eps, centered, None, B_mj, None, None)
+
+ # mujoco warp
+ B_mjw = wp.zeros((1, ndx, nu), dtype=float)
+ mjw.transition_fd(m, d, eps, centered, None, B_mjw, None, None)
+
+ # expect B to be zero since ctrl is beyond limits
+ _assert_eq(B_mjw.numpy()[0], B_mj, "B clamped")
+ # sanity check on the reference itself, so the comparison above is against zeros
+ np.testing.assert_allclose(B_mj, 0.0, atol=1e-10)
+
+ def test_transition_fd_no_state_mutation(self):
+ """Tests that transition_fd does not mutate state."""
+ # NOTE(review): the model XML below is blank in this view - presumably stripped
+ # during extraction; confirm against the original file
+ mjm, mjd, m, d = test_data.fixture(
+ xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """,
+ keyframe=0,
+ )
+
+ # save state before
+ # .numpy() results are copied so the snapshots cannot alias the live arrays
+ qpos_before = d.qpos.numpy().copy()
+ qvel_before = d.qvel.numpy().copy()
+ ctrl_before = d.ctrl.numpy().copy()
+
+ # call transition_fd
+ # NOTE(review): only the forward-difference path (centered=False) with A and B is exercised
+ eps = 1e-3
+ ndx = 2 * m.nv + m.na
+ A = wp.zeros((1, ndx, ndx), dtype=float)
+ B = wp.zeros((1, ndx, m.nu), dtype=float)
+ mjw.transition_fd(m, d, eps, False, A, B, None, None)
+
+ # check state unchanged
+ # compared with the _assert_eq tolerance rather than exactly; act is not checked
+ _assert_eq(d.qpos.numpy(), qpos_before, "qpos")
+ _assert_eq(d.qvel.numpy(), qvel_before, "qvel")
+ _assert_eq(d.ctrl.numpy(), ctrl_before, "ctrl")
+
if __name__ == "__main__":
wp.init()