.. _ppo_lstm:

.. automodule:: sb3_contrib.ppo_recurrent

Recurrent PPO
=============

Implementation of recurrent policies for the Proximal Policy Optimization (PPO)
algorithm. Other than adding support for recurrent policies (LSTM here), the behavior is the same as in SB3's core PPO algorithm.


.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpLstmPolicy
    CnnLstmPolicy
    MultiInputLstmPolicy
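
The LSTM architecture can be adjusted through ``policy_kwargs``. A minimal sketch, assuming the ``lstm_hidden_size`` and ``n_lstm_layers`` keyword names match the ``RecurrentActorCriticPolicy`` signature documented at the bottom of this page:

.. code-block:: python

    from sb3_contrib import RecurrentPPO

    # Assumed parameter names: lstm_hidden_size (size of the LSTM hidden state)
    # and n_lstm_layers (number of stacked LSTM layers), forwarded to the policy.
    model = RecurrentPPO(
        "MlpLstmPolicy",
        "CartPole-v1",
        policy_kwargs=dict(lstm_hidden_size=64, n_lstm_layers=1),
        verbose=1,
    )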


Notes
-----

- Blog post: https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/


Can I use?
----------

- Recurrent policies: ✔️
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️      ✔️
Box           ✔️      ✔️
MultiDiscrete ✔️      ✔️
MultiBinary   ✔️      ✔️
Dict          ❌      ✔️
============= ====== ===========

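
Since Dict observation spaces are supported (but not Dict action spaces), ``MultiInputLstmPolicy`` can be used with them. A minimal sketch, assuming SB3's toy ``SimpleMultiObsEnv`` (from ``stable_baselines3.common.envs``) as the Dict-observation environment:

.. code-block:: python

    from sb3_contrib import RecurrentPPO
    from stable_baselines3.common.envs import SimpleMultiObsEnv

    # SimpleMultiObsEnv is a small example environment with a Dict observation
    # space (an image key and a vector key); it is used here only for illustration.
    env = SimpleMultiObsEnv(random_start=False)
    model = RecurrentPPO("MultiInputLstmPolicy", env, verbose=1)
    model.learn(1_000)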

Example
-------

.. note::

    It is particularly important to pass the ``lstm_states``
    and ``episode_start`` arguments to the ``predict()`` method,
    so that the cell and hidden states of the LSTM are correctly updated.

.. code-block:: python

    import numpy as np

    from sb3_contrib import RecurrentPPO
    from stable_baselines3.common.evaluation import evaluate_policy

    model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
    model.learn(5000)

    env = model.get_env()
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
    print(mean_reward)

    model.save("ppo_recurrent")
    del model  # remove to demonstrate saving and loading

    model = RecurrentPPO.load("ppo_recurrent")

    obs = env.reset()
    # Cell and hidden state of the LSTM
    lstm_states = None
    num_envs = 1
    # Episode start signals are used to reset the lstm states
    episode_starts = np.ones((num_envs,), dtype=bool)
    while True:
        action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
        obs, rewards, dones, info = env.step(action)
        episode_starts = dones
        env.render()

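
The same loop also works with several parallel environments; only the size of ``episode_starts`` changes, and the batched LSTM states are handled internally by ``predict()``. A minimal sketch, assuming ``make_vec_env`` from SB3:

.. code-block:: python

    import numpy as np

    from sb3_contrib import RecurrentPPO
    from stable_baselines3.common.env_util import make_vec_env

    # Four parallel CartPole environments
    vec_env = make_vec_env("CartPole-v1", n_envs=4)
    model = RecurrentPPO("MlpLstmPolicy", vec_env, verbose=1)
    model.learn(5000)

    obs = vec_env.reset()
    lstm_states = None
    # One episode start signal per environment
    episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
    for _ in range(100):
        action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts)
        obs, rewards, dones, infos = vec_env.step(action)
        episode_starts = dones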
|
Results
-------

A report on environments with masked velocity (with and without framestack) can be found here: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4

``RecurrentPPO`` was evaluated against PPO on the following environments (a sketch of a velocity-masking wrapper, illustrating how the ``NoVel`` variants are built, follows the list):

- PendulumNoVel-v1
- LunarLanderNoVel-v2
- CartPoleNoVel-v1
- MountainCarContinuousNoVel-v0
- CarRacing-v0

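The ``NoVel`` variants mask out the velocity entries of the observation, making the task partially observable. A minimal sketch of such a wrapper, using CartPole as an example (illustrative only; the exact wrappers used for the benchmark live in the RL Zoo branch referenced below):

.. code-block:: python

    import gym
    import numpy as np


    class MaskVelocityWrapper(gym.ObservationWrapper):
        """Zero out the velocity entries of the observation vector."""

        def __init__(self, env: gym.Env, velocity_indices):
            super().__init__(env)
            self.mask = np.ones(env.observation_space.shape, dtype=np.float32)
            self.mask[list(velocity_indices)] = 0.0

        def observation(self, observation):
            return observation * self.mask


    # CartPole observations are [cart position, cart velocity, pole angle, pole angular velocity],
    # so masking indices 1 and 3 removes all velocity information.
    env = MaskVelocityWrapper(gym.make("CartPole-v1"), velocity_indices=(1, 3))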

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the repo for the experiment:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo
    cd rl-baselines3-zoo
    git checkout feat/recurrent-ppo


Run the benchmark (replace ``$ENV_ID`` with one of the environment IDs listed above):

.. code-block:: bash

    python train.py --algo ppo_lstm --env $ENV_ID --eval-episodes 10 --eval-freq 10000

|
Parameters
----------

.. autoclass:: RecurrentPPO
  :members:
  :inherited-members:


RecurrentPPO Policies
---------------------

.. autoclass:: MlpLstmPolicy
  :members:
  :inherited-members:

.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticPolicy
  :members:
  :noindex:

.. autoclass:: CnnLstmPolicy
  :members:

.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentActorCriticCnnPolicy
  :members:
  :noindex:

.. autoclass:: MultiInputLstmPolicy
  :members:

.. autoclass:: sb3_contrib.common.recurrent.policies.RecurrentMultiInputActorCriticPolicy
  :members:
  :noindex: