diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 136 |
1 file changed, 20 insertions, 116 deletions
@@ -17,8 +17,9 @@ from .backend import get_backend from .bregman import sinkhorn, jcpot_barycenter from .lp import emd from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots -from .utils import list_to_array, check_params, BaseEstimator +from .utils import list_to_array, check_params, BaseEstimator, deprecated from .unbalanced import sinkhorn_unbalanced +from .gaussian import empirical_bures_wasserstein_mapping from .optim import cg from .optim import gcg @@ -126,8 +127,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, W = nx.zeros(M.shape, type_as=M) for cpt in range(numItermax): Mreg = M + eta * W - transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, - stopThr=stopInnerThr) + if log: + transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr, log=True) + else: + transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr) # the transport has been computed. Check if classes are really # separated W = nx.ones(M.shape, type_as=M) @@ -136,7 +141,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs - return transp + if log: + return transp, log + else: + return transp def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, @@ -672,112 +680,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', return G, L -def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, - wt=None, bias=True, log=False): - r"""Return OT linear operator between samples. - - The function estimates the optimal linear operator that aligns the two - empirical distributions. 
This is equivalent to estimating the closed - form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` - and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in - :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in - :ref:`[15] <references-OT-mapping-linear>`. - - The linear operator from source to target :math:`M` - - .. math:: - M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} - - where : - - .. math:: - \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} - \Sigma_s^{-1/2} - - \mathbf{b} &= \mu_t - \mathbf{A} \mu_s - - Parameters - ---------- - xs : array-like (ns,d) - samples in the source domain - xt : array-like (nt,d) - samples in the target domain - reg : float,optional - regularization added to the diagonals of covariances (>0) - ws : array-like (ns,1), optional - weights for the source samples - wt : array-like (ns,1), optional - weights for the target samples - bias: boolean, optional - estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) - log : bool, optional - record log if True - - - Returns - ------- - A : (d, d) array-like - Linear operator - b : (1, d) array-like - bias - log : dict - log dictionary return only if log==True in parameters - - - .. _references-OT-mapping-linear: - References - ---------- - .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of - distributions", Journal of Optimization Theory and Applications - Vol 43, 1984 - - .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. 
- - - """ - xs, xt = list_to_array(xs, xt) - nx = get_backend(xs, xt) - - d = xs.shape[1] - - if bias: - mxs = nx.mean(xs, axis=0)[None, :] - mxt = nx.mean(xt, axis=0)[None, :] - - xs = xs - mxs - xt = xt - mxt - else: - mxs = nx.zeros((1, d), type_as=xs) - mxt = nx.zeros((1, d), type_as=xs) - - if ws is None: - ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] - - if wt is None: - wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - - Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) - Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) - - Cs12 = nx.sqrtm(Cs) - Cs_12 = nx.inv(Cs12) - - M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) - - A = dots(Cs_12, M0, Cs_12) - - b = mxt - nx.dot(mxs, A) - - if log: - log = {} - log['Cs'] = Cs - log['Ct'] = Ct - log['Cs12'] = Cs12 - log['Cs_12'] = Cs_12 - return A, b, log - else: - return A, b +OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping) def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5, @@ -1371,10 +1274,10 @@ class LinearTransport(BaseTransport): self.mu_t = self.distribution_estimation(Xt) # coupling estimation - returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, - ws=nx.reshape(self.mu_s, (-1, 1)), - wt=nx.reshape(self.mu_t, (-1, 1)), - bias=self.bias, log=self.log) + returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg, + ws=nx.reshape(self.mu_s, (-1, 1)), + wt=nx.reshape(self.mu_t, (-1, 1)), + bias=self.bias, log=self.log) # deal with the value of log if self.log: @@ -1514,12 +1417,13 @@ class SinkhornTransport(BaseTransport): Sciences, 7(3), 1853-1882. 
""" - def __init__(self, reg_e=1., max_iter=1000, + def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=np.infty): self.reg_e = reg_e + self.method = method self.max_iter = max_iter self.tol = tol self.verbose = verbose @@ -1560,7 +1464,7 @@ class SinkhornTransport(BaseTransport): # coupling estimation returned_ = sinkhorn( a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e, - numItermax=self.max_iter, stopThr=self.tol, + method=self.method, numItermax=self.max_iter, stopThr=self.tol, verbose=self.verbose, log=self.log) # deal with the value of log |