summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-04 10:10:44 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-04 10:10:44 +0100
commit405f352dc562eb17a2d6d7ca17c2ce14b19f2668 (patch)
treefff693f369f0e65eaa800a4b98a3c1da90db8f9e /ot
parent0ea30e9ca2398caeb853662012b09148a25cffff (diff)
add mapping estimation with kernels works!
Diffstat (limited to 'ot')
-rw-r--r--ot/da.py32
-rw-r--r--ot/optim.py2
-rw-r--r--ot/utils.py2
3 files changed, 27 insertions, 9 deletions
diff --git a/ot/da.py b/ot/da.py
index 8bb3b97..ad2f8b5 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -123,7 +123,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
return transp
-def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
"""Joint Ot and mapping estimation (uniform weights and )
"""
@@ -209,7 +209,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
if verbose:
if it%20==0:
print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
- print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+ print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],(vloss[-1]-vloss[-2])/abs(vloss[-2])))
if log:
log['loss']=vloss
return G,L,log
@@ -217,7 +217,7 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
return G,L
-def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
"""Joint Ot and mapping estimation (uniform weights and )
"""
@@ -228,15 +228,31 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
K1=np.hstack((K,np.ones((ns,1))))
I=np.eye(ns+1)
I[-1]=0
- K0 = K1.T.dot(K1)+eta*I
- Kreg=I
- sel=lambda x : x[:-1,:]
+ Kp=np.eye(ns+1)
+ Kp[:ns,:ns]=K
+
+ # ls regu
+ #K0 = K1.T.dot(K1)+eta*I
+ #Kreg=I
+
+ # RKHS regul
+ K0 = K1.T.dot(K1)+eta*Kp
+ Kreg=Kp
+
else:
K1=K
I=np.eye(ns)
+
+ # ls regul
+ #K0 = K1.T.dot(K1)+eta*I
+ #Kreg=I
+
+ # proper kernel ridge
K0=K+eta*I
Kreg=K
- sel=lambda x : x
+
+
+
if log:
log={'err':[]}
@@ -313,7 +329,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
if verbose:
if it%20==0:
print('{:5s}|{:12s}|{:8s}'.format('It.','Loss','Delta loss')+'\n'+'-'*32)
- print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],abs(vloss[-1]-vloss[-2])/abs(vloss[-2])))
+ print('{:5d}|{:8e}|{:8e}'.format(it,vloss[-1],(vloss[-1]-vloss[-2])/abs(vloss[-2])))
if log:
log['loss']=vloss
return G,L,log
diff --git a/ot/optim.py b/ot/optim.py
index dcefd24..2b8f565 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -159,6 +159,8 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
# problem linearization
Mi=M+reg*df(G)
+ # set M positive
+ Mi+=Mi.min()
# solve linear program
Gc=emd(a,b,Mi)
diff --git a/ot/utils.py b/ot/utils.py
index 47fe77f..d3df8fa 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -9,7 +9,7 @@ from scipy.spatial.distance import cdist
def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
"""Compute kernel matrix"""
if method.lower() in ['gaussian','gauss','rbf']:
- K=np.exp(dist(x1,x2)/(2*sigma**2))
+ K=np.exp(-dist(x1,x2)/(2*sigma**2))
return K
def unif(n):