@@ -109,7 +109,7 @@ def proj(X):
109
109
return Popt , proj
110
110
111
111
112
- def wda (X , y , p = 2 , reg = 1 , k = 10 , solver = None , maxiter = 100 , verbose = 0 , P0 = None ):
112
+ def wda (X , y , p = 2 , reg = 1 , k = 10 , solver = None , maxiter = 100 , verbose = 0 , P0 = None , normalize = False ):
113
113
r"""
114
114
Wasserstein Discriminant Analysis [11]_
115
115
@@ -139,6 +139,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
139
139
else should be a pymanopt.solvers
140
140
P0 : ndarray, shape (d, p)
141
141
Initial starting point for projection.
142
+ normalize : bool, optional
143
+ Normalize the Wasserstein distance by the average distance on P0 (default: False)
142
144
verbose : int, optional
143
145
Print information along iterations.
144
146
@@ -164,6 +166,18 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
164
166
# compute uniform weighs
165
167
wc = [np .ones ((x .shape [0 ]), dtype = np .float32 ) / x .shape [0 ] for x in xc ]
166
168
169
+ # pre-compute reg_c,c'
170
+ if P0 is not None and normalize :
171
+ regmean = np .zeros ((len (xc ), len (xc )))
172
+ for i , xi in enumerate (xc ):
173
+ xi = np .dot (xi , P0 )
174
+ for j , xj in enumerate (xc [i :]):
175
+ xj = np .dot (xj , P0 )
176
+ M = dist (xi , xj )
177
+ regmean [i , j ] = np .sum (M ) / (len (xi ) * len (xj ))
178
+ else :
179
+ regmean = np .ones ((len (xc ), len (xc )))
180
+
167
181
def cost (P ):
168
182
# wda loss
169
183
loss_b = 0
@@ -174,7 +188,7 @@ def cost(P):
174
188
for j , xj in enumerate (xc [i :]):
175
189
xj = np .dot (xj , P )
176
190
M = dist (xi , xj )
177
- G = sinkhorn (wc [i ], wc [j + i ], M , reg , k )
191
+ G = sinkhorn (wc [i ], wc [j + i ], M , reg * regmean [ i , j ] , k )
178
192
if j == 0 :
179
193
loss_w += np .sum (G * M )
180
194
else :
0 commit comments