summaryrefslogtreecommitdiff
path: root/test/test_da.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_da.py')
-rw-r--r--test/test_da.py13
1 files changed, 2 insertions, 11 deletions
diff --git a/test/test_da.py b/test/test_da.py
index 7526f30..a13550c 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -568,7 +568,7 @@ def test_jcpot_transport_class():
Xs = [Xs1, Xs2]
ys = [ys1, ys2]
- otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True)
+ otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log = True)
# test its computed
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
@@ -591,14 +591,8 @@ def test_jcpot_transport_class():
# test margin constraints w.r.t. modified source weights for each source domain
- D1 = np.zeros((len(np.unique(ys[i])), len(ys[i])))
- for c in np.unique(ys[i]):
- nbelemperclass = np.sum(ys[i] == c)
- if nbelemperclass != 0:
- D1[int(c), ys[i] == c] = 1.
-
assert_allclose(
- np.dot(D1, np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3)
+ np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -609,6 +603,3 @@ def test_jcpot_transport_class():
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
-
-test_jcpot_transport_class() \ No newline at end of file