diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-04 10:10:44 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-04 10:10:44 +0100 |
commit | 405f352dc562eb17a2d6d7ca17c2ce14b19f2668 (patch) | |
tree | fff693f369f0e65eaa800a4b98a3c1da90db8f9e /ot/da.py | |
parent | 0ea30e9ca2398caeb853662012b09148a25cffff (diff) |
add mapping estimation with kernels works!
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 32 |
1 files changed, 24 insertions, 8 deletions
@@ -123,7 +123,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter return transp -def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs): +def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs): """Joint Ot and mapping estimation (uniform weights and ) """ @@ -209,7 +209,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos if verbose: if it%20==0: print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32) - print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2]))) + print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],(vloss[-1]-vloss[-2])/abs(vloss[-2]))) if log: log['loss']=vloss return G,L,log @@ -217,7 +217,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos return G,L -def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs): +def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs): """Joint Ot and mapping estimation (uniform weights and ) """ @@ -228,15 +228,31 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b K1=np.hstack((K,np.ones((ns,1)))) I=np.eye(ns+1) I[-1]=0 - K0 = K1.T.dot(K1)+eta*I - Kreg=I - sel=lambda x : x[:-1,:] + Kp=np.eye(ns+1) + Kp[:ns,:ns]=K + + # ls regu + #K0 = K1.T.dot(K1)+eta*I + #Kreg=I + + # RKHS regul + K0 = K1.T.dot(K1)+eta*Kp + Kreg=Kp + else: K1=K I=np.eye(ns) + + # ls regul + #K0 = K1.T.dot(K1)+eta*I + #Kreg=I + + # proper kernel ridge K0=K+eta*I Kreg=K - sel=lambda x : x + + + if log: log={'err':[]} @@ -313,7 +329,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b if verbose: if it%20==0: print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32) - print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2]))) + print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],(vloss[-1]-vloss[-2])/abs(vloss[-2]))) if log: log['loss']=vloss return G,L,log |