Commit 6983041

Observations as TensorDicts
Approved-by: Mayank Mittal
1 parent 491ca91 commit 6983041

25 files changed (+1394 / -735 lines)

README.md

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ The package supports the following logging frameworks which can be configured th
 * Weights & Biases: https://wandb.ai/site
 * Neptune: https://docs.neptune.ai/
 
-For a demo configuration of PPO, please check the [dummy_config.yaml](config/dummy_config.yaml) file.
+For a demo configuration of PPO, please check the [example_config.yaml](config/example_config.yaml) file.
 
 
 ## Contribution Guidelines

config/dummy_config.yaml

Lines changed: 0 additions & 95 deletions
This file was deleted.

config/example_config.yaml

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+runner:
+  class_name: OnPolicyRunner
+  # -- general
+  num_steps_per_env: 24  # number of steps per environment per iteration
+  max_iterations: 1500  # number of policy updates
+  seed: 1
+  # -- observations
+  obs_groups: {"policy": ["policy"], "critic": ["policy", "privileged"]}  # maps observation groups to types. See `vec_env.py` for more information
+  # -- logging parameters
+  save_interval: 50  # check for potential saves every `save_interval` iterations
+  experiment_name: walking_experiment
+  run_name: ""
+  # -- logging writer
+  logger: tensorboard  # tensorboard, neptune, wandb
+  neptune_project: legged_gym
+  wandb_project: legged_gym
+
+  # -- policy
+  policy:
+    class_name: ActorCritic
+    activation: elu
+    actor_obs_normalization: false
+    critic_obs_normalization: false
+    actor_hidden_dims: [256, 256, 256]
+    critic_hidden_dims: [256, 256, 256]
+    init_noise_std: 1.0
+    noise_std_type: "scalar"  # 'scalar' or 'log'
+
+  # -- algorithm
+  algorithm:
+    class_name: PPO
+    # -- training
+    learning_rate: 0.001
+    num_learning_epochs: 5
+    num_mini_batches: 4  # mini batch size = num_envs * num_steps / num_mini_batches
+    schedule: adaptive  # adaptive, fixed
+    # -- value function
+    value_loss_coef: 1.0
+    clip_param: 0.2
+    use_clipped_value_loss: true
+    # -- surrogate loss
+    desired_kl: 0.01
+    entropy_coef: 0.01
+    gamma: 0.99
+    lam: 0.95
+    max_grad_norm: 1.0
+    # -- miscellaneous
+    normalize_advantage_per_mini_batch: false
+
+    # -- random network distillation
+    rnd_cfg:
+      weight: 0.0  # initial weight of the RND reward
+      weight_schedule: null  # note: this is a dictionary with a required key called "mode". Please check the RND module for more information
+      reward_normalization: false  # whether to normalize RND reward
+      # -- learning parameters
+      learning_rate: 0.001  # learning rate for RND
+      # -- network parameters
+      num_outputs: 1  # number of outputs of RND network. Note: if -1, then the network will use dimensions of the observation
+      predictor_hidden_dims: [-1]  # hidden dimensions of predictor network
+      target_hidden_dims: [-1]  # hidden dimensions of target network
+
+    # -- symmetry augmentation
+    symmetry_cfg:
+      use_data_augmentation: true  # this adds symmetric trajectories to the batch
+      use_mirror_loss: false  # this adds symmetry loss term to the loss function
+      data_augmentation_func: null  # string containing the module and function name to import
+      # Example: "legged_gym.envs.locomotion.anymal_c.symmetry:get_symmetric_states"
+      #
+      # .. code-block:: python
+      #
+      #     @torch.no_grad()
+      #     def get_symmetric_states(
+      #         obs: Optional[torch.Tensor] = None, actions: Optional[torch.Tensor] = None, cfg: "BaseEnvCfg" = None, obs_type: str = "policy"
+      #     ) -> Tuple[torch.Tensor, torch.Tensor]:
+      #
+      mirror_loss_coeff: 0.0  # coefficient for symmetry loss term. If 0, no symmetry loss is used
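
The `obs_groups` entry above ties into the commit's theme of passing observations as TensorDicts: each observation group is a named tensor, and the mapping selects which groups the actor ("policy") and the critic see. Below is a minimal, hypothetical sketch of how such grouped observations could be assembled with `tensordict`; the feature sizes and the `concat_groups` helper are illustrative assumptions, not part of rsl_rl's API.

```python
# Hypothetical sketch: grouped observations as a TensorDict.
# The group names ("policy", "privileged") match the config above; the
# feature sizes and the concat_groups helper are illustrative only.
import torch
from tensordict import TensorDict

num_envs = 4
obs = TensorDict(
    {
        "policy": torch.randn(num_envs, 48),      # proprioceptive observations
        "privileged": torch.randn(num_envs, 12),  # extra state only the critic sees
    },
    batch_size=[num_envs],
)

obs_groups = {"policy": ["policy"], "critic": ["policy", "privileged"]}


def concat_groups(obs, groups):
    """Concatenate the requested observation groups along the feature dimension."""
    return torch.cat([obs[name] for name in groups], dim=-1)


actor_input = concat_groups(obs, obs_groups["policy"])   # shape (4, 48)
critic_input = concat_groups(obs, obs_groups["critic"])  # shape (4, 60)
```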

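The `data_augmentation_func` comment in the config only shows the expected call signature. The sketch below fills it in under stated assumptions: the mirror index and sign tensors are placeholders that a real environment would derive from its own observation and joint ordering, and `cfg` is left untyped because the environment config class is project specific.

```python
# Hypothetical symmetry function matching the signature shown in the config comment.
# OBS_MIRROR_IDX, ACT_MIRROR_IDX, and ACT_MIRROR_SIGN are placeholders.
from typing import Optional, Tuple

import torch

OBS_MIRROR_IDX = torch.tensor([1, 0, 3, 2])   # placeholder: swap left/right observation entries
ACT_MIRROR_IDX = torch.tensor([1, 0])         # placeholder: swap left/right actuators
ACT_MIRROR_SIGN = torch.tensor([1.0, -1.0])   # placeholder: flip lateral action components


@torch.no_grad()
def get_symmetric_states(
    obs: Optional[torch.Tensor] = None,
    actions: Optional[torch.Tensor] = None,
    cfg=None,
    obs_type: str = "policy",
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Return the original and mirrored samples stacked along the batch dimension."""
    sym_obs, sym_actions = None, None
    if obs is not None:
        mirrored_obs = obs[..., OBS_MIRROR_IDX]
        sym_obs = torch.cat([obs, mirrored_obs], dim=0)
    if actions is not None:
        mirrored_actions = actions[..., ACT_MIRROR_IDX] * ACT_MIRROR_SIGN
        sym_actions = torch.cat([actions, mirrored_actions], dim=0)
    return sym_obs, sym_actions
```
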
licenses/dependencies/tensordict.txt

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Meta Platforms, Inc. and affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

pyproject.toml

Lines changed: 5 additions & 4 deletions
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "rsl-rl-lib"
-version = "2.3.3"
+version = "3.0.0"
 keywords = ["reinforcement-learning", "isaac", "leggedrobotics", "rl-pytorch"]
 maintainers = [
     { name="Clemens Schwarke", email="[email protected]" },
@@ -26,8 +26,9 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 dependencies = [
-    "torch>=1.10.0",
+    "torch>=2.6.0",
     "torchvision>=0.5.0",
+    "tensordict>=0.7.0",
     "numpy>=1.16.4",
     "GitPython",
     "onnx",
@@ -46,7 +47,7 @@ include = ["rsl_rl*"]
 
 [tool.isort]
 
-py_version = 37
+py_version = 38
 line_length = 120
 group_by_package = true
 
@@ -79,7 +80,7 @@ known_first_party = "rsl_rl"
 include = ["rsl_rl"]
 
 typeCheckingMode = "basic"
-pythonVersion = "3.7"
+pythonVersion = "3.8"
 pythonPlatform = "Linux"
 enableTypeIgnoreComments = true
 

rsl_rl/algorithms/distillation.py

Lines changed: 23 additions & 22 deletions
@@ -3,14 +3,12 @@
 #
 # SPDX-License-Identifier: BSD-3-Clause
 
-# torch
 import torch
 import torch.nn as nn
-import torch.optim as optim
 
-# rsl-rl
 from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
 from rsl_rl.storage import RolloutStorage
+from rsl_rl.utils import resolve_optimizer
 
 
 class Distillation:
@@ -27,6 +25,7 @@ def __init__(
         learning_rate=1e-3,
         max_grad_norm=None,
         loss_type="mse",
+        optimizer="adam",
         device="cpu",
         # Distributed training parameters
         multi_gpu_cfg: dict | None = None,
@@ -42,13 +41,15 @@ def __init__(
         self.gpu_global_rank = 0
         self.gpu_world_size = 1
 
-        self.rnd = None  # TODO: remove when runner has a proper base class
-
         # distillation components
         self.policy = policy
         self.policy.to(self.device)
         self.storage = None  # initialized later
-        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
+
+        # initialize the optimizer
+        self.optimizer = resolve_optimizer(optimizer)(self.policy.parameters(), lr=learning_rate)
+
+        # initialize the transition
         self.transition = RolloutStorage.Transition()
         self.last_hidden_states = None
 
@@ -59,40 +60,40 @@ def __init__(
         self.max_grad_norm = max_grad_norm
 
         # initialize the loss function
-        if loss_type == "mse":
-            self.loss_fn = nn.functional.mse_loss
-        elif loss_type == "huber":
-            self.loss_fn = nn.functional.huber_loss
+        loss_fn_dict = {
+            "mse": nn.functional.mse_loss,
+            "huber": nn.functional.huber_loss,
+        }
+        if loss_type in loss_fn_dict:
+            self.loss_fn = loss_fn_dict[loss_type]
         else:
-            raise ValueError(f"Unknown loss type: {loss_type}. Supported types are: mse, huber")
+            raise ValueError(f"Unknown loss type: {loss_type}. Supported types are: {list(loss_fn_dict.keys())}")
 
         self.num_updates = 0
 
-    def init_storage(
-        self, training_type, num_envs, num_transitions_per_env, student_obs_shape, teacher_obs_shape, actions_shape
-    ):
+    def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, actions_shape):
         # create rollout storage
         self.storage = RolloutStorage(
             training_type,
             num_envs,
             num_transitions_per_env,
-            student_obs_shape,
-            teacher_obs_shape,
+            obs,
             actions_shape,
-            None,
             self.device,
         )
 
-    def act(self, obs, teacher_obs):
+    def act(self, obs):
         # compute the actions
         self.transition.actions = self.policy.act(obs).detach()
-        self.transition.privileged_actions = self.policy.evaluate(teacher_obs).detach()
+        self.transition.privileged_actions = self.policy.evaluate(obs).detach()
         # record the observations
         self.transition.observations = obs
-        self.transition.privileged_observations = teacher_obs
         return self.transition.actions
 
-    def process_env_step(self, rewards, dones, infos):
+    def process_env_step(self, obs, rewards, dones, extras):
+        # update the normalizers
+        self.policy.update_normalization(obs)
+
         # record the rewards and dones
         self.transition.rewards = rewards
         self.transition.dones = dones
@@ -110,7 +111,7 @@ def update(self):
         for epoch in range(self.num_learning_epochs):
             self.policy.reset(hidden_states=self.last_hidden_states)
             self.policy.detach_hidden_states()
-            for obs, _, _, privileged_actions, dones in self.storage.generator():
+            for obs, _, privileged_actions, dones in self.storage.generator():
 
                 # inference the student for gradient computation
                 actions = self.policy.act_inference(obs)
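
The constructor change above swaps the hard-coded `optim.Adam` for `resolve_optimizer(optimizer)`, so the optimizer class is now chosen by the `optimizer` string from the config. A minimal sketch of such a name-to-class lookup is shown below; the actual `rsl_rl.utils.resolve_optimizer` may differ in the names it accepts.

```python
# Minimal sketch of a string-to-optimizer lookup in the spirit of
# rsl_rl.utils.resolve_optimizer; the real implementation may differ.
import torch.optim as optim


def resolve_optimizer(name):
    """Map an optimizer name from the config to the corresponding torch.optim class."""
    optimizers = {
        "adam": optim.Adam,
        "adamw": optim.AdamW,
        "sgd": optim.SGD,
        "rmsprop": optim.RMSprop,
    }
    try:
        return optimizers[name.lower()]
    except KeyError:
        raise ValueError(f"Unknown optimizer: {name}. Supported optimizers are: {list(optimizers)}")
```

With this pattern, `resolve_optimizer("adam")(policy.parameters(), lr=1e-3)` reproduces the previous Adam behavior while letting the new `optimizer` config entry select a different optimizer.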
