Skip to content

Commit 0cb2b2e

Browse files
authored
[MRG] Add tesing on wda (#296)
1 parent 1b5c35b commit 0cb2b2e

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

test/test_dr.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,31 @@ def test_wda():
6060
np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
6161

6262

63+
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
64+
def test_wda_normalized():
65+
66+
n_samples = 100 # nb samples in source and target datasets
67+
np.random.seed(0)
68+
69+
# generate gaussian dataset
70+
xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
71+
72+
n_features_noise = 8
73+
74+
xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
75+
76+
p = 2
77+
78+
P0 = np.random.randn(10, p)
79+
P0 /= P0.sum(0, keepdims=True)
80+
81+
Pwda, projwda = ot.dr.wda(xs, ys, p, maxiter=10, P0=P0, normalize=True)
82+
83+
projwda(xs)
84+
85+
np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
86+
87+
6388
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
6489
def test_prw():
6590
d = 100 # Dimension

0 commit comments

Comments
 (0)