Skip to content

Commit a047305

Browse files
DeepMindcopybara-github
authored andcommitted
Update environments to utilize contact sensor.
PiperOrigin-RevId: 795391597 Change-Id: I877b54f441b00712f5280dad5a31b38a87fdfff8
1 parent d886c80 commit a047305

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+549
-242
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ All notable changes to this project will be documented in this file.
1010
support MjWarp. You can pass through the implementation via the config
1111
override
1212
`registry.load('CartpoleBalance', config_overrides={'impl': 'warp'})`.
13+
- Update environments to utilize contact sensors and remove `collision.py`.
1314

1415
## [0.0.5] - 2025-06-23
1516

mujoco_playground/_src/collision.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

mujoco_playground/_src/dm_control_suite/finger.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
112112
qpos = qpos.at[2].set(jax.random.uniform(rng1, minval=-jp.pi, maxval=jp.pi))
113113

114114
data = mjx_env.make_data(
115-
self.mj_model,
115+
self._mj_model,
116116
qpos=qpos,
117-
impl=self.mjx_model.impl.value,
117+
impl=self._mjx_model.impl.value,
118118
nconmax=self._config.nconmax,
119119
njmax=self._config.njmax,
120120
)
@@ -229,7 +229,7 @@ def __init__(
229229
_XML_PATH, target_radius, self._model_assets
230230
)
231231
self._mj_model.opt.timestep = self.sim_dt
232-
self._mjx_model = mjx.put_model(self._mj_model)
232+
self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
233233
self._post_init()
234234

235235
def _post_init(self) -> None:
@@ -249,7 +249,7 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
249249
)
250250
qpos = qpos.at[2].set(jax.random.uniform(rng1, minval=-jp.pi, maxval=jp.pi))
251251

252-
data = mjx_env.init(self.mjx_model, qpos)
252+
data = mjx_env.make_data(self._mj_model, qpos=qpos, impl=self._config.impl)
253253

254254
target_angle = jax.random.uniform(rng2, minval=-jp.pi, maxval=jp.pi)
255255
hinge_x = data.xanchor[self._hinge_joint_id, 0]

mujoco_playground/_src/locomotion/apollo/base.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import mujoco
2424
from mujoco import mjx
2525
from mujoco_playground._src import mjx_env
26-
from mujoco_playground._src.collision import geoms_colliding
2726
from mujoco_playground._src.locomotion.apollo import constants as consts
2827
import numpy as np
2928

@@ -105,6 +104,31 @@ def __init__(
105104
[self._mj_model.site(name).id for name in consts.FEET_SITES]
106105
)
107106

107+
# Contact sensor IDs.
108+
self._left_feet_floor_found_sensor = [
109+
self._mj_model.sensor(foot_geom + "_floor_found").id
110+
for foot_geom in consts.LEFT_FEET_GEOMS
111+
]
112+
self._right_feet_floor_found_sensor = [
113+
self._mj_model.sensor(foot_geom + "_floor_found").id
114+
for foot_geom in consts.RIGHT_FEET_GEOMS
115+
]
116+
self._left_hand_left_thigh_found_sensor = self._mj_model.sensor(
117+
"collision_l_hand_plate_collision_capsule_body_l_thigh_found"
118+
).id
119+
self._right_hand_right_thigh_found_sensor = self._mj_model.sensor(
120+
"collision_r_hand_plate_collision_capsule_body_r_thigh_found"
121+
).id
122+
self._left_foot_right_foot_found_sensor = self._mj_model.sensor(
123+
"collision_l_sole_collision_r_sole_found"
124+
).id
125+
self._left_shin_right_shin_found_sensor = self._mj_model.sensor(
126+
"collision_capsule_body_l_shin_collision_capsule_body_r_shin_found"
127+
).id
128+
self._left_thigh_right_thigh_found_sensor = self._mj_model.sensor(
129+
"collision_capsule_body_l_thigh_collision_capsule_body_r_thigh_found"
130+
).id
131+
108132
# Sensor readings.
109133

110134
def get_gravity(self, data: mjx.Data) -> jax.Array:
@@ -144,12 +168,18 @@ def get_gyro(self, data: mjx.Data) -> jax.Array:
144168
def get_feet_ground_contacts(self, data: mjx.Data) -> jax.Array:
145169
"""Return an array indicating whether each foot is in contact with the ground."""
146170
left_feet_contact = jp.array([
147-
geoms_colliding(data, geom_id, self._floor_geom_id)
148-
for geom_id in self._left_feet_geom_id
171+
data.sensordata[
172+
self._mj_model.sensor_adr[self._mj_model.sensor_adr[sensorid]]
173+
]
174+
> 0
175+
for sensorid in self._left_feet_floor_found_sensor
149176
])
150177
right_feet_contact = jp.array([
151-
geoms_colliding(data, geom_id, self._floor_geom_id)
152-
for geom_id in self._right_feet_geom_id
178+
data.sensordata[
179+
self._mj_model.sensor_adr[self._mj_model.sensor_adr[sensorid]]
180+
]
181+
> 0
182+
for sensorid in self._right_feet_floor_found_sensor
153183
])
154184
return jp.hstack([jp.any(left_feet_contact), jp.any(right_feet_contact)])
155185

