@@ -45,8 +45,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
             self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
             self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
         else:
-            self.obs_normalizer = torch.nn.Identity()  # no normalization
-            self.critic_obs_normalizer = torch.nn.Identity()  # no normalization
+            self.obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
+            self.critic_obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
         # init storage and model
         self.alg.init_storage(
             self.env.num_envs,
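Review note: chaining `.to(self.device)` onto `torch.nn.Identity()` is a no-op for the module itself (it holds no parameters or buffers), but it keeps both branches producing a module prepared the same way, so call sites never special-case the disabled path. A minimal sketch of the idea, with a hypothetical config flag and observation size (the `EmpiricalNormalization` import path is assumed):

```python
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
empirical_normalization = False  # hypothetical config flag
num_obs = 48                     # hypothetical observation size

if empirical_normalization:
    # Assumed import path; tracks a running mean/std of observations, and
    # `until` caps how many samples may update the statistics.
    from rsl_rl.modules import EmpiricalNormalization
    obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(device)
else:
    # Identity has no parameters, so .to(device) changes nothing internally,
    # but the attribute is a Module set up the same way in both branches.
    obs_normalizer = torch.nn.Identity().to(device)

obs = torch.randn(4096, num_obs, device=device)
obs = obs_normalizer(obs)  # identical call site either way
```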
@@ -109,18 +109,21 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False):
         with torch.inference_mode():
             for i in range(self.num_steps_per_env):
                 actions = self.alg.act(obs, critic_obs)
-                obs, rewards, dones, infos = self.env.step(actions)
-                obs = self.obs_normalizer(obs)
-                if "critic" in infos["observations"]:
-                    critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
-                else:
-                    critic_obs = obs
+                obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
+                # move to the right device
                 obs, critic_obs, rewards, dones = (
                     obs.to(self.device),
                     critic_obs.to(self.device),
                     rewards.to(self.device),
                     dones.to(self.device),
                 )
+                # perform normalization
+                obs = self.obs_normalizer(obs)
+                if "critic" in infos["observations"]:
+                    critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
+                else:
+                    critic_obs = obs
+                # process the step
                 self.alg.process_env_step(rewards, dones, infos)

                 if self.log_dir is not None:
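Review note: the reordering matters when the environment and the learner live on different devices. Previously, `self.obs_normalizer` (resident on `self.device`) was applied to tensors still on the env's device; now actions are explicitly sent to `self.env.device`, returned tensors are moved first, and normalization happens second. A minimal sketch of the corrected ordering, with hypothetical device names and a stand-in env step:

```python
import torch

runner_device = "cpu"  # hypothetical: where the algorithm and normalizers live
env_device = "cpu"     # hypothetical: where the simulation steps

# Identity stands in for EmpiricalNormalization in this sketch.
normalizer = torch.nn.Identity().to(runner_device)

def fake_env_step(actions):
    # Stand-in for self.env.step(): returns tensors on the env's device.
    n = actions.shape[0]
    obs = torch.randn(n, 48, device=env_device)
    rewards = torch.zeros(n, device=env_device)
    dones = torch.zeros(n, dtype=torch.bool, device=env_device)
    return obs, rewards, dones

actions = torch.randn(4096, 12, device=runner_device)
obs, rewards, dones = fake_env_step(actions.to(env_device))
# 1) move everything onto the runner's device first ...
obs, rewards, dones = obs.to(runner_device), rewards.to(runner_device), dones.to(runner_device)
# 2) ... and only then normalize, so the normalizer never sees off-device tensors
obs = normalizer(obs)
```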