diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py
index 87ef6c92a..b5a94ee57 100644
--- a/mujoco_warp/__init__.py
+++ b/mujoco_warp/__init__.py
@@ -28,6 +28,7 @@
 from ._src.collision_sdf import sdf_narrowphase as sdf_narrowphase
 from ._src.constraint import make_constraint as make_constraint
 from ._src.derivative import deriv_smooth_vel as deriv_smooth_vel
+from ._src.derivative import transition_fd as transition_fd
 from ._src.forward import euler as euler
 from ._src.forward import forward as forward
 from ._src.forward import fwd_acceleration as fwd_acceleration
diff --git a/mujoco_warp/_src/derivative.py b/mujoco_warp/_src/derivative.py
index 87df3a8b7..9c7548bd9 100644
--- a/mujoco_warp/_src/derivative.py
+++ b/mujoco_warp/_src/derivative.py
@@ -15,11 +15,14 @@
 
 import warp as wp
 
+from . import forward
+from . import math
 from .types import BiasType
 from .types import Data
 from .types import DisableBit
 from .types import DynType
 from .types import GainType
+from .types import JointType
 from .types import Model
 from .types import TileSet
 from .types import vec10f
@@ -290,3 +293,563 @@ def deriv_smooth_vel(m: Model, d: Data, out: wp.array2d(dtype=float)):
     )
 
   # TODO(team): rne derivative
