 from torch.nn import MSELoss
 from torch.optim import Adam

-from fast_rl.agents.BaseAgent import BaseAgent, create_nn_model
+from fast_rl.agents.BaseAgent import BaseAgent, create_nn_model, create_cnn_model, get_next_conv_shape, get_conv, \
+    Flatten
 from fast_rl.core.Learner import AgentLearner
 from fast_rl.core.MarkovDecisionProcess import MDPDataBunch
 from fast_rl.core.agent_core import GreedyEpsilon, ExperienceReplay
@@ -27,6 +28,8 @@ def on_train_begin(self, n_epochs, **kwargs: Any):

     def on_epoch_begin(self, epoch, **kwargs: Any):
         self.episode = epoch
+        # if self.learn.model.training and self.iteration != 0:
+        #     self.learn.model.memory.update(item=self.learn.data.x.items[-1])
         self.iteration = 0

     def on_loss_begin(self, **kwargs: Any):
@@ -47,7 +50,7 @@ def on_loss_begin(self, **kwargs: Any):
         # self.learn.model.target_copy_over()


-class Critic(nn.Module):
+class NNCritic(nn.Module):
     def __init__(self, layer_list: list, action_size, state_size, use_bn=False, use_embed=True,
                  activation_function=None):
         super().__init__()
@@ -59,7 +62,7 @@ def __init__(self, layer_list: list, action_size, state_size, use_bn=False, use_
         self.fc3 = nn.Linear(layer_list[1], 1)

     def forward(self, x):
-        action, x = x[:, self.state_size:], x[:, :self.state_size]
+        x, action = x

         x = nn.LeakyReLU()(self.fc1(x))
         x = nn.LeakyReLU()(self.fc2(torch.cat((x, action), 1)))
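Reviewer note on the hunk above: the critics now take a (state, action) tuple instead of one concatenated tensor, so forward() no longer has to slice the state back out by self.state_size. A minimal sketch of the new calling convention; the batch and feature sizes, and the critic_model handle, are illustrative assumptions:

    # old style: critic_model(torch.cat((s, a), 1)) -- forward() split it apart again
    # new style: critic_model((s, a))               -- forward() just unpacks the tuple
    s = torch.randn(64, 8)      # assumed batch of 64 states, 8 features each
    a = torch.randn(64, 2)      # assumed matching batch of 2-dim actions
    q = critic_model((s, a))    # -> (64, 1) Q-value estimates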
@@ -68,17 +71,41 @@ def forward(self, x):
         return x


+class CNNCritic(nn.Module):
+    def __init__(self, layer_list: list, action_size, state_size, activation_function=None):
+        super().__init__()
+        self.action_size = action_size[0]
+        self.state_size = state_size[0]
+
+        layers = []
+        layers, input_size = get_conv(self.state_size, nn.LeakyReLU(), 8, 2, 3, layers)
+        layers += [Flatten()]
+        self.conv_layers = nn.Sequential(*layers)
+
+        self.fc1 = nn.Linear(input_size + self.action_size, 200)
+        self.fc2 = nn.Linear(200, 1)
+
+    def forward(self, x):
+        x, action = x
+
+        x = nn.LeakyReLU()(self.conv_layers(x))
+        x = nn.LeakyReLU()(self.fc1(torch.cat((x, action), 1)))
+        x = nn.LeakyReLU()(self.fc2(x))
+
+        return x
+
+
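For orientation, a rough sketch of how the new CNNCritic is meant to be driven; the ((C, H, W),) state size and the action size below are illustrative assumptions, mirroring the 3-dimension shape check that initialize_critic_model performs further down:

    # assumed sizes: a 3-channel 64x64 observation space and a 2-dim continuous action space
    critic = CNNCritic(layer_list=[200, 200], action_size=(2,), state_size=((3, 64, 64),))
    q = critic((torch.randn(8, 3, 64, 64), torch.randn(8, 2)))  # -> (8, 1) expected-reward estimates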
 class DDPG(BaseAgent):

     def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount=0.99,
-                 lr=1e-3, actor_lr=1e-4, exploration_strategy=None, env_was_discrete=False):
+                 lr=1e-3, actor_lr=1e-4, exploration_strategy=None):
         """
         Implementation of a continuous control algorithm using an actor/critic architecture.

         Notes:
             Uses 4 networks, 2 actors, 2 critics.
             All models use batch norm for feature invariance.
-            Critic simply predicts Q while the Actor proposes the actions to take given a state s.
+            NNCritic simply predicts Q while the Actor proposes the actions to take given a state s.

         References:
             [1] Lillicrap, Timothy P., et al. "Continuous control with deep reinforcement learning."
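As in the Lillicrap et al. paper cited above, tau drives soft (Polyak) updates of the two target networks (the t_action_model / t_critic_model used in optimize()). A minimal sketch of that update, assuming the usual form; the actual helper in BaseAgent may differ:

    # assumed soft update: target <- tau * online + (1 - tau) * target
    for t_param, param in zip(t_critic_model.parameters(), critic_model.parameters()):
        t_param.data.copy_(tau * param.data + (1.0 - tau) * t_param.data)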
@@ -93,7 +120,6 @@ def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount
             lr: Rate that the opt will learn parameter gradients.
         """
         super().__init__(data)
-        self.env_was_discrete = env_was_discrete
         self.name = 'DDPG'
         self.lr = lr
         self.discount = discount
@@ -122,21 +148,30 @@ def __init__(self, data: MDPDataBunch, memory=None, tau=1e-3, batch=64, discount
                                                            do_exploration=self.training))

     def initialize_action_model(self, layers, data):
-        return create_nn_model(layers, *data.get_action_state_size(), False, use_embed=data.train_ds.embeddable,
-                               final_activation_function=nn.Tanh)
+        actions, state = data.get_action_state_size()
+        if type(state[0]) is tuple and len(state[0]) == 3:
+            # actions, state = actions[0], state[0]
+            # If the shape has 3 dimensions, we will try using cnn's instead.
+            return create_cnn_model([200, 200], actions, state, False, kernel_size=8,
+                                    final_activation_function=nn.Tanh, action_val_to_dim=False)
+        else:
+            return create_nn_model(layers, *data.get_action_state_size(), False, use_embed=data.train_ds.embeddable,
+                                   final_activation_function=nn.Tanh, action_val_to_dim=False)

     def initialize_critic_model(self, layers, data):
         """ Instead of state -> action, we are going state + action -> single expected reward. """
-        return Critic(layers, *data.get_action_state_size())
+        actions, state = data.get_action_state_size()
+        if type(state[0]) is tuple and len(state[0]) == 3:
+            return CNNCritic(layers, *data.get_action_state_size())
+        else:
+            return NNCritic(layers, *data.get_action_state_size())

     def pick_action(self, x):
         if self.training: self.action_model.eval()
         with torch.no_grad():
-            action, x = super(DDPG, self).pick_action(x)
+            action = super(DDPG, self).pick_action(x)
         if self.training: self.action_model.train()
-
-        if not self.env_was_discrete: action = np.clip(action, -1, 1)
-        return action, np.clip(x, -1, 1)
+        return np.clip(action, -1, 1)

     def optimize(self):
         """
@@ -160,12 +195,11 @@ def optimize(self):
         s_prime = torch.from_numpy(np.array([item.result_state for item in sampled])).float()
         s = torch.from_numpy(np.array([item.current_state for item in sampled])).float()
         a = torch.from_numpy(np.array([item.actions for item in sampled]).astype(float)).float()
-        if self.env_was_discrete: a = torch.from_numpy(np.array([item.raw_action for item in sampled]).astype(float)).float()

         with torch.no_grad():
-            y = r + self.discount * self.t_critic_model(torch.cat((s_prime, self.t_action_model(s_prime)), 1))
+            y = r + self.discount * self.t_critic_model((s_prime, self.t_action_model(s_prime)))

-        y_hat = self.critic_model(torch.cat((s, a), 1))
+        y_hat = self.critic_model((s, a))

         critic_loss = self.loss_func(y_hat, y)

@@ -175,7 +209,7 @@ def optimize(self):
         critic_loss.backward()
         self.critic_optimizer.step()

-        actor_loss = -self.critic_model(torch.cat((s, self.action_model(s)), 1)).mean()
+        actor_loss = -self.critic_model((s, self.action_model(s))).mean()

         self.loss = critic_loss.cpu().detach()
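For reviewers wanting to exercise the new CNN path end to end, a rough usage sketch; the environment name and the MDPDataBunch constructor arguments are assumptions for illustration, not part of this diff:

    # assumed setup: an env whose observations are (C, H, W) images selects the CNN actor/critic
    data = MDPDataBunch.from_env('CarRacing-v0', render='rgb_array')
    model = DDPG(data)                 # memory / exploration_strategy fall back to the defaults built in __init__
    learn = AgentLearner(data, model)
    learn.fit(5)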