summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-04 13:56:51 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-04 13:56:51 +0200
commitd793f1f73e6f816458d8b307762675aa9fa84d22 (patch)
tree23e5975343c6622cc2a359258d4e15424bbfe3ea /test
parent0b005906f9d78adbf4d52d2ea9610eb3fde96a7c (diff)
correction of semi supervised mode
Diffstat (limited to 'test')
-rw-r--r--test/test_da.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 497a8ee..ecd2a3a 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -63,12 +63,12 @@ def test_sinkhorn_lpl1_transport_class():
assert_equal(transp_Xs.shape, Xs.shape)
# test semi supervised mode
- clf = ot.da.SinkhornTransport(mode="semisupervised")
- clf.fit(Xs=Xs, Xt=Xt)
+ clf = ot.da.SinkhornLpl1Transport()
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt)
n_unsup = np.sum(clf.Cost)
# test semi supervised mode
- clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf = ot.da.SinkhornLpl1Transport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = np.sum(clf.Cost)
@@ -126,12 +126,12 @@ def test_sinkhorn_l1l2_transport_class():
assert_equal(transp_Xs.shape, Xs.shape)
# test semi supervised mode
- clf = ot.da.SinkhornTransport(mode="semisupervised")
- clf.fit(Xs=Xs, Xt=Xt)
+ clf = ot.da.SinkhornL1l2Transport()
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt)
n_unsup = np.sum(clf.Cost)
# test semi supervised mode
- clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf = ot.da.SinkhornL1l2Transport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = np.sum(clf.Cost)
@@ -189,12 +189,12 @@ def test_sinkhorn_transport_class():
assert_equal(transp_Xs.shape, Xs.shape)
# test semi supervised mode
- clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf = ot.da.SinkhornTransport()
clf.fit(Xs=Xs, Xt=Xt)
n_unsup = np.sum(clf.Cost)
# test semi supervised mode
- clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf = ot.da.SinkhornTransport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = np.sum(clf.Cost)
@@ -252,12 +252,12 @@ def test_emd_transport_class():
assert_equal(transp_Xs.shape, Xs.shape)
# test semi supervised mode
- clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf = ot.da.EMDTransport()
clf.fit(Xs=Xs, Xt=Xt)
n_unsup = np.sum(clf.Cost)
# test semi supervised mode
- clf = ot.da.SinkhornTransport(mode="semisupervised")
+ clf = ot.da.EMDTransport()
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
n_semisup = np.sum(clf.Cost)