summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:49:24 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:49:24 +0200
commit6b8477d1c08696a08a1b71642712d83e560f9623 (patch)
tree702f920a75bf3f9c9b316d8c21cc71211ff27f28
parentb1f87363b160735b6e2df59380f9de56b7934b53 (diff)
pep8
-rw-r--r--examples/plot_otda_jcpot.py20
-rw-r--r--test/test_da.py7
2 files changed, 15 insertions, 12 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py
index 579ad2a..ce6b88f 100644
--- a/examples/plot_otda_jcpot.py
+++ b/examples/plot_otda_jcpot.py
@@ -34,15 +34,17 @@ dec2 = [0, -2]
pt = .4
dect = [4, 0]
-xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p = p1, bias = dec1)
-xs2, ys2 = make_data_classif('2gauss_prop', n+1, nz=sigma, p = p2, bias = dec2)
-xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p = pt, bias = dect)
+xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1)
+xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2)
+xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect)
all_Xr = [xs1, xs2]
all_Yr = [ys1, ys2]
# %%
da = 1.5
+
+
def plot_ax(dec, name):
pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5)
pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5)
@@ -58,21 +60,24 @@ pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, label='Source 1 ({:1.2f}, {:1.2f})'.format(1-p1, p1))
-pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, label='Source 2 ({:1.2f}, {:1.2f})'.format(1-p2, p2))
-pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, label='Target ({:1.2f}, {:1.2f})'.format(1-pt, pt))
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9,
+ label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1))
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9,
+ label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2))
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9,
+ label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt))
pl.title('Data')
pl.legend()
pl.axis('equal')
pl.axis('off')
-
##############################################################################
# Instantiate Sinkhorn transport algorithm and fit them for all source domains
# ----------------------------------------------------------------------------
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean')
+
def print_G(G, xs, ys, xt):
for i in range(G.shape[0]):
for j in range(G.shape[1]):
@@ -107,7 +112,6 @@ pl.legend()
pl.axis('equal')
pl.axis('off')
-
##############################################################################
# Instantiate JCPOT adaptation algorithm and fit it
# ----------------------------------------------------------------------------
diff --git a/test/test_da.py b/test/test_da.py
index a13550c..f700df9 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -511,7 +511,6 @@ def test_mapping_transport_class():
def test_linear_mapping():
-
ns = 150
nt = 200
@@ -529,7 +528,6 @@ def test_linear_mapping():
def test_linear_mapping_class():
-
ns = 150
nt = 200
@@ -568,7 +566,7 @@ def test_jcpot_transport_class():
Xs = [Xs1, Xs2]
ys = [ys1, ys2]
- otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log = True)
+ otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log=True)
# test its computed
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
@@ -592,7 +590,8 @@ def test_jcpot_transport_class():
# test margin constraints w.r.t. modified source weights for each source domain
assert_allclose(
- np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3)
+ np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
+ atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)