This repository was archived by the owner on Apr 10, 2024. It is now read-only.

Commit 2471083

Fixed bug in objectives, enabled direct access to objective tensors, changed description in objectives.

1 parent 6438448 · commit 2471083

File tree: 2 files changed, +47 −20 lines

lucid/optvis/objectives.py

Lines changed: 16 additions & 20 deletions
@@ -61,41 +61,30 @@ class Objective(object):
 
   def __init__(self, objective_func, name="", description=""):
     self.objective_func = objective_func
-    self.name = name
     self.description = description
+    self.value = None  # This value is populated after a call
 
   def __add__(self, other):
     if isinstance(other, (int, float)):
      objective_func = lambda T: other + self(T)
-      name = self.name
-      description = self.description
     else:
      objective_func = lambda T: self(T) + other(T)
-      name = ", ".join([self.name, other.name])
-      description = "Sum(" + " +\n".join([self.description, other.description]) + ")"
-    return Objective(objective_func, name=name, description=description)
+    description = "(" + " + ".join([str(self), str(other)]) + ")"
+    return Objective(objective_func, description=description)
 
   def __neg__(self):
     return -1 * self
 
   def __sub__(self, other):
     return self + (-1 * other)
 
-  @staticmethod
-  def sum(objs):
-    objective_func = lambda T: sum([obj(T) for obj in objs])
-    descriptions = [obj.description for obj in objs]
-    description = "Sum(" + " +\n".join(descriptions) + ")"
-    names = [obj.name for obj in objs]
-    name = ", ".join(names)
-    return Objective(objective_func, name=name, description=description)
-
   def __mul__(self, other):
     if isinstance(other, (int, float)):
      objective_func = lambda T: other * self(T)
     else:
      objective_func = lambda T: self(T) * other(T)
-    return Objective(objective_func, name=self.name, description=self.description)
+    description = str(self) + "·" + str(other)
+    return Objective(objective_func, description=description)
 
   def __rmul__(self, other):
     return self.__mul__(other)

@@ -104,7 +93,14 @@ def __radd__(self, other):
     return self.__add__(other)
 
   def __call__(self, T):
-    return self.objective_func(T)
+    self.value = self.objective_func(T)
+    return self.value
+
+  def __str__(self):
+    return self.description
+
+  def __repr__(self):
+    return self.description
 
 
 def _make_arg_str(arg):

@@ -124,7 +120,7 @@ def wrap_objective(f, *args, **kwds):
   """
   objective_func = f(*args, **kwds)
   objective_name = f.__name__
-  args_str = " [" + ", ".join([_make_arg_str(arg) for arg in args]) + "]"
+  args_str = "(" + ", ".join([_make_arg_str(arg) for arg in args]) + ")"
   description = objective_name.title() + args_str
   return Objective(objective_func, objective_name, description)
 

@@ -190,10 +186,10 @@ def direction(layer, vec, batch=None, cossim_pow=0):
   """Visualize a direction"""
   if batch is None:
     vec = vec[None, None, None]
-    return lambda T: _dot_cossim(T(layer), vec)
+    return lambda T: _dot_cossim(T(layer), vec, cossim_pow=cossim_pow)
   else:
     vec = vec[None, None]
-    return lambda T: _dot_cossim(T(layer)[batch], vec)
+    return lambda T: _dot_cossim(T(layer)[batch], vec, cossim_pow=cossim_pow)
 
 
 @wrap_objective
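
The effect of the reworked operators can be seen end to end in a small usage sketch (illustrative only, not part of the commit; the toy `const` objective below is invented for the example and simply ignores the model tensor accessor T):

    # Hypothetical sketch of the reworked Objective composition, assuming this
    # version of lucid.optvis.objectives is importable.
    from lucid.optvis.objectives import wrap_objective

    @wrap_objective
    def const(a):
      # Toy objective: ignores the tensor accessor T and returns a constant.
      return lambda T: a

    obj = const(3) * (const(1) - 2 * const(2) - 1)
    print(obj)          # descriptions now compose, and __str__/__repr__ expose them,
                        # e.g. "Const(3)·((Const(1) + Const(2)·2·-1) + -1)"
    result = obj(None)  # __call__ now caches its result on the objective ...
    print(obj.value)    # ... so the last value is directly accessible: -12

Passing None as T is only possible here because the toy objective never touches it; in normal use T is supplied by the renderer.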

tests/optvis/test_objectives.py

Lines changed: 31 additions & 0 deletions
@@ -41,6 +41,37 @@ def test_neuron(inceptionv1):
   objective = objectives.neuron("mixed4a_pre_relu", 42)
   assert_gradient_ascent(objective, inceptionv1)
 
+def test_composition():
+  @wrap_objective
+  def f(a):
+    return lambda T: a
+
+  a = f(1)
+  b = f(2)
+  c = f(3)
+  ab = a - 2*b
+  cab = c*(ab - 1)
+
+  assert str(cab) == "F(3)·((F(1) + F(2)·2·-1) + -1)"
+  assert cab(None) == 3*(1 - 2*2 - 1)
+  assert a.value == 1
+  assert b.value == 2
+  assert c.value == 3
+  assert ab.value == (a.value - 2*b.value)
+  assert cab.value == c.value*(ab.value - 1)
+
+
+@pytest.mark.parametrize("cossim_pow", [0, 1, 2])
+def test_cossim(cossim_pow):
+  x = np.array([1, 1], dtype=np.float32)
+  y = np.array([1, 0], dtype=np.float32)
+  T = lambda _: tf.constant(x[None, None, None, :])
+  objective = objectives.direction("dummy", y, cossim_pow=cossim_pow)
+  obj = objective(T)
+  sess = tf.Session()
+  trueval = np.dot(x, y) * (np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))) ** cossim_pow
+  assert abs(sess.run(obj) - trueval) < 1e-3
+
 
 def test_channel(inceptionv1):
   objective = objectives.channel("mixed4a_pre_relu", 42)
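
For reference, the expected values in the parametrized cossim test work out as follows (arithmetic only, not part of the commit): with x = [1, 1] and y = [1, 0], np.dot(x, y) = 1 and the cosine similarity is 1/√2 ≈ 0.707, so trueval is 1, ≈0.707, and 0.5 for cossim_pow = 0, 1, and 2 respectively.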
