
Commit 135e691

Update RLBase to the latest version (#188)
* fix RLBase imported functions
* rename functions in RLBase
* update dependency
1 parent 899b1bc commit 135e691

19 files changed, +188 −219 lines
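
The commit message is terse, so it helps to spell out what "rename functions in RLBase" means: ReinforcementLearningBase 0.9 dropped the `get_` prefix from its environment-query functions, and most of this diff is the mechanical rename. Collected in one place, the renames that appear in the file diffs below (comments only, for reference):

```julia
# RLBase 0.8 -> RLBase 0.9 renames touched by this commit:
#   get_reward(env)              -> reward(env)
#   get_terminal(env)            -> is_terminated(env)
#   get_state(env)               -> state(env)
#   get_legal_actions_mask(env)  -> legal_action_space_mask(env)
#   get_prob(explorer, x)        -> prob(explorer, x)
```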

Project.toml

Lines changed: 2 additions & 4 deletions
```diff
@@ -1,7 +1,7 @@
 name = "ReinforcementLearningCore"
 uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
 authors = ["Jun Tian <tianjun.cpp@gmail.com>"]
-version = "0.6.0"
+version = "0.6.1"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -16,7 +16,6 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
-JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
@@ -39,10 +38,9 @@ Flux = "0.11.1"
 Functors = "0.1"
 GPUArrays = "5, 6.0"
 ImageTransformations = "0.8"
-JLD = "0.10, 0.11"
 MacroTools = "0.5"
 ProgressMeter = "1.2"
-ReinforcementLearningBase = "0.8.2"
+ReinforcementLearningBase = "0.9"
 Setfield = "0.6, 0.7"
 StatsBase = "0.32, 0.33"
 Zygote = "0.5"
```

src/core/hooks.jl

Lines changed: 30 additions & 93 deletions
```diff
@@ -110,11 +110,7 @@ function (hook::RewardsPerEpisode)(::PreEpisodeStage, agent, env)
 end
 
 function (hook::RewardsPerEpisode)(::PostActStage, agent, env)
-    push!(hook.rewards[end], get_reward(env))
-end
-
-function (hook::RewardsPerEpisode)(::PostActStage, agent, env::RewardOverriddenEnv)
-    push!(hook.rewards[end], get_reward(env.env))
+    push!(hook.rewards[end], reward(env))
 end
 
 #####
@@ -125,9 +121,6 @@ end
     TotalRewardPerEpisode(; rewards = Float64[], reward = 0.0)
 
 Store the total rewards of each episode in the field of `rewards`.
-
-!!! note
-    If the environment is a [`RewardOverriddenenv`](@ref), then the original reward is recorded.
 """
 Base.@kwdef mutable struct TotalRewardPerEpisode <: AbstractHook
     rewards::Vector{Float64} = Float64[]
@@ -136,58 +129,11 @@ end
 
 Base.getindex(h::TotalRewardPerEpisode) = h.rewards
 
-(hook::TotalRewardPerEpisode)(s::AbstractStage, agent, env) =
-    hook(s, agent, env, RewardStyle(env), NumAgentStyle(env))
-(hook::TotalRewardPerEpisode)(::AbstractStage, agent, env, ::Any, ::Any) = nothing
-
-(hook::TotalRewardPerEpisode)(
-    ::PostEpisodeStage,
-    agent,
-    env,
-    ::TerminalReward,
-    ::SingleAgent,
-) = push!(hook.rewards, get_reward(env))
-(hook::TotalRewardPerEpisode)(
-    ::PostEpisodeStage,
-    agent,
-    env,
-    ::TerminalReward,
-    ::MultiAgent,
-) = push!(hook.rewards, get_reward(env, get_role(agent)))
-(hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::StepReward, ::SingleAgent) =
-    hook.reward += get_reward(env)
-(hook::TotalRewardPerEpisode)(::PostActStage, agent, env, ::StepReward, ::MultiAgent) =
-    hook.reward += get_reward(env, get_role(agent))
-(hook::TotalRewardPerEpisode)(
-    ::PostEpisodeStage,
-    agent,
-    env::RewardOverriddenEnv,
-    ::TerminalReward,
-    ::SingleAgent,
-) = push!(hook.rewards, get_reward(env.env))
-(hook::TotalRewardPerEpisode)(
-    ::PostEpisodeStage,
-    agent,
-    env::RewardOverriddenEnv,
-    ::TerminalReward,
-    ::MultiAgent,
-) = push!(hook.rewards, get_reward(env.env, get_role(agent)))
-(hook::TotalRewardPerEpisode)(
-    ::PostActStage,
-    agent,
-    env::RewardOverriddenEnv,
-    ::StepReward,
-    ::SingleAgent,
-) = hook.reward += get_reward(env.env)
-(hook::TotalRewardPerEpisode)(
-    ::PostActStage,
-    agent,
-    env::RewardOverriddenEnv,
-    ::StepReward,
-    ::MultiAgent,
-) = hook.reward += get_reward(env.env, get_role(agent))
-
-function (hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env, ::StepReward, ::Any)
+function (hook::TotalRewardPerEpisode)(::PostActStage, agent, env)
+    hook.reward += reward(env)
+end
+
+function (hook::TotalRewardPerEpisode)(::PostEpisodeStage, agent, env)
     push!(hook.rewards, hook.reward)
     hook.reward = 0
 end
@@ -205,33 +151,27 @@ Base.getindex(h::TotalBatchRewardPerEpisode) = h.rewards
 """
     TotalBatchRewardPerEpisode(batch_size::Int)
 
-Similar to [`TotalRewardPerEpisode`](@ref), but will record total rewards per episode in [`MultiThreadEnv`](@ref).
-
-!!! note
-    If the environment is a [`RewardOverriddenEnv`](@ref), then the original reward is recorded.
+Similar to [`TotalRewardPerEpisode`](@ref), but is specific to environments
+which return a `Vector` of rewards (a typical case with `MultiThreadEnv`).
 """
 function TotalBatchRewardPerEpisode(batch_size::Int)
     TotalBatchRewardPerEpisode([Float64[] for _ in 1:batch_size], zeros(batch_size))
 end
 
-function (hook::TotalBatchRewardPerEpisode)(
-    ::PostActStage,
-    agent,
-    env::MultiThreadEnv{T},
-) where {T}
-    for i in 1:length(env)
-        if T <: RewardOverriddenEnv
-            hook.reward[i] += get_reward(env[i].env)
-        else
-            hook.reward[i] += get_reward(env[i])
-        end
-        if get_terminal(env[i])
+function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env)
+    for (i, (t, r)) in enumerate(zip(is_terminated(env), reward(env)))
+        hook.reward[i] += r
+        if t
             push!(hook.rewards[i], hook.reward[i])
             hook.reward[i] = 0.0
         end
     end
 end
 
+#####
+# BatchStepsPerEpisode
+#####
+
 struct BatchStepsPerEpisode <: AbstractHook
     steps::Vector{Vector{Int}}
     step::Vector{Int}
@@ -242,16 +182,17 @@ Base.getindex(h::BatchStepsPerEpisode) = h.steps
 """
     BatchStepsPerEpisode(batch_size::Int; tag = "TRAINING")
 
-Similar to [`StepsPerEpisode`](@ref), but only work for [`MultiThreadEnv`](@ref)
+Similar to [`StepsPerEpisode`](@ref), but is specific to environments
+which return a `Vector` of rewards (a typical case with `MultiThreadEnv`).
 """
 function BatchStepsPerEpisode(batch_size::Int)
     BatchStepsPerEpisode([Int[] for _ in 1:batch_size], zeros(Int, batch_size))
 end
 
-function (hook::BatchStepsPerEpisode)(::PostActStage, agent, env::MultiThreadEnv)
-    for i in 1:length(env)
+function (hook::BatchStepsPerEpisode)(::PostActStage, agent, env)
+    for (i, t) in enumerate(is_terminated(env))
         hook.step[i] += 1
-        if get_terminal(env[i])
+        if t
             push!(hook.steps[i], hook.step[i])
             hook.step[i] = 0
         end
@@ -266,24 +207,20 @@ end
     CumulativeReward(rewards::Vector{Float64} = [0.0])
 
 Store cumulative rewards since the beginning to the field of `rewards`.
-
-!!! note
-    If the environment is a [`RewardOverriddenEnv`](@ref), then the original reward is recorded instead.
 """
 Base.@kwdef struct CumulativeReward <: AbstractHook
-    rewards::Vector{Float64} = [0.0]
+    rewards::Vector{Vector{Float64}} = [[0.0]]
 end
 
 Base.getindex(h::CumulativeReward) = h.rewards
 
-function (hook::CumulativeReward)(::PostActStage, agent, env::T) where {T}
-    if T <: RewardOverriddenEnv
-        r = get_reward(env.env)
-    else
-        r = get_reward(env)
-    end
-    push!(hook.rewards, r + hook.rewards[end])
-    @debug hook.tag CUMULATIVE_REWARD = hook.rewards[end]
+function (hook::CumulativeReward)(::PostEpisodeStage, agent, env)
+    push!(hook.rewards, [0.0])
+end
+
+function (hook::CumulativeReward)(::PostActStage, agent, env)
+    r = reward(env)
+    push!(hook.rewards[end], r + hook.rewards[end][end])
 end
 
 #####
@@ -363,7 +300,7 @@ Base.@kwdef mutable struct UploadTrajectoryEveryNStep{M,S} <: AbstractHook
     sealer::S = deepcopy
 end
 
-function (hook::UploadTrajectoryEveryNStep)(::PostActStage, agent, env)
+function (hook::UploadTrajectoryEveryNStep)(::PostActStage, agent::Agent, env)
     hook.t += 1
     if hook.t > 0 && hook.t % hook.n == 0
         put!(hook.mailbox, hook.sealer(agent.trajectory))
```
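
Taken together, the hook rework drops the `RewardStyle`/`NumAgentStyle` trait dispatch and every `RewardOverriddenEnv` special case in favor of the plain `reward(env)` / `is_terminated(env)` interface, and `CumulativeReward` now keeps one running vector per episode. A minimal usage sketch, assuming `CartPoleEnv` from ReinforcementLearningEnvironments plus the `run`, `ComposedHook`, and `RandomPolicy` APIs that live elsewhere in this package (none of them are defined in this commit):

```julia
using ReinforcementLearningCore
using ReinforcementLearningEnvironments  # assumed, for CartPoleEnv

env = CartPoleEnv()
total = TotalRewardPerEpisode()
cum = CumulativeReward()

run(RandomPolicy(action_space(env)), env, StopAfterEpisode(10), ComposedHook(total, cum))

total[]  # Vector{Float64}: one total reward per finished episode
cum[]    # Vector{Vector{Float64}}: a running cumulative-sum vector per episode
```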

src/core/run.jl

Lines changed: 1 addition & 27 deletions
```diff
@@ -33,7 +33,7 @@ function _run(
     policy(PRE_EPISODE_STAGE, env)
     hook(PRE_EPISODE_STAGE, policy, env)
 
-    while !get_terminal(env) # one episode
+    while !is_terminated(env) # one episode
         action = policy(PRE_ACT_STAGE, env)
         hook(PRE_ACT_STAGE, policy, env, action)
 
@@ -53,29 +53,3 @@ function _run(
     end
     hook
 end
-
-function _run(
-    ::Sequential,
-    ::SingleAgent,
-    policy::AbstractPolicy,
-    env::MultiThreadEnv,
-    stop_condition,
-    hook::AbstractHook = EmptyHook(),
-)
-
-    while true
-        reset!(env)
-        action = policy(PRE_ACT_STAGE, env)
-        hook(PRE_ACT_STAGE, policy, env, action)
-
-        env(action)
-        policy(POST_ACT_STAGE, env)
-        hook(POST_ACT_STAGE, policy, env)
-
-        if stop_condition(policy, env)
-            policy(PRE_ACT_STAGE, env) # let the policy see the last observation
-            break
-        end
-    end
-    hook
-end
```
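
With the dedicated `MultiThreadEnv` method deleted, a single `_run` code path remains, built on the renamed predicate. Its episode loop reduces to roughly the following sketch (stage constants and hook calls omitted; `rollout_episode!` is our name, not the package's):

```julia
# Skeleton of one episode under the renamed API (illustrative only):
function rollout_episode!(policy, env)
    reset!(env)
    total = 0.0
    while !is_terminated(env)  # was: get_terminal(env)
        env(policy(env))       # the env is callable on an action
        total += reward(env)   # was: get_reward(env)
    end
    total
end
```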

src/core/stop_conditions.jl

Lines changed: 4 additions & 8 deletions
```diff
@@ -95,7 +95,7 @@ function StopAfterEpisode(episode; cur = 0, is_show_progress = true)
 end
 
 function (s::StopAfterEpisode)(agent, env)
-    if get_terminal(env)
+    if is_terminated(env)
        s.cur += 1
        if !isnothing(s.progress)
            next!(s.progress;)
@@ -105,10 +105,6 @@ function (s::StopAfterEpisode)(agent, env)
     s.cur >= s.episode
 end
 
-(s::StopAfterEpisode)(agent, env::MultiThreadEnv) =
-    @error "MultiThreadEnv is not supported!"
-
-
 #####
 # StopAfterNoImprovement
 #####
@@ -128,7 +124,7 @@ Parameters:
 
 fn: a closure, return a scalar value, which indicates the performance of the policy (the higher the better)
 e.g.
-1. () -> get_reward(env)
+1. () -> reward(env)
 1. () -> total_reward_per_episode.reward
 
 patience: Number of epochs with no improvement after which training will be stopped.
@@ -142,7 +138,7 @@ function StopAfterNoImprovement(fn, patience::Int, δ::T = 0.0f0) where {T<:Number}
 end
 
 function (s::StopAfterNoImprovement)(agent, env)::Bool
-    get_terminal(env) || return false # post episode stage
+    is_terminated(env) || return false # post episode stage
     val = s.fn()
     improved = isfull(s.buffer) ? all(s.buffer .< (val - s.δ)) : true
     push!(s.buffer, val)
@@ -160,7 +156,7 @@ Return `true` if the environment is terminated.
 """
 struct StopWhenDone end
 
-(s::StopWhenDone)(agent, env) = get_terminal(env)
+(s::StopWhenDone)(agent, env) = is_terminated(env)
 
 #####
 # StopSignal
```
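
Stop conditions remain plain callables of `(agent, env) -> Bool`; only the terminal check is renamed. A sketch wiring `StopAfterNoImprovement` to a reward hook, following the `fn` examples in its docstring:

```julia
total = TotalRewardPerEpisode()

# Stop once `patience` consecutive episodes bring no improvement:
stop = StopAfterNoImprovement(
    () -> total.reward,  # performance metric, per the docstring's second example
    10,                  # patience, in episodes
)
```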

src/extensions/ReinforcementLearningBase.jl

Lines changed: 0 additions & 8 deletions
```diff
@@ -1,13 +1,5 @@
-using CUDA
-using Distributions: pdf
-using Random
-using Flux
 using AbstractTrees
 
-RLBase.update!(p::RandomPolicy, x) = nothing
-
-Random.rand(s::MultiContinuousSpace{<:CuArray}) = rand(CUDA.CURAND.generator(), s)
-
 Base.show(io::IO, p::AbstractPolicy) =
     AbstractTrees.print_tree(io, StructTree(p), get(io, :max_depth, 10))
 
```

src/policies/agents/agent.jl

Lines changed: 5 additions & 5 deletions
```diff
@@ -85,7 +85,7 @@ function RLBase.update!(
     ::Union{PreActStage,PostEpisodeStage},
 )
     action = policy(env)
-    push!(trajectory[:state], get_state(env))
+    push!(trajectory[:state], state(env))
     push!(trajectory[:action], action)
     action
 end
@@ -100,9 +100,9 @@ function RLBase.update!(
     ::Union{PreActStage,PostEpisodeStage},
 )
     action = policy(env)
-    push!(trajectory[:state], get_state(env))
+    push!(trajectory[:state], state(env))
     push!(trajectory[:action], action)
-    push!(trajectory[:legal_actions_mask], get_legal_actions_mask(env))
+    push!(trajectory[:legal_actions_mask], legal_action_space_mask(env))
     action
 end
 
@@ -112,6 +112,6 @@ function RLBase.update!(
     env::AbstractEnv,
     ::PostActStage,
 )
-    push!(trajectory[:reward], get_reward(env))
-    push!(trajectory[:terminal], get_terminal(env))
+    push!(trajectory[:reward], reward(env))
+    push!(trajectory[:terminal], is_terminated(env))
 end
```
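
These three `RLBase.update!` methods are the only places an `Agent` queries the environment while recording a transition, so the rename surface here is small. For reference, the renamed queries side by side on a masked environment (a sketch; `TicTacToeEnv` is assumed from ReinforcementLearningEnvironments):

```julia
using ReinforcementLearningBase
using ReinforcementLearningEnvironments  # assumed, for TicTacToeEnv

env = TicTacToeEnv()
state(env)                    # was: get_state(env)
is_terminated(env)            # was: get_terminal(env)
legal_action_space_mask(env)  # was: get_legal_actions_mask(env)
```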

src/policies/policies.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,3 +1,4 @@
 include("base.jl")
 include("agents/agents.jl")
 include("q_based_policies/q_based_policies.jl")
+include("random_policy.jl")
```

src/policies/q_based_policies/explorers/abstract_explorer.jl

Lines changed: 5 additions & 5 deletions
```diff
@@ -14,15 +14,15 @@ function (p::AbstractExplorer)(x) end
 function (p::AbstractExplorer)(x, mask) end
 
 """
-    get_prob(p::AbstractExplorer, x) -> AbstractDistribution
+    prob(p::AbstractExplorer, x) -> AbstractDistribution
 
 Get the action distribution given action values.
 """
-function RLBase.get_prob(p::AbstractExplorer, x) end
+function RLBase.prob(p::AbstractExplorer, x) end
 
 """
-    get_prob(p::AbstractExplorer, x, mask)
+    prob(p::AbstractExplorer, x, mask)
 
-Similart to `get_prob(p::AbstractExplorer, x)`, but here only the `mask`ed elements are considered.
+Similar to `prob(p::AbstractExplorer, x)`, but here only the `mask`ed elements are considered.
 """
-function RLBase.get_prob(p::AbstractExplorer, x, mask) end
+function RLBase.prob(p::AbstractExplorer, x, mask) end
```
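
Anything subtyping `AbstractExplorer` now extends `RLBase.prob` instead of `RLBase.get_prob`. A minimal concrete explorer against the renamed interface; `ArgmaxExplorer` is an illustrative name, not part of this package, and returning a `Distributions.Categorical` is one reading of the `AbstractDistribution` convention in the docstring:

```julia
using ReinforcementLearningCore
using ReinforcementLearningBase
using Distributions: Categorical

struct ArgmaxExplorer <: AbstractExplorer end

# Always pick the action with the highest estimated value.
(p::ArgmaxExplorer)(x) = argmax(x)

# The induced action distribution is degenerate at the greedy action.
function ReinforcementLearningBase.prob(p::ArgmaxExplorer, x)
    w = zeros(length(x))
    w[argmax(x)] = 1.0
    Categorical(w)
end
```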
