3030
3131
3232def _is_mujoco_fresh () -> bool :
33- """Checks if mujoco version is > 3.4 .0."""
33+ """Checks if mujoco version is > 3.5 .0."""
3434 version = importlib .metadata .version ("mujoco" )
3535 version = version .split ("." )
3636 version = tuple (map (int , version [:3 ])) + tuple (version [3 :])
37- return version > (3 , 4 , 0 )
37+ return version > (3 , 5 , 0 )
3838
3939
4040BLEEDING_EDGE_MUJOCO = _is_mujoco_fresh ()
@@ -583,17 +583,9 @@ def geom_trid_index(i, j):
583583 m .qM_mulm_madr .append (madr )
584584 m .qM_mulm_rowadr .append (len (m .qM_mulm_col ))
585585
586- # TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
587- if BLEEDING_EDGE_MUJOCO :
588- m .flexedge_J_rownnz = mjm .flexedge_J_rownnz
589- m .flexedge_J_rowadr = mjm .flexedge_J_rowadr
590- m .flexedge_J_colind = mjm .flexedge_J_colind .reshape (- 1 )
591- else :
592- mjd = mujoco .MjData (mjm )
593- mujoco .mj_forward (mjm , mjd )
594- m .flexedge_J_rownnz = mjd .flexedge_J_rownnz
595- m .flexedge_J_rowadr = mjd .flexedge_J_rowadr
596- m .flexedge_J_colind = mjd .flexedge_J_colind .reshape (- 1 )
586+ m .flexedge_J_rownnz = mjm .flexedge_J_rownnz
587+ m .flexedge_J_rowadr = mjm .flexedge_J_rowadr
588+ m .flexedge_J_colind = mjm .flexedge_J_colind .reshape (- 1 )
597589
598590 # place m on device
599591 sizes = dict ({"*" : 1 }, ** {f .name : getattr (m , f .name ) for f in dataclasses .fields (types .Model ) if f .type is int })
@@ -1065,21 +1057,7 @@ def get_data_into(
10651057 result .cinert [:] = d .cinert .numpy ()[world_id ]
10661058 result .flexvert_xpos [:] = d .flexvert_xpos .numpy ()[world_id ]
10671059 if mjm .nflexedge > 0 :
1068- # TODO(team): remove after mjwarp depends on mujoco > 3.4.0 in pyproject.toml
1069- if not BLEEDING_EDGE_MUJOCO :
1070- m = put_model (mjm )
1071- result .flexedge_J_rownnz [:] = m .flexedge_J_rownnz .numpy ()
1072- result .flexedge_J_rowadr [:] = m .flexedge_J_rowadr .numpy ()
1073- result .flexedge_J_colind [:, :] = m .flexedge_J_colind .numpy ().reshape ((mjm .nflexedge , mjm .nv ))
1074- mujoco .mju_sparse2dense (
1075- result .flexedge_J ,
1076- d .flexedge_J .numpy ()[world_id ].reshape (- 1 ),
1077- m .flexedge_J_rownnz .numpy (),
1078- m .flexedge_J_rowadr .numpy (),
1079- m .flexedge_J_colind .numpy (),
1080- )
1081- else :
1082- result .flexedge_J [:] = d .flexedge_J .numpy ()[world_id ].reshape (- 1 )
1060+ result .flexedge_J [:] = d .flexedge_J .numpy ()[world_id ].reshape (- 1 )
10831061 result .flexedge_length [:] = d .flexedge_length .numpy ()[world_id ]
10841062 result .flexedge_velocity [:] = d .flexedge_velocity .numpy ()[world_id ]
10851063 result .actuator_length [:] = d .actuator_length .numpy ()[world_id ]
@@ -2312,20 +2290,12 @@ def create_render_context(
23122290 if render_rgb and isinstance (render_rgb , bool ):
23132291 render_rgb = [render_rgb ] * ncam
23142292 elif render_rgb is None :
2315- # TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
2316- if BLEEDING_EDGE_MUJOCO :
2317- render_rgb = [mjm .cam_output [i ] & mujoco .mjtCamOutBit .mjCAMOUT_RGB for i in active_cam_indices ]
2318- else :
2319- render_rgb = [True ] * ncam
2293+ render_rgb = [mjm .cam_output [i ] & mujoco .mjtCamOutBit .mjCAMOUT_RGB for i in active_cam_indices ]
23202294
23212295 if render_depth and isinstance (render_depth , bool ):
23222296 render_depth = [render_depth ] * ncam
23232297 elif render_depth is None :
2324- # TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
2325- if BLEEDING_EDGE_MUJOCO :
2326- render_depth = [mjm .cam_output [i ] & mujoco .mjtCamOutBit .mjCAMOUT_DEPTH for i in active_cam_indices ]
2327- else :
2328- render_depth = [True ] * ncam
2298+ render_depth = [mjm .cam_output [i ] & mujoco .mjtCamOutBit .mjCAMOUT_DEPTH for i in active_cam_indices ]
23292299
23302300 assert len (render_rgb ) == ncam and len (render_depth ) == ncam , (
23312301 f"Render RGB and depth must be provided for all active cameras (got { len (render_rgb )} , { len (render_depth )} , expected { ncam } )"
@@ -2352,10 +2322,7 @@ def create_render_context(
23522322
23532323 ray = wp .zeros (int (total ), dtype = wp .vec3 )
23542324
2355- # TODO: remove after mjwarp depends on mujoco >= 3.4.1 in pyproject.toml
2356- cam_projection = np .zeros (mjm .ncam , dtype = int )
2357- if BLEEDING_EDGE_MUJOCO :
2358- cam_projection = mjm .cam_projection
2325+ cam_projection = mjm .cam_projection
23592326
23602327 offset = 0
23612328 for idx , cam_id in enumerate (active_cam_indices ):
0 commit comments