summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:21:47 +0100
committerRémi Flamary <remi.flamary@gmail.com>2018-03-20 16:21:47 +0100
commit8fc9fce6c920c646ea7324ac0af54ad53e9aa1bf (patch)
treeec05a571277c01cbfcb62162a06bf34b098ddb16
parentcb739f625921e7fc19113d6d758e27ac69eac24b (diff)
add class LinearTransport
-rw-r--r--examples/plot_otda_linear_mapping.py35
-rw-r--r--ot/da.py239
2 files changed, 265 insertions, 9 deletions
diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py
index eff2648..44aa9c5 100644
--- a/examples/plot_otda_linear_mapping.py
+++ b/examples/plot_otda_linear_mapping.py
@@ -9,7 +9,7 @@ Created on Tue Mar 20 14:31:15 2018
import numpy as np
import pylab as pl
import ot
-
+import scipy.linalg as linalg
#%%
@@ -19,11 +19,13 @@ n=1000
d=2
sigma=.1
+# source samples
angles=np.random.rand(n,1)*2*np.pi
xs=np.concatenate((np.sin(angles),np.cos(angles)),axis=1)+sigma*np.random.randn(n,2)
-
xs[:n//2,1]+=2
+
+# target samples
anglet=np.random.rand(n,1)*2*np.pi
xt=np.concatenate((np.sin(anglet),np.cos(anglet)),axis=1)+sigma*np.random.randn(n,2)
xt[:n//2,1]+=2
@@ -43,7 +45,33 @@ pl.plot(xt[:,0],xt[:,1],'o')
Ae,be=ot.da.OT_mapping_linear(xs,xt)
+Ae1=linalg.inv(Ae)
+be1=-be.dot(Ae1)
+
xst=xs.dot(Ae)+be
+xts=xt.dot(Ae1)+be1
+
+##%%
+
+pl.figure(1,(5,5))
+pl.clf()
+pl.plot(xs[:,0],xs[:,1],'+')
+pl.plot(xt[:,0],xt[:,1],'o')
+pl.plot(xst[:,0],xst[:,1],'+')
+pl.plot(xts[:,0],xts[:,1],'o')
+
+pl.show()
+
+
+#%% Example class with on images
+
+mapping=ot.da.LinearTransport()
+
+mapping.fit(Xs=xs,Xt=xt)
+
+
+xst=mapping.transform(Xs=xs)
+xts=mapping.inverse_transform(Xt=xt)
##%%
@@ -51,4 +79,5 @@ pl.figure(1,(5,5))
pl.clf()
pl.plot(xs[:,0],xs[:,1],'+')
pl.plot(xt[:,0],xt[:,1],'o')
-pl.plot(xst[:,0],xst[:,1],'+') \ No newline at end of file
+pl.plot(xst[:,0],xst[:,1],'+')
+pl.plot(xts[:,0],xts[:,1],'o')
diff --git a/ot/da.py b/ot/da.py
index 63bee5a..ab5f860 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -634,13 +634,76 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
-def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False):
- """ return OT linear operator between samples"""
+def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False):
+ """ return OT linear operator between samples
+
+ The function estimate the optimal linear operator that align the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14].
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+ reg : float,optional
+ regularization added to the daigonals of convariances (>0)
+ ws : np.ndarray (ns,1), optional
+ weights for the source samples
+ wt : np.ndarray (ns,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ estimate bias b else b=0 (default:True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d x d) ndarray
+ Linear operator
+ b : (1 x d) ndarray
+ bias
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+
+ """
d=xs.shape[1]
- mxs=xs.mean(0,keepdims=True)
- mxt=xt.mean(0,keepdims=True)
+ if bias:
+ mxs=xs.mean(0,keepdims=True)
+ mxt=xt.mean(0,keepdims=True)
+
+ xs=xs-mxs
+ xt=xt-mxt
+ else:
+ mxs=np.zeros((1,d))
+ mxt=np.zeros((1,d))
if ws is None:
@@ -658,12 +721,17 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False):
M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
- A=Cs_12.dot(M0.dot(Cs_12)).T
+ A=Cs_12.dot(M0.dot(Cs_12))
b=mxt-mxs.dot(A)
if log:
- pass
+ log={}
+ log['Cs']=Cs
+ log['Ct']=Ct
+ log['Cs12']=Cs12
+ log['Cs_12']=Cs_12
+ return A,b,log
else:
return A,b
@@ -1216,6 +1284,165 @@ class BaseTransport(BaseEstimator):
return transp_Xt
+class LinearTransport(BaseTransport):
+ """ OT linear operator between empirical distributions
+
+ The function estimate the optimal linear operator that align the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)`
+ and :math:`N(\mu_t,\Sigma_t)` as proposed in [14].
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(x)=Ax+b
+
+ where :
+
+ .. math::
+ A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \Sigma_s^{-1/2}
+ .. math::
+ b=\mu_t-A\mu_s
+
+ Parameters
+ ----------
+ reg : float,optional
+ regularization added to the daigonals of convariances (>0)
+ bias: boolean, optional
+ estimate bias b else b=0 (default:True)
+ log : bool, optional
+ record log if True
+
+
+ """
+
+ def __init__(self, reg=1e-8,bias=True,log=False,
+ distribution_estimation=distribution_estimation_uniform):
+
+ self.bias=bias
+ self.log=log
+ self.reg=reg
+ self.distribution_estimation=distribution_estimation
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ self.mu_s = self.distribution_estimation(Xs)
+ self.mu_t = self.distribution_estimation(Xt)
+
+
+
+ # coupling estimation
+ returned_ = OT_mapping_linear(Xs,Xt,reg=self.reg,
+ ws=self.mu_s.reshape((-1,1)),
+ wt=self.mu_t.reshape((-1,1)),
+ bias=self.bias,log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.A_, self.B_, self.log_ = returned_
+ else:
+ self.A_, self.B_, = returned_
+ self.log_ = dict()
+
+ # re compute inverse mapping
+ self.A1_=linalg.inv(self.A_)
+ self.B1_=-self.B_.dot(self.A1_)
+
+ return self
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transport source samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ transp_Xs= Xs.dot(self.A_)+self.B_
+
+ return transp_Xs
+
+ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
+ batch_size=128):
+ """Transports target samples Xt onto target samples Xs
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+ transp_Xt : array-like, shape (n_source_samples, n_features)
+ The transported target samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xt=Xt):
+
+ transp_Xt= Xt.dot(self.A1_)+self.B1_
+
+ return transp_Xt
+
+
+
class SinkhornTransport(BaseTransport):
"""Domain Adapatation OT method based on Sinkhorn Algorithm