summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/demo_OTDA_2D.py117
-rw-r--r--ot/da.py3
-rw-r--r--ot/datasets.py10
3 files changed, 124 insertions, 6 deletions
diff --git a/examples/demo_OTDA_2D.py b/examples/demo_OTDA_2D.py
new file mode 100644
index 0000000..fbaf56d
--- /dev/null
+++ b/examples/demo_OTDA_2D.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+"""
+demo of Optimal transport for domain adaptation
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+
+#%% parameters
+
+n=150 # nb bins
+
+xs,ys=ot.datasets.get_data_classif('3gauss',n)
+xt,yt=ot.datasets.get_data_classif('3gauss2',n)
+
+a,b = ot.unif(n),ot.unif(n)
+# loss matrix
+M=ot.dist(xs,xt)
+#M/=M.max()
+
+#%% plot samples
+
+pl.figure(1)
+
+pl.subplot(2,2,1)
+pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
+pl.legend(loc=0)
+pl.title('Source distributions')
+
+pl.subplot(2,2,2)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
+pl.legend(loc=0)
+pl.title('target distributions')
+
+pl.figure(2)
+pl.imshow(M,interpolation='nearest')
+pl.title('Cost matrix M')
+
+
+#%% OT estimation
+
+# EMD
+G0=ot.emd(a,b,M)
+
+# sinkhorn
+lambd=1e-1
+Gs=ot.sinkhorn(a,b,M,lambd)
+
+
+# Group lasso regularization
+reg=1e-1
+eta=1e0
+Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta)
+
+
+#%% visu matrices
+
+pl.figure(3)
+
+pl.subplot(2,3,1)
+pl.imshow(G0,interpolation='nearest')
+pl.title('OT matrix ')
+
+pl.subplot(2,3,2)
+pl.imshow(Gs,interpolation='nearest')
+pl.title('OT matrix Sinkhorn')
+
+pl.subplot(2,3,3)
+pl.imshow(Gg,interpolation='nearest')
+pl.title('OT matrix Group lasso')
+
+pl.subplot(2,3,4)
+ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
+pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
+
+
+pl.subplot(2,3,5)
+ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1])
+pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
+
+pl.subplot(2,3,6)
+ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1])
+pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
+
+#%% sample interpolation
+
+xst0=n*G0.dot(xt)
+xsts=n*Gs.dot(xt)
+xstg=n*Gg.dot(xt)
+
+pl.figure(4)
+pl.subplot(2,3,1)
+
+
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
+pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
+pl.title('Interp samples')
+pl.legend(loc=0)
+
+pl.subplot(2,3,2)
+
+
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
+pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
+pl.title('Interp samples Sinkhorn')
+
+pl.subplot(2,3,3)
+
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
+pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
+pl.title('Interp samples Grouplasso') \ No newline at end of file
diff --git a/ot/da.py b/ot/da.py
index 8ecd952..3ef9fc7 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -47,4 +47,5 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1):
if idx_begin==-1:
W[indices_labels[0],t]=np.min(all_maj)
- return transp \ No newline at end of file
+ return transp
+ \ No newline at end of file
diff --git a/ot/datasets.py b/ot/datasets.py
index edc29a9..f22e345 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -37,17 +37,17 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
x[y==2,0]=-1.; x[y==2,1]=1.
x[y==3,0]=1. ; x[y==3,1]=0
- x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
+ x[y!=3,:]+=1.5*nz*np.random.randn(sum(y!=3),2)
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
elif dataset.lower()=='3gauss2':
- y=np.floor((np.arange(n)*1.0/n*4))+1
+ y=np.floor((np.arange(n)*1.0/n*3))+1
x=np.zeros((n,2))
y[y==4]=3
# class 1
- x[y==1,0]=-1.; x[y==1,1]=-1.
- x[y==2,0]=-1.; x[y==2,1]=1.
- x[y==3,0]=1. ; x[y==3,1]=0
+ x[y==1,0]=-2.; x[y==1,1]=-2.
+ x[y==2,0]=-2.; x[y==2,1]=2.
+ x[y==3,0]=2. ; x[y==3,1]=0
x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)