"""
Visualizing Gradients
=====================

**Author:** `Justin Silver <https://github.com/j-silv>`__

This tutorial explains how to extract and visualize gradients at any
layer in a neural network. By inspecting how information flows from the
end of the network to the parameters we want to optimize, we can debug
issues such as `vanishing or exploding
gradients <https://arxiv.org/abs/1211.5063>`__ that occur during
training.

Before starting, make sure you understand `tensors and how to manipulate
them <https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html>`__.
A basic knowledge of `how autograd
works <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
would also be useful.

"""


######################################################################
# Setup
# -----
#
# First, make sure `PyTorch is
# installed <https://pytorch.org/get-started/locally/>`__ and then import
# the necessary libraries.
#

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt


######################################################################
# Next, we’ll be creating a network intended for the MNIST dataset,
# similar to the architecture described by the `batch normalization
# paper <https://arxiv.org/abs/1502.03167>`__.
#
# To illustrate the importance of gradient visualization, we will
# instantiate one version of the network with batch normalization
# (BatchNorm), and one without it. Batch normalization is an extremely
# effective technique to resolve `vanishing/exploding
# gradients <https://arxiv.org/abs/1211.5063>`__, and we will be verifying
# that experimentally.
#
# The model we use has a configurable number of repeating fully-connected
# layers which alternate between ``nn.Linear``, ``norm_layer``, and
# ``nn.Sigmoid``. If batch normalization is enabled, then ``norm_layer``
# will use
# `BatchNorm1d <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html>`__,
# otherwise it will use the
# `Identity <https://docs.pytorch.org/docs/stable/generated/torch.nn.Identity.html>`__
# transformation.
#

def fc_layer(in_size, out_size, norm_layer):
    """Return a stack of linear->norm->sigmoid layers"""
    return nn.Sequential(nn.Linear(in_size, out_size), norm_layer(out_size), nn.Sigmoid())

class Net(nn.Module):
    """Define a network that has num_layers of linear->norm->sigmoid transformations"""
    def __init__(self, in_size=28*28, hidden_size=128,
                 out_size=10, num_layers=3, batchnorm=False):
        super().__init__()
        if batchnorm is False:
            norm_layer = nn.Identity
        else:
            norm_layer = nn.BatchNorm1d

        layers = []
        layers.append(fc_layer(in_size, hidden_size, norm_layer))

        for i in range(num_layers - 1):
            layers.append(fc_layer(hidden_size, hidden_size, norm_layer))

        layers.append(nn.Linear(hidden_size, out_size))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = torch.flatten(x, 1)
        return self.layers(x)


######################################################################
# Next we set up some dummy data, instantiate two versions of the model,
# and initialize the optimizers.
#

# set up dummy data
x = torch.randn(10, 28, 28)
y = torch.randint(10, (10,))

# init model
model_bn = Net(batchnorm=True, num_layers=3)
model_nobn = Net(batchnorm=False, num_layers=3)

model_bn.train()
model_nobn.train()

optimizer_bn = optim.SGD(model_bn.parameters(), lr=0.01, momentum=0.9)
optimizer_nobn = optim.SGD(model_nobn.parameters(), lr=0.01, momentum=0.9)


######################################################################
# We can verify that batch normalization is only being applied to one of
# the models by probing one of the internal layers:
#

print(model_bn.layers[0])
print(model_nobn.layers[0])
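
######################################################################
# If you prefer a programmatic check, the same thing can be asserted with
# ``isinstance``. This sketch relies on the block layout produced by
# ``fc_layer`` (the ``nn.Linear`` at index 0, the normalization layer at
# index 1):
#

assert isinstance(model_bn.layers[0][1], nn.BatchNorm1d)
assert isinstance(model_nobn.layers[0][1], nn.Identity)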


######################################################################
# Registering hooks
# -----------------
#

######################################################################
# Because we wrapped the logic and state of our model in an
# ``nn.Module``, we need another way to access the intermediate
# gradients if we want to avoid modifying the module code directly. This
# is done by `registering a
# hook <https://docs.pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`__.
#
# .. warning::
#
#    Using backward pass hooks attached to output tensors is preferred over
#    using ``retain_grad()`` on the tensors themselves. An alternative method
#    is to directly attach module hooks (e.g. ``register_full_backward_hook()``)
#    so long as the ``nn.Module`` instance does not perform any in-place
#    operations. For more information, please refer to `this
#    issue <https://github.com/pytorch/pytorch/issues/61519>`__.
#
# The following code defines our hooks and gathers descriptive names for
# the network’s layers.
#

# note that wrapper functions are used for Python closure
# so that we can pass arguments.

def hook_forward(module_name, grads, hook_backward):
    def hook(module, args, output):
        """Forward pass hook which attaches backward pass hooks to intermediate tensors"""
        output.register_hook(hook_backward(module_name, grads))
    return hook

def hook_backward(module_name, grads):
    def hook(grad):
        """Backward pass hook which appends gradients"""
        grads.append((module_name, grad))
    return hook

def get_all_layers(model, hook_forward, hook_backward):
    """Register forward pass hook (which registers a backward hook) to model outputs

    Returns:
        - layers: a dict with keys as layer/module and values as layer/module names
          e.g. layers[nn.Conv2d] = layer1.0.conv1
        - grads: a list of tuples with module name and tensor output gradient
          e.g. grads[0] == (layer1.0.conv1, torch.Tensor(...))
    """
    layers = dict()
    grads = []
    for name, layer in model.named_modules():
        # skip Sequential and/or wrapper modules
        if not any(layer.children()):
            layers[layer] = name
            layer.register_forward_hook(hook_forward(name, grads, hook_backward))
    return layers, grads

# register hooks
layers_bn, grads_bn = get_all_layers(model_bn, hook_forward, hook_backward)
layers_nobn, grads_nobn = get_all_layers(model_nobn, hook_forward, hook_backward)
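
######################################################################
# As an aside, the module-level alternative mentioned in the warning above
# would look roughly like the sketch below (it is safe here because our
# modules perform no in-place operations). The ``demo_backward_hook`` name
# and its ``print`` body are just for illustration; we remove the hook
# immediately so it does not interfere with the rest of this tutorial.
#

def demo_backward_hook(module, grad_input, grad_output):
    """Module hook: grad_output holds gradients w.r.t. the module's outputs"""
    print(f"{module.__class__.__name__}: mean |grad| = {grad_output[0].abs().mean():.3e}")

# e.g. attach it to the first Linear layer of the BatchNorm model ...
handle = model_bn.layers[0][0].register_full_backward_hook(demo_backward_hook)
# ... and detach it again, since we rely on the tensor hooks instead
handle.remove()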


######################################################################
# Training and visualization
# --------------------------
#
# Let’s now train the models for a few epochs:
#

epochs = 10

for epoch in range(epochs):

    # important to clear, because the backward hooks
    # append to these lists on every backward pass
    grads_bn.clear()
    grads_nobn.clear()

    optimizer_bn.zero_grad()
    optimizer_nobn.zero_grad()

    y_pred_bn = model_bn(x)
    y_pred_nobn = model_nobn(x)

    loss_bn = F.cross_entropy(y_pred_bn, y)
    loss_nobn = F.cross_entropy(y_pred_nobn, y)

    loss_bn.backward()
    loss_nobn.backward()

    optimizer_bn.step()
    optimizer_nobn.step()


######################################################################
# After running the forward and backward pass, the gradients for all the
# intermediate tensors should be present in ``grads_bn`` and
# ``grads_nobn``. We compute the mean absolute value of each gradient
# matrix so that we can compare the two models.
#

def get_grads(grads):
    layer_idx = []
    avg_grads = []
    for idx, (name, grad) in enumerate(grads):
        if grad is not None:
            avg_grad = grad.abs().mean()
            avg_grads.append(avg_grad)
            # idx is backwards since we appended in backward pass
            layer_idx.append(len(grads) - 1 - idx)
    return layer_idx, avg_grads

layer_idx_bn, avg_grads_bn = get_grads(grads_bn)
layer_idx_nobn, avg_grads_nobn = get_grads(grads_nobn)


######################################################################
# With the average gradients computed, we can now plot them and see how
# the values change as a function of the network depth. Notice that when
# we don’t apply batch normalization, the gradient values in the
# intermediate layers fall to zero very quickly. The batch normalization
# model, however, maintains non-zero gradients in its intermediate layers.
#

fig, ax = plt.subplots()
ax.plot(layer_idx_bn, avg_grads_bn, label="With BatchNorm", marker="o")
ax.plot(layer_idx_nobn, avg_grads_nobn, label="Without BatchNorm", marker="x")
ax.set_xlabel("Layer depth")
ax.set_ylabel("Average gradient")
ax.set_title("Gradient flow")
ax.grid(True)
ax.legend()
plt.show()
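
######################################################################
# As an optional extra (not part of the figure above), a logarithmic
# y-axis can make the comparison easier to read, since the gradients in
# the model without batch normalization become very small:
#

fig_log, ax_log = plt.subplots()
ax_log.semilogy(layer_idx_bn, avg_grads_bn, label="With BatchNorm", marker="o")
ax_log.semilogy(layer_idx_nobn, avg_grads_nobn, label="Without BatchNorm", marker="x")
ax_log.set_xlabel("Layer depth")
ax_log.set_ylabel("Average gradient (log scale)")
ax_log.set_title("Gradient flow (log scale)")
ax_log.grid(True)
ax_log.legend()
plt.show()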


######################################################################
# Conclusion
# ----------
#
# In this tutorial, we demonstrated how to visualize the gradient flow
# through a neural network wrapped in an ``nn.Module`` class. We
# qualitatively showed how batch normalization helps to alleviate the
# vanishing gradient issue which occurs with deep neural networks.
#
# If you would like to learn more about how PyTorch’s autograd system
# works, please visit the `references <#references>`__ below. If you have
# any feedback for this tutorial (improvements, typo fixes, etc.) then
# please use the `PyTorch Forums <https://discuss.pytorch.org/>`__ and/or
# the `issue tracker <https://github.com/pytorch/tutorials/issues>`__ to
# reach out.
#


######################################################################
# (Optional) Additional exercises
# -------------------------------
#
# - Try increasing the number of layers (``num_layers``) in our model and
#   see what effect this has on the gradient flow graph
# - How would you adapt the code to visualize average activations instead
#   of average gradients? (*Hint: in the hook_forward() function we have
#   access to the raw tensor output; see the sketch after this list*)
# - What are some other methods to deal with vanishing and exploding
#   gradients?
#
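
######################################################################
# For the second exercise, one possible starting point (a sketch, not used
# elsewhere in this tutorial) is a forward hook that records the detached
# output tensor itself instead of registering a backward hook on it:
#

def hook_forward_activations(module_name, activations):
    def hook(module, args, output):
        """Forward pass hook which appends detached output activations"""
        activations.append((module_name, output.detach()))
    return hook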


######################################################################
# References
# ----------
#
# - `A Gentle Introduction to
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__
# - `Automatic Differentiation with
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
# - `Autograd
#   mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
# - `Batch Normalization: Accelerating Deep Network Training by Reducing
#   Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__
# - `On the difficulty of training Recurrent Neural
#   Networks <https://arxiv.org/abs/1211.5063>`__
#