diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 4 |
1 files changed, 3 insertions, 1 deletions
@@ -229,11 +229,13 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b I=np.eye(ns+1) I[-1]=0 K0 = K1.T.dot(K1)+eta*I + Kreg=I sel=lambda x : x[:-1,:] else: K1=K I=np.eye(ns) K0=K+eta*I + Kreg=K sel=lambda x : x if log: @@ -247,7 +249,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b def loss(L,G): """Compute full loss""" - return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L)) + return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(Kreg).dot(L)) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" |