Skip to content

Commit 73fd7c6

Browse files
committed
Merge branch 'release'
2 parents a1d25d1 + 2fab9bb commit 73fd7c6

File tree

1 file changed: +11 additions, -8 deletions

rsl_rl/runners/on_policy_runner.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
4545
self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
4646
self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
4747
else:
48-
self.obs_normalizer = torch.nn.Identity() # no normalization
49-
self.critic_obs_normalizer = torch.nn.Identity() # no normalization
48+
self.obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
49+
self.critic_obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
5050
# init storage and model
5151
self.alg.init_storage(
5252
self.env.num_envs,
@@ -109,18 +109,21 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
109109
with torch.inference_mode():
110110
for i in range(self.num_steps_per_env):
111111
actions = self.alg.act(obs, critic_obs)
112-
obs, rewards, dones, infos = self.env.step(actions)
113-
obs = self.obs_normalizer(obs)
114-
if "critic" in infos["observations"]:
115-
critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
116-
else:
117-
critic_obs = obs
112+
obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
113+
# move to the right device
118114
obs, critic_obs, rewards, dones = (
119115
obs.to(self.device),
120116
critic_obs.to(self.device),
121117
rewards.to(self.device),
122118
dones.to(self.device),
123119
)
120+
# perform normalization
121+
obs = self.obs_normalizer(obs)
122+
if "critic" in infos["observations"]:
123+
critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
124+
else:
125+
critic_obs = obs
126+
# process the step
124127
self.alg.process_env_step(rewards, dones, infos)
125128

126129
if self.log_dir is not None:

0 commit comments

Comments (0)