summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-02-20 16:11:56 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-02-20 16:11:56 +0100
commit6d9b281271167d3676538f2ef8518abea82ef9c8 (patch)
tree7d1fae1d15a0ec70e229819a68b9f3a1ceea8f02 /ot/da.py
parent806a406e1ca2e9ca0bfdfe0516c75865e8098205 (diff)
parent5ff8030ce300f3d066e1edba2b36e60709b023b8 (diff)
Merge branch 'master' of github.com:rflamary/POT
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py26
1 files changed, 16 insertions, 10 deletions
diff --git a/ot/da.py b/ot/da.py
index 1d3d0ba..c688654 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -330,17 +330,17 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
if bias:
xs1 = np.hstack((xs, np.ones((ns, 1))))
xstxs = xs1.T.dot(xs1)
- I = np.eye(d + 1)
- I[-1] = 0
- I0 = I[:, :-1]
+ Id = np.eye(d + 1)
+ Id[-1] = 0
+ I0 = Id[:, :-1]
def sel(x):
return x[:-1, :]
else:
xs1 = xs
xstxs = xs1.T.dot(xs1)
- I = np.eye(d)
- I0 = I
+ Id = np.eye(d)
+ I0 = Id
def sel(x):
return x
@@ -361,7 +361,7 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
def solve_L(G):
""" solve L problem with fixed G (least square)"""
xst = ns * G.dot(xt)
- return np.linalg.solve(xstxs + eta * I, xs1.T.dot(xst) + eta * I0)
+ return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0)
def solve_G(L, G0):
"""Update G with CG algorithm"""
@@ -520,8 +520,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
K = kernel(xs, xs, method=kerneltype, sigma=sigma)
if bias:
K1 = np.hstack((K, np.ones((ns, 1))))
- I = np.eye(ns + 1)
- I[-1] = 0
+ Id = np.eye(ns + 1)
+ Id[-1] = 0
Kp = np.eye(ns + 1)
Kp[:ns, :ns] = K
@@ -535,14 +535,14 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
else:
K1 = K
- I = np.eye(ns)
+ Id = np.eye(ns)
# ls regul
# K0 = K1.T.dot(K1)+eta*I
# Kreg=I
# proper kernel ridge
- K0 = K + eta * I
+ K0 = K + eta * Id
Kreg = K
if log:
@@ -933,6 +933,7 @@ def distribution_estimation_uniform(X):
class BaseTransport(BaseEstimator):
+
"""Base class for OTDA objects
Notes
@@ -1180,6 +1181,7 @@ class BaseTransport(BaseEstimator):
class SinkhornTransport(BaseTransport):
+
"""Domain Adapatation OT method based on Sinkhorn Algorithm
Parameters
@@ -1289,6 +1291,7 @@ class SinkhornTransport(BaseTransport):
class EMDTransport(BaseTransport):
+
"""Domain Adapatation OT method based on Earth Mover's Distance
Parameters
@@ -1377,6 +1380,7 @@ class EMDTransport(BaseTransport):
class SinkhornLpl1Transport(BaseTransport):
+
"""Domain Adapatation OT method based on sinkhorn algorithm +
LpL1 class regularization.
@@ -1486,6 +1490,7 @@ class SinkhornLpl1Transport(BaseTransport):
class SinkhornL1l2Transport(BaseTransport):
+
"""Domain Adapatation OT method based on sinkhorn algorithm +
l1l2 class regularization.
@@ -1608,6 +1613,7 @@ class SinkhornL1l2Transport(BaseTransport):
class MappingTransport(BaseEstimator):
+
"""MappingTransport: DA methods that aims at jointly estimating a optimal
transport coupling and the associated mapping