From 8fc9fce6c920c646ea7324ac0af54ad53e9aa1bf Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Tue, 20 Mar 2018 16:21:47 +0100 Subject: add class LinearTransport --- examples/plot_otda_linear_mapping.py | 35 ++++- ot/da.py | 239 ++++++++++++++++++++++++++++++++++- 2 files changed, 265 insertions(+), 9 deletions(-) diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py index eff2648..44aa9c5 100644 --- a/examples/plot_otda_linear_mapping.py +++ b/examples/plot_otda_linear_mapping.py @@ -9,7 +9,7 @@ Created on Tue Mar 20 14:31:15 2018 import numpy as np import pylab as pl import ot - +import scipy.linalg as linalg #%% @@ -19,11 +19,13 @@ n=1000 d=2 sigma=.1 +# source samples angles=np.random.rand(n,1)*2*np.pi xs=np.concatenate((np.sin(angles),np.cos(angles)),axis=1)+sigma*np.random.randn(n,2) - xs[:n//2,1]+=2 + +# target samples anglet=np.random.rand(n,1)*2*np.pi xt=np.concatenate((np.sin(anglet),np.cos(anglet)),axis=1)+sigma*np.random.randn(n,2) xt[:n//2,1]+=2 @@ -43,7 +45,33 @@ pl.plot(xt[:,0],xt[:,1],'o') Ae,be=ot.da.OT_mapping_linear(xs,xt) +Ae1=linalg.inv(Ae) +be1=-be.dot(Ae1) + xst=xs.dot(Ae)+be +xts=xt.dot(Ae1)+be1 + +##%% + +pl.figure(1,(5,5)) +pl.clf() +pl.plot(xs[:,0],xs[:,1],'+') +pl.plot(xt[:,0],xt[:,1],'o') +pl.plot(xst[:,0],xst[:,1],'+') +pl.plot(xts[:,0],xts[:,1],'o') + +pl.show() + + +#%% Example class with on images + +mapping=ot.da.LinearTransport() + +mapping.fit(Xs=xs,Xt=xt) + + +xst=mapping.transform(Xs=xs) +xts=mapping.inverse_transform(Xt=xt) ##%% @@ -51,4 +79,5 @@ pl.figure(1,(5,5)) pl.clf() pl.plot(xs[:,0],xs[:,1],'+') pl.plot(xt[:,0],xt[:,1],'o') -pl.plot(xst[:,0],xst[:,1],'+') \ No newline at end of file +pl.plot(xst[:,0],xst[:,1],'+') +pl.plot(xts[:,0],xts[:,1],'o') diff --git a/ot/da.py b/ot/da.py index 63bee5a..ab5f860 100644 --- a/ot/da.py +++ b/ot/da.py @@ -634,13 +634,76 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', return G, L -def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False): - """ return OT linear operator between samples""" +def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False): + """ return OT linear operator between samples + + The function estimate the optimal linear operator that align the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. + + The linear operator from source to target :math:`M` + + .. math:: + M(x)=Ax+b + + where : + + .. math:: + A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \Sigma_s^{-1/2} + .. math:: + b=\mu_t-A\mu_s + + Parameters + ---------- + xs : np.ndarray (ns,d) + samples in the source domain + xt : np.ndarray (nt,d) + samples in the target domain + reg : float,optional + regularization added to the daigonals of convariances (>0) + ws : np.ndarray (ns,1), optional + weights for the source samples + wt : np.ndarray (ns,1), optional + weights for the target samples + bias: boolean, optional + estimate bias b else b=0 (default:True) + log : bool, optional + record log if True + + + Returns + ------- + A : (d x d) ndarray + Linear operator + b : (1 x d) ndarray + bias + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + + """ d=xs.shape[1] - mxs=xs.mean(0,keepdims=True) - mxt=xt.mean(0,keepdims=True) + if bias: + mxs=xs.mean(0,keepdims=True) + mxt=xt.mean(0,keepdims=True) + + xs=xs-mxs + xt=xt-mxt + else: + mxs=np.zeros((1,d)) + mxt=np.zeros((1,d)) if ws is None: @@ -658,12 +721,17 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False): M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) - A=Cs_12.dot(M0.dot(Cs_12)).T + A=Cs_12.dot(M0.dot(Cs_12)) b=mxt-mxs.dot(A) if log: - pass + log={} + log['Cs']=Cs + log['Ct']=Ct + log['Cs12']=Cs12 + log['Cs_12']=Cs_12 + return A,b,log else: return A,b @@ -1216,6 +1284,165 @@ class BaseTransport(BaseEstimator): return transp_Xt +class LinearTransport(BaseTransport): + """ OT linear operator between empirical distributions + + The function estimate the optimal linear operator that align the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. + + The linear operator from source to target :math:`M` + + .. math:: + M(x)=Ax+b + + where : + + .. math:: + A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \Sigma_s^{-1/2} + .. math:: + b=\mu_t-A\mu_s + + Parameters + ---------- + reg : float,optional + regularization added to the daigonals of convariances (>0) + bias: boolean, optional + estimate bias b else b=0 (default:True) + log : bool, optional + record log if True + + + """ + + def __init__(self, reg=1e-8,bias=True,log=False, + distribution_estimation=distribution_estimation_uniform): + + self.bias=bias + self.log=log + self.reg=reg + self.distribution_estimation=distribution_estimation + + def fit(self, Xs=None, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + self.mu_s = self.distribution_estimation(Xs) + self.mu_t = self.distribution_estimation(Xt) + + + + # coupling estimation + returned_ = OT_mapping_linear(Xs,Xt,reg=self.reg, + ws=self.mu_s.reshape((-1,1)), + wt=self.mu_t.reshape((-1,1)), + bias=self.bias,log=self.log) + + # deal with the value of log + if self.log: + self.A_, self.B_, self.log_ = returned_ + else: + self.A_, self.B_, = returned_ + self.log_ = dict() + + # re compute inverse mapping + self.A1_=linalg.inv(self.A_) + self.B1_=-self.B_.dot(self.A1_) + + return self + + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + """Transports source samples Xs onto target ones Xt + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xs : array-like, shape (n_source_samples, n_features) + The transport source samples. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + transp_Xs= Xs.dot(self.A_)+self.B_ + + return transp_Xs + + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, + batch_size=128): + """Transports target samples Xt onto target samples Xs + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xt : array-like, shape (n_source_samples, n_features) + The transported target samples. + """ + + # check the necessary inputs parameters are here + if check_params(Xt=Xt): + + transp_Xt= Xt.dot(self.A1_)+self.B1_ + + return transp_Xt + + + class SinkhornTransport(BaseTransport): """Domain Adapatation OT method based on Sinkhorn Algorithm -- cgit v1.2.3