From 5992b14bedc1d51b10278474aa2f41f4ce822c38 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Fri, 4 Nov 2016 10:18:20 +0100 Subject: add demo mapping --- docs/source/examples.rst | 5 ++ examples/demo_OTDA_mapping.py | 110 ++++++++++++++++++++++++++++++++++++++++++ ot/da.py | 4 +- 3 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 examples/demo_OTDA_mapping.py diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 6093299..f209543 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -32,3 +32,8 @@ Color transfer in images ------------------------ .. literalinclude:: ../../examples/demo_OTDA_color_images.py + +OT mapping estimation for domain adaptation +------------------------------------------- + +.. literalinclude:: ../../examples/demo_OTDA_mapping.py diff --git a/examples/demo_OTDA_mapping.py b/examples/demo_OTDA_mapping.py new file mode 100644 index 0000000..f5da2ff --- /dev/null +++ b/examples/demo_OTDA_mapping.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +""" +Demo of OT mapping estimation for somain adaptation +""" + +import numpy as np +import matplotlib.pylab as pl +import ot + + + +#%% dataset generation + +np.random.seed(0) + +n=100 # nb samples in source and target datasets +theta=2*np.pi/20 +nz=0.1 +xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz) +xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz) + +# one of the target mode changes its variance (no linear mapping) +xt[yt==2]*=3 +xt=xt+4 + + +#%% plot samples + +pl.figure(1,(8,5)) +pl.clf() + +pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples') +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples') + +pl.legend(loc=0) +pl.title('Source and target distributions') + + + +#%% OT linear mapping estimation + +eta=1e-8 # quadratic regularization for regression +mu=1e0 # weight of the OT linear term +bias=True # estimate a bias + +ot_mapping=ot.da.OTDA_mapping_linear() +ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True) + +xst=ot_mapping.predict(xs) # use the estimated mapping +xst0=ot_mapping.interp() # use barycentric mapping + + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("barycentric mapping") + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Learned mapping") + + + +#%% Kernel mapping estimation + +eta=1e-5 # quadratic regularization for regression +mu=1e-1 # weight of the OT linear term +bias=True # estimate a bias +sigma=1 # sigma bandwidth fot gaussian kernel + + +ot_mapping_kernel=ot.da.OTDA_mapping_kernel() +ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True) + +xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping +xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping + + +#%% Plotting the mapped samples + +pl.figure(2,(10,7)) +pl.clf() +pl.subplot(2,2,1) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples') +pl.title("Bary. mapping (linear)") +pl.legend(loc=0) + +pl.subplot(2,2,2) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (linear)") + +pl.subplot(2,2,3) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping') +pl.title("Bary. mapping (kernel)") + +pl.subplot(2,2,4) +pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2) +pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping') +pl.title("Estim. mapping (kernel)") + + + + + diff --git a/ot/da.py b/ot/da.py index ad2f8b5..4e5fda2 100644 --- a/ot/da.py +++ b/ot/da.py @@ -203,7 +203,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos if it>=numItermax: loop=0 - if abs(vloss[-1]-vloss[-2])=numItermax: loop=0 - if abs(vloss[-1]-vloss[-2])