mujoco_playground/_src/locomotion/apollo/joystick.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from mujoco import mjx
2323
from mujoco.mjx._src import math
2424
from mujoco_playground._src import gait, mjx_env
25-
from mujoco_playground._src.collision import geoms_colliding
2625
from mujoco_playground._src.locomotion.apollo import base
2726
from mujoco_playground._src.locomotion.apollo import constants as consts
2827

@@ -377,27 +376,17 @@ def _cost_action_rate(self, act: jax.Array, last_act: jax.Array) -> jax.Array:
377376
return jp.sum(jp.square(act - last_act))
378377

379378
def _cost_collision(self, data: mjx.Data) -> jax.Array:
379+
adr = self._mj_model.sensor_adr
380380
# Hand - thigh.
381-
c = geoms_colliding(data, self._left_hand_geom_id, self._left_thigh_geom_id)
382-
c |= geoms_colliding(
383-
data, self._right_hand_geom_id, self._right_thigh_geom_id
384-
)
381+
c = data.sensordata[adr[self._left_hand_left_thigh_found_sensor]] > 0
382+
c |= data.sensordata[adr[self._right_hand_right_thigh_found_sensor]] > 0
385383
# Foot - foot.
386-
c |= geoms_colliding(
387-
data, self._left_foot_geom_id, self._right_foot_geom_id
388-
)
384+
c |= data.sensordata[adr[self._left_foot_right_foot_found_sensor]] > 0
389385
# Shin - shin.
390-
c |= geoms_colliding(
391-
data,
392-
self._left_shin_geom_id,
393-
self._right_shin_geom_id,
394-
)
386+
c |= data.sensordata[adr[self._left_shin_right_shin_found_sensor]] > 0
395387
# Thigh - thigh.
396-
c |= geoms_colliding(
397-
data,
398-
self._left_thigh_geom_id,
399-
self._right_thigh_geom_id,
400-
)
388+
c |= data.sensordata[adr[self._left_thigh_right_thigh_found_sensor]] > 0
389+
401390
return jp.any(c)
402391

403392
def _cost_pose(self, qpos: jax.Array, commands: jax.Array) -> jax.Array:

mujoco_playground/_src/locomotion/apollo/xmls/scene_mjx_feetonly_flat_terrain.xml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@
2323
<geom name="floor" size="0 0 0.01" type="plane" material="groundplane" priority="1" friction="0.6" condim="3"/>
2424
</worldbody>
2525

26+
<sensor>
27+
<contact name="collision_l_sole_floor_found" geom1="collision_l_sole" geom2="floor" reduce="mindist" num="1" data="found"/>
28+
<contact name="collision_r_sole_floor_found" geom1="collision_r_sole" geom2="floor" reduce="mindist" num="1" data="found"/>
29+
<contact name="collision_l_hand_plate_collision_capsule_body_l_thigh_found" geom1="collision_l_hand_plate" geom2="collision_capsule_body_l_thigh" reduce="mindist" num="1" data="found"/>
30+
<contact name="collision_r_hand_plate_collision_capsule_body_r_thigh_found" geom1="collision_r_hand_plate" geom2="collision_capsule_body_r_thigh" reduce="mindist" num="1" data="found"/>
31+
<contact name="collision_l_sole_collision_r_sole_found" geom1="collision_l_sole" geom2="collision_r_sole" reduce="mindist" num="1" data="found"/>
32+
<contact name="collision_capsule_body_l_shin_collision_capsule_body_r_shin_found" geom1="collision_capsule_body_l_shin" geom2="collision_capsule_body_r_shin" reduce="mindist" num="1" data="found"/>
33+
<contact name="collision_capsule_body_l_thigh_collision_capsule_body_r_thigh_found" geom1="collision_capsule_body_l_thigh" geom2="collision_capsule_body_r_thigh" reduce="mindist" num="1" data="found"/>
34+
</sensor>
35+
2636
<keyframe>
2737
<key name="stand"
2838
qpos="

mujoco_playground/_src/locomotion/barkour/joystick.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from mujoco import mjx
2525
import numpy as np
2626

27-
from mujoco_playground._src import collision
2827
from mujoco_playground._src import mjx_env
2928

