"""
Understanding requires_grad, retain_grad, Leaf, and Non-leaf Tensors
====================================================================

**Author:** `Justin Silver <https://github.com/j-silv>`__

This tutorial explains the subtleties of ``requires_grad``,
``retain_grad``, leaf, and non-leaf tensors using a simple example.

Before starting, make sure you understand `tensors and how to manipulate
them <https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html>`__.
A basic knowledge of `how autograd
works <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
would also be useful.

"""


######################################################################
# Setup
# -----
#
# First, make sure `PyTorch is
# installed <https://pytorch.org/get-started/locally/>`__ and then import
# the necessary libraries.
#

import torch
import torch.nn.functional as F



######################################################################
# Next, we instantiate a simple network so that we can focus on the
# gradients. This will be an affine layer, followed by a ReLU activation,
# and ending with an MSE loss between the prediction and label tensors.
#
# .. math::
#
#    \mathbf{y}_{\text{pred}} = \text{ReLU}(\mathbf{x} \mathbf{W} + \mathbf{b})
#
# .. math::
#
#    L = \text{MSE}(\mathbf{y}_{\text{pred}}, \mathbf{y})
#
# Note that ``requires_grad=True`` is necessary for the parameters
# (``W`` and ``b``) so that PyTorch tracks operations involving those
# tensors. We’ll discuss this more in a later
# `section <#requires-grad>`__.
#

# tensor setup
x = torch.ones(1, 3)                      # input with shape: (1, 3)
W = torch.ones(3, 2, requires_grad=True)  # weights with shape: (3, 2)
b = torch.ones(1, 2, requires_grad=True)  # bias with shape: (1, 2)
y = torch.ones(1, 2)                      # output with shape: (1, 2)

# forward pass
z = (x @ W) + b               # pre-activation with shape: (1, 2)
y_pred = F.relu(z)            # activation with shape: (1, 2)
loss = F.mse_loss(y_pred, y)  # scalar loss


######################################################################
# Leaf vs. non-leaf tensors
# -------------------------
#
# After running the forward pass, PyTorch autograd has built up a `dynamic
# computational
# graph <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-graph>`__
# which is shown below. This is a `Directed Acyclic Graph
# (DAG) <https://en.wikipedia.org/wiki/Directed_acyclic_graph>`__ which
# keeps a record of input tensors (leaf nodes), all subsequent operations
# on those tensors, and the intermediate/output tensors (non-leaf nodes).
# The graph is used to compute gradients for each tensor starting from the
# graph roots (outputs) to the leaves (inputs) using the `chain
# rule <https://en.wikipedia.org/wiki/Chain_rule>`__ from calculus:
#
# .. math::
#
#    \mathbf{y} = \mathbf{f}_k\bigl(\mathbf{f}_{k-1}(\dots \mathbf{f}_1(\mathbf{x}) \dots)\bigr)
#
# .. math::
#
#    \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =
#    \frac{\partial \mathbf{f}_k}{\partial \mathbf{f}_{k-1}} \cdot
#    \frac{\partial \mathbf{f}_{k-1}}{\partial \mathbf{f}_{k-2}} \cdot
#    \cdots \cdot
#    \frac{\partial \mathbf{f}_1}{\partial \mathbf{x}}
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-1.png
#    :alt: Computational graph after forward pass
#
#    Computational graph after forward pass
#
# PyTorch considers a node to be a *leaf* if it is not the result of a
# tensor operation with at least one input having ``requires_grad=True``
# (e.g. ``x``, ``W``, ``b``, and ``y``), and everything else to be
# *non-leaf* (e.g. ``z``, ``y_pred``, and ``loss``). You can verify this
# programmatically by probing the ``is_leaf`` attribute of the tensors:
#

# prints True because new tensors are leaves by convention
print(f"{x.is_leaf=}")

# prints False because the tensor is the result of an operation with at
# least one input having requires_grad=True
print(f"{z.is_leaf=}")

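######################################################################
# The record that autograd keeps can be inspected through the ``grad_fn``
# attribute, which stores the backward function for the operation that
# produced a non-leaf tensor and is ``None`` for leaf tensors. A quick
# sketch using the tensors above:
#

print(f"{W.grad_fn=}")     # None, because W is a leaf tensor
print(f"{z.grad_fn=}")     # backward node recorded for the addition of b
print(f"{loss.grad_fn=}")  # backward node recorded for F.mse_loss
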
110
+ ######################################################################
111
+ # The distinction between leaf and non-leaf determines whether the
112
+ # tensor’s gradient will be stored in the ``grad`` property after the
113
+ # backward pass, and thus be usable for `gradient
114
+ # descent <https://en.wikipedia.org/wiki/Gradient_descent>`__. We’ll cover
115
+ # this some more in the `following section <#retain-grad>`__.
116
+ #
117
+ # Let’s now investigate how PyTorch calculates and stores gradients for
118
+ # the tensors in its computational graph.
119
+ #
120
+
121
+
122
+ ######################################################################
123
+ # ``requires_grad``
124
+ # -----------------
125
+ #
126
+ # To build the computational graph which can be used for gradient
127
+ # calculation, we need to pass in the ``requires_grad=True`` parameter to
128
+ # a tensor constructor. By default, the value is ``False``, and thus
129
+ # PyTorch does not track gradients on any created tensors. To verify this,
130
+ # try not setting ``requires_grad``, re-run the forward pass, and then run
131
+ # backpropagation. You will see:
132
+ #
133
+ # ::
134
+ #
135
+ # >>> loss.backward()
136
+ # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
137
+ #
138
+ # This error means that autograd can’t backpropagate to any leaf tensors
139
+ # because ``loss`` is not tracking gradients. If you need to change the
140
+ # property, you can call ``requires_grad_()`` on the tensor (notice the \_
141
+ # suffix).
142
+ #
143
+ # We can sanity check which nodes require gradient calculation, just like
144
+ # we did above with the ``is_leaf`` attribute:
145
+ #
146
+
147
+ print (f"{ x .requires_grad = } " ) # prints False because requires_grad=False by default
148
+ print (f"{ W .requires_grad = } " ) # prints True because we set requires_grad=True in constructor
149
+ print (f"{ z .requires_grad = } " ) # prints True because tensor is a non-leaf node
150
+
151
+
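######################################################################
# The in-place ``requires_grad_()`` method mentioned above flips the flag
# on an existing tensor. Here is a minimal sketch on a throwaway tensor
# (``t`` is hypothetical and not used elsewhere in this tutorial):
#

t = torch.zeros(3)            # requires_grad defaults to False
t.requires_grad_()            # note the trailing underscore (in-place)
print(f"{t.requires_grad=}")  # prints True
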

######################################################################
# It’s useful to remember that a non-leaf tensor has
# ``requires_grad=True`` by definition, since backpropagation would fail
# otherwise. If the tensor is a leaf, then it will only have
# ``requires_grad=True`` if it was specifically set by the user. Put
# differently, if at least one of the inputs to an operation requires the
# gradient, then the resulting tensor will require it as well.
#
# There are two exceptions to this rule, both demonstrated below:
#
# 1. Any ``nn.Module`` that has ``nn.Parameter`` will have
#    ``requires_grad=True`` for its parameters (see
#    `here <https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models>`__)
# 2. Locally disabling gradient computation with context managers (see
#    `here <https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation>`__)
#
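# A quick check of both exceptions (``lin`` below is a hypothetical layer
# created only for this demonstration):
#

lin = torch.nn.Linear(3, 2)            # nn.Linear registers its weight and bias as nn.Parameter
print(f"{lin.weight.requires_grad=}")  # prints True

with torch.no_grad():                  # locally disable gradient tracking
    z_tmp = (x @ W) + b
print(f"{z_tmp.requires_grad=}")       # prints False

######################################################################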
# In summary, ``requires_grad`` tells autograd which tensors need to have
# their gradients calculated for backpropagation to work. This is
# different from which tensors have their ``grad`` field populated, which
# is the topic of the next section.
#


######################################################################
# ``retain_grad``
# ---------------
#
# To actually perform optimization (for example with SGD or Adam), we need
# to run the backward pass so that we can extract the gradients.
#

loss.backward()


######################################################################
# Calling ``backward()`` populates the ``grad`` field of all leaf tensors
# that have ``requires_grad=True``. The ``grad`` attribute holds the
# gradient of the loss with respect to the tensor we are probing. Before
# running ``backward()``, this attribute is set to ``None``.
#

print(f"{W.grad=}")
print(f"{b.grad=}")

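######################################################################
# These gradients are what an optimizer consumes. As a purely illustrative
# sketch, a single gradient-descent step could look like the following; it
# is done out of place so the tutorial’s parameters stay untouched, and
# ``lr`` is a hypothetical learning rate:
#

lr = 0.1                     # hypothetical learning rate
with torch.no_grad():
    W_new = W - lr * W.grad  # out-of-place update, W itself is unchanged
    b_new = b - lr * b.grad
print(f"{W_new=}")
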


######################################################################
# You might be wondering about the other tensors in our network. Let’s
# check the remaining leaf nodes:
#

# prints all None because requires_grad=False
print(f"{x.grad=}")
print(f"{y.grad=}")


######################################################################
# The gradients for these tensors haven’t been populated because we did
# not explicitly tell PyTorch to calculate their gradient
# (``requires_grad=False``).
#
# Let’s now look at an intermediate non-leaf node:
#

print(f"{z.grad=}")


######################################################################
# PyTorch returns ``None`` for the gradient and also warns us that a
# non-leaf node’s ``grad`` attribute is being accessed. Although autograd
# has to calculate intermediate gradients for backpropagation to work, it
# assumes you don’t need to access the values afterwards. To change this
# behavior, we can use the ``retain_grad()`` function on a tensor. This
# tells the autograd engine to populate that tensor’s ``grad`` after
# calling ``backward()``.
#

# we have to re-run the forward pass
z = (x @ W) + b
y_pred = F.relu(z)
loss = F.mse_loss(y_pred, y)

# tell PyTorch to store the gradients after backward()
z.retain_grad()
y_pred.retain_grad()
loss.retain_grad()

# have to zero out gradients otherwise they would accumulate
W.grad = None
b.grad = None

# backpropagation
loss.backward()

# print gradients for all tensors that have requires_grad=True
print(f"{W.grad=}")
print(f"{b.grad=}")
print(f"{z.grad=}")
print(f"{y_pred.grad=}")
print(f"{loss.grad=}")

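######################################################################
# As a sanity check, we can reproduce these numbers by hand with the chain
# rule. The sketch below assumes the tensors defined earlier in this
# tutorial; the names ``dL_dy_pred`` and ``dL_dz`` are introduced only for
# this check:
#

with torch.no_grad():
    dL_dy_pred = 2.0 * (y_pred - y) / y.numel()  # derivative of the mean squared error
    dL_dz = dL_dy_pred * (z > 0)                 # ReLU passes the gradient where z > 0
    print(torch.allclose(dL_dz, z.grad))         # True
    print(torch.allclose(x.T @ dL_dz, W.grad))   # True, since dL/dW = x^T @ dL/dz
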

######################################################################
# We get the same result for ``W.grad`` as before. Also note that because
# the loss is scalar, the gradient of the loss with respect to itself is
# simply ``1.0``.
#
# If we look at the state of the computational graph now, we see that the
# ``retains_grad`` attribute has changed for the intermediate tensors. By
# convention, this attribute will print ``False`` for any leaf node, even
# if it requires its gradient.
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-2.png
#    :alt: Computational graph after backward pass
#
#    Computational graph after backward pass
#
# If you call ``retain_grad()`` on a leaf node that already has
# ``requires_grad=True``, it results in a no-op, because its ``grad`` is
# populated during backpropagation anyway. If we call ``retain_grad()`` on
# a node that has ``requires_grad=False``, PyTorch actually throws an
# error, since it can’t store a gradient that is never calculated:
#
# ::
#
#    >>> x.retain_grad()
#    RuntimeError: can't retain_grad on Tensor that has requires_grad=False
#

######################################################################
# Summary table
# -------------
#
# Using ``retain_grad()`` and ``retains_grad`` only makes sense for
# non-leaf nodes, since the ``grad`` attribute will already be populated
# for leaf tensors that have ``requires_grad=True``. By default, these
# non-leaf nodes do not retain (store) their gradient after
# backpropagation. We can change that by rerunning the forward pass,
# telling PyTorch to store the gradients, and then performing
# backpropagation.
#
# The following table summarizes the above discussion; these are the only
# combinations of the three attributes that are valid for PyTorch tensors.
#
# +--------------------+--------------------+--------------------+--------------------------------------------------+------------------------------------+
# | ``is_leaf``        | ``requires_grad``  | ``retains_grad``   | ``requires_grad_()``                             | ``retain_grad()``                  |
# +====================+====================+====================+==================================================+====================================+
# | ``True``           | ``False``          | ``False``          | sets ``requires_grad`` to ``True`` or ``False``  | no-op                              |
# +--------------------+--------------------+--------------------+--------------------------------------------------+------------------------------------+
# | ``True``           | ``True``           | ``False``          | sets ``requires_grad`` to ``True`` or ``False``  | no-op                              |
# +--------------------+--------------------+--------------------+--------------------------------------------------+------------------------------------+
# | ``False``          | ``True``           | ``False``          | no-op                                            | sets ``retains_grad`` to ``True``  |
# +--------------------+--------------------+--------------------+--------------------------------------------------+------------------------------------+
# | ``False``          | ``True``           | ``True``           | no-op                                            | no-op                              |
# +--------------------+--------------------+--------------------+--------------------------------------------------+------------------------------------+
#
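# As a quick check, the rows of this table can be reproduced with tensors
# from this tutorial. ``z_plain`` below is a hypothetical fresh non-leaf
# tensor created only to illustrate the third row; ``z`` itself was
# retained earlier, so it illustrates the last row:
#

z_plain = (x @ W) + b  # non-leaf tensor without retain_grad() (third row)

for name, t in [("x", x), ("W", W), ("z_plain", z_plain), ("z", z)]:
    print(f"{name}: is_leaf={t.is_leaf}, "
          f"requires_grad={t.requires_grad}, retains_grad={t.retains_grad}")
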

######################################################################
# Conclusion
# ----------
#
# In this tutorial, we covered when and how PyTorch computes gradients for
# leaf and non-leaf tensors. By using ``retain_grad``, we can access the
# gradients of intermediate tensors within autograd’s computational graph.
#
# If you would like to learn more about how PyTorch’s autograd system
# works, please visit the `references <#references>`__ below. If you have
# any feedback for this tutorial (improvements, typo fixes, etc.) then
# please use the `PyTorch Forums <https://discuss.pytorch.org/>`__ and/or
# the `issue tracker <https://github.com/pytorch/tutorials/issues>`__ to
# reach out.
#


######################################################################
# References
# ----------
#
# - `A Gentle Introduction to
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__
# - `Automatic Differentiation with
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
# - `Autograd
#   mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
#