summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-03 16:28:46 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-03 16:28:46 +0100
commit86b1c88eb0c2c43853bca38de96d2278cc90ceba (patch)
tree3bf323b3cd422595ccb105725754ac89f0e36a5c /ot
parent566645ad184e1205f7f666ea2f19021254c33d74 (diff)
add mapping estimation with kernels (still debugging)
Diffstat (limited to 'ot')
-rw-r--r--ot/da.py35
1 files changed, 21 insertions, 14 deletions
diff --git a/ot/da.py b/ot/da.py
index 66680cd..e4aa0be 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -217,25 +217,23 @@ def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbos
return G,L
-def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
+def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 20,stopInnerThr=1e-9,stopThr=1e-6,log=False,**kwargs):
"""Joint Ot and mapping estimation (uniform weights and )
"""
ns,nt,d=xs.shape[0],xt.shape[0],xt.shape[1]
+ K=kernel(xs,xs,method=kerneltype,sigma=sigma)
if bias:
- K=
- xs1=np.hstack((xs,np.ones((ns,1))))
- xstxs=xs1.T.dot(xs1)
- I=np.eye(d+1)
+ K1=np.hstack((K,np.ones((ns,1))))
+ I=np.eye(ns+1)
I[-1]=0
- I0=I[:,:-1]
+ K0 = K1.T.dot(K1)+eta*I
sel=lambda x : x[:-1,:]
else:
- xs1=xs
- xstxs=xs1.T.dot(xs1)
- I=np.eye(d)
- I0=I
+ K1=K
+ I=np.eye(ns)
+ K0=K+eta*I
sel=lambda x : x
if log:
@@ -249,16 +247,21 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=
def loss(L,G):
"""Compute full loss"""
- return np.sum((xs1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L-I0)**2)
+ return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.sum(sel(L)**2)
- def solve_L(G):
+ def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
xst=ns*G.dot(xt)
- return np.linalg.solve(xstxs+eta*I,xs1.T.dot(xst)+eta*I0)
+ return np.linalg.solve(K0,xst)
+
+ def solve_L_bias(G):
+ """ solve L problem with fixed G (least square)"""
+ xst=ns*G.dot(xt)
+ return np.linalg.solve(K0,K1.T.dot(xst))
def solve_G(L,G0):
"""Update G with CG algorithm"""
- xsi=xs1.dot(L)
+ xsi=K1.dot(L)
def f(G):
return np.sum((xsi-ns*G.dot(xt))**2)
def df(G):
@@ -266,6 +269,10 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kernel='gaussian',sigma=1,bias=
G=cg(a,b,M,1.0/mu,f,df,G0=G0,numItermax=numInnerItermax,stopThr=stopInnerThr)
return G
+ if bias:
+ solve_L=solve_L_bias
+ else:
+ solve_L=solve_L_nobias
L=solve_L(G)