| author | Rémi Flamary <remi.flamary@gmail.com> | 2018-05-09 13:26:33 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-05-09 13:26:33 +0200 |
| commit | 27032b6fa0f2f68af3fe4f90e5dcbb68f130a962 (patch) | |
| tree | bef4b8a49cf410652e9c100274d010519acaf0f8 /ot/da.py | |
| parent | 1ff35860db2d612748270299d7ce0037b8d40702 (diff) | |
| parent | 0496e2b1b2c2f4ea2d7f313ccf58c612efaa70bf (diff) | |
Merge pull request #42 from rflamary/linear_mapping
Linear mapping + tests
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 277 |
1 file changed, 275 insertions(+), 2 deletions(-)
@@ -10,6 +10,7 @@ Domain adaptation with optimal transport
 # License: MIT License
 
 import numpy as np
+import scipy.linalg as linalg
 
 from .bregman import sinkhorn
 from .lp import emd
@@ -356,7 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
 
     def loss(L, G):
         """Compute full loss"""
-        return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
+        return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \
+            np.sum(G * M) + eta * np.sum(sel(L - I0)**2)
 
     def solve_L(G):
         """ solve L problem with fixed G (least square)"""
@@ -556,7 +558,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
 
     def loss(L, G):
         """Compute full loss"""
-        return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+        return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \
+            np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
 
     def solve_L_nobias(G):
         """ solve L problem with fixed G (least square)"""
@@ -633,6 +636,110 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
     return G, L
 
 
+def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
+                      wt=None, bias=True, log=False):
+    """Return the OT linear operator between samples.
+
+    The function estimates the optimal linear operator that aligns the two
+    empirical distributions. This is equivalent to estimating the closed
+    form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+    and :math:`N(\mu_t,\Sigma_t)`, as proposed in [14] and discussed in
+    remark 2.29 of [15].
+
+    The linear operator from source to target is
+
+    .. math::
+        M(x) = Ax + b
+
+    where:
+
+    .. math::
+        A = \Sigma_s^{-1/2} (\Sigma_s^{1/2} \Sigma_t \Sigma_s^{1/2})^{1/2}
+        \Sigma_s^{-1/2}
+    .. math::
+        b = \mu_t - A \mu_s
+
+    Parameters
+    ----------
+    xs : np.ndarray (ns,d)
+        samples in the source domain
+    xt : np.ndarray (nt,d)
+        samples in the target domain
+    reg : float, optional
+        regularization added to the diagonals of the covariances (>0)
+    ws : np.ndarray (ns,1), optional
+        weights for the source samples
+    wt : np.ndarray (nt,1), optional
+        weights for the target samples
+    bias : boolean, optional
+        estimate bias b, else b=0 (default: True)
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    A : (d x d) ndarray
+        Linear operator
+    b : (1 x d) ndarray
+        bias
+    log : dict
+        log dictionary, returned only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+        distributions", Journal of Optimization Theory and Applications
+        Vol 43, 1984
+
+    .. [15] Peyré, G., & Cuturi, M. (2018). "Computational Optimal
+        Transport".
+
+
+    """
+
+    d = xs.shape[1]
+
+    if bias:
+        mxs = xs.mean(0, keepdims=True)
+        mxt = xt.mean(0, keepdims=True)
+
+        xs = xs - mxs
+        xt = xt - mxt
+    else:
+        mxs = np.zeros((1, d))
+        mxt = np.zeros((1, d))
+
+    if ws is None:
+        ws = np.ones((xs.shape[0], 1)) / xs.shape[0]
+
+    if wt is None:
+        wt = np.ones((xt.shape[0], 1)) / xt.shape[0]
+
+    Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d)
+    Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d)
+
+    Cs12 = linalg.sqrtm(Cs)
+    Cs_12 = linalg.inv(Cs12)
+
+    M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+
+    A = Cs_12.dot(M0.dot(Cs_12))
+
+    b = mxt - mxs.dot(A)
+
+    if log:
+        log = {}
+        log['Cs'] = Cs
+        log['Ct'] = Ct
+        log['Cs12'] = Cs12
+        log['Cs_12'] = Cs_12
+        return A, b, log
+    else:
+        return A, b
+
+
 @deprecated("The class OTDA is deprecated in 0.3.1 and will be "
             "removed in 0.5"
             "\n\tfor standard transport use class EMDTransport instead.")
@@ -1180,6 +1287,172 @@ class BaseTransport(BaseEstimator):
         return transp_Xt
 
 
+class LinearTransport(BaseTransport):
+    """OT linear operator between empirical distributions
+
+    The class estimates the optimal linear operator that aligns the two
+    empirical distributions. This is equivalent to estimating the closed
+    form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+    and :math:`N(\mu_t,\Sigma_t)`, as proposed in [14] and discussed in
+    remark 2.29 of [15].
+
+    The linear operator from source to target is
+
+    .. math::
+        M(x) = Ax + b
+
+    where:
+
+    .. math::
+        A = \Sigma_s^{-1/2} (\Sigma_s^{1/2} \Sigma_t \Sigma_s^{1/2})^{1/2}
+        \Sigma_s^{-1/2}
+    .. math::
+        b = \mu_t - A \mu_s
+
+    Parameters
+    ----------
+    reg : float, optional
+        regularization added to the diagonals of the covariances (>0)
+    bias : boolean, optional
+        estimate bias b, else b=0 (default: True)
+    log : bool, optional
+        record log if True
+
+    References
+    ----------
+
+    .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+        distributions", Journal of Optimization Theory and Applications
+        Vol 43, 1984
+
+    .. [15] Peyré, G., & Cuturi, M. (2018). "Computational Optimal
+        Transport".
+
+    """
+
+    def __init__(self, reg=1e-8, bias=True, log=False,
+                 distribution_estimation=distribution_estimation_uniform):
+
+        self.bias = bias
+        self.log = log
+        self.reg = reg
+        self.distribution_estimation = distribution_estimation
+
+    def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+        """Estimate the linear mapping from source and target sets of samples
+        (Xs, ys) and (Xt, yt)
+
+        Parameters
+        ----------
+        Xs : array-like, shape (n_source_samples, n_features)
+            The training input samples.
+        ys : array-like, shape (n_source_samples,)
+            The class labels
+        Xt : array-like, shape (n_target_samples, n_features)
+            The training input samples.
+        yt : array-like, shape (n_target_samples,)
+            The class labels. If some target samples are unlabeled, fill the
+            yt's elements with -1.
+
+            Warning: Note that, due to this convention, -1 cannot be used as a
+            class label
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+
+        self.mu_s = self.distribution_estimation(Xs)
+        self.mu_t = self.distribution_estimation(Xt)
+
+        # estimation of the linear operator
+        returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
+                                      ws=self.mu_s.reshape((-1, 1)),
+                                      wt=self.mu_t.reshape((-1, 1)),
+                                      bias=self.bias, log=self.log)
+
+        # deal with the value of log
+        if self.log:
+            self.A_, self.B_, self.log_ = returned_
+        else:
+            self.A_, self.B_ = returned_
+            self.log_ = dict()
+
+        # compute the inverse mapping
+        self.A1_ = linalg.inv(self.A_)
+        self.B1_ = -self.B_.dot(self.A1_)
+
+        return self
+
+    def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+        """Transports source samples Xs onto target ones Xt
+
+        Parameters
+        ----------
+        Xs : array-like, shape (n_source_samples, n_features)
+            The training input samples.
+        ys : array-like, shape (n_source_samples,)
+            The class labels
+        Xt : array-like, shape (n_target_samples, n_features)
+            The training input samples.
+        yt : array-like, shape (n_target_samples,)
+            The class labels. If some target samples are unlabeled, fill the
+            yt's elements with -1.
+
+            Warning: Note that, due to this convention, -1 cannot be used as a
+            class label
+        batch_size : int, optional (default=128)
+            The batch size for out of sample transform
+
+        Returns
+        -------
+        transp_Xs : array-like, shape (n_source_samples, n_features)
+            The transported source samples.
+        """
+
+        # check that the necessary input parameters are here
+        if check_params(Xs=Xs):
+
+            transp_Xs = Xs.dot(self.A_) + self.B_
+
+            return transp_Xs
+
+    def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
+                          batch_size=128):
+        """Transports target samples Xt onto source samples Xs
+
+        Parameters
+        ----------
+        Xs : array-like, shape (n_source_samples, n_features)
+            The training input samples.
+        ys : array-like, shape (n_source_samples,)
+            The class labels
+        Xt : array-like, shape (n_target_samples, n_features)
+            The training input samples.
+        yt : array-like, shape (n_target_samples,)
+            The class labels. If some target samples are unlabeled, fill the
+            yt's elements with -1.
+
+            Warning: Note that, due to this convention, -1 cannot be used as a
+            class label
+        batch_size : int, optional (default=128)
+            The batch size for out of sample inverse transform
+
+        Returns
+        -------
+        transp_Xt : array-like, shape (n_target_samples, n_features)
+            The transported target samples.
+        """
+
+        # check that the necessary input parameters are here
+        if check_params(Xt=Xt):
+
+            transp_Xt = Xt.dot(self.A1_) + self.B1_
+
+            return transp_Xt
+
+
 class SinkhornTransport(BaseTransport):
     """Domain Adaptation OT method based on Sinkhorn Algorithm
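
To make the new API concrete, here is a minimal usage sketch, not part of the patch itself: the toy data, the variable names, and the explicit `reg` value are illustrative assumptions, and it presumes a POT build that includes this merge so that `ot.da.OT_mapping_linear` and `ot.da.LinearTransport` are available.

```python
import numpy as np
import ot.da

# toy data: source ~ N(0, I), target is a scaled and shifted Gaussian
rng = np.random.RandomState(42)
Xs = rng.randn(500, 2)
Xt = rng.randn(500, 2).dot(np.array([[2.0, 0.0], [0.0, 0.5]])) \
    + np.array([4.0, 1.0])

# closed-form estimate: x -> x.dot(A) + b maps the source Gaussian
# onto the target Gaussian (same row-vector convention as transform())
A, b = ot.da.OT_mapping_linear(Xs, Xt, reg=1e-6)
Xs_mapped = Xs.dot(A) + b

# sanity check: the mapped source covariance should approach the target's
print(np.abs(np.cov(Xs_mapped.T) - np.cov(Xt.T)).max())

# same mapping through the scikit-learn style class added in this PR
# (reg pinned to the function default so both code paths agree)
mapping = ot.da.LinearTransport(reg=1e-6)
mapping.fit(Xs=Xs, Xt=Xt)
Xs_mapped2 = mapping.transform(Xs=Xs)        # source -> target
Xt_back = mapping.inverse_transform(Xt=Xt)   # target -> source

assert np.allclose(Xs_mapped, Xs_mapped2)
```

Note that the two entry points ship with different defaults in this patch (`reg=1e-6` for `OT_mapping_linear`, `reg=1e-8` for `LinearTransport`), which is why the sketch passes `reg` explicitly.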