Commit 3c181eb

committed
Fix #3186 - create leaf/non-leaf/requires_grad/retain_grad tutorial
This was originally bundled with PR #3389, but now broken into two separate tutorials after discussing with PyTorch team.
1 parent b5637fa commit 3c181eb

5 files changed: +346 -0 lines changed

Lines changed: 339 additions & 0 deletions
@@ -0,0 +1,339 @@
"""
Understanding requires_grad, retain_grad, Leaf, and Non-leaf Tensors
====================================================================

**Author:** `Justin Silver <https://github.com/j-silv>`__

This tutorial explains the subtleties of ``requires_grad``,
``retain_grad``, leaf, and non-leaf tensors using a simple example.

Before starting, make sure you understand `tensors and how to manipulate
them <https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html>`__.
A basic knowledge of `how autograd
works <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
would also be useful.

"""


######################################################################
# Setup
# -----
#
# First, make sure `PyTorch is
# installed <https://pytorch.org/get-started/locally/>`__ and then import
# the necessary libraries.
#

import torch
import torch.nn.functional as F


######################################################################
# Next, we instantiate a simple network to focus on the gradients. This
# will be an affine layer, followed by a ReLU activation, and ending with
# an MSE loss between the prediction and label tensors.
#
# .. math::
#
#    \mathbf{y}_{\text{pred}} = \text{ReLU}(\mathbf{x} \mathbf{W} + \mathbf{b})
#
# .. math::
#
#    L = \text{MSE}(\mathbf{y}_{\text{pred}}, \mathbf{y})
#
# Note that ``requires_grad=True`` is necessary for the parameters
# (``W`` and ``b``) so that PyTorch tracks operations involving those
# tensors. We’ll discuss this more in a later
# `section <#requires-grad>`__.
#

# tensor setup
x = torch.ones(1, 3)                      # input with shape: (1, 3)
W = torch.ones(3, 2, requires_grad=True)  # weights with shape: (3, 2)
b = torch.ones(1, 2, requires_grad=True)  # bias with shape: (1, 2)
y = torch.ones(1, 2)                      # output with shape: (1, 2)

# forward pass
z = (x @ W) + b               # pre-activation with shape: (1, 2)
y_pred = F.relu(z)            # activation with shape: (1, 2)
loss = F.mse_loss(y_pred, y)  # scalar loss


######################################################################
# Leaf vs. non-leaf tensors
# -------------------------
#
# After running the forward pass, PyTorch autograd has built up a `dynamic
# computational
# graph <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-graph>`__
# which is shown below. This is a `Directed Acyclic Graph
# (DAG) <https://en.wikipedia.org/wiki/Directed_acyclic_graph>`__ which
# keeps a record of input tensors (leaf nodes), all subsequent operations
# on those tensors, and the intermediate/output tensors (non-leaf nodes).
# The graph is used to compute gradients for each tensor starting from the
# graph roots (outputs) to the leaves (inputs) using the `chain
# rule <https://en.wikipedia.org/wiki/Chain_rule>`__ from calculus:
#
# .. math::
#
#    \mathbf{y} = \mathbf{f}_k\bigl(\mathbf{f}_{k-1}(\dots \mathbf{f}_1(\mathbf{x}) \dots)\bigr)
#
# .. math::
#
#    \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =
#    \frac{\partial \mathbf{f}_k}{\partial \mathbf{f}_{k-1}} \cdot
#    \frac{\partial \mathbf{f}_{k-1}}{\partial \mathbf{f}_{k-2}} \cdot
#    \cdots \cdot
#    \frac{\partial \mathbf{f}_1}{\partial \mathbf{x}}
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-1.png
#    :alt: Computational graph after forward pass
#
#    Computational graph after forward pass
#
# PyTorch considers a node to be a *leaf* if it is not the result of a
# tensor operation with at least one input having ``requires_grad=True``
# (e.g. ``x``, ``W``, ``b``, and ``y``), and everything else to be
# *non-leaf* (e.g. ``z``, ``y_pred``, and ``loss``). You can verify this
# programmatically by probing the ``is_leaf`` attribute of the tensors:
#

# prints True because new tensors are leaves by convention
print(f"{x.is_leaf=}")

# prints False because the tensor is the result of an operation with at
# least one input having requires_grad=True
print(f"{z.is_leaf=}")

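
######################################################################
# As a quick, optional sanity check (a sketch, not required for the rest
# of the tutorial), we can also peek at the graph that autograd recorded
# by following the ``grad_fn`` attribute. Leaf tensors have
# ``grad_fn=None``, while each non-leaf tensor points to the backward
# function that created it:
#

print(f"{loss.grad_fn=}")    # e.g. a MseLossBackward0 object
print(f"{y_pred.grad_fn=}")  # e.g. a ReluBackward0 object
print(f"{z.grad_fn=}")       # e.g. an AddBackward0 object
print(f"{x.grad_fn=}")       # None, because x is a leaf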

######################################################################
# The distinction between leaf and non-leaf determines whether the
# tensor’s gradient will be stored in the ``grad`` property after the
# backward pass, and thus be usable for `gradient
# descent <https://en.wikipedia.org/wiki/Gradient_descent>`__. We’ll cover
# this in more detail in the `following section <#retain-grad>`__.
#
# Let’s now investigate how PyTorch calculates and stores gradients for
# the tensors in its computational graph.
#


######################################################################
# ``requires_grad``
# -----------------
#
# To build the computational graph which can be used for gradient
# calculation, we need to pass in the ``requires_grad=True`` parameter to
# a tensor constructor. By default, the value is ``False``, and thus
# PyTorch does not track gradients on any created tensors. To verify this,
# try creating ``W`` and ``b`` without setting ``requires_grad``, re-run
# the forward pass, and then run backpropagation. You will see:
#
# ::
#
#    >>> loss.backward()
#    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
#
# This error means that autograd can’t backpropagate to any leaf tensors
# because ``loss`` is not tracking gradients. If you need to change the
# property, you can call ``requires_grad_()`` on the tensor (notice the \_
# suffix); a short example follows the check below.
#
# We can sanity check which nodes require gradient calculation, just like
# we did above with the ``is_leaf`` attribute:
#

print(f"{x.requires_grad=}")  # prints False because requires_grad=False by default
print(f"{W.requires_grad=}")  # prints True because we set requires_grad=True in constructor
print(f"{z.requires_grad=}")  # prints True because tensor is a non-leaf node

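
######################################################################
# As a brief aside (a minimal sketch using a throwaway tensor, so it does
# not affect the rest of this tutorial), ``requires_grad_()`` flips the
# flag in place on an existing leaf tensor:
#

tmp = torch.ones(2, 2)          # leaf tensor, requires_grad=False by default
print(f"{tmp.requires_grad=}")  # False

tmp.requires_grad_(True)        # in-place toggle (note the trailing underscore)
print(f"{tmp.requires_grad=}")  # True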

######################################################################
# It’s useful to remember that a non-leaf tensor has
# ``requires_grad=True`` by definition, since backpropagation would fail
# otherwise. If the tensor is a leaf, then it will only have
# ``requires_grad=True`` if it was specifically set by the user. Another
# way to phrase this is that if at least one of the inputs to a tensor
# requires the gradient, then it will require the gradient as well.
#
# There are two exceptions to this rule:
#
# 1. Any ``nn.Module`` that has ``nn.Parameter`` will have
#    ``requires_grad=True`` for its parameters (see
#    `here <https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models>`__)
# 2. Locally disabling gradient computation with context managers (see
#    `here <https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation>`__)
#
# Both exceptions are sketched briefly below.
#
# In summary, ``requires_grad`` tells autograd which tensors need to have
# their gradients calculated for backpropagation to work. This is
# different from which tensors have their ``grad`` field populated, which
# is the topic of the next section.
#

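
######################################################################
# A minimal sketch of the two exceptions above (the ``nn.Linear`` module
# below is a stand-in created just for illustration and is not used
# elsewhere in this tutorial):
#

import torch.nn as nn

# 1. nn.Parameter tensors inside an nn.Module require gradients by default
layer = nn.Linear(3, 2)
print(f"{layer.weight.requires_grad=}")  # True
print(f"{layer.bias.requires_grad=}")    # True

# 2. under torch.no_grad(), results are not tracked even though W requires grad
with torch.no_grad():
    z_no_grad = (x @ W) + b
print(f"{z_no_grad.requires_grad=}")  # False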

######################################################################
# ``retain_grad``
# ---------------
#
# To actually perform optimization (e.g. SGD, Adam, etc.), we need to run
# the backward pass so that we can extract the gradients.
#

loss.backward()


######################################################################
# Calling ``backward()`` populates the ``grad`` field of all leaf tensors
# which had ``requires_grad=True``. The ``grad`` is the gradient of the
# loss with respect to the tensor we are probing. Before running
# ``backward()``, this attribute is set to ``None``.
#

print(f"{W.grad=}")
print(f"{b.grad=}")

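
######################################################################
# As an optional cross-check (a sketch based on the chain rule shown
# earlier, not part of the original flow), we can recompute these
# gradients by hand and compare them against what autograd stored. With
# the all-ones inputs, ``z`` is positive everywhere, so the ReLU
# derivative is 1, and the mean-squared-error term contributes
# ``2 * (y_pred - y) / N``.
#

with torch.no_grad():
    grad_y_pred = 2.0 * (y_pred - y) / y.numel()     # dL/d(y_pred) for MSE with mean reduction
    grad_z = grad_y_pred * (z > 0)                   # ReLU passes the gradient only where z > 0
    manual_W_grad = x.T @ grad_z                     # dL/dW, shape (3, 2)
    manual_b_grad = grad_z.sum(dim=0, keepdim=True)  # dL/db, shape (1, 2)

print(f"{torch.allclose(manual_W_grad, W.grad)=}")  # expected: True
print(f"{torch.allclose(manual_b_grad, b.grad)=}")  # expected: True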

######################################################################
# You might be wondering about the other tensors in our network. Let’s
# check the remaining leaf nodes:
#

# prints all None because requires_grad=False
print(f"{x.grad=}")
print(f"{y.grad=}")


######################################################################
# The gradients for these tensors haven’t been populated because we did
# not explicitly tell PyTorch to calculate their gradient
# (``requires_grad=False``).
#
# Let’s now look at an intermediate non-leaf node:
#

print(f"{z.grad=}")


######################################################################
# PyTorch returns ``None`` for the gradient and also warns us that a
# non-leaf node’s ``grad`` attribute is being accessed. Although autograd
# has to calculate intermediate gradients for backpropagation to work, it
# assumes you don’t need to access the values afterwards. To change this
# behavior, we can use the ``retain_grad()`` function on a tensor. This
# tells the autograd engine to populate that tensor’s ``grad`` after
# calling ``backward()``.
#

# we have to re-run the forward pass
z = (x @ W) + b
y_pred = F.relu(z)
loss = F.mse_loss(y_pred, y)

# tell PyTorch to store the gradients after backward()
z.retain_grad()
y_pred.retain_grad()
loss.retain_grad()

# zero out the existing gradients, otherwise they would accumulate
W.grad = None
b.grad = None

# backpropagation
loss.backward()

# print gradients for all tensors that have requires_grad=True
print(f"{W.grad=}")
print(f"{b.grad=}")
print(f"{z.grad=}")
print(f"{y_pred.grad=}")
print(f"{loss.grad=}")


######################################################################
# We get the same result for ``W.grad`` as before. Also note that because
# the loss is scalar, the gradient of the loss with respect to itself is
# simply ``1.0``.
#
# If we look at the state of the computational graph now, we see that the
# ``retains_grad`` attribute has changed for the intermediate tensors. By
# convention, this attribute will print ``False`` for any leaf node, even
# if it requires its gradient.
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-2.png
#    :alt: Computational graph after backward pass
#
#    Computational graph after backward pass
#
# If you call ``retain_grad()`` on a leaf node that requires its gradient,
# it results in a no-op, since its ``grad`` is stored after ``backward()``
# anyway. If we call ``retain_grad()`` on a node that has
# ``requires_grad=False``, PyTorch actually throws an error, since it
# can’t store the gradient if it is never calculated.
#
# ::
#
#    >>> x.retain_grad()
#    RuntimeError: can't retain_grad on Tensor that has requires_grad=False
#

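
######################################################################
# As a quick optional check, we can confirm the ``retains_grad`` flags
# programmatically, just like we did earlier with ``is_leaf`` and
# ``requires_grad``:
#

print(f"{W.retains_grad=}")       # False: leaf tensors report False by convention
print(f"{z.retains_grad=}")       # True, because we called z.retain_grad()
print(f"{y_pred.retains_grad=}")  # True, because we called y_pred.retain_grad()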

######################################################################
# Summary table
# -------------
#
# Using ``retain_grad()`` and ``retains_grad`` only makes sense for
# non-leaf nodes, since the ``grad`` attribute will already be populated
# for leaf tensors that have ``requires_grad=True``. By default, these
# non-leaf nodes do not retain (store) their gradient after
# backpropagation. We can change that by rerunning the forward pass,
# telling PyTorch to store the gradients, and then performing
# backpropagation.
#
# The following table summarizes the above discussion and can be used as
# a reference. These scenarios are the only ones that are valid for
# PyTorch tensors.
#
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``is_leaf``    | ``requires_grad``      | ``retains_grad``       | ``requires_grad_()``                              | ``retain_grad()``                   |
# +================+========================+========================+===================================================+=====================================+
# | ``True``       | ``False``              | ``False``              | sets ``requires_grad`` to ``True`` or ``False``   | no-op                               |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``True``       | ``True``               | ``False``              | sets ``requires_grad`` to ``True`` or ``False``   | no-op                               |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``False``      | ``True``               | ``False``              | no-op                                             | sets ``retains_grad`` to ``True``   |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``False``      | ``True``               | ``True``               | no-op                                             | no-op                               |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
#


######################################################################
# Conclusion
# ----------
#
# In this tutorial, we covered when and how PyTorch computes gradients for
# leaf and non-leaf tensors. By using ``retain_grad``, we can access the
# gradients of intermediate tensors within autograd’s computational graph.
#
# If you would like to learn more about how PyTorch’s autograd system
# works, please visit the `references <#references>`__ below. If you have
# any feedback for this tutorial (improvements, typo fixes, etc.) then
# please use the `PyTorch Forums <https://discuss.pytorch.org/>`__ and/or
# the `issue tracker <https://github.com/pytorch/tutorials/issues>`__ to
# reach out.
#


######################################################################
# References
# ----------
#
# - `A Gentle Introduction to
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__
# - `Automatic Differentiation with
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
# - `Autograd
#   mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
#

index.rst

Lines changed: 7 additions & 0 deletions
@@ -99,6 +99,13 @@ Welcome to PyTorch Tutorials
   :link: intermediate/pinmem_nonblock.html
   :tags: Getting-Started

.. customcarditem::
   :header: Understanding requires_grad, retain_grad, Leaf, and Non-leaf Tensors
   :card_description: Learn the subtleties of requires_grad, retain_grad, leaf, and non-leaf tensors
   :image: _static/img/thumbnails/cropped/understanding_leaf_vs_nonleaf.png
   :link: beginner/understanding_leaf_vs_nonleaf_tutorial.html
   :tags: Getting-Started

.. Image/Video

.. customcarditem::
