Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 74 additions & 40 deletions mujoco_warp/viewer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 The Newton Developers
# Copyright 2026 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,6 +62,7 @@ class EngineOptions(enum.IntEnum):
_KEYFRAME = flags.DEFINE_integer("keyframe", 0, "keyframe to initialize simulation.")
_DEVICE = flags.DEFINE_string("device", None, "override the default Warp device")
_REPLAY = flags.DEFINE_string("replay", None, "keyframe sequence to replay, keyframe name must prefix match")
_VIEWER = flags.DEFINE_enum("viewer", "mujoco", ["mujoco", "viser"], "Viewer backend (mujoco native or mjviser web)")

_VIEWER_GLOBAL_STATE = {"running": True, "step_once": False}

Expand Down Expand Up @@ -105,6 +106,70 @@ def _compile_step(m, d):
return capture.graph


def _make_warp_step_fn(mjm, m, d, graph, ctrls=None):
ctrlid = 0
opt = copy.copy(mjm.opt)

def step_fn(mjm, mjd):
nonlocal ctrlid, opt, m, graph
if ctrls is not None and ctrlid < len(ctrls):
mjd.ctrl[:] = ctrls[ctrlid]
ctrlid += 1
if mjm.opt != opt:
opt = copy.copy(mjm.opt)
m = mjw.put_model(mjm)
graph = _compile_step(m, d) if wp.get_device().is_cuda else None
wp.copy(d.ctrl, wp.array([mjd.ctrl.astype(np.float32)]))
wp.copy(d.act, wp.array([mjd.act.astype(np.float32)]))
wp.copy(d.xfrc_applied, wp.array([mjd.xfrc_applied.astype(np.float32)]))
wp.copy(d.qpos, wp.array([mjd.qpos.astype(np.float32)]))
wp.copy(d.qvel, wp.array([mjd.qvel.astype(np.float32)]))
wp.copy(d.time, wp.array([mjd.time], dtype=wp.float32))
if graph is None:
mjw.step(m, d)
else:
wp.capture_launch(graph)
wp.synchronize()
mjw.get_data_into(mjd, mjm, d)

return step_fn


def _make_c_step_fn(ctrls=None):
if ctrls is None:
return mujoco.mj_step

ctrlid = 0

def step_fn(mjm, mjd):
nonlocal ctrlid
if ctrlid < len(ctrls):
mjd.ctrl[:] = ctrls[ctrlid]
ctrlid += 1
mujoco.mj_step(mjm, mjd)

return step_fn


def _run_viser_viewer(mjm, mjd, step_fn):
from mjviser import Viewer as MjViserViewer

MjViserViewer(mjm, mjd, step_fn=step_fn).run()


def _run_passive_viewer(mjm, mjd, step_fn):
with mujoco.viewer.launch_passive(mjm, mjd, key_callback=key_callback) as viewer:
while True:
start = time.time()
if _VIEWER_GLOBAL_STATE["running"] or _VIEWER_GLOBAL_STATE["step_once"]:
_VIEWER_GLOBAL_STATE["step_once"] = False
step_fn(mjm, mjd)
viewer.sync()
elapsed = time.time() - start
if elapsed < mjm.opt.timestep:
time.sleep(mjm.opt.timestep - elapsed)


def _main(argv: Sequence[str]) -> None:
"""Runs viewer app."""
if len(argv) < 2:
Expand All @@ -115,7 +180,6 @@ def _main(argv: Sequence[str]) -> None:
mjm = _load_model(epath.Path(argv[1]))
mjd = mujoco.MjData(mjm)
ctrls = None
ctrlid = 0
if _REPLAY.value:
keys = find_keys(mjm, _REPLAY.value)
if not keys:
Expand Down Expand Up @@ -169,45 +233,15 @@ def _main(argv: Sequence[str]) -> None:
print(f"Data\n nworld: {d.nworld} nconmax: {int(d.naconmax / d.nworld)} njmax: {d.njmax}\n")
print(f"MuJoCo Warp simulating with dt = {m.opt.timestep.numpy()[0]:.3f}...")

with mujoco.viewer.launch_passive(mjm, mjd, key_callback=key_callback) as viewer:
opt = copy.copy(mjm.opt)

while True:
start = time.time()

if ctrls is not None and ctrlid < len(ctrls):
mjd.ctrl[:] = ctrls[ctrlid]
ctrlid += 1

if _ENGINE.value == EngineOptions.C:
mujoco.mj_step(mjm, mjd)
else: # mjwarp
wp.copy(d.ctrl, wp.array([mjd.ctrl.astype(np.float32)]))
wp.copy(d.act, wp.array([mjd.act.astype(np.float32)]))
wp.copy(d.xfrc_applied, wp.array([mjd.xfrc_applied.astype(np.float32)]))
wp.copy(d.qpos, wp.array([mjd.qpos.astype(np.float32)]))
wp.copy(d.qvel, wp.array([mjd.qvel.astype(np.float32)]))
wp.copy(d.time, wp.array([mjd.time], dtype=wp.float32))
# if the user changed an option in the MuJoCo Simulate UI, go ahead and recompile the step
# TODO: update memory tied to option max iterations
if mjm.opt != opt:
opt = copy.copy(mjm.opt)
m = mjw.put_model(mjm)
graph = _compile_step(m, d) if wp.get_device().is_cuda else None
if _VIEWER_GLOBAL_STATE["running"] or _VIEWER_GLOBAL_STATE["step_once"]:
_VIEWER_GLOBAL_STATE["step_once"] = False
if graph is None:
mjw.step(m, d)
else:
wp.capture_launch(graph)
wp.synchronize()
mjw.get_data_into(mjd, mjm, d)

viewer.sync()
if _ENGINE.value == EngineOptions.WARP:
step_fn = _make_warp_step_fn(mjm, m, d, graph, ctrls)
else:
step_fn = _make_c_step_fn(ctrls)

elapsed = time.time() - start
if elapsed < mjm.opt.timestep:
time.sleep(mjm.opt.timestep - elapsed)
if _VIEWER.value == "viser":
_run_viser_viewer(mjm, mjd, step_fn)
else:
_run_passive_viewer(mjm, mjd, step_fn)


def main():
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dev = [
"lsprotocol>=2023.0.1,<2024.0.0",
"mujoco>=3.6.0.dev0",
"warp-lang>=1.11.0.dev0",
"mjviser>=0.0.10",
]
# TODO(team): cpu and cuda JAX optional dependencies are temporary, remove after we land MJX:Warp
cpu = [
Expand Down
Loading
Loading