author    Rémi Flamary <remi.flamary@gmail.com>    2018-05-09 13:26:33 +0200
committer GitHub <noreply@github.com>              2018-05-09 13:26:33 +0200
commit    27032b6fa0f2f68af3fe4f90e5dcbb68f130a962 (patch)
tree      bef4b8a49cf410652e9c100274d010519acaf0f8 /ot/da.py
parent    1ff35860db2d612748270299d7ce0037b8d40702 (diff)
parent    0496e2b1b2c2f4ea2d7f313ccf58c612efaa70bf (diff)
Merge pull request #42 from rflamary/linear_mapping
Linear mapping + tests
Diffstat (limited to 'ot/da.py')
-rw-r--r--  ot/da.py | 277
1 file changed, 275 insertions(+), 2 deletions(-)
diff --git a/ot/da.py b/ot/da.py
index 532dcd2..48b418f 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -10,6 +10,7 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
+import scipy.linalg as linalg
from .bregman import sinkhorn
from .lp import emd
@@ -356,7 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
def loss(L, G):
"""Compute full loss"""
- return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
+ return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
def solve_L(G):
""" solve L problem with fixed G (least square)"""
@@ -556,7 +558,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def loss(L, G):
"""Compute full loss"""
- return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+ return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \
+ np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
@@ -633,6 +636,110 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
+def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
+ """ return OT linear operator between samples
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark 2.29 in [15].
+
+    The linear operator from source to target :math:`M` is given by
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+    reg : float, optional
+        regularization added to the diagonals of covariances (>0)
+    ws : np.ndarray (ns,1), optional
+        weights for the source samples
+    wt : np.ndarray (nt,1), optional
+        weights for the target samples
+    bias: boolean, optional
+        estimate bias b if True, else b=0 (default: True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+    A : np.ndarray (d,d)
+        Linear operator
+    b : np.ndarray (1,d)
+        bias
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+    .. [15] Peyré, G., & Cuturi, M. (2018). "Computational Optimal
+        Transport".
+
+
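+    Examples
+    --------
+    A minimal sketch on synthetic data (the data and tolerance below are
+    illustrative, not part of this patch):
+
+    >>> import numpy as np
+    >>> from ot.da import OT_mapping_linear
+    >>> np.random.seed(42)
+    >>> xs = np.random.randn(100, 2)
+    >>> # build targets as an exact affine push-forward of the sources
+    >>> xt = xs.dot(np.array([[2., 0.], [0., .5]])) + np.array([1., -1.])
+    >>> A, b = OT_mapping_linear(xs, xt)
+    >>> # the closed form recovers the affine map (up to reg)
+    >>> np.allclose(xs.dot(A) + b, xt, atol=1e-3)
+    True
+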
+ """
+
+ d = xs.shape[1]
+
+ if bias:
+ mxs = xs.mean(0, keepdims=True)
+ mxt = xt.mean(0, keepdims=True)
+
+ xs = xs - mxs
+ xt = xt - mxt
+ else:
+ mxs = np.zeros((1, d))
+ mxt = np.zeros((1, d))
+
+ if ws is None:
+ ws = np.ones((xs.shape[0], 1)) / xs.shape[0]
+
+ if wt is None:
+ wt = np.ones((xt.shape[0], 1)) / xt.shape[0]
+
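+    # weighted empirical covariances, with reg added on the diagonal
+    # so that the matrix square roots and inverses below are well defined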
+ Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d)
+ Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d)
+
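+    # closed-form Monge map between the two Gaussian fits (cf. docstring):
+    # A = Cs^{-1/2} (Cs^{1/2} Ct Cs^{1/2})^{1/2} Cs^{-1/2}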
+ Cs12 = linalg.sqrtm(Cs)
+ Cs_12 = linalg.inv(Cs12)
+
+ M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+
+ A = Cs_12.dot(M0.dot(Cs_12))
+
+ b = mxt - mxs.dot(A)
+
+ if log:
+ log = {}
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ log['Cs12'] = Cs12
+ log['Cs_12'] = Cs_12
+ return A, b, log
+ else:
+ return A, b
+
+
@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
"removed in 0.5"
"\n\tfor standard transport use class EMDTransport instead.")
@@ -1180,6 +1287,172 @@ class BaseTransport(BaseEstimator):
return transp_Xt
+class LinearTransport(BaseTransport):
+ """ OT linear operator between empirical distributions
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in
+ remark 2.29 in [15].
+
+    The linear operator from source to target :math:`M` is given by
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+    reg : float, optional
+        regularization added to the diagonals of covariances (>0)
+    bias: boolean, optional
+        estimate bias b if True, else b=0 (default: True)
+ log : bool, optional
+ record log if True
+
+ References
+ ----------
+
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+    .. [15] Peyré, G., & Cuturi, M. (2018). "Computational Optimal
+        Transport".
+
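+    Examples
+    --------
+    A minimal fit/transform sketch on synthetic data (illustrative, not
+    part of this patch):
+
+    >>> import numpy as np
+    >>> from ot.da import LinearTransport
+    >>> np.random.seed(42)
+    >>> Xs = np.random.randn(50, 2)
+    >>> Xt = Xs.dot(np.array([[2., 0.], [0., .5]])) + np.array([1., -1.])
+    >>> mapping = LinearTransport().fit(Xs=Xs, Xt=Xt)
+    >>> transp_Xs = mapping.transform(Xs=Xs)  # maps Xs onto Xt
+    >>> Xs_back = mapping.inverse_transform(Xt=transp_Xs)  # recovers Xs
+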
+ """
+
+ def __init__(self, reg=1e-8, bias=True, log=False,
+ distribution_estimation=distribution_estimation_uniform):
+
+ self.bias = bias
+ self.log = log
+ self.reg = reg
+ self.distribution_estimation = distribution_estimation
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+            Warning: Note that, due to this convention, -1 cannot be used
+            as a class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ self.mu_s = self.distribution_estimation(Xs)
+ self.mu_t = self.distribution_estimation(Xt)
+
+ # coupling estimation
+ returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
+ ws=self.mu_s.reshape((-1, 1)),
+ wt=self.mu_t.reshape((-1, 1)),
+ bias=self.bias, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.A_, self.B_, self.log_ = returned_
+ else:
+            self.A_, self.B_ = returned_
+ self.log_ = dict()
+
+        # compute the inverse mapping
+ self.A1_ = linalg.inv(self.A_)
+ self.B1_ = -self.B_.dot(self.A1_)
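+        # (the forward map is x -> x.dot(A_) + B_, so its inverse is
+        # x -> (x - B_).dot(A1_) = x.dot(A1_) + B1_ with B1_ = -B_.dot(A1_))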
+
+ return self
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+            Warning: Note that, due to this convention, -1 cannot be used
+            as a class label
+        batch_size : int, optional (default=128)
+            The batch size for out of sample transform
+
+ Returns
+ -------
+        transp_Xs : array-like, shape (n_source_samples, n_features)
+            The transported source samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ transp_Xs = Xs.dot(self.A_) + self.B_
+
+ return transp_Xs
+
+ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
+ batch_size=128):
+ """Transports target samples Xt onto target samples Xs
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+            Warning: Note that, due to this convention, -1 cannot be used
+            as a class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+        transp_Xt : array-like, shape (n_target_samples, n_features)
+ The transported target samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xt=Xt):
+
+ transp_Xt = Xt.dot(self.A1_) + self.B1_
+
+ return transp_Xt
+
+
class SinkhornTransport(BaseTransport):
"""Domain Adapatation OT method based on Sinkhorn Algorithm