summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:42:09 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-01 09:42:09 +0200
commitb1f87363b160735b6e2df59380f9de56b7934b53 (patch)
tree7e36f3cc31dad1ffeefc53c0f41d4802dd8ec54c /examples
parent547a03ef87e4aa92edc1e89ee2db04114e1a8ad5 (diff)
add dataset clean plot
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_otda_jcpot.py28
1 files changed, 9 insertions, 19 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py
index 1641fb0..579ad2a 100644
--- a/examples/plot_otda_jcpot.py
+++ b/examples/plot_otda_jcpot.py
@@ -16,6 +16,7 @@ This example introduces a target shift problem with two 2D source and 1 target d
import pylab as pl
import numpy as np
import ot
+from ot.datasets import make_data_classif
##############################################################################
# Generate data
@@ -24,17 +25,6 @@ n = 50
sigma = 0.3
np.random.seed(1985)
-
-def get_data(n, p, dec):
- y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n))))
- x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + sigma * np.random.randn(len(y), 2)
-
- x[:, 0] += dec[0]
- x[:, 1] += dec[1]
-
- return x, y
-
-
p1 = .2
dec1 = [0, 2]
@@ -44,15 +34,15 @@ dec2 = [0, -2]
pt = .4
dect = [4, 0]
-xs1, ys1 = get_data(n, p1, dec1)
-xs2, ys2 = get_data(n + 1, p2, dec2)
-xt, yt = get_data(n, pt, dect)
+xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p = p1, bias = dec1)
+xs2, ys2 = make_data_classif('2gauss_prop', n+1, nz=sigma, p = p2, bias = dec2)
+xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p = pt, bias = dect)
+
all_Xr = [xs1, xs2]
all_Yr = [ys1, ys2]
# %%
-da = 1.5
-
+da = 1.5
def plot_ax(dec, name):
pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5)
pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5)
@@ -68,9 +58,9 @@ pl.clf()
plot_ax(dec1, 'Source 1')
plot_ax(dec2, 'Source 2')
plot_ax(dect, 'Target')
-pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, label='Source 1 (0.8,0.2)')
-pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, label='Source 2 (0.1,0.9)')
-pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, label='Target (0.6,0.4)')
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, label='Source 1 ({:1.2f}, {:1.2f})'.format(1-p1, p1))
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, label='Source 2 ({:1.2f}, {:1.2f})'.format(1-p2, p2))
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, label='Target ({:1.2f}, {:1.2f})'.format(1-pt, pt))
pl.title('Data')
pl.legend()