diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-04 10:18:20 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-04 10:18:20 +0100 |
commit | 5992b14bedc1d51b10278474aa2f41f4ce822c38 (patch) | |
tree | cbc145a8b2d6a50c30d8239212b8c6347861c45e | |
parent | 405f352dc562eb17a2d6d7ca17c2ce14b19f2668 (diff) |
add demo mapping
-rw-r--r-- | docs/source/examples.rst | 5 | ||||
-rw-r--r-- | examples/demo_OTDA_mapping.py | 110 | ||||
-rw-r--r-- | ot/da.py | 4 |
3 files changed, 117 insertions, 2 deletions
diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 6093299..f209543 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -32,3 +32,8 @@ Color transfer in images ------------------------ .. literalinclude:: ../../examples/demo_OTDA_color_images.py + +OT mapping estimation for domain adaptation +------------------------------------------- + +.. literalinclude:: ../../examples/demo_OTDA_mapping.py diff --git a/examples/demo_OTDA_mapping.py b/examples/demo_OTDA_mapping.py new file mode 100644 index 0000000..f5da2ff --- /dev/null +++ b/examples/demo_OTDA_mapping.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +""" +Demo of OT mapping estimation for somain adaptation +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% dataset generation + +np.random.seed(0) + +n=100 # nb samples in source and target datasets +theta=2*np.pi/20 +nz=0.1 +xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz) +xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz) + +# one of the target mode changes its variance (no linear mapping) +xt[yt==2]*=3 +xt=xt+4 + + +#%% plot samples + +pl.figure(1,(8,5)) +pl.clf() + +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +pl.legend(loc=0) +pl.title('Source and target distributions') + + + +#%% OT linear mapping estimation + +eta=1e-8 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=True # estimate a bias + +ot_mapping=ot.da.OTDA_mapping_linear() +ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + +xst=ot_mapping.predict(xs) # use the estimated mapping +xst0=ot_mapping.interp() # use barycentric mapping + + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("barycentric mapping") + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Learned mapping") + + + +#%% Kernel mapping estimation + +eta=1e-5 # quadratic regularization for regression +mu=1e-1 # weight of the OT linear term +bias=True # estimate a bias +sigma=1 # sigma bandwidth fot gaussian kernel + + +ot_mapping_kernel=ot.da.OTDA_mapping_kernel() +ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + +xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping +xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping + + +#%% Plotting the mapped samples + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples') +pl.title("Bary. mapping (linear)") +pl.legend(loc=0) + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (linear)") + +pl.subplot(2,2,3) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("Bary. mapping (kernel)") + +pl.subplot(2,2,4) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (kernel)") + + + + + @@ -203,7 +203,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos if it>=numItermax: loop=0 - if abs(vloss[-1]-vloss[-2])<stopThr: + if abs(vloss[-1]-vloss[-2])/abs(vloss[-2])<stopThr: loop=0 if verbose: @@ -323,7 +323,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b if it>=numItermax: loop=0 - if abs(vloss[-1]-vloss[-2])<stopThr: + if abs(vloss[-1]-vloss[-2])/abs(vloss[-2])<stopThr: loop=0 if verbose: |