3029
_FEET_SITES = [
@@ -119,9 +118,30 @@ def __init__(
119118
super().__init__(config, config_overrides)
120119
xml_path = mjx_env.MENAGERIE_PATH / "google_barkour_vb" / "scene_mjx.xml"
121120
self._xml_path = xml_path.as_posix()
122-
xml = epath.Path(xml_path).read_text()
123121
self._model_assets = get_assets()
124-
mj_model = mujoco.MjModel.from_xml_string(xml, assets=self._model_assets)
122+
mj_spec = mujoco.MjSpec.from_file(
123+
xml_path.as_posix(), assets=self._model_assets
124+
)
125+
# add contact sensors
126+
feet_floor_found_sensor = []
127+
for geom in _FEET_GEOMS:
128+
name = f"{geom}_floor_found"
129+
mj_spec.add_sensor(
130+
name=name,
131+
type=mujoco.mjtSensor.mjSENS_CONTACT,
132+
objtype=mujoco.mjtObj.mjOBJ_GEOM,
133+
objname=geom,
134+
reftype=mujoco.mjtObj.mjOBJ_GEOM,
135+
refname="floor",
136+
intprm=[1, 1, 1], # data=found, reduce=mindist
137+
datatype=mujoco.mjtDataType.mjDATATYPE_REAL,
138+
needstage=mujoco.mjtStage.mjSTAGE_ACC,
139+
dim=1,
140+
)
141+
feet_floor_found_sensor.append(name)
142+
self._feet_floor_found_sensor = feet_floor_found_sensor
143+
# compile spec
144+
mj_model = mj_spec.compile()
125145
mj_model.vis.global_.offwidth = 3840
126146
mj_model.vis.global_.offheight = 2160
127147
mj_model.dof_damping[6:] = 0.5239
@@ -239,8 +259,11 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
239259
torso_z = data.xpos[self._torso_body_id, -1]
240260

241261
contact = jp.array([
242-
collision.geoms_colliding(data, geom_id, self._floor_geom_id)
243-
for geom_id in self._feet_geom_id
262+
data.sensordata[
263+
self._mj_model.sensor_adr[self._mj_model.sensor(sensor).id]
264+
]
265+
> 0
266+
for sensor in self._feet_floor_found_sensor
244267
])
245268
contact_filt = contact | state.info["last_contact"]
246269
first_contact = (state.info["feet_air_time"] > 0.0) * contact_filt

mujoco_playground/_src/locomotion/berkeley_humanoid/joystick.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from mujoco_playground._src import gait
2727
from mujoco_playground._src import mjx_env
28-
from mujoco_playground._src.collision import geoms_colliding
2928
from mujoco_playground._src.locomotion.berkeley_humanoid import base as berkeley_humanoid_base
3029
from mujoco_playground._src.locomotion.berkeley_humanoid import berkeley_humanoid_constants as consts
3130

@@ -184,6 +183,12 @@ def _post_init(self) -> None:
184183
qpos_noise_scale[faa_ids] = self._config.noise_config.scales.faa_pos
185184
self._qpos_noise_scale = jp.array(qpos_noise_scale)
186185

186+
# Contact sensor IDs.
187+
self._feet_floor_found_sensor = [
188+
self._mj_model.sensor(f"{geom}_floor_found").id
189+
for geom in consts.FEET_GEOMS
190+
]
191+
187192
def reset(self, rng: jax.Array) -> mjx_env.State:
188193
qpos = self._init_q
189194
qvel = jp.zeros(self.mjx_model.nv)
@@ -264,9 +269,10 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
264269
metrics["swing_peak"] = jp.zeros(())
265270

266271
contact = jp.array([
267-
geoms_colliding(data, geom_id, self._floor_geom_id)
268-
for geom_id in self._feet_geom_id
272+
data.sensordata[self._mj_model.sensor_adr[sensor_id]] > 0
273+
for sensor_id in self._feet_floor_found_sensor
269274
])
275+
270276
obs = self._get_obs(data, info, contact)
271277
reward, done = jp.zeros(2)
272278
return mjx_env.State(data, obs, reward, done, metrics, info)
@@ -299,8 +305,8 @@ def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
299305
state.info["motor_targets"] = motor_targets
300306

301307
contact = jp.array([
302-
geoms_colliding(data, geom_id, self._floor_geom_id)
303-
for geom_id in self._feet_geom_id
308+
data.sensordata[self._mj_model.sensor_adr[sensor_id]] > 0
309+
for sensor_id in self._feet_floor_found_sensor
304310
])
305311
contact_filt = contact | state.info["last_contact"]
306312
first_contact = (state.info["feet_air_time"] > 0.0) * contact_filt

mujoco_playground/_src/locomotion/berkeley_humanoid/xmls/scene_mjx_feetonly_flat_terrain.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
friction="0.6" condim="3"/>
2525
</worldbody>
2626

27+
<include file="sensor.xml"/>
28+
2729
<keyframe>
2830
<key name="home"
2931
qpos="

mujoco_playground/_src/locomotion/berkeley_humanoid/xmls/scene_mjx_feetonly_rough_terrain.xml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
friction="1.0"/>
2424
</worldbody>
2525

26+
<include file="sensor.xml"/>
27+
2628
<keyframe>
2729
<key name="home"
2830
qpos="

0 commit comments

Comments
 (0)