diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-20 16:27:49 +0100 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2018-03-20 16:27:49 +0100 |
commit | c1046238d826fe9cf1294f8ea60b8d44743fac78 (patch) | |
tree | 505a707c8458c74901952fd84d9054f7922a86c6 /ot/da.py | |
parent | 8fc9fce6c920c646ea7324ac0af54ad53e9aa1bf (diff) |
passing tests
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 147 |
1 file changed, 72 insertions(+), 75 deletions(-)
@@ -357,7 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, def loss(L, G): """Compute full loss""" - return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.sum(sel(L - I0)**2) + return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \ + np.sum(G * M) + eta * np.sum(sel(L - I0)**2) def solve_L(G): """ solve L problem with fixed G (least square)""" @@ -557,7 +558,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def loss(L, G): """Compute full loss""" - return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) + return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \ + np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" @@ -634,25 +636,26 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', return G, L -def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False): +def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, + wt=None, bias=True, log=False): """ return OT linear operator between samples The function estimate the optimal linear operator that align the two - empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. - + The linear operator from source to target :math:`M` .. math:: M(x)=Ax+b - + where : - + .. math:: A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} \Sigma_s^{-1/2} - .. math:: + .. 
math:: b=\mu_t-A\mu_s Parameters @@ -666,7 +669,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False): ws : np.ndarray (ns,1), optional weights for the source samples wt : np.ndarray (ns,1), optional - weights for the target samples + weights for the target samples bias: boolean, optional estimate bias b else b=0 (default:True) log : bool, optional @@ -686,55 +689,52 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False): References ---------- - .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of + .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of distributions", Journal of Optimization Theory and Applications Vol 43, 1984 """ - d=xs.shape[1] - + d = xs.shape[1] + if bias: - mxs=xs.mean(0,keepdims=True) - mxt=xt.mean(0,keepdims=True) - - xs=xs-mxs - xt=xt-mxt + mxs = xs.mean(0, keepdims=True) + mxt = xt.mean(0, keepdims=True) + + xs = xs - mxs + xt = xt - mxt else: - mxs=np.zeros((1,d)) - mxt=np.zeros((1,d)) + mxs = np.zeros((1, d)) + mxt = np.zeros((1, d)) - if ws is None: - ws=np.ones((xs.shape[0],1))/xs.shape[0] - + ws = np.ones((xs.shape[0], 1)) / xs.shape[0] + if wt is None: - wt=np.ones((xt.shape[0],1))/xt.shape[0] - - Cs=(xs*ws).T.dot(xs)/ws.sum()+reg*np.eye(d) - Ct=(xt*wt).T.dot(xt)/wt.sum()+reg*np.eye(d) - - - Cs12=linalg.sqrtm(Cs) - Cs_12=linalg.inv(Cs12) - - M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) - - A=Cs_12.dot(M0.dot(Cs_12)) - - b=mxt-mxs.dot(A) - + wt = np.ones((xt.shape[0], 1)) / xt.shape[0] + + Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d) + Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d) + + Cs12 = linalg.sqrtm(Cs) + Cs_12 = linalg.inv(Cs12) + + M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + + A = Cs_12.dot(M0.dot(Cs_12)) + + b = mxt - mxs.dot(A) + if log: - log={} - log['Cs']=Cs - log['Ct']=Ct - log['Cs12']=Cs12 - log['Cs_12']=Cs_12 - return A,b,log + log = {} + log['Cs'] = Cs + log['Ct'] = Ct + log['Cs12'] = Cs12 + log['Cs_12'] = Cs_12 + return A, b, log else: - 
return A,b - + return A, b @deprecated("The class OTDA is deprecated in 0.3.1 and will be " @@ -1288,42 +1288,42 @@ class LinearTransport(BaseTransport): """ OT linear operator between empirical distributions The function estimate the optimal linear operator that align the two - empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. - + The linear operator from source to target :math:`M` .. math:: M(x)=Ax+b - + where : - + .. math:: A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} \Sigma_s^{-1/2} - .. math:: + .. math:: b=\mu_t-A\mu_s Parameters ---------- reg : float,optional - regularization added to the daigonals of convariances (>0) + regularization added to the daigonals of convariances (>0) bias: boolean, optional estimate bias b else b=0 (default:True) log : bool, optional record log if True - - + + """ - - def __init__(self, reg=1e-8,bias=True,log=False, + + def __init__(self, reg=1e-8, bias=True, log=False, distribution_estimation=distribution_estimation_uniform): - - self.bias=bias - self.log=log - self.reg=reg - self.distribution_estimation=distribution_estimation + + self.bias = bias + self.log = log + self.reg = reg + self.distribution_estimation = distribution_estimation def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples @@ -1349,17 +1349,15 @@ class LinearTransport(BaseTransport): self : object Returns self. 
""" - + self.mu_s = self.distribution_estimation(Xs) self.mu_t = self.distribution_estimation(Xt) - - # coupling estimation - returned_ = OT_mapping_linear(Xs,Xt,reg=self.reg, - ws=self.mu_s.reshape((-1,1)), - wt=self.mu_t.reshape((-1,1)), - bias=self.bias,log=self.log) + returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, + ws=self.mu_s.reshape((-1, 1)), + wt=self.mu_t.reshape((-1, 1)), + bias=self.bias, log=self.log) # deal with the value of log if self.log: @@ -1367,10 +1365,10 @@ class LinearTransport(BaseTransport): else: self.A_, self.B_, = returned_ self.log_ = dict() - + # re compute inverse mapping - self.A1_=linalg.inv(self.A_) - self.B1_=-self.B_.dot(self.A1_) + self.A1_ = linalg.inv(self.A_) + self.B1_ = -self.B_.dot(self.A1_) return self @@ -1403,7 +1401,7 @@ class LinearTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xs=Xs): - transp_Xs= Xs.dot(self.A_)+self.B_ + transp_Xs = Xs.dot(self.A_) + self.B_ return transp_Xs @@ -1437,12 +1435,11 @@ class LinearTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xt=Xt): - transp_Xt= Xt.dot(self.A1_)+self.B1_ + transp_Xt = Xt.dot(self.A1_) + self.B1_ return transp_Xt - class SinkhornTransport(BaseTransport): """Domain Adapatation OT method based on Sinkhorn Algorithm |