summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2018-06-11 11:24:57 +0200
committerGitHub <noreply@github.com>2018-06-11 11:24:57 +0200
commit47730fc727c0f54e8459964d9208ad824e3f91da (patch)
tree48952166e88e602f1843bd15c0187c7d5ffb6cac /test/test_da.py
parent4641ec5f2ddbff1a468afaf65741aecae44738cc (diff)
parent530dc93a60e9b81fb8d1b44680deea77dacf660b (diff)
Merge branch 'master' into remove_otda_v05
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py50
1 files changed, 25 insertions, 25 deletions
diff --git a/test/test_da.py b/test/test_da.py
index dc8bc5f..b429315 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -8,7 +8,7 @@ import numpy as np
from numpy.testing.utils import assert_allclose, assert_equal
import ot
-from ot.datasets import get_data_classif
+from ot.datasets import make_data_classif
from ot.utils import unif
@@ -19,8 +19,8 @@ def test_sinkhorn_lpl1_transport_class():
ns = 150
nt = 200
- Xs, ys = get_data_classif('3gauss', ns)
- Xt, yt = get_data_classif('3gauss2', nt)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
otda = ot.da.SinkhornLpl1Transport()
@@ -45,7 +45,7 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = get_data_classif('3gauss', ns + 1)
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -55,7 +55,7 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = get_data_classif('3gauss2', nt + 1)
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -92,8 +92,8 @@ def test_sinkhorn_l1l2_transport_class():
ns = 150
nt = 200
- Xs, ys = get_data_classif('3gauss', ns)
- Xt, yt = get_data_classif('3gauss2', nt)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
otda = ot.da.SinkhornL1l2Transport()
@@ -119,7 +119,7 @@ def test_sinkhorn_l1l2_transport_class():
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = get_data_classif('3gauss', ns + 1)
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -129,7 +129,7 @@ def test_sinkhorn_l1l2_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = get_data_classif('3gauss2', nt + 1)
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -173,8 +173,8 @@ def test_sinkhorn_transport_class():
ns = 150
nt = 200
- Xs, ys = get_data_classif('3gauss', ns)
- Xt, yt = get_data_classif('3gauss2', nt)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
otda = ot.da.SinkhornTransport()
@@ -200,7 +200,7 @@ def test_sinkhorn_transport_class():
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = get_data_classif('3gauss', ns + 1)
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -210,7 +210,7 @@ def test_sinkhorn_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = get_data_classif('3gauss2', nt + 1)
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -252,8 +252,8 @@ def test_emd_transport_class():
ns = 150
nt = 200
- Xs, ys = get_data_classif('3gauss', ns)
- Xt, yt = get_data_classif('3gauss2', nt)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
otda = ot.da.EMDTransport()
@@ -278,7 +278,7 @@ def test_emd_transport_class():
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = get_data_classif('3gauss', ns + 1)
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -288,7 +288,7 @@ def test_emd_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = get_data_classif('3gauss2', nt + 1)
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -329,9 +329,9 @@ def test_mapping_transport_class():
ns = 60
nt = 120
- Xs, ys = get_data_classif('3gauss', ns)
- Xt, yt = get_data_classif('3gauss2', nt)
- Xs_new, _ = get_data_classif('3gauss', ns + 1)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
##########################################################################
# kernel == linear mapping tests
@@ -449,8 +449,8 @@ def test_linear_mapping():
ns = 150
nt = 200
- Xs, ys = get_data_classif('3gauss', ns)
- Xt, yt = get_data_classif('3gauss2', nt)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
A, b = ot.da.OT_mapping_linear(Xs, Xt)
@@ -467,8 +467,8 @@ def test_linear_mapping_class():
ns = 150
nt = 200
- Xs, ys = get_data_classif('3gauss', ns)
- Xt, yt = get_data_classif('3gauss2', nt)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
otmap = ot.da.LinearTransport()
@@ -483,4 +483,4 @@ def test_linear_mapping_class():
Ct = np.cov(Xt.T)
Cst = np.cov(Xst.T)
- np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+ np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) \ No newline at end of file