summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-04 10:18:20 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-04 10:18:20 +0100
commit5992b14bedc1d51b10278474aa2f41f4ce822c38 (patch)
treecbc145a8b2d6a50c30d8239212b8c6347861c45e /examples
parent405f352dc562eb17a2d6d7ca17c2ce14b19f2668 (diff)
add demo mapping
Diffstat (limited to 'examples')
-rw-r--r--examples/demo_OTDA_mapping.py110
1 files changed, 110 insertions, 0 deletions
diff --git a/examples/demo_OTDA_mapping.py b/examples/demo_OTDA_mapping.py
new file mode 100644
index 0000000..f5da2ff
--- /dev/null
+++ b/examples/demo_OTDA_mapping.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+"""
+Demo of OT mapping estimation for somain adaptation
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+
+#%% dataset generation
+
+np.random.seed(0)
+
+n=100 # nb samples in source and target datasets
+theta=2*np.pi/20
+nz=0.1
+xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz)
+xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz)
+
+# one of the target mode changes its variance (no linear mapping)
+xt[yt==2]*=3
+xt=xt+4
+
+
+#%% plot samples
+
+pl.figure(1,(8,5))
+pl.clf()
+
+pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
+
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+
+#%% OT linear mapping estimation
+
+eta=1e-8 # quadratic regularization for regression
+mu=1e0 # weight of the OT linear term
+bias=True # estimate a bias
+
+ot_mapping=ot.da.OTDA_mapping_linear()
+ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)
+
+xst=ot_mapping.predict(xs) # use the estimated mapping
+xst0=ot_mapping.interp() # use barycentric mapping
+
+
+pl.figure(2,(10,7))
+pl.clf()
+pl.subplot(2,2,1)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
+pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping')
+pl.title("barycentric mapping")
+
+pl.subplot(2,2,2)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
+pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
+pl.title("Learned mapping")
+
+
+
+#%% Kernel mapping estimation
+
+eta=1e-5 # quadratic regularization for regression
+mu=1e-1 # weight of the OT linear term
+bias=True # estimate a bias
+sigma=1 # sigma bandwidth fot gaussian kernel
+
+
+ot_mapping_kernel=ot.da.OTDA_mapping_kernel()
+ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)
+
+xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping
+xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping
+
+
+#%% Plotting the mapped samples
+
+pl.figure(2,(10,7))
+pl.clf()
+pl.subplot(2,2,1)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
+pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples')
+pl.title("Bary. mapping (linear)")
+pl.legend(loc=0)
+
+pl.subplot(2,2,2)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
+pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
+pl.title("Estim. mapping (linear)")
+
+pl.subplot(2,2,3)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
+pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping')
+pl.title("Bary. mapping (kernel)")
+
+pl.subplot(2,2,4)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
+pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping')
+pl.title("Estim. mapping (kernel)")
+
+
+
+
+