summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-04 10:10:44 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-04 10:10:44 +0100
commit405f352dc562eb17a2d6d7ca17c2ce14b19f2668 (patch)
treefff693f369f0e65eaa800a4b98a3c1da90db8f9e /ot/da.py
parent0ea30e9ca2398caeb853662012b09148a25cffff (diff)
add mapping estimation with kernels works!
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py32
1 files changed, 24 insertions, 8 deletions
diff --git a/ot/da.py b/ot/da.py
index 8bb3b97..ad2f8b5 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -123,7 +123,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
return transp
-def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
"""Joint Ot and mapping estimation (uniform weights and )
"""
@@ -209,7 +209,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
if verbose:
if it%20==0:
print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
- print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+ print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],(vloss[-1]-vloss[-2])/abs(vloss[-2])))
if log:
log['loss']=vloss
return G,L,log
@@ -217,7 +217,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
return G,L
-def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
"""Joint Ot and mapping estimation (uniform weights and )
"""
@@ -228,15 +228,31 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
K1=np.hstack((K,np.ones((ns,1))))
I=np.eye(ns+1)
I[-1]=0
- K0 = K1.T.dot(K1)+eta*I
- Kreg=I
- sel=lambda x : x[:-1,:]
+ Kp=np.eye(ns+1)
+ Kp[:ns,:ns]=K
+
+ # ls regu
+ #K0 = K1.T.dot(K1)+eta*I
+ #Kreg=I
+
+ # RKHS regul
+ K0 = K1.T.dot(K1)+eta*Kp
+ Kreg=Kp
+
else:
K1=K
I=np.eye(ns)
+
+ # ls regul
+ #K0 = K1.T.dot(K1)+eta*I
+ #Kreg=I
+
+ # proper kernel ridge
K0=K+eta*I
Kreg=K
- sel=lambda x : x
+
+
+
if log:
log={'err':[]}
@@ -313,7 +329,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
if verbose:
if it%20==0:
print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
- print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+ print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],(vloss[-1]-vloss[-2])/abs(vloss[-2])))
if log:
log['loss']=vloss
return G,L,log