+
+
+@wp.kernel
+def _get_state(
+  # Model:
+  nq: int,
+  nv: int,
+  na: int,
+  # Data in:
+  qpos_in: wp.array2d(dtype=float),
+  qvel_in: wp.array2d(dtype=float),
+  act_in: wp.array2d(dtype=float),
+  # Out:
+  state_out: wp.array2d(dtype=float),
+):
+  # pack state = [qpos, qvel, act] of each world into one row of state_out
+  worldid = wp.tid()
+  for i in range(nq):
+    state_out[worldid, i] = qpos_in[worldid, i]
+    # qvel shares the qpos loop, which covers all nv entries only if nv <= nq
+    if i < nv:
+      state_out[worldid, nq + i] = qvel_in[worldid, i]
+  for i in range(na):
+    state_out[worldid, nq + nv + i] = act_in[worldid, i]
+
+
+@wp.kernel
+def _set_state(
+  # Model:
+  nq: int,
+  nv: int,
+  na: int,
+  # In:
+  state_in: wp.array2d(dtype=float),
+  # Data out:
+  qpos_out: wp.array2d(dtype=float),
+  qvel_out: wp.array2d(dtype=float),
+  act_out: wp.array2d(dtype=float),
+):
+  # unpack state = [qpos, qvel, act] from one row of state_in (inverse of _get_state)
+  worldid = wp.tid()
+  for i in range(nq):
+    qpos_out[worldid, i] = state_in[worldid, i]
+    # qvel shares the qpos loop, which covers all nv entries only if nv <= nq
+    if i < nv:
+      qvel_out[worldid, i] = state_in[worldid, nq + i]
+  for i in range(na):
+    act_out[worldid, i] = state_in[worldid, nq + nv + i]
+
+
+@wp.kernel
+def _state_diff(
+  # Model:
+  nq: int,
+  nv: int,
+  na: int,
+  njnt: int,
+  jnt_type: wp.array(dtype=int),
+  jnt_qposadr: wp.array(dtype=int),
+  jnt_dofadr: wp.array(dtype=int),
+  # In:
+  state1_in: wp.array2d(dtype=float),
+  state2_in: wp.array2d(dtype=float),
+  inv_h: float,
+  # Out:
+  ds_out: wp.array2d(dtype=float),
+):
+  # finite difference two state vectors: ds = (s2 - s1) / h
+  # ds layout: [position difference (nv), velocity difference (nv), activation difference (na)]
+  worldid = wp.tid()
+
+  # position difference via joint type
+  for jntid in range(njnt):
+    jnttype = jnt_type[jntid]
+    qpos_adr = jnt_qposadr[jntid]
+    dof_adr = jnt_dofadr[jntid]
+
+    if jnttype == JointType.FREE:
+      # linear position difference
+      ds_out[worldid, dof_adr + 0] = (state2_in[worldid, qpos_adr + 0] - state1_in[worldid, qpos_adr + 0]) * inv_h
+      ds_out[worldid, dof_adr + 1] = (state2_in[worldid, qpos_adr + 1] - state1_in[worldid, qpos_adr + 1]) * inv_h
+      ds_out[worldid, dof_adr + 2] = (state2_in[worldid, qpos_adr + 2] - state1_in[worldid, qpos_adr + 2]) * inv_h
+      # quaternion difference: first three components of math.quat_sub(q2, q1)
+      q1 = wp.quat(
+        state1_in[worldid, qpos_adr + 3],
+        state1_in[worldid, qpos_adr + 4],
+        state1_in[worldid, qpos_adr + 5],
+        state1_in[worldid, qpos_adr + 6],
+      )
+      q2 = wp.quat(
+        state2_in[worldid, qpos_adr + 3],
+        state2_in[worldid, qpos_adr + 4],
+        state2_in[worldid, qpos_adr + 5],
+        state2_in[worldid, qpos_adr + 6],
+      )
+      dq = math.quat_sub(q2, q1)
+      ds_out[worldid, dof_adr + 3] = dq[0] * inv_h
+      ds_out[worldid, dof_adr + 4] = dq[1] * inv_h
+      ds_out[worldid, dof_adr + 5] = dq[2] * inv_h
+    elif jnttype == JointType.BALL:
+      q1 = wp.quat(
+        state1_in[worldid, qpos_adr + 0],
+        state1_in[worldid, qpos_adr + 1],
+        state1_in[worldid, qpos_adr + 2],
+        state1_in[worldid, qpos_adr + 3],
+      )
+      q2 = wp.quat(
+        state2_in[worldid, qpos_adr + 0],
+        state2_in[worldid, qpos_adr + 1],
+        state2_in[worldid, qpos_adr + 2],
+        state2_in[worldid, qpos_adr + 3],
+      )
+      dq = math.quat_sub(q2, q1)
+      ds_out[worldid, dof_adr + 0] = dq[0] * inv_h
+      ds_out[worldid, dof_adr + 1] = dq[1] * inv_h
+      ds_out[worldid, dof_adr + 2] = dq[2] * inv_h
+    else:  # SLIDE, HINGE
+      ds_out[worldid, dof_adr] = (state2_in[worldid, qpos_adr] - state1_in[worldid, qpos_adr]) * inv_h
+
+  # velocity and activation difference
+  for i in range(nv):
+    ds_out[worldid, nv + i] = (state2_in[worldid, nq + i] - state1_in[worldid, nq + i]) * inv_h
+  for i in range(na):
+    ds_out[worldid, 2 * nv + i] = (state2_in[worldid, nq + nv + i] - state1_in[worldid, nq + nv + i]) * inv_h
+
+
+@wp.kernel
+def _perturb_position(
+  # Model:
+  nq: int,
+  njnt: int,
+  jnt_type: wp.array(dtype=int),
+  jnt_qposadr: wp.array(dtype=int),
+  jnt_dofadr: wp.array(dtype=int),
+  # Data in:
+  qpos_in: wp.array2d(dtype=float),
+  # In:
+  dof_idx: int,
+  eps: float,
+  # Data out:
+  qpos_out: wp.array2d(dtype=float),
+):
+  # qpos_out = qpos_in with the single dof dof_idx nudged by eps
+  worldid = wp.tid()
+
+  # copy qpos_in to qpos_out
+  for i in range(nq):
+    qpos_out[worldid, i] = qpos_in[worldid, i]
+
+  # find joint for this dof and perturb
+  for jntid in range(njnt):
+    jnttype = jnt_type[jntid]
+    qpos_adr = jnt_qposadr[jntid]
+    dof_adr = jnt_dofadr[jntid]
+
+    if jnttype == JointType.FREE:
+      if dof_idx >= dof_adr and dof_idx < dof_adr + 3:
+        qpos_out[worldid, qpos_adr + (dof_idx - dof_adr)] += eps
+      elif dof_idx >= dof_adr + 3 and dof_idx < dof_adr + 6:
+        # rotational dof: nudge the quaternion along unit axis local_idx with math.quat_integrate(q, v, eps)
+        q = wp.quat(
+          qpos_in[worldid, qpos_adr + 3],
+          qpos_in[worldid, qpos_adr + 4],
+          qpos_in[worldid, qpos_adr + 5],
+          qpos_in[worldid, qpos_adr + 6],
+        )
+        local_idx = dof_idx - dof_adr - 3
+        if local_idx == 0:
+          v = wp.vec3(1.0, 0.0, 0.0)
+        elif local_idx == 1:
+          v = wp.vec3(0.0, 1.0, 0.0)
+        else:
+          v = wp.vec3(0.0, 0.0, 1.0)
+        q_new = math.quat_integrate(q, v, eps)
+        qpos_out[worldid, qpos_adr + 3] = q_new[0]
+        qpos_out[worldid, qpos_adr + 4] = q_new[1]
+        qpos_out[worldid, qpos_adr + 5] = q_new[2]
+        qpos_out[worldid, qpos_adr + 6] = q_new[3]
+    elif jnttype == JointType.BALL:
+      if dof_idx >= dof_adr and dof_idx < dof_adr + 3:
+        q = wp.quat(
+          qpos_in[worldid, qpos_adr + 0],
+          qpos_in[worldid, qpos_adr + 1],
+          qpos_in[worldid, qpos_adr + 2],
+          qpos_in[worldid, qpos_adr + 3],
+        )
+        local_idx = dof_idx - dof_adr
+        if local_idx == 0:
+          v = wp.vec3(1.0, 0.0, 0.0)
+        elif local_idx == 1:
+          v = wp.vec3(0.0, 1.0, 0.0)
+        else:
+          v = wp.vec3(0.0, 0.0, 1.0)
+        q_new = math.quat_integrate(q, v, eps)
+        qpos_out[worldid, qpos_adr + 0] = q_new[0]
+        qpos_out[worldid, qpos_adr + 1] = q_new[1]
+        qpos_out[worldid, qpos_adr + 2] = q_new[2]
+        qpos_out[worldid, qpos_adr + 3] = q_new[3]
+    else:  # SLIDE, HINGE
+      if dof_idx == dof_adr:
+        qpos_out[worldid, qpos_adr] += eps
+
+
+@wp.kernel
+def _perturb_array(
+  # In:
+  idx: int,
+  eps: float,
+  arr_in: wp.array2d(dtype=float),
+  # Out:
+  arr_out: wp.array2d(dtype=float),
+):
+  # arr_out = arr_in with eps added to column idx
+  worldid = wp.tid()
+  for i in range(arr_in.shape[1]):
+    if i == idx:
+      arr_out[worldid, i] = arr_in[worldid, i] + eps
+    else:
+      arr_out[worldid, i] = arr_in[worldid, i]
+
+
+@wp.kernel
+def _diff_vectors(
+  # In:
+  x1_in: wp.array2d(dtype=float),
+  x2_in: wp.array2d(dtype=float),
+  inv_h: float,
+  n: int,
+  # Out:
+  dx_out: wp.array2d(dtype=float),
+):
+  # dx = (x2 - x1) / h over the first n columns, with inv_h = 1 / h
+  worldid = wp.tid()
+  for i in range(n):
+    dx_out[worldid, i] = (x2_in[worldid, i] - x1_in[worldid, i]) * inv_h
+
+
+@wp.kernel
+def _copy_to_jacobian_col(
+  # In:
+  col_in: wp.array2d(dtype=float),
+  col_idx: int,
+  nrow: int,
+  # Out:
+  jac_out: wp.array3d(dtype=float),
+):
+  # write the first nrow entries of col_in into column col_idx of jac_out
+  worldid = wp.tid()
+  for i in range(nrow):
+    jac_out[worldid, i, col_idx] = col_in[worldid, i]
+
+
+@event_scope
+def transition_fd(
+  m: Model,
+  d: Data,
+  eps: float,
+  centered: bool = False,
+  A: wp.array3d(dtype=float) = None,
+  B: wp.array3d(dtype=float) = None,
+  C: wp.array3d(dtype=float) = None,
+  D: wp.array3d(dtype=float) = None,
+):
+  """Finite differenced transition matrices (control theory notation).
+
+  Computes: d(x_next) = A*dx + B*du, d(sensor) = C*dx + D*du
+  where dx = [dqpos, dqvel, dact] is the state perturbation in tangent space.
+
+  Outputs passed as None are not computed. qpos, qvel, act and ctrl are restored
+  to their input values before returning.
+  NOTE(review): other Data fields that forward.step writes (sensordata is read
+  back after every step) are not restored - confirm callers do not need them.
+
+  Args:
+    m: model
+    d: data
+    eps: finite difference epsilon (must be nonzero, it is used as a divisor)
+    centered: if True, use centered differences
+    A: output state transition matrix (nworld, ndx, ndx) where ndx = 2*nv+na
+    B: output control transition matrix (nworld, ndx, nu)
+    C: output state observation matrix (nworld, nsensordata, ndx)
+    D: output control observation matrix (nworld, nsensordata, nu)
+  """
+  # TODO(team): add option for scratch memory
+
+  nq, nv, na, nu = m.nq, m.nv, m.na, m.nu
+  ns = m.nsensordata
+  ndx = 2 * nv + na
+  nworld = d.nworld
+
+  # skip sensor computations if not requested
+  skip_sensor = C is None and D is None
+
+  # save current state
+  state_size = nq + nv + na
+  state0 = wp.zeros((nworld, state_size), dtype=float)
+  ctrl0 = wp.zeros((nworld, nu), dtype=float) if nu > 0 else None
+  wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[state0])
+  if nu > 0:
+    wp.copy(ctrl0, d.ctrl)
+
+  # baseline step
+  forward.step(m, d)
+
+  # save baseline next state and sensors
+  next_state = wp.zeros((nworld, state_size), dtype=float)
+  wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_state])
+  sensor0 = None
+  if not skip_sensor:
+    sensor0 = wp.zeros((nworld, ns), dtype=float)
+    wp.copy(sensor0, d.sensordata)
+
+  # restore state
+  wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+  if nu > 0:
+    wp.copy(d.ctrl, ctrl0)
+
+  # allocate work arrays
+  next_plus = wp.zeros((nworld, state_size), dtype=float)
+  next_minus = wp.zeros((nworld, state_size), dtype=float) if centered else None
+  ds = wp.zeros((nworld, ndx), dtype=float)
+  sensor_plus = wp.zeros((nworld, ns), dtype=float) if not skip_sensor else None
+  sensor_minus = wp.zeros((nworld, ns), dtype=float) if not skip_sensor and centered else None
+  dsensor = wp.zeros((nworld, ns), dtype=float) if not skip_sensor else None
+
+  inv_eps = 1.0 / eps
+  inv_2eps = 1.0 / (2.0 * eps) if centered else inv_eps
+
+  # finite difference controls
+  if (B is not None or D is not None) and nu > 0:
+    ctrl_temp = wp.zeros((nworld, nu), dtype=float)
+    for i in range(nu):
+      # nudge forward
+      wp.launch(_perturb_array, dim=nworld, inputs=[i, eps, ctrl0], outputs=[ctrl_temp])
+      wp.copy(d.ctrl, ctrl_temp)
+      forward.step(m, d)
+      wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_plus])
+      if not skip_sensor:
+        wp.copy(sensor_plus, d.sensordata)
+
+      # restore and nudge backward if centered
+      wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+      wp.copy(d.ctrl, ctrl0)
+
+      if centered:
+        wp.launch(_perturb_array, dim=nworld, inputs=[i, -eps, ctrl0], outputs=[ctrl_temp])
+        wp.copy(d.ctrl, ctrl_temp)
+        forward.step(m, d)
+        wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_minus])
+        if not skip_sensor:
+          wp.copy(sensor_minus, d.sensordata)
+        wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+        wp.copy(d.ctrl, ctrl0)
+
+      # compute derivatives
+      if B is not None:
+        if centered:
+          wp.launch(
+            _state_diff,
+            dim=nworld,
+            inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_minus, next_plus, inv_2eps],
+            outputs=[ds],
+          )
+        else:
+          wp.launch(
+            _state_diff,
+            dim=nworld,
+            inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_state, next_plus, inv_eps],
+            outputs=[ds],
+          )
+        wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[ds, i, ndx], outputs=[B])
+
+      if D is not None:
+        # _diff_vectors computes (x2 - x1) * inv_h, so the minus sample goes first: (plus - minus) / (2 * eps)
+        if centered:
+          wp.launch(_diff_vectors, dim=nworld, inputs=[sensor_minus, sensor_plus, inv_2eps, ns], outputs=[dsensor])
+        else:
+          wp.launch(_diff_vectors, dim=nworld, inputs=[sensor0, sensor_plus, inv_eps, ns], outputs=[dsensor])
+        wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[dsensor, i, ns], outputs=[D])
+
+  # finite difference activations
+  if (A is not None or C is not None) and na > 0:
+    act0 = wp.zeros((nworld, na), dtype=float)
+    wp.copy(act0, d.act)
+    act_temp = wp.zeros((nworld, na), dtype=float)
+    for i in range(na):
+      # nudge forward
+      wp.launch(_perturb_array, dim=nworld, inputs=[i, eps, act0], outputs=[act_temp])
+      wp.copy(d.act, act_temp)
+      forward.step(m, d)
+      wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_plus])
+      if not skip_sensor:
+        wp.copy(sensor_plus, d.sensordata)
+
+      # restore and nudge backward if centered
+      wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+      if nu > 0:
+        wp.copy(d.ctrl, ctrl0)
+
+      if centered:
+        wp.launch(_perturb_array, dim=nworld, inputs=[i, -eps, act0], outputs=[act_temp])
+        wp.copy(d.act, act_temp)
+        forward.step(m, d)
+        wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_minus])
+        if not skip_sensor:
+          wp.copy(sensor_minus, d.sensordata)
+        wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+        if nu > 0:
+          wp.copy(d.ctrl, ctrl0)
+
+      # compute derivatives
+      col_idx = 2 * nv + i
+      if A is not None:
+        if centered:
+          wp.launch(
+            _state_diff,
+            dim=nworld,
+            inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_minus, next_plus, inv_2eps],
+            outputs=[ds],
+          )
+        else:
+          wp.launch(
+            _state_diff,
+            dim=nworld,
+            inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_state, next_plus, inv_eps],
+            outputs=[ds],
+          )
+        wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[ds, col_idx, ndx], outputs=[A])
+
+      if C is not None:
+        if centered:
+          wp.launch(_diff_vectors, dim=nworld, inputs=[sensor_minus, sensor_plus, inv_2eps, ns], outputs=[dsensor])
+        else:
+          wp.launch(_diff_vectors, dim=nworld, inputs=[sensor0, sensor_plus, inv_eps, ns], outputs=[dsensor])
+        wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[dsensor, col_idx, ns], outputs=[C])
+
+  # finite difference velocities
+  if A is not None or C is not None:
+    qvel0 = wp.zeros((nworld, nv), dtype=float)
+    wp.copy(qvel0, d.qvel)
+    qvel_temp = wp.zeros((nworld, nv), dtype=float)
+    for i in range(nv):
+      # nudge forward
+      wp.launch(_perturb_array, dim=nworld, inputs=[i, eps, qvel0], outputs=[qvel_temp])
+      wp.copy(d.qvel, qvel_temp)
+      forward.step(m, d)
+      wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_plus])
+      if not skip_sensor:
+        wp.copy(sensor_plus, d.sensordata)
+
+      # restore and nudge backward if centered
+      wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+      if nu > 0:
+        wp.copy(d.ctrl, ctrl0)
+
+      if centered:
+        wp.launch(_perturb_array, dim=nworld, inputs=[i, -eps, qvel0], outputs=[qvel_temp])
+        wp.copy(d.qvel, qvel_temp)
+        forward.step(m, d)
+        wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_minus])
+        if not skip_sensor:
+          wp.copy(sensor_minus, d.sensordata)
+        wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+        if nu > 0:
+          wp.copy(d.ctrl, ctrl0)
+
+      # compute derivatives
+      col_idx = nv + i
+      if A is not None:
+        if centered:
+          wp.launch(
+            _state_diff,
+            dim=nworld,
+            inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_minus, next_plus, inv_2eps],
+            outputs=[ds],
+          )
+        else:
+          wp.launch(
+            _state_diff,
+            dim=nworld,
+            inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_state, next_plus, inv_eps],
+            outputs=[ds],
+          )
+        wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[ds, col_idx, ndx], outputs=[A])
+
+      if C is not None:
+        if centered:
+          wp.launch(_diff_vectors, dim=nworld, inputs=[sensor_minus, sensor_plus, inv_2eps, ns], outputs=[dsensor])
+        else:
+          wp.launch(_diff_vectors, dim=nworld, inputs=[sensor0, sensor_plus, inv_eps, ns], outputs=[dsensor])
+        wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[dsensor, col_idx, ns], outputs=[C])
+
+  # finite difference positions
+  if A is not None or C is not None:
+    qpos_perturbed = wp.zeros((nworld, nq), dtype=float)
+    for i in range(nv):
+      # nudge position forward
+      wp.launch(
+        _perturb_position,
+        dim=nworld,
+        inputs=[nq, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, d.qpos, i, eps],
+        outputs=[qpos_perturbed],
+      )
+      wp.copy(d.qpos, qpos_perturbed)
+      forward.step(m, d)
+      wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_plus])
+      if not skip_sensor:
+        wp.copy(sensor_plus, d.sensordata)
+
+      # restore and nudge backward if centered
+      wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+      if nu > 0:
+        wp.copy(d.ctrl, ctrl0)
+
+      if centered:
+        wp.launch(
+          _perturb_position,
+          dim=nworld,
+          inputs=[nq, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, d.qpos, i, -eps],
+          outputs=[qpos_perturbed],
+        )
+        wp.copy(d.qpos, qpos_perturbed)
+        forward.step(m, d)
+        wp.launch(_get_state, dim=nworld, inputs=[nq, nv, na, d.qpos, d.qvel, d.act], outputs=[next_minus])
+        if not skip_sensor:
+          wp.copy(sensor_minus, d.sensordata)
+        wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+        if nu > 0:
+          wp.copy(d.ctrl, ctrl0)
+
+      # compute derivatives
+      col_idx = i
+      if A is not None:
+        if centered:
+          wp.launch(
+            _state_diff,
+            dim=nworld,
+            inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_minus, next_plus, inv_2eps],
+            outputs=[ds],
+          )
+        else:
+          wp.launch(
+            _state_diff,
+            dim=nworld,
+            inputs=[nq, nv, na, m.njnt, m.jnt_type, m.jnt_qposadr, m.jnt_dofadr, next_state, next_plus, inv_eps],
+            outputs=[ds],
+          )
+        wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[ds, col_idx, ndx], outputs=[A])
+
+      if C is not None:
+        if centered:
+          wp.launch(_diff_vectors, dim=nworld, inputs=[sensor_minus, sensor_plus, inv_2eps, ns], outputs=[dsensor])
+        else:
+          wp.launch(_diff_vectors, dim=nworld, inputs=[sensor0, sensor_plus, inv_eps, ns], outputs=[dsensor])
+        wp.launch(_copy_to_jacobian_col, dim=nworld, inputs=[dsensor, col_idx, ns], outputs=[C])
+
+  # restore final state
+  wp.launch(_set_state, dim=nworld, inputs=[nq, nv, na, state0], outputs=[d.qpos, d.qvel, d.act])
+  if nu > 0:
+    wp.copy(d.ctrl, ctrl0)
diff --git a/mujoco_warp/_src/derivative_test.py b/mujoco_warp/_src/derivative_test.py
index 6ce0e5c7b..5c390498d 100644
--- a/mujoco_warp/_src/derivative_test.py
+++ b/mujoco_warp/_src/derivative_test.py
@@ -24,13 +24,9 @@
 import mujoco_warp as mjw
 from mujoco_warp import test_data
 
