@@ -104,6 +104,8 @@ class PPOLoss(LossModule):
         * **Scalar**: one value applied to the summed entropy of every action head.
         * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
         Defaults to ``0.01``.
+
+        See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
     log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
         predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
         This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
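
For context on the ``log_explained_variance`` flag: explained variance follows its standard definition. A minimal sketch of the metric (an illustration of the formula, not necessarily the module's exact implementation):

    import torch

    def explained_variance(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # 1 - Var(target - pred) / Var(target): 1.0 is a perfect critic,
        # 0.0 means the critic does no better than predicting the mean target.
        return 1.0 - (target - pred).var() / target.var()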
@@ -217,7 +219,7 @@ class PPOLoss(LossModule):
         >>> action = spec.rand(batch)
         >>> data = TensorDict({"observation": torch.randn(*batch, n_obs),
         ...         "action": action,
-        ...         "sample_log_prob": torch.randn_like(action[..., 1]),
+        ...         "action_log_prob": torch.randn_like(action[..., 1]),
         ...         ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
         ...         ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
         ...         ("next", "reward"): torch.randn(*batch, 1),
@@ -227,6 +229,8 @@ class PPOLoss(LossModule):
         TensorDict(
             fields={
                 entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                explained_variance: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+                kl_approx: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                 loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                 loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                 loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
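
The two new output entries are diagnostics rather than trained loss terms; an illustrative way to surface them during training (assuming a logger object with a ``log_scalar`` method):

    out = loss(data)
    logger.log_scalar("kl_approx", out["kl_approx"].item())  # policy drift since the data was collected
    logger.log_scalar("explained_variance", out["explained_variance"].item())  # critic fit; best is 1.0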
@@ -279,12 +283,69 @@ class PPOLoss(LossModule):
         ...     next_observation=torch.randn(*batch, n_obs))
         >>> loss_objective.backward()

+        **Simple Entropy Coefficient Examples**:
+        >>> # Scalar entropy coefficient (default behavior)
+        >>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)
+        >>>
+        >>> # Per-head entropy coefficients (for composite action spaces)
+        >>> entropy_coeff = {
+        ...     ("agent0", "action_log_prob"): 0.01,  # Low exploration
+        ...     ("agent1", "action_log_prob"): 0.05,  # High exploration
+        ... }
+        >>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)
+
     .. note::
         There is an exception regarding compatibility with non-tensordict-based modules.
         If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`,
         this class must be used with tensordicts and cannot function as a tensordict-independent module.
         This is because composite action spaces inherently rely on the structured representation of data provided by
         tensordicts to handle their actions.
+
+    .. _ppo_entropy_coefficients:
+
+    .. note::
+        **Entropy Bonus and Coefficient Management**
+
+        The entropy bonus encourages exploration by adding the negative entropy of the policy to the loss.
+        This can be configured in two ways:
+
+        **Scalar Coefficient (Default)**: Use a single coefficient for all action heads:
+        >>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)
+
+        **Per-Head Coefficients**: Use different coefficients for different action components:
+        >>> # For a robot with movement and gripper actions
+        >>> entropy_coeff = {
+        ...     ("agent0", "action_log_prob"): 0.01,  # Movement: low exploration
+        ...     ("agent1", "action_log_prob"): 0.05,  # Gripper: high exploration
+        ... }
+        >>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)
+
+        **Key Requirements**: When using per-head coefficients, you must provide the full nested key
+        path to each action head's log probability (e.g., ``("agent0", "action_log_prob")``).
+
+        **Monitoring Entropy Loss**:
+
+        When using composite action spaces, the loss output includes:
+        - ``"entropy"``: Summed entropy across all action heads (for logging)
+        - ``"composite_entropy"``: Individual entropy values for each action head
+        - ``"loss_entropy"``: The weighted entropy loss term
+
+        Example output:
+        >>> result = loss(data)
+        >>> print(result["entropy"])  # Total entropy: 2.34
+        >>> print(result["composite_entropy"])  # Per-head: {"movement": 1.2, "gripper": 1.14}
+        >>> print(result["loss_entropy"])  # Weighted loss: -0.0234
+
+        **Common Issues**:
+
+        **KeyError: "Missing entropy coeff for head 'head_name'"**:
+        - Ensure you provide coefficients for ALL action heads
+        - Use full nested keys: ``("head_name", "action_log_prob")``
+        - Check that your action space structure matches the coefficient mapping
+
+        **Incorrect Entropy Calculation**:
+        - Call ``set_composite_lp_aggregate(False).set()`` before creating your policy
+        - Verify that your action space uses :class:`~tensordict.nn.distributions.CompositeDistribution`
     """

     @dataclass
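
The ``set_composite_lp_aggregate(False).set()`` call referenced in the troubleshooting note comes from ``tensordict.nn``; it must run before the policy is built so that per-head log-probs (and hence per-head entropies) stay separate instead of being summed. A minimal ordering sketch, with ``make_composite_policy`` as a hypothetical stand-in for your policy constructor:

    from tensordict.nn import set_composite_lp_aggregate

    # Disable log-prob aggregation globally, *before* creating the policy.
    set_composite_lp_aggregate(False).set()

    policy = make_composite_policy()  # hypothetical helper returning a CompositeDistribution actor
    loss = PPOLoss(policy, critic, entropy_coeff={
        ("agent0", "action_log_prob"): 0.01,
        ("agent1", "action_log_prob"): 0.05,
    })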
@@ -911,27 +972,37 @@ def _weighted_loss_entropy(
     Otherwise, use the scalar `self.entropy_coeff`.
     The entries in self._entropy_coeff_map require the full nested key to the entropy head.
     """
+    # Mode 1: Use scalar entropy coefficient (default behavior)
     if self._entropy_coeff_map is None:
+        # If entropy is a TensorDict (composite action space), sum all entropy values
         if is_tensor_collection(entropy):
             entropy = _sum_td_features(entropy)
+        # Apply scalar coefficient: loss = -coeff * entropy (negative for maximization)
         return -self.entropy_coeff * entropy

-    loss_term = None  # running sum over heads
-    coeff = 0
+    # Mode 2: Use per-head entropy coefficients (for composite action spaces)
+    loss_term = None  # Initialize running sum over action heads
+    coeff = 0  # Placeholder for coefficient value
+    # Iterate through all entropy heads in the composite action space
     for head_name, entropy_head in entropy.items(
         include_nested=True, leaves_only=True
     ):
         try:
+            # Look up the coefficient for this specific action head
             coeff = self._entropy_coeff_map[head_name]
         except KeyError as exc:
+            # Provide clear error message if coefficient mapping is incomplete
             raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
+        # Convert coefficient to tensor with matching dtype and device
         coeff_t = torch.as_tensor(
             coeff, dtype=entropy_head.dtype, device=entropy_head.device
         )
+        # Compute weighted loss for this head: -coeff * entropy
         head_loss_term = -coeff_t * entropy_head
+        # Accumulate loss terms across all heads
         loss_term = (
             head_loss_term if loss_term is None else loss_term + head_loss_term
-        )  # accumulate
+        )

     return loss_term
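
Stripped of the TensorDict machinery, the per-head branch reduces to a weighted sum of negated entropies. A self-contained sketch with plain tensors and hypothetical head names:

    import torch

    entropy = {"movement": torch.tensor(1.20), "gripper": torch.tensor(1.14)}
    coeff_map = {"movement": 0.01, "gripper": 0.05}

    loss_term = None
    for head_name, entropy_head in entropy.items():
        coeff_t = torch.as_tensor(coeff_map[head_name], dtype=entropy_head.dtype)
        head_loss_term = -coeff_t * entropy_head
        loss_term = head_loss_term if loss_term is None else loss_term + head_loss_term

    print(loss_term)  # -(0.01 * 1.20 + 0.05 * 1.14) = -0.0690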
@@ -972,10 +1043,12 @@ class ClipPPOLoss(PPOLoss):
         ``samples_mc_entropy`` will control how many
         samples will be used to compute this estimate.
         Defaults to ``1``.
-    entropy_coeff: (scalar | Mapping[NesstedKey, scalar], optional): entropy multiplier when computing the total loss.
+    entropy_coeff: (scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
         * **Scalar**: one value applied to the summed entropy of every action head.
         * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
         Defaults to ``0.01``.
+
+        See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
     critic_coeff (scalar, optional): critic loss multiplier when computing the total
         loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
         loss from the forward outputs.
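
A quick illustration of the ``critic_coeff=None`` behavior documented above (hedged: output key layout as described in this docstring):

    loss_fn = ClipPPOLoss(actor, critic, critic_coeff=None)
    out = loss_fn(data)
    # The value loss is excluded from the forward outputs and can be optimized separately.
    assert "loss_critic" not in out.keys()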
@@ -1269,6 +1342,8 @@ class KLPENPPOLoss(PPOLoss):
         * **Scalar**: one value applied to the summed entropy of every action head.
         * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
         Defaults to ``0.01``.
+
+        See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
     critic_coeff (scalar, optional): critic loss multiplier when computing the total
         loss. Defaults to ``1.0``.
     loss_critic_type (str, optional): loss function for the value discrepancy.