diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-01 09:49:24 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-01 09:49:24 +0200 |
commit | 6b8477d1c08696a08a1b71642712d83e560f9623 (patch) | |
tree | 702f920a75bf3f9c9b316d8c21cc71211ff27f28 /examples | |
parent | b1f87363b160735b6e2df59380f9de56b7934b53 (diff) |
pep8
Diffstat (limited to 'examples')
-rw-r--r-- | examples/plot_otda_jcpot.py | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py index 579ad2a..ce6b88f 100644 --- a/examples/plot_otda_jcpot.py +++ b/examples/plot_otda_jcpot.py @@ -34,15 +34,17 @@ dec2 = [0, -2] pt = .4 dect = [4, 0] -xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p = p1, bias = dec1) -xs2, ys2 = make_data_classif('2gauss_prop', n+1, nz=sigma, p = p2, bias = dec2) -xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p = pt, bias = dect) +xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1) +xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2) +xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect) all_Xr = [xs1, xs2] all_Yr = [ys1, ys2] # %% da = 1.5 + + def plot_ax(dec, name): pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5) pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5) @@ -58,21 +60,24 @@ pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, label='Source 1 ({:1.2f}, {:1.2f})'.format(1-p1, p1)) -pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, label='Source 2 ({:1.2f}, {:1.2f})'.format(1-p2, p2)) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, label='Target ({:1.2f}, {:1.2f})'.format(1-pt, pt)) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, + label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1)) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, + label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2)) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, + label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt)) pl.title('Data') pl.legend() pl.axis('equal') pl.axis('off') - ############################################################################## # Instantiate Sinkhorn transport algorithm and fit them for all source domains # ---------------------------------------------------------------------------- ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean') + def print_G(G, xs, ys, xt): for i in range(G.shape[0]): for j in range(G.shape[1]): @@ -107,7 +112,6 @@ pl.legend() pl.axis('equal') pl.axis('off') - ############################################################################## # Instantiate JCPOT adaptation algorithm and fit it # ---------------------------------------------------------------------------- |