-# tolerance for difference between MuJoCo and mjwarp smooth calculations - mostly
-# due to float precision
-_TOLERANCE = 5e-5
-
 
 def _assert_eq(a, b, name):
-  tol = _TOLERANCE * 10  # avoid test noise
+  tol = 1e-4
   err_msg = f"mismatch: {name}"
   np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol)
 
@@ -118,6 +114,179 @@ def test_smooth_vel(self, jacobian):
 
     _assert_eq(mjw_out, mj_out, "qM - dt * qDeriv")
 
+  @parameterized.parameters(False, True)
+  def test_transition_fd_linear_system(self, centered):
+    """Tests A and B matrices match MuJoCo mjd_transitionFD."""
+    # simple linear system with 3 slide joints
+    mjm, mjd, m, d = test_data.fixture(
+      xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+    """,
+      keyframe=0,
+    )
+
+    # larger eps needed for float32 precision
+    eps = 1e-3
+    ndx = 2 * mjm.nv + mjm.na
+
+    # mujoco reference
+    A_mj = np.zeros((ndx, ndx))
+    B_mj = np.zeros((ndx, mjm.nu))
+    mujoco.mjd_transitionFD(mjm, mjd, eps, centered, A_mj, B_mj, None, None)
+
+    # mujoco warp
+    A_mjw = wp.zeros((1, ndx, ndx), dtype=float)
+    B_mjw = wp.zeros((1, ndx, mjm.nu), dtype=float)
+    mjw.transition_fd(m, d, eps, centered, A_mjw, B_mjw, None, None)
+
+    _assert_eq(A_mjw.numpy()[0], A_mj, "A")
+    _assert_eq(B_mjw.numpy()[0], B_mj, "B")
+
+  @parameterized.parameters(False, True)
+  def test_transition_fd_sensor_derivatives(self, centered):
+    """Tests C and D matrices against MuJoCo mjd_transitionFD."""
+    mjm, mjd, m, d = test_data.fixture(
+      xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+    """,
+    )
+
+    # larger eps needed for float32 precision
+    eps = 1e-3
+    nv = mjm.nv
+    nu = mjm.nu
+    ns = mjm.nsensordata
+    ndx = 2 * nv + mjm.na
+
+    # mujoco reference
+    C_mj = np.zeros((ns, ndx))
+    D_mj = np.zeros((ns, nu))
+    mujoco.mjd_transitionFD(mjm, mjd, eps, centered, None, None, C_mj, D_mj)
+
+    # mujoco warp
+    C_mjw = wp.zeros((1, ns, ndx), dtype=float)
+    D_mjw = wp.zeros((1, ns, nu), dtype=float)
+    mjw.transition_fd(m, d, eps, centered, None, None, C_mjw, D_mjw)
+
+    _assert_eq(C_mjw.numpy()[0], C_mj, "C")
+    _assert_eq(D_mjw.numpy()[0], D_mj, "D")
+
+  @parameterized.parameters(False, True)
+  def test_transition_fd_clamped_ctrl(self, centered):
+    """Tests that B matrix is zero when ctrl is at or beyond limits."""
+    mjm, mjd, m, d = test_data.fixture(
+      xml="""
+
+
+
+
+
+
+
+
+
+
+
+    """,
+    )
+
+    eps = 1e-3
+    nv = mjm.nv
+    nu = mjm.nu
+    ndx = 2 * nv + mjm.na
+
+    # set ctrl beyond limits
+    mjd.ctrl[0] = 2.0
+    d.ctrl.fill_(2.0)
+
+    # mujoco reference - B should be zero
+    B_mj = np.zeros((ndx, nu))
+    mujoco.mjd_transitionFD(mjm, mjd, eps, centered, None, B_mj, None, None)
+
+    # mujoco warp
+    B_mjw = wp.zeros((1, ndx, nu), dtype=float)
+    mjw.transition_fd(m, d, eps, centered, None, B_mjw, None, None)
+
+    # expect B to be zero since ctrl is beyond limits
+    _assert_eq(B_mjw.numpy()[0], B_mj, "B clamped")
+    np.testing.assert_allclose(B_mj, 0.0, atol=1e-10)
+
+  def test_transition_fd_no_state_mutation(self):
+    """Tests that transition_fd does not mutate state."""
+    mjm, mjd, m, d = test_data.fixture(
+      xml="""
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+    """,
+      keyframe=0,
+    )
+
+    # save state before
+    qpos_before = d.qpos.numpy().copy()
+    qvel_before = d.qvel.numpy().copy()
+    ctrl_before = d.ctrl.numpy().copy()
+
+    # call transition_fd
+    eps = 1e-3
+    ndx = 2 * m.nv + m.na
+    A = wp.zeros((1, ndx, ndx), dtype=float)
+    B = wp.zeros((1, ndx, m.nu), dtype=float)
+    mjw.transition_fd(m, d, eps, False, A, B, None, None)
+
+    # check state unchanged
+    _assert_eq(d.qpos.numpy(), qpos_before, "qpos")
+    _assert_eq(d.qvel.numpy(), qvel_before, "qvel")
+    _assert_eq(d.ctrl.numpy(), ctrl_before, "ctrl")
+
 
 if __name__ == "__main__":
   wp.init()