Skip to content

Commit 0716a13

Browse files
Merge pull request #211 from kywch:warp-rsl
PiperOrigin-RevId: 854981803 Change-Id: Ib5124bcb7271df69d106ebb4566d1c6e48affb15
2 parents 4df68e7 + 943a1ec commit 0716a13

File tree

3 files changed

+34
-6
lines changed

3 files changed

+34
-6
lines changed

learning/train_jax_ppo.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@
7272
f"Name of the environment. One of {', '.join(registry.ALL_ENVS)}",
7373
)
7474
_IMPL = flags.DEFINE_enum("impl", "jax", ["jax", "warp"], "MJX implementation")
75+
_PLAYGROUND_CONFIG_OVERRIDES = flags.DEFINE_string(
76+
"playground_config_overrides",
77+
None,
78+
"Overrides for the playground env config.",
79+
)
7580
_VISION = flags.DEFINE_boolean("vision", False, "Use vision input")
7681
_LOAD_CHECKPOINT_PATH = flags.DEFINE_string(
7782
"load_checkpoint_path", None, "Path to load checkpoint from"
@@ -264,7 +269,12 @@ def main(argv):
264269
if _VISION.value:
265270
env_cfg.vision = True
266271
env_cfg.vision_config.render_batch_size = ppo_params.num_envs
267-
env = registry.load(_ENV_NAME.value, config=env_cfg)
272+
env_cfg_overrides = {}
273+
if _PLAYGROUND_CONFIG_OVERRIDES.value is not None:
274+
env_cfg_overrides = json.loads(_PLAYGROUND_CONFIG_OVERRIDES.value)
275+
env = registry.load(
276+
_ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
277+
)
268278
if _RUN_EVALS.present:
269279
ppo_params.run_evals = _RUN_EVALS.value
270280
if _LOG_TRAINING_METRICS.present:
@@ -273,6 +283,8 @@ def main(argv):
273283
ppo_params.training_metrics_steps = _TRAINING_METRICS_STEPS.value
274284

275285
print(f"Environment Config:\n{env_cfg}")
286+
if env_cfg_overrides:
287+
print(f"Environment Config Overrides:\n{env_cfg_overrides}\n")
276288
print(f"PPO Training Parameters:\n{ppo_params}")
277289

278290
# Generate unique experiment name
@@ -408,7 +420,9 @@ def progress(num_steps, metrics):
408420
# Load evaluation environment.
409421
eval_env = None
410422
if not _VISION.value:
411-
eval_env = registry.load(_ENV_NAME.value, config=env_cfg)
423+
eval_env = registry.load(
424+
_ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
425+
)
412426
num_envs = 1
413427
if _VISION.value:
414428
num_envs = env_cfg.vision_config.render_batch_size
@@ -419,7 +433,9 @@ def progress(num_steps, metrics):
419433
from rscope import brax as rscope_utils
420434

421435
if not _VISION.value:
422-
rscope_env = registry.load(_ENV_NAME.value, config=env_cfg)
436+
rscope_env = registry.load(
437+
_ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
438+
)
423439
rscope_env = wrapper.wrap_for_brax_training(
424440
rscope_env,
425441
episode_length=ppo_params.episode_length,

learning/train_rsl_rl.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@
5757
f"{', '.join(mujoco_playground.registry.ALL_ENVS)}"
5858
),
5959
)
60+
_IMPL = flags.DEFINE_enum("impl", "jax", ["jax", "warp"], "MJX implementation")
61+
_PLAYGROUND_CONFIG_OVERRIDES = flags.DEFINE_string(
62+
"playground_config_overrides",
63+
None,
64+
"Overrides for the playground env config.",
65+
)
6066
_LOAD_RUN_NAME = flags.DEFINE_string(
6167
"load_run_name", None, "Run name to load from (for checkpoint restoration)."
6268
)
@@ -118,8 +124,14 @@ def main(argv):
118124

119125
# Load default config from registry
120126
env_cfg = registry.get_default_config(_ENV_NAME.value)
127+
env_cfg.impl = _IMPL.value
121128
print(f"Environment config:\n{env_cfg}")
122129

130+
env_cfg_overrides = {}
131+
if _PLAYGROUND_CONFIG_OVERRIDES.value is not None:
132+
env_cfg_overrides = json.loads(_PLAYGROUND_CONFIG_OVERRIDES.value)
133+
print(f"Environment config overrides:\n{env_cfg_overrides}\n")
134+
123135
# Generate unique experiment name
124136
now = datetime.now()
125137
timestamp = now.strftime("%Y%m%d-%H%M%S")
@@ -163,7 +175,7 @@ def render_callback(_, state):
163175

164176
# Create the environment
165177
raw_env = registry.load(
166-
_ENV_NAME.value, config=env_cfg, config_overrides={"impl": "jax"}
178+
_ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
167179
)
168180
brax_env = wrapper_torch.RSLRLBraxWrapper(
169181
raw_env,
@@ -219,7 +231,7 @@ def render_callback(_, state):
219231

220232
# Example: run a single rollout
221233
eval_env = registry.load(
222-
_ENV_NAME.value, config=env_cfg, config_overrides={"impl": "jax"}
234+
_ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
223235
)
224236
jit_reset = jax.jit(eval_env.reset)
225237
jit_step = jax.jit(eval_env.step)

mujoco_playground/_src/manipulation/leap_hand/reorient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,4 +628,4 @@ def rand(rng):
628628
"actuator_biasprm": actuator_biasprm,
629629
})
630630

631-
return model, in_axes
631+
return model, in_axes

0 commit comments

Comments
 (0)