@@ -110,11 +110,7 @@ function (hook::RewardsPerEpisode)(::PreEpisodeStage, agent, env)
110110end
111111
# After every action, append the reward reported by `env` to the last
# vector stored in `hook.rewards`.
(hook::RewardsPerEpisode)(::PostActStage, agent, env) =
    push!(hook.rewards[end], reward(env))
119115
120116# ####
125121 TotalRewardPerEpisode(; rewards = Float64[], reward = 0.0)
126122
127123Store the total rewards of each episode in the field of `rewards`.
128-
129- !!! note
130- If the environment is a [`RewardOverriddenenv`](@ref), then the original reward is recorded.
131124"""
132125Base. @kwdef mutable struct TotalRewardPerEpisode <: AbstractHook
133126 rewards:: Vector{Float64} = Float64[]
@@ -136,58 +129,11 @@ end
136129
# Zero-argument indexing (`hook[]`) exposes the recorded per-episode totals.
function Base.getindex(h::TotalRewardPerEpisode)
    return h.rewards
end
138131
139- (hook:: TotalRewardPerEpisode )(s:: AbstractStage , agent, env) =
140- hook (s, agent, env, RewardStyle (env), NumAgentStyle (env))
141- (hook:: TotalRewardPerEpisode )(:: AbstractStage , agent, env, :: Any , :: Any ) = nothing
142-
143- (hook:: TotalRewardPerEpisode )(
144- :: PostEpisodeStage ,
145- agent,
146- env,
147- :: TerminalReward ,
148- :: SingleAgent ,
149- ) = push! (hook. rewards, get_reward (env))
150- (hook:: TotalRewardPerEpisode )(
151- :: PostEpisodeStage ,
152- agent,
153- env,
154- :: TerminalReward ,
155- :: MultiAgent ,
156- ) = push! (hook. rewards, get_reward (env, get_role (agent)))
157- (hook:: TotalRewardPerEpisode )(:: PostActStage , agent, env, :: StepReward , :: SingleAgent ) =
158- hook. reward += get_reward (env)
159- (hook:: TotalRewardPerEpisode )(:: PostActStage , agent, env, :: StepReward , :: MultiAgent ) =
160- hook. reward += get_reward (env, get_role (agent))
161- (hook:: TotalRewardPerEpisode )(
162- :: PostEpisodeStage ,
163- agent,
164- env:: RewardOverriddenEnv ,
165- :: TerminalReward ,
166- :: SingleAgent ,
167- ) = push! (hook. rewards, get_reward (env. env))
168- (hook:: TotalRewardPerEpisode )(
169- :: PostEpisodeStage ,
170- agent,
171- env:: RewardOverriddenEnv ,
172- :: TerminalReward ,
173- :: MultiAgent ,
174- ) = push! (hook. rewards, get_reward (env. env, get_role (agent)))
175- (hook:: TotalRewardPerEpisode )(
176- :: PostActStage ,
177- agent,
178- env:: RewardOverriddenEnv ,
179- :: StepReward ,
180- :: SingleAgent ,
181- ) = hook. reward += get_reward (env. env)
182- (hook:: TotalRewardPerEpisode )(
183- :: PostActStage ,
184- agent,
185- env:: RewardOverriddenEnv ,
186- :: StepReward ,
187- :: MultiAgent ,
188- ) = hook. reward += get_reward (env. env, get_role (agent))
189-
190- function (hook:: TotalRewardPerEpisode )(:: PostEpisodeStage , agent, env, :: StepReward , :: Any )
# After every action, add the reward reported by `env` to the running total
# kept in `hook.reward`.
(hook::TotalRewardPerEpisode)(::PostActStage, agent, env) =
    hook.reward += reward(env)
135+
# When an episode ends, archive the accumulated total in `hook.rewards`
# and reset the accumulator for the next episode.
function (hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env)
    episode_total = hook.reward
    push!(hook.rewards, episode_total)
    hook.reward = 0
end
@@ -205,33 +151,27 @@ Base.getindex(h::TotalBatchRewardPerEpisode) = h.rewards
205151"""
206152 TotalBatchRewardPerEpisode(batch_size::Int)
207153
208- Similar to [`TotalRewardPerEpisode`](@ref), but will record total rewards per episode in [`MultiThreadEnv`](@ref).
209-
210- !!! note
211- If the environment is a [`RewardOverriddenEnv`](@ref), then the original reward is recorded.
154+ Similar to [`TotalRewardPerEpisode`](@ref), but is specific to environments
155+ which return a `Vector` of rewards (a typical case with `MultiThreadEnv`).
212156"""
function TotalBatchRewardPerEpisode(batch_size::Int)
    # One initially empty history of episode totals per batch slot ...
    histories = [Float64[] for _ in 1:batch_size]
    # ... and one running total per slot, all starting at zero.
    running_totals = zeros(batch_size)
    TotalBatchRewardPerEpisode(histories, running_totals)
end
216160
217- function (hook:: TotalBatchRewardPerEpisode )(
218- :: PostActStage ,
219- agent,
220- env:: MultiThreadEnv{T} ,
221- ) where {T}
222- for i in 1 : length (env)
223- if T <: RewardOverriddenEnv
224- hook. reward[i] += get_reward (env[i]. env)
225- else
226- hook. reward[i] += get_reward (env[i])
227- end
228- if get_terminal (env[i])
# After every action, walk the termination flags and rewards of `env` in
# lockstep: each reward is added to the running total of its slot, and a
# slot whose flag is set has its total archived and reset.
function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env)
    terminal_flags = is_terminated(env)
    step_rewards = reward(env)
    for (slot, (done, r)) in enumerate(zip(terminal_flags, step_rewards))
        hook.reward[slot] += r
        # Nothing more to do for a slot whose episode is still running.
        done || continue
        push!(hook.rewards[slot], hook.reward[slot])
        hook.reward[slot] = 0.0
    end
end
234170
171+ # ####
172+ # BatchStepsPerEpisode
173+ # ####
174+
235175struct BatchStepsPerEpisode <: AbstractHook
236176 steps:: Vector{Vector{Int}}
237177 step:: Vector{Int}
@@ -242,16 +182,17 @@ Base.getindex(h::BatchStepsPerEpisode) = h.steps
242182"""
243183 BatchStepsPerEpisode(batch_size::Int)
244184
245- Similar to [`StepsPerEpisode`](@ref), but only work for [`MultiThreadEnv`](@ref)
185+ Similar to [`StepsPerEpisode`](@ref), but is specific to environments
186+ which return a `Vector` of termination flags (a typical case with `MultiThreadEnv`).
246187"""
function BatchStepsPerEpisode(batch_size::Int)
    # One initially empty history of episode lengths per batch slot ...
    step_histories = [Int[] for _ in 1:batch_size]
    # ... and one step counter per slot, all starting at zero.
    step_counters = zeros(Int, batch_size)
    BatchStepsPerEpisode(step_histories, step_counters)
end
250191
251- function (hook:: BatchStepsPerEpisode )(:: PostActStage , agent, env:: MultiThreadEnv )
252- for i in 1 : length ( env)
192+ function (hook:: BatchStepsPerEpisode )(:: PostActStage , agent, env)
193+ for (i, t) in enumerate ( is_terminated ( env) )
253194 hook. step[i] += 1
254- if get_terminal (env[i])
195+ if t
255196 push! (hook. steps[i], hook. step[i])
256197 hook. step[i] = 0
257198 end
@@ -266,24 +207,20 @@ end
266207 CumulativeReward(; rewards::Vector{Vector{Float64}} = [[0.0]])
267208
268209Store the cumulative reward within each episode, one vector per episode, in the field `rewards`.
269-
270- !!! note
271- If the environment is a [`RewardOverriddenEnv`](@ref), then the original reward is recorded instead.
272210"""
Base.@kwdef struct CumulativeReward <: AbstractHook
    # A vector of running-sum vectors; the default holds a single vector
    # seeded with 0.0.
    rewards::Vector{Vector{Float64}} = Vector{Float64}[[0.0]]
