summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:21:47 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:21:47 +0100
commit8fc9fce6c920c646ea7324ac0af54ad53e9aa1bf (patch)
treeec05a571277c01cbfcb62162a06bf34b098ddb16 /ot/da.py
parentcb739f625921e7fc19113d6d758e27ac69eac24b (diff)
add class LinearTransport
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py239
1 files changed, 233 insertions, 6 deletions
diff --git a/ot/da.py b/ot/da.py
index 63bee5a..ab5f860 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -634,13 +634,76 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
-def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False):
- """ return OT linear operator between samples"""
+def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False):
+ """ return OT linear operator between samples
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14].
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+ reg : float,optional
+ regularization added to the diagonals of covariances (>0)
+ ws : np.ndarray (ns,1), optional
+ weights for the source samples
+ wt : np.ndarray (nt,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ estimate bias b else b=0 (default:True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d x d) ndarray
+ Linear operator
+ b : (1 x d) ndarray
+ bias
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+
+ """
d=xs.shape[1]
- mxs=xs.mean(0,keepdims=True)
- mxt=xt.mean(0,keepdims=True)
+ if bias:
+ mxs=xs.mean(0,keepdims=True)
+ mxt=xt.mean(0,keepdims=True)
+
+ xs=xs-mxs
+ xt=xt-mxt
+ else:
+ mxs=np.zeros((1,d))
+ mxt=np.zeros((1,d))
if ws is None:
@@ -658,12 +721,17 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False):
M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
- A=Cs_12.dot(M0.dot(Cs_12)).T
+ A=Cs_12.dot(M0.dot(Cs_12))
b=mxt-mxs.dot(A)
if log:
- pass
+ log={}
+ log['Cs']=Cs
+ log['Ct']=Ct
+ log['Cs12']=Cs12
+ log['Cs_12']=Cs_12
+ return A,b,log
else:
return A,b
@@ -1216,6 +1284,165 @@ class BaseTransport(BaseEstimator):
return transp_Xt
+class LinearTransport(BaseTransport):
+ """ OT linear operator between empirical distributions
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14].
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+ reg : float,optional
+ regularization added to the diagonals of covariances (>0)
+ bias: boolean, optional
+ estimate bias b else b=0 (default:True)
+ log : bool, optional
+ record log if True
+
+
+ """
+
+ def __init__(self, reg=1e-8,bias=True,log=False,
+ distribution_estimation=distribution_estimation_uniform):
+
+ self.bias=bias
+ self.log=log
+ self.reg=reg
+ self.distribution_estimation=distribution_estimation
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Estimate the linear mapping from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ self.mu_s = self.distribution_estimation(Xs)
+ self.mu_t = self.distribution_estimation(Xt)
+
+
+
+ # coupling estimation
+ returned_ = OT_mapping_linear(Xs,Xt,reg=self.reg,
+ ws=self.mu_s.reshape((-1,1)),
+ wt=self.mu_t.reshape((-1,1)),
+ bias=self.bias,log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.A_, self.B_, self.log_ = returned_
+ else:
+ self.A_, self.B_, = returned_
+ self.log_ = dict()
+
+ # re compute inverse mapping
+ self.A1_=linalg.inv(self.A_)
+ self.B1_=-self.B_.dot(self.A1_)
+
+ return self
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transport source samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ transp_Xs= Xs.dot(self.A_)+self.B_
+
+ return transp_Xs
+
+ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
+ batch_size=128):
+ """Transports target samples Xt onto source samples Xs
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+ transp_Xt : array-like, shape (n_target_samples, n_features)
+ The transported target samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xt=Xt):
+
+ transp_Xt= Xt.dot(self.A1_)+self.B1_
+
+ return transp_Xt
+
+
+
class SinkhornTransport(BaseTransport):
"""Domain Adapatation OT method based on Sinkhorn Algorithm