
Commit 869fe06

Add redirected Relu grad and test
1 parent 6a7dc95 commit 869fe06

2 files changed: +185 -0 lines changed

lucid/misc/redirected_relu_grad.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# Copyright 2018 The Lucid Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Redirected ReLu Gradient Overrides

When visualizing models we often[0] have to optimize through ReLu activation
functions. Where accessing pre-relu tensors is too hard, we use these
overrides to allow gradient to flow back through the ReLu—even if it didn't
activate ("dead neuron") and thus its derivative is 0.

Usage:
```python
from lucid.misc.gradient_override import gradient_override_map
from lucid.misc.redirected_relu_grad import redirected_relu_grad

with gradient_override_map({'Relu': redirected_relu_grad}):
  model.import_graph(…)
```

Discussion:
ReLus block the flow of the gradient during backpropagation when their input is
negative. ReLu6s also do so when the input is larger than 6. These overrides
change this behavior to allow gradient pushing the input into a desired regime
between these points.

In effect, this replaces the relu gradient with the following:

Regime       | Effect
============================================================
0 <= x <= 6  | pass through gradient
x < 0        | pass through gradient pushing the input up
x > 6        | pass through gradient pushing the input down

Or visually:

ReLu:                   |   |____________
                        |  /|
                        | / |
            ____________|/  |
                        0   6

Override:   ------------|   |------------
                allow ->     <- allow

Our implementation contains one extra complication:
tf.train.Optimizer performs gradient _descent_, so in the update step the
optimizer changes values in the opposite direction of the gradient. Thus, the
sign of the gradient in our overrides has the opposite of the intuitive effect:
negative gradient pushes the input up, positive pushes it down.
Thus, the code below only allows _negative_ gradient when the input is already
negative, and allows _positive_ gradient when the input is already above 6.

[0] That is because many model architectures don't provide easy access
to pre-relu tensors. For example, GoogLeNet's mixed__ layers are passed through
an activation function before being concatenated. We are still interested in
the entire concatenated layer; we would just like to skip the activation
function.
"""

import tensorflow as tf


def redirected_relu_grad(op, grad):
  assert op.type == "Relu"
  x = op.inputs[0]

  # Compute ReLu gradient
  relu_grad = tf.where(x < 0., tf.zeros_like(grad), grad)

  # Compute redirected gradient: where do we need to zero out incoming gradient
  # to prevent input going lower if it's already negative?
  neg_pushing_lower = tf.logical_and(x < 0., grad > 0.)
  redirected_grad = tf.where(neg_pushing_lower, tf.zeros_like(grad), grad)

  # Ensure we have at least a rank 2 tensor, as we expect a batch dimension
  assert_op = tf.Assert(tf.greater(tf.rank(relu_grad), 1), [tf.rank(relu_grad)])
  with tf.control_dependencies([assert_op]):
    # only use redirected gradient where nothing got through original gradient
    batch = tf.shape(relu_grad)[0]
    reshaped_relu_grad = tf.reshape(relu_grad, [batch, -1])
    relu_grad_mag = tf.norm(reshaped_relu_grad, axis=1)
  return tf.where(relu_grad_mag > 0., relu_grad, redirected_grad)


def redirected_relu6_grad(op, grad):
  assert op.type == "Relu6"
  x = op.inputs[0]

  # Compute ReLu6 gradient
  relu6_cond = tf.logical_or(x < 0., x > 6.)
  relu_grad = tf.where(relu6_cond, tf.zeros_like(grad), grad)

  # Compute redirected gradient: where do we need to zero out incoming gradient
  # to prevent input going lower if it's already negative, or going higher if
  # it's already bigger than 6?
  neg_pushing_lower = tf.logical_and(x < 0., grad > 0.)
  pos_pushing_higher = tf.logical_and(x > 6., grad < 0.)
  dir_filter = tf.logical_or(neg_pushing_lower, pos_pushing_higher)
  redirected_grad = tf.where(dir_filter, tf.zeros_like(grad), grad)

  # Ensure we have at least a rank 2 tensor, as we expect a batch dimension
  assert_op = tf.Assert(tf.greater(tf.rank(relu_grad), 1), [tf.rank(relu_grad)])
  with tf.control_dependencies([assert_op]):
    # only use redirected gradient where nothing got through original gradient
    batch = tf.shape(relu_grad)[0]
    reshaped_relu_grad = tf.reshape(relu_grad, [batch, -1])
    relu_grad_mag = tf.norm(reshaped_relu_grad, axis=1)
  return tf.where(relu_grad_mag > 0., relu_grad, redirected_grad)
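To make the sign convention described in the docstring concrete, here is a minimal hand-check of the Relu override on a single "dead" input, in the same style as the tests below. This is only a sketch assuming the TensorFlow 1.x graph API used throughout this commit; the constants and shapes are illustrative, not part of the committed code.

```python
import tensorflow as tf

from lucid.misc.gradient_override import gradient_override_map
from lucid.misc.redirected_relu_grad import redirected_relu_grad

with tf.Session().as_default():
  # A "dead" pre-relu value: relu(x) == 0, so the plain Relu gradient is 0.
  x = tf.constant(-2., shape=[1, 1])
  with gradient_override_map({"Relu": redirected_relu_grad}):
    y = tf.nn.relu(x)
    # A -1 incoming gradient is let through: under gradient _descent_ it
    # becomes a positive update that pushes x up, toward the active regime.
    grad_up = tf.gradients(y, x, [tf.constant(-1., shape=[1, 1])])[0]
    # A +1 incoming gradient is blocked: it would push x further below 0.
    grad_down = tf.gradients(y, x, [tf.constant(1., shape=[1, 1])])[0]
  assert (grad_up.eval() == -1.).all()
  assert (grad_down.eval() == 0.).all()
```

This matches the (-1., -1., -1.) and (1., -1., 0.) cases in the parametrized test below.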
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import pytest

import tensorflow as tf
from lucid.misc.gradient_override import use_gradient

def test_use_gradient():
  def foo_grad(op, grad):
    return tf.constant(42), tf.constant(43)

  @use_gradient(foo_grad)
  def foo(x, y):
    return x + y

  with tf.Session().as_default() as sess:
    x = tf.constant(1.)
    y = tf.constant(2.)
    z = foo(x, y)
    grad_wrt_x = tf.gradients(z, x, [1.])[0]
    grad_wrt_y = tf.gradients(z, y, [1.])[0]
    assert grad_wrt_x.eval() == 42
    assert grad_wrt_y.eval() == 43


from lucid.misc.gradient_override import gradient_override_map

def test_gradient_override_map():

  def gradient_override(op, grad):
    return tf.constant(42)

  with tf.Session().as_default() as sess:
    a = tf.constant(1.)
    standard_relu = tf.nn.relu(a)
    grad_wrt_a = tf.gradients(standard_relu, a, [1.])[0]
    with gradient_override_map({"Relu": gradient_override}):
      overriden_relu = tf.nn.relu(a)
      overriden_grad_wrt_a = tf.gradients(overriden_relu, a, [1.])[0]
    assert grad_wrt_a.eval() != overriden_grad_wrt_a.eval()
    assert overriden_grad_wrt_a.eval() == 42


from lucid.misc.redirected_relu_grad import redirected_relu_grad, redirected_relu6_grad

relu_examples = [
  (1., -1., 0.), (-1., -1., -1.),
  (1., 1., 1.), (-1., 1., -1.),
]
relu6_examples = relu_examples + [
  (1., 7., 1.), (-1., 7., 0.),
]
nonls = [("Relu", tf.nn.relu, redirected_relu_grad, relu_examples),
         ("Relu6", tf.nn.relu6, redirected_relu6_grad, relu6_examples)]

@pytest.mark.parametrize("nonl_name,nonl,nonl_grad_override,examples", nonls)
def test_gradient_override_relu6_directionality(nonl_name, nonl,
                                                nonl_grad_override, examples):
  for incoming_grad, input, grad in examples:
    with tf.Session().as_default() as sess:
      batched_shape = [1, 1]
      incoming_grad_t = tf.constant(incoming_grad, shape=batched_shape)
      input_t = tf.constant(input, shape=batched_shape)
      with gradient_override_map({nonl_name: nonl_grad_override}):
        nonl_t = nonl(input_t)
        grad_wrt_input = tf.gradients(nonl_t, input_t, [incoming_grad_t])[0]
      assert (grad_wrt_input.eval() == grad).all()
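For symmetry, the same kind of hand-check for the Relu6 override with an input above the cap, mirroring the (1., 7., 1.) and (-1., 7., 0.) parametrized examples above. Again a sketch under the same TF 1.x assumptions, with illustrative shapes and values:

```python
import tensorflow as tf

from lucid.misc.gradient_override import gradient_override_map
from lucid.misc.redirected_relu_grad import redirected_relu6_grad

with tf.Session().as_default():
  # Above the Relu6 cap: relu6(x) == 6, so the plain Relu6 gradient is 0.
  x = tf.constant(7., shape=[1, 1])
  with gradient_override_map({"Relu6": redirected_relu6_grad}):
    y = tf.nn.relu6(x)
    # A +1 incoming gradient passes through: descent pushes x back down toward 6.
    grad_down = tf.gradients(y, x, [tf.constant(1., shape=[1, 1])])[0]
    # A -1 incoming gradient is blocked: it would push x even further above 6.
    grad_up = tf.gradients(y, x, [tf.constant(-1., shape=[1, 1])])[0]
  assert (grad_down.eval() == 1.).all()
  assert (grad_up.eval() == 0.).all()
```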