end
276214
# Zero-argument indexing (`hook[]`) exposes the recorded cumulative rewards.
function Base.getindex(h::CumulativeReward)
    return h.rewards
end
278216
279- function (hook:: CumulativeReward )(:: PostActStage , agent, env:: T ) where {T}
280- if T <: RewardOverriddenEnv
281- r = get_reward (env. env)
282- else
283- r = get_reward (env)
284- end
285- push! (hook. rewards, r + hook. rewards[end ])
286- @debug hook. tag CUMULATIVE_REWARD = hook. rewards[end ]
# When an episode ends, start a fresh running-sum vector seeded with 0.0.
(hook::CumulativeReward)(::PostEpisodeStage, agent, env) =
    push!(hook.rewards, [0.0])
220+
# After every action, extend the last running-sum vector with its previous
# final value plus the reward reported by `env`.
function (hook::CumulativeReward)(::PostActStage, agent, env)
    r = reward(env)
    running = hook.rewards[end]
    push!(running, r + last(running))
end
288225
289226# ####
@@ -363,7 +300,7 @@ Base.@kwdef mutable struct UploadTrajectoryEveryNStep{M,S} <: AbstractHook
363300 sealer:: S = deepcopy
364301end
365302
366- function (hook:: UploadTrajectoryEveryNStep )(:: PostActStage , agent, env)
303+ function (hook:: UploadTrajectoryEveryNStep )(:: PostActStage , agent:: Agent , env)
367304 hook. t += 1
368305 if hook. t > 0 && hook. t % hook. n == 0
369306 put! (hook. mailbox, hook. sealer (agent. trajectory))
0 commit comments