diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-02-20 16:11:56 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-02-20 16:11:56 +0100 |
commit | 6d9b281271167d3676538f2ef8518abea82ef9c8 (patch) | |
tree | 7d1fae1d15a0ec70e229819a68b9f3a1ceea8f02 /ot/da.py | |
parent | 806a406e1ca2e9ca0bfdfe0516c75865e8098205 (diff) | |
parent | 5ff8030ce300f3d066e1edba2b36e60709b023b8 (diff) |
Merge branch 'master' of github.com:rflamary/POT
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 26 |
1 files changed, 16 insertions, 10 deletions
@@ -330,17 +330,17 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, if bias: xs1 = np.hstack((xs, np.ones((ns, 1)))) xstxs = xs1.T.dot(xs1) - I = np.eye(d + 1) - I[-1] = 0 - I0 = I[:, :-1] + Id = np.eye(d + 1) + Id[-1] = 0 + I0 = Id[:, :-1] def sel(x): return x[:-1, :] else: xs1 = xs xstxs = xs1.T.dot(xs1) - I = np.eye(d) - I0 = I + Id = np.eye(d) + I0 = Id def sel(x): return x @@ -361,7 +361,7 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, def solve_L(G): """ solve L problem with fixed G (least square)""" xst = ns * G.dot(xt) - return np.linalg.solve(xstxs + eta * I, xs1.T.dot(xst) + eta * I0) + return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0) def solve_G(L, G0): """Update G with CG algorithm""" @@ -520,8 +520,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', K = kernel(xs, xs, method=kerneltype, sigma=sigma) if bias: K1 = np.hstack((K, np.ones((ns, 1)))) - I = np.eye(ns + 1) - I[-1] = 0 + Id = np.eye(ns + 1) + Id[-1] = 0 Kp = np.eye(ns + 1) Kp[:ns, :ns] = K @@ -535,14 +535,14 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', else: K1 = K - I = np.eye(ns) + Id = np.eye(ns) # ls regul # K0 = K1.T.dot(K1)+eta*I # Kreg=I # proper kernel ridge - K0 = K + eta * I + K0 = K + eta * Id Kreg = K if log: @@ -933,6 +933,7 @@ def distribution_estimation_uniform(X): class BaseTransport(BaseEstimator): + """Base class for OTDA objects Notes @@ -1180,6 +1181,7 @@ class BaseTransport(BaseEstimator): class SinkhornTransport(BaseTransport): + """Domain Adapatation OT method based on Sinkhorn Algorithm Parameters @@ -1289,6 +1291,7 @@ class SinkhornTransport(BaseTransport): class EMDTransport(BaseTransport): + """Domain Adapatation OT method based on Earth Mover's Distance Parameters @@ -1377,6 +1380,7 @@ class EMDTransport(BaseTransport): class SinkhornLpl1Transport(BaseTransport): + """Domain Adapatation OT method based on sinkhorn algorithm + LpL1 class regularization. @@ -1486,6 +1490,7 @@ class SinkhornLpl1Transport(BaseTransport): class SinkhornL1l2Transport(BaseTransport): + """Domain Adapatation OT method based on sinkhorn algorithm + l1l2 class regularization. @@ -1608,6 +1613,7 @@ class SinkhornL1l2Transport(BaseTransport): class MappingTransport(BaseEstimator): + """MappingTransport: DA methods that aims at jointly estimating a optimal transport coupling and the associated mapping |