summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-11-03 17:12:26 +0100
committerRémi Flamary <remi.flamary@gmail.com>2016-11-03 17:12:26 +0100
commit0ea30e9ca2398caeb853662012b09148a25cffff (patch)
tree783843a5c0cc0319e00d49528a4ec27e7d626a8a /ot/da.py
parent7e16b7a80f3a2896351262a02af27a60401b6a5e (diff)
add mapping estimation with kernels (smaller bugs)
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/ot/da.py b/ot/da.py
index 49fa79e..8bb3b97 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -229,11 +229,13 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
I=np.eye(ns+1)
I[-1]=0
K0 = K1.T.dot(K1)+eta*I
+ Kreg=I
sel=lambda x : x[:-1,:]
else:
K1=K
I=np.eye(ns)
K0=K+eta*I
+ Kreg=K
sel=lambda x : x
if log:
@@ -247,7 +249,7 @@ def joint_OT_mapping_kernel(xs,xt,mu=1,eta=0.001,kerneltype='gaussian',sigma=1,b
def loss(L,G):
"""Compute full loss"""
- return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(K0).dot(L))
+ return np.sum((K1.dot(L)-ns*G.dot(xt))**2)+mu*np.sum(G*M)+eta*np.trace(L.T.dot(Kreg).dot(L))
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""