diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-01 09:42:09 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-01 09:42:09 +0200 |
commit | b1f87363b160735b6e2df59380f9de56b7934b53 (patch) | |
tree | 7e36f3cc31dad1ffeefc53c0f41d4802dd8ec54c /examples | |
parent | 547a03ef87e4aa92edc1e89ee2db04114e1a8ad5 (diff) |
add dataset clean plot
Diffstat (limited to 'examples')
-rw-r--r-- | examples/plot_otda_jcpot.py | 28 |
1 files changed, 9 insertions, 19 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py index 1641fb0..579ad2a 100644 --- a/examples/plot_otda_jcpot.py +++ b/examples/plot_otda_jcpot.py @@ -16,6 +16,7 @@ This example introduces a target shift problem with two 2D source and 1 target d import pylab as pl import numpy as np import ot +from ot.datasets import make_data_classif ############################################################################## # Generate data @@ -24,17 +25,6 @@ n = 50 sigma = 0.3 np.random.seed(1985) - -def get_data(n, p, dec): - y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n)))) - x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + sigma * np.random.randn(len(y), 2) - - x[:, 0] += dec[0] - x[:, 1] += dec[1] - - return x, y - - p1 = .2 dec1 = [0, 2] @@ -44,15 +34,15 @@ dec2 = [0, -2] pt = .4 dect = [4, 0] -xs1, ys1 = get_data(n, p1, dec1) -xs2, ys2 = get_data(n + 1, p2, dec2) -xt, yt = get_data(n, pt, dect) +xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p = p1, bias = dec1) +xs2, ys2 = make_data_classif('2gauss_prop', n+1, nz=sigma, p = p2, bias = dec2) +xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p = pt, bias = dect) + all_Xr = [xs1, xs2] all_Yr = [ys1, ys2] # %% -da = 1.5 - +da = 1.5 def plot_ax(dec, name): pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5) pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5) @@ -68,9 +58,9 @@ pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, label='Source 1 (0.8,0.2)') -pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, label='Source 2 (0.1,0.9)') -pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, label='Target (0.6,0.4)') +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, label='Source 1 ({:1.2f}, {:1.2f})'.format(1-p1, p1)) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, label='Source 2 ({:1.2f}, {:1.2f})'.format(1-p2, p2)) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, label='Target ({:1.2f}, {:1.2f})'.format(1-pt, pt)) pl.title('Data') pl.legend() |