summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-07 13:50:11 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-07 13:50:11 +0200
commit2c9f992157844d6253a302905417e86580ac6b12 (patch)
treed7314b752b8d50f7bcc8e34eab11e2a109ddb222
parent34e13d467e376e9bfee2eb15771d9308518c2adb (diff)
upd
-rw-r--r--examples/plot_otda_classes.py1
-rw-r--r--examples/plot_otda_jcpot.py16
-rw-r--r--ot/bregman.py2
-rw-r--r--test/test_da.py2
4 files changed, 10 insertions, 11 deletions
diff --git a/examples/plot_otda_classes.py b/examples/plot_otda_classes.py
index c311fbd..f028022 100644
--- a/examples/plot_otda_classes.py
+++ b/examples/plot_otda_classes.py
@@ -17,7 +17,6 @@ approaches currently supported in POT.
import matplotlib.pylab as pl
import ot
-
##############################################################################
# Generate data
# -------------
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py
index ce6b88f..316fa8b 100644
--- a/examples/plot_otda_jcpot.py
+++ b/examples/plot_otda_jcpot.py
@@ -118,16 +118,16 @@ pl.axis('off')
otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
otda.fit(all_Xr, all_Yr, xt)
-ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2'])
-ws2 = otda.proportions_.dot(otda.log_['all_domains'][1]['D2'])
+ws1 = otda.proportions_.dot(otda.log_['D2'][0])
+ws2 = otda.proportions_.dot(otda.log_['D2'][1])
pl.figure(3)
pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt)
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
@@ -146,16 +146,16 @@ pl.axis('off')
# ----------------------------------------------------------------------------
h_res = np.array([1 - pt, pt])
-ws1 = h_res.dot(otda.log_['all_domains'][0]['D2'])
-ws2 = h_res.dot(otda.log_['all_domains'][1]['D2'])
+ws1 = h_res.dot(otda.log_['D2'][0])
+ws2 = h_res.dot(otda.log_['D2'][1])
pl.figure(4)
pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt)
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
diff --git a/ot/bregman.py b/ot/bregman.py
index ec81924..61dfa52 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1608,7 +1608,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
# build the cost matrix and the Gibbs kernel
Mtmp = dist(Xs[d], Xt, metric=metric)
Mtmp = Mtmp / np.median(Mtmp)
- M.append(M)
+ M.append(Mtmp)
Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
np.divide(Mtmp, -reg, out=Ktmp)
diff --git a/test/test_da.py b/test/test_da.py
index 372ebd4..4eaf193 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -589,7 +589,7 @@ def test_jcpot_transport_class():
# test margin constraints w.r.t. modified source weights for each source domain
assert_allclose(
- np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
+ np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
atol=1e-3)
# test transform