summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-21 08:29:50 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-21 08:29:50 +0100
commit6fdf5de8fa27fa16d6b8910fe96eb67b7761aa0e (patch)
tree55b8d6e463b131134b4aa9d4e3013bbb77811da6 /test/test_da.py
parent287c659ad35f5036ba2687caf73009ef455c7239 (diff)
add linear mapping test + autopep8
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 593dc53..7b63daf 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -444,6 +444,24 @@ def test_mapping_transport_class():
assert len(otda.log_.keys()) != 0
+def test_linear_mapping():
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = get_data_classif('3gauss', ns)
+ Xt, yt = get_data_classif('3gauss2', nt)
+
+ A, b = ot.da.OT_mapping_linear(Xs, Xt)
+
+ Xst = Xs.dot(A) + b
+
+ Ct = np.cov(Xt.T)
+ Cst = np.cov(Xst.T)
+
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
def test_otda():
n_samples = 150 # nb samples