@@ -61,41 +61,30 @@ class Objective(object):
61
61
62
62
def __init__ (self , objective_func , name = "" , description = "" ):
63
63
self .objective_func = objective_func
64
- self .name = name
65
64
self .description = description
65
+ self .value = None # This value is populated after a call
66
66
67
67
def __add__ (self , other ):
68
68
if isinstance (other , (int , float )):
69
69
objective_func = lambda T : other + self (T )
70
- name = self .name
71
- description = self .description
72
70
else :
73
71
objective_func = lambda T : self (T ) + other (T )
74
- name = ", " .join ([self .name , other .name ])
75
- description = "Sum(" + " +\n " .join ([self .description , other .description ]) + ")"
76
- return Objective (objective_func , name = name , description = description )
72
+ description = "(" + " + " .join ([str (self ), str (other )]) + ")"
73
+ return Objective (objective_func , description = description )
77
74
78
75
def __neg__ (self ):
79
76
return - 1 * self
80
77
81
78
def __sub__ (self , other ):
82
79
return self + (- 1 * other )
83
80
84
- @staticmethod
85
- def sum (objs ):
86
- objective_func = lambda T : sum ([obj (T ) for obj in objs ])
87
- descriptions = [obj .description for obj in objs ]
88
- description = "Sum(" + " +\n " .join (descriptions ) + ")"
89
- names = [obj .name for obj in objs ]
90
- name = ", " .join (names )
91
- return Objective (objective_func , name = name , description = description )
92
-
93
81
def __mul__ (self , other ):
94
82
if isinstance (other , (int , float )):
95
83
objective_func = lambda T : other * self (T )
96
84
else :
97
85
objective_func = lambda T : self (T ) * other (T )
98
- return Objective (objective_func , name = self .name , description = self .description )
86
+ description = str (self ) + "·" + str (other )
87
+ return Objective (objective_func , description = description )
99
88
100
89
def __rmul__ (self , other ):
101
90
return self .__mul__ (other )
@@ -104,7 +93,14 @@ def __radd__(self, other):
104
93
return self .__add__ (other )
105
94
106
95
def __call__ (self , T ):
107
- return self .objective_func (T )
96
+ self .value = self .objective_func (T )
97
+ return self .value
98
+
99
+ def __str__ (self ):
100
+ return self .description
101
+
102
+ def __repr__ (self ):
103
+ return self .description
108
104
109
105
110
106
def _make_arg_str (arg ):
@@ -124,7 +120,7 @@ def wrap_objective(f, *args, **kwds):
124
120
"""
125
121
objective_func = f (* args , ** kwds )
126
122
objective_name = f .__name__
127
- args_str = " [ " + ", " .join ([_make_arg_str (arg ) for arg in args ]) + "] "
123
+ args_str = "( " + ", " .join ([_make_arg_str (arg ) for arg in args ]) + ") "
128
124
description = objective_name .title () + args_str
129
125
return Objective (objective_func , objective_name , description )
130
126
@@ -190,10 +186,10 @@ def direction(layer, vec, batch=None, cossim_pow=0):
190
186
"""Visualize a direction"""
191
187
if batch is None :
192
188
vec = vec [None , None , None ]
193
- return lambda T : _dot_cossim (T (layer ), vec )
189
+ return lambda T : _dot_cossim (T (layer ), vec , cossim_pow = cossim_pow )
194
190
else :
195
191
vec = vec [None , None ]
196
- return lambda T : _dot_cossim (T (layer )[batch ], vec )
192
+ return lambda T : _dot_cossim (T (layer )[batch ], vec , cossim_pow = cossim_pow )
197
193
198
194
199
195
@wrap_objective
0 commit comments