From d98dcafc66dde2b94e7839ee0dbc8100746b1350 Mon Sep 17 00:00:00 2001 From: Maryam Date: Wed, 25 Jan 2023 13:50:05 -0500 Subject: [PATCH 1/4] Add unit tests for recurrent buffer --- tests/hive/replays/test_recurrent_buffer.py | 148 ++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 tests/hive/replays/test_recurrent_buffer.py diff --git a/tests/hive/replays/test_recurrent_buffer.py b/tests/hive/replays/test_recurrent_buffer.py new file mode 100644 index 00000000..b938f23a --- /dev/null +++ b/tests/hive/replays/test_recurrent_buffer.py @@ -0,0 +1,148 @@ + +import numpy as np +import pytest +from pytest_lazyfixture import lazy_fixture + +from hive.replays.recurrent_replay import RecurrentReplayBuffer + +### size, seq_length, +### add --> check shape, + +OBS_SHAPE = (4, 4) +CAPACITY = 60 +MAX_SEQ_LEN = 10 +N_STEP_HORIZON = 1 +GAMMA = 0.99 + +@pytest.fixture() +def rec_buffer(): + return RecurrentReplayBuffer( + capacity=CAPACITY, + max_seq_len = MAX_SEQ_LEN, + observation_shape=OBS_SHAPE, + observation_dtype=np.float32, + extra_storage_types={"priority": (np.int8, ())}, + ) + +@pytest.fixture( + params=[ + pytest.lazy_fixture("rec_buffer"), + ] +) + +def buffer(request): + return request.param + +### truncated and terminated instead of done??? +@pytest.fixture() +def full_buffer(buffer): + for i in range(CAPACITY + 20): + buffer.add( + observation=np.ones(OBS_SHAPE) * i, + action=i, + reward=i % 10, + done=((i + 1) % 15) == 0, + priority=(i % 10) + 1, + ) + return buffer + + + +@pytest.fixture() +def full_n_step_buffer(): + n_step_buffer = RecurrentReplayBuffer( + capacity=CAPACITY, + observation_shape=OBS_SHAPE, + observation_dtype=np.float32, + n_step=N_STEP_HORIZON, + gamma=GAMMA, + ) + for i in range(CAPACITY + 20): + n_step_buffer.add( + observation=np.ones(OBS_SHAPE) * i, + action=i, + reward=i % 10, + done=((i + 1) % 15) == 0, + ) + return n_step_buffer + + +@pytest.mark.parametrize("constructor", [RecurrentReplayBuffer]) +@pytest.mark.parametrize("observation_shape", [(), (2,), (3, 4)]) +@pytest.mark.parametrize("observation_dtype", [np.uint8, np.float32]) +@pytest.mark.parametrize("action_shape", [(), (5,)]) +@pytest.mark.parametrize("action_dtype", [np.int8, np.float32]) +@pytest.mark.parametrize("reward_shape", [(), (6,)]) +@pytest.mark.parametrize("reward_dtype", [np.int8, np.float32]) +@pytest.mark.parametrize("extra_storage_types", [None, {"foo": (np.float32, (7,))}]) + + +def test_constructor( + constructor, + observation_shape, + observation_dtype, + action_shape, + action_dtype, + reward_shape, + reward_dtype, + extra_storage_types, +): + buffer = constructor( + capacity=10, + max_seq_len=MAX_SEQ_LEN, + observation_shape=observation_shape, + observation_dtype=observation_dtype, + action_shape=action_shape, + action_dtype=action_dtype, + reward_shape=reward_shape, + reward_dtype=reward_dtype, + extra_storage_types=extra_storage_types, + ) + assert buffer.size() == 0 + assert buffer._max_seq_len == MAX_SEQ_LEN + assert buffer._storage["observation"].shape == (10,) + observation_shape + assert buffer._storage["observation"].dtype == observation_dtype + assert buffer._storage["action"].shape == (10,) + action_shape + assert buffer._storage["action"].dtype == action_dtype + assert buffer._storage["reward"].shape == (10,) + reward_shape + assert buffer._storage["reward"].dtype == reward_dtype + if extra_storage_types is not None: + for key in extra_storage_types: + assert buffer._storage[key].shape == (10,) + extra_storage_types[key][1] + assert buffer._storage[key].dtype == extra_storage_types[key][0] + + +def test_add(buffer): + assert buffer.size() == 0 + done_time = 15 + for i in range(33): #until the buffer is full instead of CAPACITY + buffer.add( + observation=np.ones(OBS_SHAPE) * i, + action=i, + reward=i % 10, + done=((i + 1) % done_time) == 0, + priority=(i % 10) + 1, + ) + + assert buffer.size() == i + ((i) // done_time) * (MAX_SEQ_LEN - 1) + assert buffer._cursor == ((i + MAX_SEQ_LEN) + (((i) // done_time) * (MAX_SEQ_LEN - 1))) % CAPACITY + + + more_steps = 40 + for i in range(more_steps): + buffer.add( + observation=np.ones(OBS_SHAPE) * i, + action=i, + reward=i % 10, + done=((i + 1) % done_time) == 0, + priority=(i % 10) + 1, + ) + assert buffer.size() == CAPACITY - MAX_SEQ_LEN + assert buffer._cursor == (((i + 1) + (((i) // done_time) * (MAX_SEQ_LEN - 1)) ) % CAPACITY) + + assert buffer._num_added == (more_steps + (((i) // done_time) * (MAX_SEQ_LEN - 1))) + CAPACITY + + + + + From 04a2336c945cad6f09c9298e89d1bc37767c6175 Mon Sep 17 00:00:00 2001 From: Maryam Date: Wed, 25 Jan 2023 17:19:45 -0500 Subject: [PATCH 2/4] Reformat test_recurrent_buffer --- tests/hive/replays/test_recurrent_buffer.py | 32 ++++++++++----------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/hive/replays/test_recurrent_buffer.py b/tests/hive/replays/test_recurrent_buffer.py index b938f23a..40f4ea39 100644 --- a/tests/hive/replays/test_recurrent_buffer.py +++ b/tests/hive/replays/test_recurrent_buffer.py @@ -1,4 +1,3 @@ - import numpy as np import pytest from pytest_lazyfixture import lazy_fixture @@ -14,25 +13,27 @@ N_STEP_HORIZON = 1 GAMMA = 0.99 + @pytest.fixture() def rec_buffer(): return RecurrentReplayBuffer( capacity=CAPACITY, - max_seq_len = MAX_SEQ_LEN, + max_seq_len=MAX_SEQ_LEN, observation_shape=OBS_SHAPE, observation_dtype=np.float32, extra_storage_types={"priority": (np.int8, ())}, ) + @pytest.fixture( params=[ pytest.lazy_fixture("rec_buffer"), ] ) - def buffer(request): return request.param + ### truncated and terminated instead of done??? @pytest.fixture() def full_buffer(buffer): @@ -47,7 +48,6 @@ def full_buffer(buffer): return buffer - @pytest.fixture() def full_n_step_buffer(): n_step_buffer = RecurrentReplayBuffer( @@ -75,8 +75,6 @@ def full_n_step_buffer(): @pytest.mark.parametrize("reward_shape", [(), (6,)]) @pytest.mark.parametrize("reward_dtype", [np.int8, np.float32]) @pytest.mark.parametrize("extra_storage_types", [None, {"foo": (np.float32, (7,))}]) - - def test_constructor( constructor, observation_shape, @@ -115,7 +113,7 @@ def test_constructor( def test_add(buffer): assert buffer.size() == 0 done_time = 15 - for i in range(33): #until the buffer is full instead of CAPACITY + for i in range(33): # until the buffer is full instead of CAPACITY buffer.add( observation=np.ones(OBS_SHAPE) * i, action=i, @@ -125,8 +123,10 @@ def test_add(buffer): ) assert buffer.size() == i + ((i) // done_time) * (MAX_SEQ_LEN - 1) - assert buffer._cursor == ((i + MAX_SEQ_LEN) + (((i) // done_time) * (MAX_SEQ_LEN - 1))) % CAPACITY - + assert ( + buffer._cursor + == ((i + MAX_SEQ_LEN) + (((i) // done_time) * (MAX_SEQ_LEN - 1))) % CAPACITY + ) more_steps = 40 for i in range(more_steps): @@ -138,11 +138,11 @@ def test_add(buffer): priority=(i % 10) + 1, ) assert buffer.size() == CAPACITY - MAX_SEQ_LEN - assert buffer._cursor == (((i + 1) + (((i) // done_time) * (MAX_SEQ_LEN - 1)) ) % CAPACITY) - - assert buffer._num_added == (more_steps + (((i) // done_time) * (MAX_SEQ_LEN - 1))) + CAPACITY - - - - + assert buffer._cursor == ( + ((i + 1) + (((i) // done_time) * (MAX_SEQ_LEN - 1))) % CAPACITY + ) + assert ( + buffer._num_added + == (more_steps + (((i) // done_time) * (MAX_SEQ_LEN - 1))) + CAPACITY + ) From 522ff4d8bed9c911d2259c0fcc034dffd999b323 Mon Sep 17 00:00:00 2001 From: mrsamsami Date: Tue, 7 Feb 2023 18:08:02 -0500 Subject: [PATCH 3/4] Add sample tests --- tests/hive/replays/test_recurrent_buffer.py | 86 +++++++++++++++++---- 1 file changed, 73 insertions(+), 13 deletions(-) diff --git a/tests/hive/replays/test_recurrent_buffer.py b/tests/hive/replays/test_recurrent_buffer.py index 40f4ea39..739c667b 100644 --- a/tests/hive/replays/test_recurrent_buffer.py +++ b/tests/hive/replays/test_recurrent_buffer.py @@ -4,14 +4,12 @@ from hive.replays.recurrent_replay import RecurrentReplayBuffer -### size, seq_length, -### add --> check shape, -OBS_SHAPE = (4, 4) +OBS_SHAPE = (2, 2) CAPACITY = 60 MAX_SEQ_LEN = 10 -N_STEP_HORIZON = 1 -GAMMA = 0.99 +N_STEP_HORIZON = 2 +GAMMA = 1 @pytest.fixture() @@ -37,12 +35,13 @@ def buffer(request): ### truncated and terminated instead of done??? @pytest.fixture() def full_buffer(buffer): - for i in range(CAPACITY + 20): + done_time = 15 + for i in range(33): # until the buffer is full instead of CAPACITY buffer.add( observation=np.ones(OBS_SHAPE) * i, - action=i, - reward=i % 10, - done=((i + 1) % 15) == 0, + action=(i % done_time) + 1, + reward=(i % done_time) + 1, + done=((i + 1) % done_time) == 0, priority=(i % 10) + 1, ) return buffer @@ -52,17 +51,20 @@ def full_buffer(buffer): def full_n_step_buffer(): n_step_buffer = RecurrentReplayBuffer( capacity=CAPACITY, + max_seq_len=MAX_SEQ_LEN, observation_shape=OBS_SHAPE, observation_dtype=np.float32, n_step=N_STEP_HORIZON, gamma=GAMMA, ) - for i in range(CAPACITY + 20): + done_time = 15 + for i in range(33): # until the buffer is full instead of CAPACITY n_step_buffer.add( observation=np.ones(OBS_SHAPE) * i, - action=i, - reward=i % 10, - done=((i + 1) % 15) == 0, + action=(i % done_time) + 1, + reward=(i % done_time) + 1, + done=((i + 1) % done_time) == 0, + priority=(i % 10) + 1, ) return n_step_buffer @@ -112,6 +114,7 @@ def test_constructor( def test_add(buffer): assert buffer.size() == 0 + ## add to the buffer until the buffer is full done_time = 15 for i in range(33): # until the buffer is full instead of CAPACITY buffer.add( @@ -128,6 +131,7 @@ def test_add(buffer): == ((i + MAX_SEQ_LEN) + (((i) // done_time) * (MAX_SEQ_LEN - 1))) % CAPACITY ) + ## when the buffer is full more_steps = 40 for i in range(more_steps): buffer.add( @@ -146,3 +150,59 @@ def test_add(buffer): buffer._num_added == (more_steps + (((i) // done_time) * (MAX_SEQ_LEN - 1))) + CAPACITY ) + + +def test_sample_shape(full_buffer): + # sample transitions from buffer + batch_size = CAPACITY - 1 + batch = full_buffer.sample(batch_size) + # check if the shape of batch is correct + assert batch["indices"].shape == (batch_size,) + assert batch["observation"].shape == (batch_size, 10) + OBS_SHAPE + assert batch["action"].shape == (batch_size, 10) + assert batch["done"].shape == (batch_size, 10) + assert batch["reward"].shape == (batch_size, 10) + assert batch["trajectory_lengths"].shape == (batch_size,) + assert batch["next_observation"].shape == (batch_size, 10) + OBS_SHAPE + + +def test_sample(full_buffer): + # sample transitions from buffer + batch_size = 50 + batch = full_buffer.sample(batch_size) + for b in range(batch_size): + t = 0 + + while batch["action"][b, t] == 0: + t += 1 + + while t < MAX_SEQ_LEN and batch["action"][b, t] > 0: + if t > 0: + assert batch["action"][b, t] - batch["action"][b, t - 1] == 1 + assert batch["reward"][b, t] - batch["reward"][b, t - 1] == 1 + t += 1 + + +def test_sample_n_step(full_n_step_buffer): + # sample transitions from buffer + batch_size = 50 + batch = full_n_step_buffer.sample(batch_size) + for b in range(batch_size): + t = 0 + + while batch["action"][b, t] == 0: + t += 1 + + while t < MAX_SEQ_LEN - 1 and batch["action"][b, t] > 0: + if t > 0: + assert batch["action"][b, t] - batch["action"][b, t - 1] == 1 + if t == full_n_step_buffer.size() - 1 or batch["reward"][b, t + 1] == 0: + assert ( + batch["reward"][b, t] - batch["reward"][b, t - 1] + == 1 - batch["reward"][b, t] * GAMMA + ) + else: + assert ( + batch["reward"][b, t] - batch["reward"][b, t - 1] == 1 + GAMMA + ) + t += 1 From 18e378024400bfe110869f3b1765bec64b52bb86 Mon Sep 17 00:00:00 2001 From: Maryam Date: Wed, 22 Feb 2023 13:03:27 -0500 Subject: [PATCH 4/4] Reformat with black ==23.1.0 --- hive/agents/rainbow.py | 1 - hive/agents/td3.py | 2 +- hive/envs/marlgrid/ma_envs/base.py | 1 - hive/envs/marlgrid/ma_envs/checkers.py | 1 - hive/envs/marlgrid/ma_envs/pursuit.py | 1 - hive/replays/circular_replay.py | 1 - 6 files changed, 1 insertion(+), 6 deletions(-) diff --git a/hive/agents/rainbow.py b/hive/agents/rainbow.py index 2e795090..a6339dc0 100644 --- a/hive/agents/rainbow.py +++ b/hive/agents/rainbow.py @@ -221,7 +221,6 @@ def create_q_networks(self, representation_net): @torch.no_grad() def act(self, observation, agent_traj_state=None): - if self._training: if not self._learn_schedule.get_value(): epsilon = 1.0 diff --git a/hive/agents/td3.py b/hive/agents/td3.py index dad56ae4..4fecd247 100644 --- a/hive/agents/td3.py +++ b/hive/agents/td3.py @@ -399,7 +399,7 @@ def update(self, update_info, agent_traj_state=None): def _update_target(self): """Update the target network.""" - for (network, target_network) in [ + for network, target_network in [ (self._actor, self._target_actor), (self._critic, self._target_critic), ]: diff --git a/hive/envs/marlgrid/ma_envs/base.py b/hive/envs/marlgrid/ma_envs/base.py index 554297c3..2f1833bd 100644 --- a/hive/envs/marlgrid/ma_envs/base.py +++ b/hive/envs/marlgrid/ma_envs/base.py @@ -135,7 +135,6 @@ def render( ) if show_agent_views: - target_partial_width = int( img.shape[0] * agent_col_width_frac - 2 * agent_col_padding_px ) diff --git a/hive/envs/marlgrid/ma_envs/checkers.py b/hive/envs/marlgrid/ma_envs/checkers.py index f7eb44e3..62065565 100644 --- a/hive/envs/marlgrid/ma_envs/checkers.py +++ b/hive/envs/marlgrid/ma_envs/checkers.py @@ -85,7 +85,6 @@ def step(self, actions): agent.step_reward = 0 if agent.active: - cur_pos = agent.pos[:] cur_cell = self.grid.get(*cur_pos) fwd_pos = agent.front_pos[:] diff --git a/hive/envs/marlgrid/ma_envs/pursuit.py b/hive/envs/marlgrid/ma_envs/pursuit.py index 6d964b26..8137b994 100644 --- a/hive/envs/marlgrid/ma_envs/pursuit.py +++ b/hive/envs/marlgrid/ma_envs/pursuit.py @@ -54,7 +54,6 @@ def step(self, actions): agent.step_reward = 0 if agent.active: - cur_pos = agent.pos[:] cur_cell = self.grid.get(*cur_pos) fwd_pos = agent.front_pos[:] diff --git a/hive/replays/circular_replay.py b/hive/replays/circular_replay.py index 4c2a7e01..26bd38fb 100644 --- a/hive/replays/circular_replay.py +++ b/hive/replays/circular_replay.py @@ -334,7 +334,6 @@ class SimpleReplayBuffer(BaseReplayBuffer): """ def __init__(self, capacity=1e5, compress=False, seed=42, **kwargs): - self._numpy_rng = np.random.default_rng(seed) self._capacity = int(capacity) self._compress = compress