diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 17:12:26 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2016-11-03 17:12:26 +0100 |
commit | 0ea30e9ca2398caeb853662012b09148a25cffff (patch) | |
tree | 783843a5c0cc0319e00d49528a4ec27e7d626a8a /ot/da.py | |
parent | 7e16b7a80f3a2896351262a02af27a60401b6a5e (diff) |
add mapping estimation with kernels (smaller bugs)
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 4 |
1 files changed, 3 insertions, 1 deletions
@@ -229,11 +229,13 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b I=np.eye(ns+1) I[-1]=0 K0 = K1.T.dot(K1)+eta*I + Kreg=I sel=lambda x : x[:-1,:] else: K1=K I=np.eye(ns) K0=K+eta*I + Kreg=K sel=lambda x : x if log: @@ -247,7 +249,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b def loss(L,G): """Compute full loss""" - return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L)) + return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(Kreg).dot(L)) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" |