@@ -31,6 +31,8 @@
 
 from lucid.optvis import objectives, param, transform
 from lucid.misc.io import show
+from lucid.misc.redirected_relu_grad import redirected_relu_grad, redirected_relu6_grad
+from lucid.misc.gradient_override import gradient_override_map
 
 # pylint: disable=invalid-name
 
@@ -40,8 +42,8 @@
 
 
 def render_vis(model, objective_f, param_f=None, optimizer=None,
-               transforms=None, thresholds=(512,),
-               print_objectives=None, verbose=True,):
+               transforms=None, thresholds=(512,), print_objectives=None,
+               verbose=True, relu_gradient_override=True, use_fixed_seed=False):
   """Flexible optimization-based feature vis.
 
   There are a lot of ways one might wish to customize optimization-based
@@ -72,6 +74,11 @@ def render_vis(model, objective_f, param_f=None, optimizer=None,
       whose values get logged during the optimization.
     verbose: Should we display the visualization when we hit a threshold?
       This should only be used in IPython.
+    relu_gradient_override: Whether to use the gradient override scheme
+      described in lucid/misc/redirected_relu_grad.py. On by default!
+    use_fixed_seed: Seed the RNG with a fixed value so results are reproducible.
+      Off by default. As of tf 1.8 this does not work as intended, see:
+      https://github.com/tensorflow/tensorflow/issues/9171
   Returns:
     2D array of optimization results containing evaluations of supplied
     param_f snapshotted at specified thresholds. Usually that will mean one or
@@ -80,7 +87,11 @@ def render_vis(model, objective_f, param_f=None, optimizer=None,
 
   with tf.Graph().as_default() as graph, tf.Session() as sess:
 
-    T = make_vis_T(model, objective_f, param_f, optimizer, transforms)
+    if use_fixed_seed:  # does not mean results are reproducible, see Args doc
+      tf.set_random_seed(0)
+
+    T = make_vis_T(model, objective_f, param_f, optimizer, transforms,
+                   relu_gradient_override)
     print_objective_func = make_print_objective_func(print_objectives, T)
     loss, vis_op, t_image = T("loss"), T("vis_op"), T("input")
     tf.global_variables_initializer().run()
@@ -105,7 +116,7 @@ def render_vis(model, objective_f, param_f=None, optimizer=None,
 
 
 def make_vis_T(model, objective_f, param_f=None, optimizer=None,
-               transforms=None):
+               transforms=None, relu_gradient_override=False):
   """Even more flexible optimization-based feature vis.
 
   This function is the inner core of render_vis(), and can be used
@@ -155,10 +166,19 @@ def make_vis_T(model, objective_f, param_f=None, optimizer=None,
   transform_f = make_transform_f(transforms)
   optimizer = make_optimizer(optimizer, [])
 
-  T = import_model(model, transform_f(t_image), t_image)
+  global_step = tf.train.get_or_create_global_step()
+  init_global_step = tf.variables_initializer([global_step])
+  init_global_step.run()
+
+  if relu_gradient_override:
+    with gradient_override_map({'Relu': redirected_relu_grad,
+                                'Relu6': redirected_relu6_grad}):
+      T = import_model(model, transform_f(t_image), t_image)
+  else:
+    T = import_model(model, transform_f(t_image), t_image)
   loss = objective_f(T)
 
-  global_step = tf.Variable(0, trainable=False, name="global_step")
+
   vis_op = optimizer.minimize(-loss, global_step=global_step)
 
   local_vars = locals()
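
With these changes, callers can toggle the redirected ReLU gradient override and (tentatively) seed the RNG from render_vis itself. A minimal usage sketch follows; the InceptionV1 model and the "mixed4a_pre_relu:476" objective are illustrative choices, not part of this change:

from lucid.modelzoo.vision_models import InceptionV1
from lucid.optvis import render

model = InceptionV1()
model.load_graphdef()

# Redirected ReLU gradients are applied by default.
images = render.render_vis(model, "mixed4a_pre_relu:476")

# Disable the override to compare against vanilla ReLU gradients, and
# attempt to fix the seed (see the tf 1.8 caveat in the Args doc).
images = render.render_vis(model, "mixed4a_pre_relu:476",
                           relu_gradient_override=False,
                           use_fixed_seed=True)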
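
For finer-grained control, make_vis_T can be driven by hand. A sketch under the same assumptions as above; note that make_vis_T now runs the global-step initializer internally, so it must be invoked with a default session already active:

import tensorflow as tf
from lucid.optvis import render

with tf.Graph().as_default(), tf.Session() as sess:
  T = render.make_vis_T(model, "mixed4a_pre_relu:476",
                        relu_gradient_override=True)
  tf.global_variables_initializer().run()
  for _ in range(512):
    T("vis_op").run()                # one optimization step
  visualization = T("input").eval()  # snapshot the parameterized image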