Skip to content

Commit ca805ed

Browse files
committed
upgrade to mujoco 3.5.0
1 parent fc91589 commit ca805ed

File tree

10 files changed

+43
-197
lines changed

10 files changed

+43
-197
lines changed

mujoco_warp/_src/broadphase_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from mujoco_warp import DisableBit
2727
from mujoco_warp import test_data
2828
from mujoco_warp._src import collision_driver
29-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
3029

3130

3231
def broadphase_caller(m, d):
@@ -186,7 +185,7 @@ def test_broadphase(self, broadphase, filter):
186185
(0.011, 0, 1),
187186
(0.00999, 0, 0),
188187
(0, 0.00999, 0),
189-
(0.00999, 0.00999, 1 if BLEEDING_EDGE_MUJOCO else 0),
188+
(0.00999, 0.00999, 1),
190189
)
191190
def test_broadphase_margin(self, margin1, margin2, ncollision):
192191
_MJCF = f"""

mujoco_warp/_src/collision_convex.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from mujoco_warp._src.collision_primitive import contact_params
2525
from mujoco_warp._src.collision_primitive import geom_collision_pair
2626
from mujoco_warp._src.collision_primitive import write_contact
27-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
2827
from mujoco_warp._src.math import make_frame
2928
from mujoco_warp._src.math import upper_trid_index
3029
from mujoco_warp._src.types import MJ_MAX_EPAFACES
@@ -92,10 +91,7 @@ def _hfield_filter(
9291
r2 = geom_rbound[rbound_id, g2]
9392

9493
# TODO(team): margin?
95-
if BLEEDING_EDGE_MUJOCO:
96-
margin = geom_margin[margin_id, g1] + geom_margin[margin_id, g2]
97-
else:
98-
margin = wp.max(geom_margin[margin_id, g1], geom_margin[margin_id, g2])
94+
margin = geom_margin[margin_id, g1] + geom_margin[margin_id, g2]
9995

10096
# box-sphere test: horizontal plane
10197
for i in range(2):

mujoco_warp/_src/collision_driver.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from mujoco_warp._src.collision_convex import convex_narrowphase
2121
from mujoco_warp._src.collision_primitive import primitive_narrowphase
2222
from mujoco_warp._src.collision_sdf import sdf_narrowphase
23-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
2423
from mujoco_warp._src.math import upper_tri_index
2524
from mujoco_warp._src.types import MJ_MAXVAL
2625
from mujoco_warp._src.types import BroadphaseFilter
@@ -102,27 +101,18 @@ def _plane_filter(
102101
if size1 == 0.0:
103102
# geom1 is a plane
104103
dist = wp.dot(xpos2 - xpos1, wp.vec3(xmat1[0, 2], xmat1[1, 2], xmat1[2, 2]))
105-
if BLEEDING_EDGE_MUJOCO:
106-
return dist <= size2 + margin1 + margin2
107-
else:
108-
return dist <= size2 + wp.max(margin1, margin2)
104+
return dist <= size2 + margin1 + margin2
109105
elif size2 == 0.0:
110106
# geom2 is a plane
111107
dist = wp.dot(xpos1 - xpos2, wp.vec3(xmat2[0, 2], xmat2[1, 2], xmat2[2, 2]))
112-
if BLEEDING_EDGE_MUJOCO:
113-
return dist <= size1 + margin1 + margin2
114-
else:
115-
return dist <= size1 + wp.max(margin1, margin2)
108+
return dist <= size1 + margin1 + margin2
116109

117110
return True
118111

119112

120113
@wp.func
121114
def _sphere_filter(size1: float, size2: float, margin1: float, margin2: float, xpos1: wp.vec3, xpos2: wp.vec3) -> bool:
122-
if BLEEDING_EDGE_MUJOCO:
123-
bound = size1 + size2 + margin1 + margin2
124-
else:
125-
bound = size1 + size2 + wp.max(margin1, margin2)
115+
bound = size1 + size2 + margin1 + margin2
126116
dif = xpos2 - xpos1
127117
dist_sq = wp.dot(dif, dif)
128118
return dist_sq <= bound * bound
@@ -151,10 +141,7 @@ def _aabb_filter(
151141
center1 = xmat1 @ center1 + xpos1
152142
center2 = xmat2 @ center2 + xpos2
153143

154-
if BLEEDING_EDGE_MUJOCO:
155-
margin = margin1 + margin2
156-
else:
157-
margin = wp.max(margin1, margin2)
144+
margin = margin1 + margin2
158145

159146
max_x1 = -MJ_MAXVAL
160147
max_y1 = -MJ_MAXVAL
@@ -249,10 +236,7 @@ def _obb_filter(
249236
xmat2: wp.mat33,
250237
) -> bool:
251238
"""Oriented bounding boxes collision (see Gottschalk et al.), see mj_collideOBB."""
252-
if BLEEDING_EDGE_MUJOCO:
253-
margin = margin1 + margin2
254-
else:
255-
margin = wp.max(margin1, margin2)
239+
margin = margin1 + margin2
256240

257241
xcenter = mat23()
258242
normal = mat63()

mujoco_warp/_src/collision_primitive.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from mujoco_warp._src.collision_primitive_core import sphere_capsule
3030
from mujoco_warp._src.collision_primitive_core import sphere_cylinder
3131
from mujoco_warp._src.collision_primitive_core import sphere_sphere
32-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
3332
from mujoco_warp._src.math import make_frame
3433
from mujoco_warp._src.math import safe_div
3534
from mujoco_warp._src.math import upper_trid_index
@@ -553,12 +552,8 @@ def contact_params(
553552
solreffriction = wp.vec2(0.0, 0.0)
554553
solimp = mix * geom_solimp[solimp_id, g1] + (1.0 - mix) * geom_solimp[solimp_id, g2]
555554
# geom priority is ignored
556-
if BLEEDING_EDGE_MUJOCO:
557-
margin = geom_margin[margin_id, g1] + geom_margin[margin_id, g2]
558-
gap = geom_gap[gap_id, g1] + geom_gap[gap_id, g2]
559-
else:
560-
margin = wp.max(geom_margin[margin_id, g1], geom_margin[margin_id, g2])
561-
gap = wp.max(geom_gap[gap_id, g1], geom_gap[gap_id, g2])
555+
margin = geom_margin[margin_id, g1] + geom_margin[margin_id, g2]
556+
gap = geom_gap[gap_id, g1] + geom_gap[gap_id, g2]
562557

563558
friction = vec5(
564559
wp.max(MJ_MINMU, friction[0]),

mujoco_warp/_src/io.py

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030

3131

3232
def _is_mujoco_fresh() -> bool:
33-
"""Checks if mujoco version is > 3.4.0."""
33+
"""Checks if mujoco version is > 3.5.0."""
3434
version = importlib.metadata.version("mujoco")
3535
version = version.split(".")
3636
version = tuple(map(int, version[:3])) + tuple(version[3:])
37-
return version > (3, 4, 0)
37+
return version > (3, 5, 0)
3838

3939

4040
BLEEDING_EDGE_MUJOCO = _is_mujoco_fresh()
@@ -583,17 +583,9 @@ def geom_trid_index(i, j):
583583
m.qM_mulm_madr.append(madr)
584584
m.qM_mulm_rowadr.append(len(m.qM_mulm_col))
585585

586-
# TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
587-
if BLEEDING_EDGE_MUJOCO:
588-
m.flexedge_J_rownnz = mjm.flexedge_J_rownnz
589-
m.flexedge_J_rowadr = mjm.flexedge_J_rowadr
590-
m.flexedge_J_colind = mjm.flexedge_J_colind.reshape(-1)
591-
else:
592-
mjd = mujoco.MjData(mjm)
593-
mujoco.mj_forward(mjm, mjd)
594-
m.flexedge_J_rownnz = mjd.flexedge_J_rownnz
595-
m.flexedge_J_rowadr = mjd.flexedge_J_rowadr
596-
m.flexedge_J_colind = mjd.flexedge_J_colind.reshape(-1)
586+
m.flexedge_J_rownnz = mjm.flexedge_J_rownnz
587+
m.flexedge_J_rowadr = mjm.flexedge_J_rowadr
588+
m.flexedge_J_colind = mjm.flexedge_J_colind.reshape(-1)
597589

598590
# place m on device
599591
sizes = dict({"*": 1}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int})
@@ -1065,21 +1057,7 @@ def get_data_into(
10651057
result.cinert[:] = d.cinert.numpy()[world_id]
10661058
result.flexvert_xpos[:] = d.flexvert_xpos.numpy()[world_id]
10671059
if mjm.nflexedge > 0:
1068-
# TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
1069-
if not BLEEDING_EDGE_MUJOCO:
1070-
m = put_model(mjm)
1071-
result.flexedge_J_rownnz[:] = m.flexedge_J_rownnz.numpy()
1072-
result.flexedge_J_rowadr[:] = m.flexedge_J_rowadr.numpy()
1073-
result.flexedge_J_colind[:, :] = m.flexedge_J_colind.numpy().reshape((mjm.nflexedge, mjm.nv))
1074-
mujoco.mju_sparse2dense(
1075-
result.flexedge_J,
1076-
d.flexedge_J.numpy()[world_id].reshape(-1),
1077-
m.flexedge_J_rownnz.numpy(),
1078-
m.flexedge_J_rowadr.numpy(),
1079-
m.flexedge_J_colind.numpy(),
1080-
)
1081-
else:
1082-
result.flexedge_J[:] = d.flexedge_J.numpy()[world_id].reshape(-1)
1060+
result.flexedge_J[:] = d.flexedge_J.numpy()[world_id].reshape(-1)
10831061
result.flexedge_length[:] = d.flexedge_length.numpy()[world_id]
10841062
result.flexedge_velocity[:] = d.flexedge_velocity.numpy()[world_id]
10851063
result.actuator_length[:] = d.actuator_length.numpy()[world_id]
@@ -2312,20 +2290,12 @@ def create_render_context(
23122290
if render_rgb and isinstance(render_rgb, bool):
23132291
render_rgb = [render_rgb] * ncam
23142292
elif render_rgb is None:
2315-
# TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
2316-
if BLEEDING_EDGE_MUJOCO:
2317-
render_rgb = [mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_RGB for i in active_cam_indices]
2318-
else:
2319-
render_rgb = [True] * ncam
2293+
render_rgb = [mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_RGB for i in active_cam_indices]
23202294

23212295
if render_depth and isinstance(render_depth, bool):
23222296
render_depth = [render_depth] * ncam
23232297
elif render_depth is None:
2324-
# TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
2325-
if BLEEDING_EDGE_MUJOCO:
2326-
render_depth = [mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_DEPTH for i in active_cam_indices]
2327-
else:
2328-
render_depth = [True] * ncam
2298+
render_depth = [mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_DEPTH for i in active_cam_indices]
23292299

23302300
assert len(render_rgb) == ncam and len(render_depth) == ncam, (
23312301
f"Render RGB and depth must be provided for all active cameras (got {len(render_rgb)}, {len(render_depth)}, expected {ncam})"
@@ -2352,10 +2322,7 @@ def create_render_context(
23522322

23532323
ray = wp.zeros(int(total), dtype=wp.vec3)
23542324

2355-
# TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
2356-
cam_projection = np.zeros(mjm.ncam, dtype=int)
2357-
if BLEEDING_EDGE_MUJOCO:
2358-
cam_projection = mjm.cam_projection
2325+
cam_projection = mjm.cam_projection
23592326

23602327
offset = 0
23612328
for idx, cam_id in enumerate(active_cam_indices):

mujoco_warp/_src/io_test.py

Lines changed: 5 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -306,44 +306,11 @@ def test_get_data_into_io_test_models(self, xml):
306306

307307
# flexedge_J
308308
if xml == "flex/floppy.xml":
309-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
310-
311-
if BLEEDING_EDGE_MUJOCO:
312-
_assert_eq(
313-
mjd_result.flexedge_J.reshape(-1),
314-
d.flexedge_J.numpy()[0].reshape(-1),
315-
"flexedge_J",
316-
)
317-
else:
318-
m = mjwarp.put_model(mjm)
319-
_assert_eq(
320-
mjd_result.flexedge_J_rownnz,
321-
m.flexedge_J_rownnz.numpy(),
322-
"flexedge_J_rownnz",
323-
)
324-
_assert_eq(
325-
mjd_result.flexedge_J_rowadr,
326-
m.flexedge_J_rowadr.numpy(),
327-
"flexedge_J_rowadr",
328-
)
329-
_assert_eq(
330-
mjd_result.flexedge_J_colind,
331-
m.flexedge_J_colind.numpy().reshape((mjm.nflexedge, mjm.nv)),
332-
"flexedge_J_colind",
333-
)
334-
flexedge_J = np.zeros((mjm.nflexedge, mjm.nv))
335-
mujoco.mju_sparse2dense(
336-
flexedge_J,
337-
d.flexedge_J.numpy().reshape(-1),
338-
m.flexedge_J_rownnz.numpy(),
339-
m.flexedge_J_rowadr.numpy(),
340-
m.flexedge_J_colind.numpy().reshape(-1),
341-
)
342-
_assert_eq(
343-
mjd_result.flexedge_J,
344-
flexedge_J,
345-
"flexedge_J",
346-
)
309+
_assert_eq(
310+
mjd_result.flexedge_J.reshape(-1),
311+
d.flexedge_J.numpy()[0].reshape(-1),
312+
"flexedge_J",
313+
)
347314

348315
def test_ellipsoid_fluid_model(self):
349316
mjm = mujoco.MjModel.from_xml_string(
@@ -1265,13 +1232,6 @@ def test_bvh_creation(self, nworld):
12651232

12661233
def test_output_buffers(self):
12671234
"""Test that the output rgb and depth buffers have correct shapes and addresses."""
1268-
# TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
1269-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
1270-
1271-
if not BLEEDING_EDGE_MUJOCO:
1272-
self.skipTest("Skipping test that requires mujoco >= 3.4.1")
1273-
return
1274-
12751235
mjm, mjd, m, d = test_data.fixture(xml=_CAMERA_TEST_XML)
12761236
width, height = 32, 24
12771237
rc = mjwarp.create_render_context(mjm, cam_res=(width, height), render_rgb=True, render_depth=True)
@@ -1288,13 +1248,6 @@ def test_output_buffers(self):
12881248
_assert_eq(depth_adr, [0, width * height, 2 * width * height], "depth_adr")
12891249

12901250
def test_heterogeneous_camera(self):
1291-
# TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
1292-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
1293-
1294-
if not BLEEDING_EDGE_MUJOCO:
1295-
self.skipTest("Skipping test that requires mujoco >= 3.4.1")
1296-
return
1297-
12981251
"""Tests render context with different resolutions and output."""
12991252
mjm, mjd, m, d = test_data.fixture(xml=_CAMERA_TEST_XML)
13001253
cam_res = [(64, 64), (32, 32), (16, 16)]
@@ -1320,13 +1273,6 @@ def test_heterogeneous_camera(self):
13201273
_assert_eq(rc.depth_adr.numpy(), rc_xml.depth_adr.numpy(), "depth_adr")
13211274

13221275
def test_cam_active_filtering(self):
1323-
# TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
1324-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
1325-
1326-
if not BLEEDING_EDGE_MUJOCO:
1327-
self.skipTest("Skipping test that requires mujoco >= 3.4.1")
1328-
return
1329-
13301276
mjm, mjd, m, d = test_data.fixture(xml=_CAMERA_TEST_XML)
13311277
width, height = 32, 32
13321278

@@ -1339,13 +1285,6 @@ def test_cam_active_filtering(self):
13391285

13401286
def test_rgb_only_and_depth_only(self):
13411287
"""Test that disabling rgb or depth correctly reduces the shape and invalidates the address."""
1342-
# TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
1343-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
1344-
1345-
if not BLEEDING_EDGE_MUJOCO:
1346-
self.skipTest("Skipping test that requires mujoco >= 3.4.1")
1347-
return
1348-
13491288
mjm, mjd, m, d = test_data.fixture(xml=_CAMERA_TEST_XML)
13501289
width, height = 32, 32
13511290
pixels = width * height

mujoco_warp/_src/smooth_test.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -491,17 +491,9 @@ def test_flex(self):
491491
mujoco.mj_comPos(mjm, mjd)
492492
mujoco.mj_flex(mjm, mjd)
493493

494-
# TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
495-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
496-
497-
if BLEEDING_EDGE_MUJOCO:
498-
rownnz = mjm.flexedge_J_rownnz
499-
rowadr = mjm.flexedge_J_rowadr
500-
colind = mjm.flexedge_J_colind.reshape(-1)
501-
else:
502-
rownnz = mjd.flexedge_J_rownnz
503-
rowadr = mjd.flexedge_J_rowadr
504-
colind = mjd.flexedge_J_colind.reshape(-1)
494+
rownnz = mjm.flexedge_J_rownnz
495+
rowadr = mjm.flexedge_J_rowadr
496+
colind = mjm.flexedge_J_colind.reshape(-1)
505497

506498
mj_flexedge_J = np.zeros((mjm.nflexedge, mjm.nv), dtype=float)
507499
mujoco.mju_sparse2dense(mj_flexedge_J, mjd.flexedge_J.ravel(), rownnz, rowadr, colind)

mujoco_warp/_src/types_test.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,32 +55,6 @@ def test_field_order(self, mj_class, mjw_class):
5555

5656
actual_fields = [f.name for f in dataclasses.fields(mjw_class)]
5757

58-
# TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
59-
from mujoco_warp._src.io import BLEEDING_EDGE_MUJOCO
60-
61-
_FLEXEDGE_J_FIELDS = ("flexedge_J_rownnz", "flexedge_J_rowadr", "flexedge_J_colind")
62-
63-
def _remove_fields(fields, to_remove):
64-
for field in to_remove:
65-
if field in fields:
66-
fields.remove(field)
67-
68-
if not BLEEDING_EDGE_MUJOCO:
69-
_remove_fields(actual_fields, _FLEXEDGE_J_FIELDS)
70-
_remove_fields(desired_fields, _FLEXEDGE_J_FIELDS)
71-
_remove_fields(
72-
actual_fields,
73-
[
74-
"cam_projection",
75-
],
76-
)
77-
_remove_fields(
78-
desired_fields,
79-
[
80-
"cam_projection",
81-
],
82-
)
83-
8458
self.assertListEqual(actual_fields, desired_fields)
8559

8660
@parameterized.parameters(Option, Model, Data)

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ requires-python = ">=3.10"
2727
dependencies = [
2828
"absl-py",
2929
"etils[epath]",
30-
"mujoco>=3.4.0",
30+
"mujoco>=3.5.0",
3131
"numpy",
3232
"warp-lang>=1.11.0",
3333
]
@@ -54,7 +54,7 @@ dev = [
5454
"ruff",
5555
"pygls>=1.0.0,<2.0.0",
5656
"lsprotocol>=2023.0.1,<2024.0.0",
57-
"mujoco>=3.4.1.dev0",
57+
"mujoco>=3.5.0.dev0",
5858
"warp-lang>=1.11.0.dev0",
5959
]
6060
# TODO(team): cpu and cuda JAX optional dependencies are temporary, remove after we land MJX:Warp

0 commit comments

Comments
 (0)