From 92e1abc6d81264e4c983db2652eca7a32bb32896 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 14 Jul 2025 13:48:24 -0700 Subject: [PATCH 1/4] Fix DQN w RNN tutorial --- intermediate_source/dqn_with_rnn_tutorial.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/intermediate_source/dqn_with_rnn_tutorial.py b/intermediate_source/dqn_with_rnn_tutorial.py index bcc484f0a00..f28ad9f6903 100644 --- a/intermediate_source/dqn_with_rnn_tutorial.py +++ b/intermediate_source/dqn_with_rnn_tutorial.py @@ -342,7 +342,10 @@ # will return a new instance of the LSTM (with shared weights) that will # assume that the input data is sequential in nature. # -policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval) +from torchrl.modules import set_recurrent_mode + +with set_recurrent_mode(True): + policy = Seq(feature, lstm, mlp, qval) ###################################################################### # Because we still have a couple of uninitialized parameters we should @@ -389,7 +392,9 @@ # For the sake of efficiency, we're only running a few thousands iterations # here. In a real setting, the total number of frames should be set to 1M. # -collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device) +collector = SyncDataCollector( + env, stoch_policy, frames_per_batch=50, total_frames=200, device=device +) rb = TensorDictReplayBuffer( storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10 ) @@ -464,5 +469,5 @@ # # Further Reading # --------------- -# +# # - The TorchRL documentation can be found `here `_. From 5e85c0a72b0c6b09fc876fd1c73ea55f09178391 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 14 Jul 2025 16:02:40 -0700 Subject: [PATCH 2/4] Update intermediate_source/dqn_with_rnn_tutorial.py --- intermediate_source/dqn_with_rnn_tutorial.py | 1 + 1 file changed, 1 insertion(+) diff --git a/intermediate_source/dqn_with_rnn_tutorial.py b/intermediate_source/dqn_with_rnn_tutorial.py index f28ad9f6903..9b41dbfecaf 100644 --- a/intermediate_source/dqn_with_rnn_tutorial.py +++ b/intermediate_source/dqn_with_rnn_tutorial.py @@ -392,6 +392,7 @@ # For the sake of efficiency, we're only running a few thousands iterations # here. In a real setting, the total number of frames should be set to 1M. # + collector = SyncDataCollector( env, stoch_policy, frames_per_batch=50, total_frames=200, device=device ) From f4d1b1cc7b406a4bc43eba443da779ab2342415f Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 17 Jul 2025 14:14:21 -0700 Subject: [PATCH 3/4] Update --- intermediate_source/dqn_with_rnn_tutorial.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/intermediate_source/dqn_with_rnn_tutorial.py b/intermediate_source/dqn_with_rnn_tutorial.py index 9b41dbfecaf..462415dcc74 100644 --- a/intermediate_source/dqn_with_rnn_tutorial.py +++ b/intermediate_source/dqn_with_rnn_tutorial.py @@ -344,8 +344,7 @@ # from torchrl.modules import set_recurrent_mode -with set_recurrent_mode(True): - policy = Seq(feature, lstm, mlp, qval) +policy = Seq(feature, lstm, mlp, qval) ###################################################################### # Because we still have a couple of uninitialized parameters we should @@ -428,7 +427,8 @@ rb.extend(data.unsqueeze(0).to_tensordict().cpu()) for _ in range(utd): s = rb.sample().to(device, non_blocking=True) - loss_vals = loss_fn(s) + with set_recurrent_mode(True): + loss_vals = loss_fn(s) loss_vals["loss"].backward() optim.step() optim.zero_grad() From 0ce4e162c4f39f53777cf4e29c4199f1c66968e1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 22 Jul 2025 18:28:00 +0100 Subject: [PATCH 4/4] bump torchrl and tensordict req (#3474) --- .ci/docker/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index a25c4494b64..bd95726bee4 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -29,8 +29,8 @@ tensorboard jinja2==3.1.3 pytorch-lightning torchx -torchrl==0.7.2 -tensordict==0.7.2 +torchrl==0.9.2 +tensordict==0.9.1 # For ax_multiobjective_nas_tutorial.py ax-platform>=0.4.0,<0.5.0 nbformat>=5.9.2