diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-03-31 17:36:00 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-03-31 17:36:00 +0200 |
commit | ba493aa5488507937b7f9707faa17128c9aa1872 (patch) | |
tree | a99d7afcc2ca0988fc5c9f3c94dc240c1ead2cff /ot | |
parent | 6aa0f1f4e275098948d4b312530119e5d95b8884 (diff) |
readme move to bregman
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 157 | ||||
-rw-r--r-- | ot/da.py | 190 |
2 files changed, 167 insertions, 180 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index d5e3563..d17aaf0 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -10,6 +10,7 @@ Bregman projections for regularized OT # Hicham Janati <hicham.janati@inria.fr> # Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> # Alexander Tong <alexander.tong@yale.edu> +# Ievgen Redko <ievgen.redko@univ-st-etienne.fr> # # License: MIT License @@ -18,7 +19,6 @@ import warnings from .utils import unif, dist from scipy.optimize import fmin_l_bfgs_b - def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" @@ -1501,6 +1501,161 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, else: return np.sum(K0, axis=1) +def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, + stopThr=1e-6, verbose=False, log=False, **kwargs): + r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] + + The function solves the following optimization problem: + + .. math:: + + \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k + W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) + + s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} + + where : + + - :math:`\lambda_k` is the weight of k-th source domain + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes + - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C + - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)` + + The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. + + The algorithm used for solving the problem is the Iterative Bregman projections algorithm + with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution. + + Parameters + ---------- + Xs : list of K np.ndarray(nsk,d) + features of all source domains' samples + Ys : list of K np.ndarray(nsk,) + labels of all source domains' samples + Xt : np.ndarray (nt,d) + samples in the target domain + reg : float + Regularization term > 0 + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on relative change in the barycenter (>0) + log : bool, optional + record log if True + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + + Returns + ------- + gamma : List of K (nsk x nt) ndarrays + Optimal transportation matrices for the given parameters for each pair of source and target domains + h : (C,) ndarray + proportion estimation in the target domain + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. + + ''' + nbclasses = len(np.unique(Ys[0])) + nbdomains = len(Xs) + + # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 + all_domains = [] + + # log dictionary + if log: + log = {'niter': 0, 'err': [], 'all_domains': []} + + for d in range(nbdomains): + dom = {} + nsk = Xs[d].shape[0] # get number of elements for this domain + dom['nbelem'] = nsk + classes = np.unique(Ys[d]) # get number of classes for this domain + + # format classes to start from 0 for convenience + if np.min(classes) != 0: + Ys[d] = Ys[d] - np.min(classes) + classes = np.unique(Ys[d]) + + # build the corresponding D_1 and D_2 matrices + D1 = np.zeros((nbclasses, nsk)) + D2 = np.zeros((nbclasses, nsk)) + + for c in classes: + nbelemperclass = np.sum(Ys[d] == c) + if nbelemperclass != 0: + D1[int(c), Ys[d] == c] = 1. + D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) + dom['D1'] = D1 + dom['D2'] = D2 + + # build the cost matrix and the Gibbs kernel + M = dist(Xs[d], Xt, metric=metric) + M = M / np.median(M) + + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + dom['K'] = K + + all_domains.append(dom) + + # uniform target distribution + a = unif(np.shape(Xt)[0]) + + cpt = 0 # iterations count + err = 1 + old_bary = np.ones((nbclasses)) + + while (err > stopThr and cpt < numItermax): + + bary = np.zeros((nbclasses)) + + # update coupling matrices for marginal constraints w.r.t. uniform target distribution + for d in range(nbdomains): + all_domains[d]['K'] = projC(all_domains[d]['K'], a) + other = np.sum(all_domains[d]['K'], axis=1) + bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains + + bary = np.exp(bary) + + # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] + for d in range(nbdomains): + new = np.dot(all_domains[d]['D2'].T, bary) + all_domains[d]['K'] = projR(all_domains[d]['K'], new) + + err = np.linalg.norm(bary - old_bary) + cpt = cpt + 1 + old_bary = bary + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + bary = bary / np.sum(bary) + couplings = [all_domains[d]['K'] for d in range(nbdomains)] + + if log: + log['niter'] = cpt + log['all_domains'] = all_domains + return couplings, bary, log + else: + return couplings, bary def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, @@ -7,20 +7,20 @@ Domain adaptation with optimal transport # Nicolas Courty <ncourty@irisa.fr> # Michael Perrot <michael.perrot@univ-st-etienne.fr> # Nathalie Gayraud <nat.gayraud@gmail.com> +# Ievgen Redko <ievgen.redko@univ-st-etienne.fr> # # License: MIT License import numpy as np import scipy.linalg as linalg -from .bregman import sinkhorn, projR, projC +from .bregman import sinkhorn from .lp import emd from .utils import unif, dist, kernel, cost_normalization from .utils import check_params, BaseEstimator from .unbalanced import sinkhorn_unbalanced from .optim import cg from .optim import gcg -from functools import reduce def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, @@ -128,7 +128,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, W = np.ones(M.shape) for (i, c) in enumerate(classes): majs = np.sum(transp[indices_labels[i]], axis=0) - majs = p * ((majs + epsilon)**(p - 1)) + majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs return transp @@ -360,8 +360,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, def loss(L, G): """Compute full loss""" - return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \ - np.sum(G * M) + eta * np.sum(sel(L - I0)**2) + return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ + np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2) def solve_L(G): """ solve L problem with fixed G (least square)""" @@ -373,10 +373,11 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, xsi = xs1.dot(L) def f(G): - return np.sum((xsi - ns * G.dot(xt))**2) + return np.sum((xsi - ns * G.dot(xt)) ** 2) def df(G): return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) return G @@ -563,8 +564,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def loss(L, G): """Compute full loss""" - return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \ - np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) + return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ + np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" @@ -581,10 +582,11 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', xsi = K1.dot(L) def f(G): - return np.sum((xsi - ns * G.dot(xt))**2) + return np.sum((xsi - ns * G.dot(xt)) ** 2) def df(G): return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) return G @@ -746,163 +748,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, return A, b -def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, - stopThr=1e-6, verbose=False, log=False, **kwargs): - r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] - - The function solves the following optimization problem: - - .. math:: - - \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k - W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) - - s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} - - where : - - - :math:`\lambda_k` is the weight of k-th source domain - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) - - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes - - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C - - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` - - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)` - - The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. - - The algorithm used for solving the problem is the Iterative Bregman projections algorithm - with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution. - - Parameters - ---------- - Xs : list of K np.ndarray(nsk,d) - features of all source domains' samples - Ys : list of K np.ndarray(nsk,) - labels of all source domains' samples - Xt : np.ndarray (nt,d) - samples in the target domain - reg : float - Regularization term > 0 - metric : string, optional (default="sqeuclidean") - The ground metric for the Wasserstein problem - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on relative change in the barycenter (>0) - log : bool, optional - record log if True - verbose : bool, optional (default=False) - Controls the verbosity of the optimization algorithm - - Returns - ------- - gamma : List of K (nsk x nt) ndarrays - Optimal transportation matrices for the given parameters for each pair of source and target domains - h : (C,) ndarray - proportion estimation in the target domain - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia - "Optimal transport for multi-source domain adaptation under target shift", - International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. - - ''' - nbclasses = len(np.unique(Ys[0])) - nbdomains = len(Xs) - - # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 - all_domains = [] - - # log dictionary - if log: - log = {'niter': 0, 'err': [], 'all_domains': []} - - for d in range(nbdomains): - dom = {} - nsk = Xs[d].shape[0] # get number of elements for this domain - dom['nbelem'] = nsk - classes = np.unique(Ys[d]) # get number of classes for this domain - - # format classes to start from 0 for convenience - if np.min(classes) != 0: - Ys[d] = Ys[d] - np.min(classes) - classes = np.unique(Ys[d]) - - # build the corresponding D_1 and D_2 matrices - D1 = np.zeros((nbclasses, nsk)) - D2 = np.zeros((nbclasses, nsk)) - - for c in classes: - nbelemperclass = np.sum(Ys[d] == c) - if nbelemperclass != 0: - D1[int(c), Ys[d] == c] = 1. - D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) - dom['D1'] = D1 - dom['D2'] = D2 - - # build the cost matrix and the Gibbs kernel - M = dist(Xs[d], Xt, metric=metric) - M = M / np.median(M) - - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - dom['K'] = K - - all_domains.append(dom) - - # uniform target distribution - a = unif(np.shape(Xt)[0]) - - cpt = 0 # iterations count - err = 1 - old_bary = np.ones((nbclasses)) - - while (err > stopThr and cpt < numItermax): - - bary = np.zeros((nbclasses)) - - # update coupling matrices for marginal constraints w.r.t. uniform target distribution - for d in range(nbdomains): - all_domains[d]['K'] = projC(all_domains[d]['K'], a) - other = np.sum(all_domains[d]['K'], axis=1) - bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains - - bary = np.exp(bary) - - # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] - for d in range(nbdomains): - new = np.dot(all_domains[d]['D2'].T, bary) - all_domains[d]['K'] = projR(all_domains[d]['K'], new) - - err = np.linalg.norm(bary - old_bary) - cpt = cpt + 1 - old_bary = bary - - if log: - log['err'].append(err) - - if verbose: - if cpt % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - bary = bary / np.sum(bary) - couplings = [all_domains[d]['K'] for d in range(nbdomains)] - - if log: - log['niter'] = cpt - log['all_domains'] = all_domains - return couplings, bary, log - else: - return couplings, bary - - def distribution_estimation_uniform(X): """estimates a uniform distribution from an array of samples X @@ -921,7 +766,6 @@ def distribution_estimation_uniform(X): class BaseTransport(BaseEstimator): - """Base class for OTDA objects Notes @@ -1079,7 +923,6 @@ class BaseTransport(BaseEstimator): transp_Xs = [] for bi in batch_ind: - # get the nearest neighbor in the source domain D0 = dist(Xs[bi], self.xs_) idx = np.argmin(D0, axis=1) @@ -1148,7 +991,6 @@ class BaseTransport(BaseEstimator): transp_Xt = [] for bi in batch_ind: - D0 = dist(Xt[bi], self.xt_) idx = np.argmin(D0, axis=1) @@ -1294,7 +1136,6 @@ class LinearTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xs=Xs): - transp_Xs = Xs.dot(self.A_) + self.B_ return transp_Xs @@ -1328,14 +1169,12 @@ class LinearTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xt=Xt): - transp_Xt = Xt.dot(self.A1_) + self.B1_ return transp_Xt class SinkhornTransport(BaseTransport): - """Domain Adapatation OT method based on Sinkhorn Algorithm Parameters @@ -1445,7 +1284,6 @@ class SinkhornTransport(BaseTransport): class EMDTransport(BaseTransport): - """Domain Adapatation OT method based on Earth Mover's Distance Parameters @@ -1537,7 +1375,6 @@ class EMDTransport(BaseTransport): class SinkhornLpl1Transport(BaseTransport): - """Domain Adapatation OT method based on sinkhorn algorithm + LpL1 class regularization. @@ -1639,7 +1476,6 @@ class SinkhornLpl1Transport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): - super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt) returned_ = sinkhorn_lpl1_mm( @@ -1658,7 +1494,6 @@ class SinkhornLpl1Transport(BaseTransport): class SinkhornL1l2Transport(BaseTransport): - """Domain Adapatation OT method based on sinkhorn algorithm + l1l2 class regularization. @@ -1782,7 +1617,6 @@ class SinkhornL1l2Transport(BaseTransport): class MappingTransport(BaseEstimator): - """MappingTransport: DA methods that aims at jointly estimating a optimal transport coupling and the associated mapping @@ -1956,7 +1790,6 @@ class MappingTransport(BaseEstimator): class UnbalancedSinkhornTransport(BaseTransport): - """Domain Adapatation unbalanced OT method based on sinkhorn algorithm Parameters @@ -2075,7 +1908,6 @@ class UnbalancedSinkhornTransport(BaseTransport): class JCPOTTransport(BaseTransport): - """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm. Parameters |