45 changes: 45 additions & 0 deletions examples/ppo/README.md
@@ -0,0 +1,45 @@
# Training PPO with decentralized averaging

This tutorial will walk you through the steps to set up collaborative training of [PPO](https://arxiv.org/pdf/1707.06347.pdf), an on-policy reinforcement learning algorithm, to play Atari Breakout. It uses the [stable-baselines3 implementation of PPO](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html); the hyperparameters are taken from [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml), and the collaborative training is built on `hivemind.Optimizer`, which exchanges information between peers.
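
Under the hood, each peer wraps an ordinary PyTorch optimizer in `hivemind.Optimizer`, which accumulates gradients locally and performs a collaborative update once all peers together have processed the target batch size. Below is a minimal sketch of that idea with a toy linear model (the actual PPO setup lives in `ppo.py` in this folder):

```python
import torch
import hivemind

model = torch.nn.Linear(4, 2)  # toy model for illustration; the real example wraps the PPO policy's optimizer

dht = hivemind.DHT(start=True)  # start a new DHT (pass initial_peers=[...] to join an existing run)
opt = hivemind.Optimizer(
    dht=dht,
    run_id="ppo_hivemind",       # peers that share the same run_id train together
    optimizer=torch.optim.Adam(model.parameters(), lr=2.5e-4),
    batch_size_per_step=256,     # samples processed locally per opt.step() call
    target_batch_size=32768,     # global number of samples per collaborative update
    verbose=True,
)
# afterwards, `opt` is used like a regular optimizer: loss.backward(); opt.step(); opt.zero_grad()
```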

## Preparation

* Install hivemind: `pip install git+https://github.com/learning-at-home/hivemind.git`
* Install the remaining dependencies: `pip install -r requirements.txt`

## Running an experiment

### First peer
Run the first DHT peer to welcome trainers and record training statistics (e.g., loss and performance):
- In this example, we use [TensorBoard](https://www.tensorflow.org/tensorboard) to plot training metrics. If you're unfamiliar with TensorBoard, here's a [quickstart tutorial](https://www.tensorflow.org/tensorboard/get_started).
- Run `python3 ppo.py`

```
$ python3 ppo.py
To connect other peers to this one, use --initial-peers /ip4/127.0.0.1/tcp/41926/p2p/QmUmiebP4BxdEPEpQb28cqyhaheDugFRn7MCoLJr556xYt
A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
[Powered by Stella]
Using cuda device
Wrapping the env in a VecTransposeImage.
[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.
Jun 20 13:23:20.515 [INFO] Found no active peers: None
Jun 20 13:23:20.533 [INFO] Initializing optimizer manually since it has no tensors in state dict. To override this, provide initialize_optimizer=False
Logging to logs/bs-256.target_bs-32768.n_envs-8.n_steps-128.n_epochs-1_1
---------------------------------
| rollout/ | |
| ep_len_mean | 521 |
| ep_rew_mean | 0 |
| time/ | |
| fps | 582 |
| iterations | 1 |
| time_elapsed | 1 |
| total_timesteps | 1024 |
| train/ | |
| timesteps | 1024 |
---------------------------------
Jun 20 13:23:23.525 [INFO] ppo_hivemind accumulated 1024 samples for epoch #0 from 1 peers. ETA 52.20 sec (refresh in 1$.00 sec)

```
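
### Other peers
To add more trainers to the same run (e.g., from another machine or another terminal), pass the multiaddress printed by the first peer:
- Run `python3 ppo.py --initial-peers <ADDRESS_PRINTED_BY_THE_FIRST_PEER>`, substituting the full `/ip4/.../p2p/...` string from the first peer's output (the value above is just a placeholder). Every peer contributes its samples toward the shared target batch size and averages gradients and model state with the others.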
117 changes: 117 additions & 0 deletions examples/ppo/ppo.py
@@ -0,0 +1,117 @@
import argparse
import pathlib

import torch

import hivemind
from hivemind import Float16Compression, SizeAdaptiveCompression, Uniform8BitQuantization

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--n-steps', type=int, default=128, help='Number of rollout steps per agent')
    parser.add_argument('--n-envs', type=int, default=8, help='Number of training envs')
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--target-batch-size', type=int, default=32768)
    parser.add_argument('--n-epochs', type=int, default=1, help='Number of training epochs per rollout')
    parser.add_argument('--learning-rate', type=float, default=2.5e-4)
    parser.add_argument('--tb-logs-path', type=pathlib.Path, default='./logs', help='Path to tensorboard logs folder')
    parser.add_argument('--experiment-prefix', type=str, help='Experiment prefix for tensorboard logs')
    parser.add_argument('--initial-peers', nargs='+', default=[])
    parser.add_argument('--averaging-compression', action='store_true')
    args = parser.parse_args()
    return args

def generate_experiment_name(args):
    exp_name_dict = {
        'bs': args.batch_size,
        'target_bs': args.target_batch_size,
        'n_envs': args.n_envs,
        'n_steps': args.n_steps,
        'n_epochs': args.n_epochs,
    }

    exp_name = [f'{key}-{value}' for key, value in exp_name_dict.items()]
    exp_name = '.'.join(exp_name)

    if args.experiment_prefix:
        exp_name = f'{args.experiment_prefix}.{exp_name}'
    exp_name = exp_name.replace('000.', 'k.')
    return exp_name


class AdamWithClipping(torch.optim.Adam):
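    """Adam that clips the total gradient norm to max_grad_norm right before each optimizer step."""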
Review comment (Member):
note to @mryab : we've recently merged the same clipping functionality here:
https://github.com/learning-at-home/hivemind/blob/master/hivemind/moe/server/layers/optim.py#L48

Would you prefer if we...

  • keep everything as is, accept some code duplication?
  • extract moe.server.layers.optim to utils.optim and use it here?
  • keep wrapper in hivemind.optim and import from there?
  • insert your option here :)

@mryab (Member), Jun 20, 2022:
I'm 50:50 between the "keep here, accept duplication" and "move OptimizerWrapper and ClippingWrapper to hivemind.optim.wrapper" solutions, so ultimately, it's @foksly's call

Review comment (Member):
The utils option is also acceptable, but I'm slightly against this folder becoming too bloated. That said, it looks like a reasonable place to put such code, so any solution of these three is fine by me (as long as you don't import the wrapper from hivemind.moe)
    def __init__(self, *args, max_grad_norm: float, **kwargs):
        self.max_grad_norm = max_grad_norm
        super().__init__(*args, **kwargs)

    def step(self, *args, **kwargs):
        iter_params = (param for group in self.param_groups for param in group["params"])
        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
        return super().step(*args, **kwargs)


def configure_dht_opts(args):
    opts = {
        'start': True,
    }
    if args.initial_peers:
        opts['initial_peers'] = args.initial_peers

    return opts


if __name__ == "__main__":
    args = parse_args()

    dht_opts = configure_dht_opts(args)
    dht = hivemind.DHT(**dht_opts)
    print("To connect other peers to this one, use --initial-peers", *[str(addr) for addr in dht.get_visible_maddrs()])

    env = make_atari_env('BreakoutNoFrameskip-v4', n_envs=args.n_envs)
    env = VecFrameStack(env, n_stack=4)

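    # Note: SB3's built-in gradient clipping is effectively disabled here (max_grad_norm=10000.0);
    # clipping is instead performed inside AdamWithClipping (max_grad_norm=0.5), presumably so that
    # it happens right at the actual optimizer step.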
    model = PPO(
        'CnnPolicy', env,
        verbose=1,
        batch_size=args.batch_size,
        n_steps=args.n_steps,
        n_epochs=args.n_epochs,
        learning_rate=args.learning_rate,
        clip_range=0.1,
        vf_coef=0.5,
        ent_coef=0.01,
        tensorboard_log=args.tb_logs_path,
        max_grad_norm=10000.0,
        policy_kwargs={'optimizer_class': AdamWithClipping, 'optimizer_kwargs': {'max_grad_norm': 0.5}}
    )

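    # Optional compression for peer-to-peer averaging: tensors smaller than the threshold are sent
    # as float16, larger tensors are quantized to 8 bits.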
    compression_opts = {}
    if args.averaging_compression:
        averaging_compression = SizeAdaptiveCompression(
            threshold=2 ** 10 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
        )
        compression_opts.update({
            'grad_compression': averaging_compression,
            'state_averaging_compression': averaging_compression
        })

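    # Wrap the policy's optimizer with hivemind.Optimizer: each peer reports batch_size_per_step
    # samples per local step, and a collaborative update is performed once the peers together have
    # accumulated target_batch_size samples.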
    model.policy.optimizer_class = hivemind.Optimizer
    model.policy.optimizer = hivemind.Optimizer(
        dht=dht,
        optimizer=model.policy.optimizer,
        run_id='ppo_hivemind',
        batch_size_per_step=args.batch_size,
        target_batch_size=args.target_batch_size,
        offload_optimizer=False,
        verbose=True,
        use_local_updates=False,
        matchmaking_time=4,
        averaging_timeout=15,
        **compression_opts,
    )
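    # Fetch the latest model/optimizer state from the other peers (if any) so that a newly joined
    # peer starts from the collaboration's current state.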
    model.policy.optimizer.load_state_from_peers()
    model.learn(total_timesteps=int(5e11), tb_log_name=generate_experiment_name(args))
1 change: 1 addition & 0 deletions examples/ppo/requirements.txt
@@ -0,0 +1 @@
stable-baselines3[extra]==1.5.0