summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_dr.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/test/test_dr.py b/test/test_dr.py
index 741f2ad..6d7fc9a 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -61,6 +61,28 @@ def test_wda():
@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_wda_low_reg():
+
+ n_samples = 100 # nb samples in source and target datasets
+ np.random.seed(0)
+
+ # generate gaussian dataset
+ xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)
+
+ n_features_noise = 8
+
+ xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))
+
+ p = 2
+
+ Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log')
+
+ projwda(xs)
+
+ np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_wda_normalized():
n_samples = 100 # nb samples in source and target datasets