diff --git a/advanced_source/pendulum.py b/advanced_source/pendulum.py
index 3084fe8312..03aae4c3ec 100644
--- a/advanced_source/pendulum.py
+++ b/advanced_source/pendulum.py
@@ -100,7 +100,7 @@
 from tensordict.nn import TensorDictModule
 from torch import nn
 
-from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
+from torchrl.data import Bounded, Composite, Unbounded
 from torchrl.envs import (
     CatTensors,
     EnvBase,
@@ -403,14 +403,14 @@ def _reset(self, tensordict):
 def _make_spec(self, td_params):
     # Under the hood, this will populate self.output_spec["observation"]
-    self.observation_spec = CompositeSpec(
-        th=BoundedTensorSpec(
+    self.observation_spec = Composite(
+        th=Bounded(
             low=-torch.pi,
             high=torch.pi,
             shape=(),
             dtype=torch.float32,
         ),
-        thdot=BoundedTensorSpec(
+        thdot=Bounded(
             low=-td_params["params", "max_speed"],
             high=td_params["params", "max_speed"],
             shape=(),
@@ -426,24 +426,26 @@ def _make_spec(self, td_params):
     self.state_spec = self.observation_spec.clone()
     # action-spec will be automatically wrapped in input_spec when
     # `self.action_spec = spec` will be called supported
-    self.action_spec = BoundedTensorSpec(
+    self.action_spec = Bounded(
         low=-td_params["params", "max_torque"],
         high=td_params["params", "max_torque"],
         shape=(1,),
         dtype=torch.float32,
     )
-    self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))
+    self.reward_spec = Unbounded(shape=(*td_params.shape, 1))
 
 
 def make_composite_from_td(td):
     # custom function to convert a ``tensordict`` in a similar spec structure
     # of unbounded values.
-    composite = CompositeSpec(
+    composite = Composite(
         {
-            key: make_composite_from_td(tensor)
-            if isinstance(tensor, TensorDictBase)
-            else UnboundedContinuousTensorSpec(
-                dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
+            key: (
+                make_composite_from_td(tensor)
+                if isinstance(tensor, TensorDictBase)
+                else Unbounded(
+                    dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
+                )
             )
             for key, tensor in td.items()
         },
@@ -687,7 +689,7 @@ def _reset(
     # is of type ``Composite``
     @_apply_to_composite
     def transform_observation_spec(self, observation_spec):
-        return BoundedTensorSpec(
+        return Bounded(
             low=-1,
             high=1,
             shape=observation_spec.shape,
@@ -711,7 +713,7 @@ def _reset(
     # is of type ``Composite``
     @_apply_to_composite
     def transform_observation_spec(self, observation_spec):
-        return BoundedTensorSpec(
+        return Bounded(
             low=-1,
             high=1,
             shape=observation_spec.shape,
diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py
index ece80d3f94..f2c9a7d712 100644
--- a/intermediate_source/per_sample_grads.py
+++ b/intermediate_source/per_sample_grads.py
@@ -168,9 +168,90 @@ def compute_loss(params, buffers, sample, target):
 # we can double check that the results using ``grad`` and ``vmap`` match the
 # results of hand processing each one individually:
 
-for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
+# Create a float64 baseline for more precise comparison
+def compute_grad_fp64(sample, target):
+    # Cast the sample to float64; the integer class targets need no conversion
+    sample_fp64 = sample.to(torch.float64)
+    target_fp64 = target
+
+    # Create a float64 copy of the model and load the trained weights so the
+    # baseline is computed with the same parameters as the float32 model
+    model_fp64 = SimpleCNN().to(device=device).to(torch.float64)
+    model_fp64.load_state_dict(model.state_dict())
+
+    sample_fp64 = sample_fp64.unsqueeze(0)  # prepend batch dimension
+    target_fp64 = target_fp64.unsqueeze(0)
+
+    prediction = model_fp64(sample_fp64)
+    loss = loss_fn(prediction, target_fp64)
+
+    return torch.autograd.grad(loss, list(model_fp64.parameters()))
+
+
+def compute_fp64_baseline(data, targets, indices):
+    """Compute float64 gradient for a specific sample"""
+    # Only compute for the sample with the largest difference to save computation
+    i = indices[0]  # Sample index
+    sample_grad = compute_grad_fp64(data[i], targets[i])
+    return sample_grad
+
+
+for i, (per_sample_grad, ft_per_sample_grad) in enumerate(
+    zip(per_sample_grads, ft_per_sample_grads.values())
+):
+    is_close = torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
+    if not is_close:
+        # Calculate and print the maximum absolute difference
+        abs_diff = (per_sample_grad - ft_per_sample_grad).abs()
+        max_diff = abs_diff.max().item()
+        mean_diff = abs_diff.mean().item()
+        print(f"Gradient {i} mismatch:")
+        print(f"  Max absolute difference: {max_diff}")
+        print(f"  Mean absolute difference: {mean_diff}")
+        print(f"  Shape of tensors: {per_sample_grad.shape}")
+
+        # Find the location of maximum difference
+        max_idx = abs_diff.argmax().item()
+        flat_idx = max_idx
+        if len(abs_diff.shape) > 1:
+            # Convert flat index to multi-dimensional index
+            indices = []
+            temp_shape = abs_diff.shape
+            for dim in reversed(temp_shape):
+                indices.insert(0, flat_idx % dim)
+                flat_idx //= dim
+            print(f"  Max difference at index: {indices}")
+            print(f"  Manual gradient value: {per_sample_grad[tuple(indices)].item()}")
+            print(
+                f"  Functional gradient value: {ft_per_sample_grad[tuple(indices)].item()}"
+            )
+
+            # Compute float64 baseline for the sample with the largest difference
+            print("\nComputing float64 baseline for comparison...")
+            try:
+                fp64_grads = compute_fp64_baseline(data, targets, indices)
+                fp64_value = fp64_grads[i][
+                    tuple(indices[1:])
+                ].item()  # Skip batch dimension
+                print(f"  Float64 baseline value: {fp64_value}")
+
+                # Compare both methods against float64 baseline
+                manual_diff = abs(per_sample_grad[tuple(indices)].item() - fp64_value)
+                functional_diff = abs(
+                    ft_per_sample_grad[tuple(indices)].item() - fp64_value
+                )
+                print(f"  Manual method vs float64 difference: {manual_diff}")
+                print(f"  Functional method vs float64 difference: {functional_diff}")
+
+                if manual_diff < functional_diff:
+                    print("  Manual method is closer to float64 baseline")
+                else:
+                    print("  Functional method is closer to float64 baseline")
+            except Exception as e:
+                print(f"  Error computing float64 baseline: {e}")
+
+    # Keep the original assertion
     assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
-
 ######################################################################
 # A quick note: there are limitations around what types of functions can be
 # transformed by ``vmap``. The best functions to transform are ones that are pure
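
Note (not part of the patch): the following is a minimal, standalone sketch of the renamed TorchRL spec classes that the pendulum.py hunks migrate to. Composite replaces CompositeSpec, Bounded replaces BoundedTensorSpec, and Unbounded replaces UnboundedContinuousTensorSpec; the constructor arguments below mirror the calls in the diff, while the bound values (8.0 for speed, 2.0 for torque) are illustrative pendulum defaults rather than values taken from the patch.

import torch
from torchrl.data import Bounded, Composite, Unbounded

# Observation spec with two bounded scalar entries, mirroring the tutorial's "th"/"thdot"
observation_spec = Composite(
    th=Bounded(low=-torch.pi, high=torch.pi, shape=(), dtype=torch.float32),
    thdot=Bounded(low=-8.0, high=8.0, shape=(), dtype=torch.float32),  # illustrative max_speed
    shape=(),
)
action_spec = Bounded(low=-2.0, high=2.0, shape=(1,), dtype=torch.float32)  # illustrative max_torque
reward_spec = Unbounded(shape=(1,))

# The renamed classes behave like the old ones: specs can be sampled and validated
sample_obs = observation_spec.rand()
sample_act = action_spec.rand()
assert action_spec.is_in(sample_act)
print(sample_obs)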