Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 0cc7c7e

Browse files
authored
add RandomStartPolicy & restrict QBasedPolicy to return index of action (#189)
1 parent 135e691 commit 0cc7c7e

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/policies/policies.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ include("base.jl")
22
include("agents/agents.jl")
33
include("q_based_policies/q_based_policies.jl")
44
include("random_policy.jl")
5+
include("random_start_policy.jl")

src/policies/q_based_policies/q_based_policy.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@ end
1919
Flux.functor(x::QBasedPolicy) = (learner = x.learner,), y -> @set x.learner = y.learner
2020

2121
(π::QBasedPolicy)(env) = π(env, ActionStyle(env))
22-
(π::QBasedPolicy)(env, ::MinimalActionSet) = action_space(env)[π.explorer(π.learner(env))]
23-
(π::QBasedPolicy)(env, ::FullActionSet) =
24-
action_space(env)[π.explorer(π.learner(env), legal_action_space_mask(env))]
22+
(π::QBasedPolicy)(env, ::MinimalActionSet) = π.explorer(π.learner(env))
23+
(π::QBasedPolicy)(env, ::FullActionSet) = π.explorer(π.learner(env), legal_action_space_mask(env))
2524

2625
RLBase.prob(p::QBasedPolicy, env) = prob(p, env, ActionStyle(env))
2726
RLBase.prob(p::QBasedPolicy, env, ::MinimalActionSet) =

0 commit comments

Comments
 (0)