
Commit 7f63d17

[Doc] Better doc on multi-head entropy
1 parent d34dbb2 commit 7f63d17

File tree

1 file changed: +80 -5 lines changed


torchrl/objectives/ppo.py

Lines changed: 80 additions & 5 deletions
@@ -104,6 +104,8 @@ class PPOLoss(LossModule):
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
+
+ See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
@@ -217,7 +219,7 @@ class PPOLoss(LossModule):
>>> action = spec.rand(batch)
>>> data = TensorDict({"observation": torch.randn(*batch, n_obs),
... "action": action,
- ... "sample_log_prob": torch.randn_like(action[..., 1]),
+ ... "action_log_prob": torch.randn_like(action[..., 1]),
... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
... ("next", "reward"): torch.randn(*batch, 1),
@@ -227,6 +229,8 @@ class PPOLoss(LossModule):
TensorDict(
fields={
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+ explained_variance: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
+ kl_approx: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
@@ -279,12 +283,69 @@ class PPOLoss(LossModule):
... next_observation=torch.randn(*batch, n_obs))
>>> loss_objective.backward()

+ **Simple Entropy Coefficient Examples**:
+ >>> # Scalar entropy coefficient (default behavior)
+ >>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)
+ >>>
+ >>> # Per-head entropy coefficients (for composite action spaces)
+ >>> entropy_coeff = {
+ ... ("agent0", "action_log_prob"): 0.01, # Low exploration
+ ... ("agent1", "action_log_prob"): 0.05, # High exploration
+ ... }
+ >>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)
+
.. note::
There is an exception regarding compatibility with non-tensordict-based modules.
If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`,
this class must be used with tensordicts and cannot function as a tensordict-independent module.
This is because composite action spaces inherently rely on the structured representation of data provided by
tensordicts to handle their actions.
+
+ .. _ppo_entropy_coefficients:
+
+ .. note::
+ **Entropy Bonus and Coefficient Management**
+
+ The entropy bonus encourages exploration by adding the negative entropy of the policy to the loss.
+ This can be configured in two ways:
+
+ **Scalar Coefficient (Default)**: Use a single coefficient for all action heads:
+ >>> loss = PPOLoss(actor, critic, entropy_coeff=0.01)
+
+ **Per-Head Coefficients**: Use different coefficients for different action components:
+ >>> # For a robot with movement and gripper actions
+ >>> entropy_coeff = {
+ ... ("agent0", "action_log_prob"): 0.01, # Movement: low exploration
+ ... ("agent1", "action_log_prob"): 0.05, # Gripper: high exploration
+ ... }
+ >>> loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)
+
+ **Key Requirements**: When using per-head coefficients, you must provide the full nested key
+ path to each action head's log probability (e.g., `("agent0", "action_log_prob")`).
+
+ **Monitoring Entropy Loss**:
+
+ When using composite action spaces, the loss output includes:
+ - `"entropy"`: Summed entropy across all action heads (for logging)
+ - `"composite_entropy"`: Individual entropy values for each action head
+ - `"loss_entropy"`: The weighted entropy loss term
+
+ Example output:
+ >>> result = loss(data)
+ >>> print(result["entropy"]) # Total entropy: 2.34
+ >>> print(result["composite_entropy"]) # Per-head: {"movement": 1.2, "gripper": 1.14}
+ >>> print(result["loss_entropy"]) # Weighted loss: -0.0234
+
+ **Common Issues**:
+
+ **KeyError: "Missing entropy coeff for head 'head_name'"**:
+ - Ensure you provide coefficients for ALL action heads
+ - Use full nested keys: `("head_name", "action_log_prob")`
+ - Check that your action space structure matches the coefficient mapping
+
+ **Incorrect Entropy Calculation**:
+ - Call `set_composite_lp_aggregate(False).set()` before creating your policy
+ - Verify that your action space uses :class:`~tensordict.nn.distributions.CompositeDistribution`
"""

@dataclass
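Not part of the diff: a minimal sketch of the setup the note above describes, i.e. disabling composite log-prob aggregation before the policy is built and keying each coefficient by the full nested log-prob key. The head names ("agent0", "agent1") are placeholders and the actor/critic construction is omitted.

from tensordict.nn import set_composite_lp_aggregate

# Keep per-head log-probs and entropies separate (see the troubleshooting note
# above); call this before constructing the probabilistic actor.
set_composite_lp_aggregate(False).set()

# One coefficient per action head, keyed by the full nested log-prob key.
# The head names here are placeholders.
entropy_coeff = {
    ("agent0", "action_log_prob"): 0.01,  # less exploration
    ("agent1", "action_log_prob"): 0.05,  # more exploration
}

# The mapping would then be passed as, e.g.:
#   loss = PPOLoss(actor, critic, entropy_coeff=entropy_coeff)
# with actor and critic built elsewhere.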
@@ -911,27 +972,37 @@ def _weighted_loss_entropy(
Otherwise, use the scalar `self.entropy_coeff`.
The entries in self._entropy_coeff_map require the full nested key to the entropy head.
"""
+ # Mode 1: Use scalar entropy coefficient (default behavior)
if self._entropy_coeff_map is None:
+ # If entropy is a TensorDict (composite action space), sum all entropy values
if is_tensor_collection(entropy):
entropy = _sum_td_features(entropy)
+ # Apply scalar coefficient: loss = -coeff * entropy (negative for maximization)
return -self.entropy_coeff * entropy

- loss_term = None # running sum over heads
- coeff = 0
+ # Mode 2: Use per-head entropy coefficients (for composite action spaces)
+ loss_term = None # Initialize running sum over action heads
+ coeff = 0 # Placeholder for coefficient value
+ # Iterate through all entropy heads in the composite action space
for head_name, entropy_head in entropy.items(
include_nested=True, leaves_only=True
):
try:
+ # Look up the coefficient for this specific action head
coeff = self._entropy_coeff_map[head_name]
except KeyError as exc:
+ # Provide clear error message if coefficient mapping is incomplete
raise KeyError(f"Missing entropy coeff for head '{head_name}'") from exc
+ # Convert coefficient to tensor with matching dtype and device
coeff_t = torch.as_tensor(
coeff, dtype=entropy_head.dtype, device=entropy_head.device
)
+ # Compute weighted loss for this head: -coeff * entropy
head_loss_term = -coeff_t * entropy_head
+ # Accumulate loss terms across all heads
loss_term = (
head_loss_term if loss_term is None else loss_term + head_loss_term
- ) # accumulate
+ )

return loss_term
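Not part of the diff: a self-contained sketch that reproduces what the per-head branch of _weighted_loss_entropy computes, on a toy entropy TensorDict. The head keys and entropy values are made up for illustration.

import torch
from tensordict import TensorDict

# Toy composite entropy: one scalar entry per action head, addressed by its
# full nested key (values are arbitrary).
entropy = TensorDict(
    {
        ("agent0", "action_log_prob"): torch.tensor(1.20),
        ("agent1", "action_log_prob"): torch.tensor(1.14),
    },
    batch_size=[],
)

# Per-head coefficients, keyed by the same nested keys.
entropy_coeff_map = {
    ("agent0", "action_log_prob"): 0.01,
    ("agent1", "action_log_prob"): 0.05,
}

# Same accumulation as the per-head branch above: sum of -coeff * entropy.
loss_entropy = sum(
    -entropy_coeff_map[key] * value
    for key, value in entropy.items(include_nested=True, leaves_only=True)
)
print(loss_entropy)  # tensor(-0.0690) = -(0.01 * 1.20 + 0.05 * 1.14)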
@@ -972,10 +1043,12 @@ class ClipPPOLoss(PPOLoss):
``samples_mc_entropy`` will control how many
samples will be used to compute this estimate.
Defaults to ``1``.
- entropy_coeff: (scalar | Mapping[NesstedKey, scalar], optional): entropy multiplier when computing the total loss.
+ entropy_coeff: (scalar | Mapping[NestedKey, scalar], optional): entropy multiplier when computing the total loss.
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
+
+ See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
critic_coeff (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
loss from the forward outputs.
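Not part of the diff: since ClipPPOLoss here and KLPENPPOLoss below accept the same scalar-or-mapping entropy_coeff, a small pre-flight check can surface the KeyError from the troubleshooting note before the loss is built. The head keys below are illustrative.

# Hypothetical pre-flight check: make sure every expected action head has a
# coefficient before building ClipPPOLoss/KLPENPPOLoss with a mapping.
expected_heads = {
    ("agent0", "action_log_prob"),
    ("agent1", "action_log_prob"),
}
entropy_coeff = {
    ("agent0", "action_log_prob"): 0.01,
    ("agent1", "action_log_prob"): 0.05,
}

missing = expected_heads - entropy_coeff.keys()
if missing:
    raise KeyError(f"Missing entropy coeff for heads: {sorted(missing)}")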
@@ -1269,6 +1342,8 @@ class KLPENPPOLoss(PPOLoss):
* **Scalar**: one value applied to the summed entropy of every action head.
* **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
Defaults to ``0.01``.
+
+ See :ref:`ppo_entropy_coefficients` for detailed usage examples and troubleshooting.
critic_coeff (scalar, optional): critic loss multiplier when computing the total
loss. Defaults to ``1.0``.
loss_critic_type (str, optional): loss function for the value discrepancy.

0 commit comments
