From 547a03ef87e4aa92edc1e89ee2db04114e1a8ad5 Mon Sep 17 00:00:00 2001 From: ievred Date: Wed, 1 Apr 2020 09:13:58 +0200 Subject: fix test example add M to log --- examples/plot_otda_jcpot.py | 20 ++++++-------------- ot/bregman.py | 1 + test/test_da.py | 13 ++----------- 3 files changed, 9 insertions(+), 25 deletions(-) diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py index 5e5fff8..1641fb0 100644 --- a/examples/plot_otda_jcpot.py +++ b/examples/plot_otda_jcpot.py @@ -81,11 +81,7 @@ pl.axis('off') ############################################################################## # Instantiate Sinkhorn transport algorithm and fit them for all source domains # ---------------------------------------------------------------------------- -ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-2, metric='euclidean') - -M1 = ot.dist(xs1, xt, 'euclidean') -M2 = ot.dist(xs2, xt, 'euclidean') - +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean') def print_G(G, xs, ys, xt): for i in range(G.shape[0]): @@ -125,7 +121,7 @@ pl.axis('off') ############################################################################## # Instantiate JCPOT adaptation algorithm and fit it # ---------------------------------------------------------------------------- -otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, tol=1e-9, verbose=True, log=True) +otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) otda.fit(all_Xr, all_Yr, xt) ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2']) @@ -136,8 +132,8 @@ pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) @@ -154,10 +150,6 @@ pl.axis('off') ############################################################################## # Run oracle transport algorithm with known proportions # ---------------------------------------------------------------------------- - -otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log=True) -otda.fit(all_Xr, all_Yr, xt) - h_res = np.array([1 - pt, pt]) ws1 = h_res.dot(otda.log_['all_domains'][0]['D2']) @@ -168,8 +160,8 @@ pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) diff --git a/ot/bregman.py b/ot/bregman.py index d17aaf0..fb959e9 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1603,6 +1603,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # build the cost matrix and the Gibbs kernel M = dist(Xs[d], Xt, metric=metric) M = M / np.median(M) + dom['M'] = M K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) diff --git a/test/test_da.py b/test/test_da.py index 7526f30..a13550c 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -568,7 +568,7 @@ def test_jcpot_transport_class(): Xs = [Xs1, Xs2] ys = [ys1, ys2] - otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True) + otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log = True) # test its computed otda.fit(Xs=Xs, ys=ys, Xt=Xt) @@ -591,14 +591,8 @@ def test_jcpot_transport_class(): # test margin constraints w.r.t. modified source weights for each source domain - D1 = np.zeros((len(np.unique(ys[i])), len(ys[i]))) - for c in np.unique(ys[i]): - nbelemperclass = np.sum(ys[i] == c) - if nbelemperclass != 0: - D1[int(c), ys[i] == c] = 1. - assert_allclose( - np.dot(D1, np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) + np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -609,6 +603,3 @@ def test_jcpot_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) - - -test_jcpot_transport_class() \ No newline at end of file -- cgit v1.2.3