summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/ot/da.py b/ot/da.py
index 49fa79e..8bb3b97 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -229,11 +229,13 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
I=np.eye(ns+1)
I[-1]=0
K0 = K1.T.dot(K1)+eta*I
+ Kreg=I
sel=lambda x : x[:-1,:]
else:
K1=K
I=np.eye(ns)
K0=K+eta*I
+ Kreg=K
sel=lambda x : x
if log:
@@ -247,7 +249,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
def loss(L,G):
"""Compute full loss"""
- return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L))
+ return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(Kreg).dot(L))
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""