diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-03-31 09:43:15 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-03-31 09:43:15 +0200 |
commit | 171b962cea369aee2513884a1fb3dca8920b77cd (patch) | |
tree | de507a4fb61c72f59b6a20571772e605818b59e6 /ot/da.py | |
parent | fa06bb377d083c61f1ac0b067aeeab0fca2b5e7b (diff) |
added jcpot
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 288 |
1 files changed, 287 insertions, 1 deletions
from functools import reduce

import numpy as np
import scipy.linalg as linalg

from .bregman import sinkhorn, projR, projC
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, BaseEstimator
from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg


# NOTE: in ot/da.py this function is inserted between ``OT_mapping_linear``
# and ``distribution_estimation_uniform`` (both unchanged, not shown here).
def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
                     stopThr=1e-6, verbose=False, log=False, **kwargs):
    r"""Joint OT and proportion estimation as proposed in [27]

    The function solves the following optimization problem:

    .. math::

        \mathbf{h} = \mathop{\arg\min}_{\mathbf{h} \in \Delta_C}\quad
        \sum_{k=1}^K \lambda_k
        W_{reg}\left((\mathbf{D}_2^{(k)} \mathbf{h})^T
        \mathbf{\delta}_{\mathbf{X}^{(k)}}, \mu\right)

        s.t. \gamma^T_k \mathbf{1}_n = \mathbf{1}_n/n

             \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}

             \gamma\geq 0

    where :

    - :math:`\mathbf{h}` is the vector of class proportions in the target
      domain (size C, the number of classes)
    - :math:`\lambda_k` is the weight of the k-th source domain; this
      implementation weights all domains equally (1/K)
    - :math:`\mathbf{D}_1^{(k)}` is the (C, n_k) 0/1 class-membership matrix
      of the k-th source domain
    - :math:`\mathbf{D}_2^{(k)}` has the same support as
      :math:`\mathbf{D}_1^{(k)}`, each non-zero entry being one over the
      number of samples of that class in the k-th domain
    - the target samples carry the weights returned by
      ``ot.utils.unif(n)``

    For each source domain the cost matrix ``dist(Xs[k], Xt, metric)`` is
    divided by its median and turned into the kernel ``exp(-M / reg)``. The
    kernels are then updated alternately with ``projC`` (target marginal)
    and ``projR`` (marginal induced by the current proportions) until the
    proportions stop moving or ``numItermax`` is reached.

    Parameters
    ----------
    Xs : list of K np.ndarray (n_k, d)
        samples of each source domain
    Ys : list of K np.ndarray (n_k,)
        class labels of each source domain. Labels are shifted so that the
        smallest label of each domain becomes 0 and are then used as row
        indices; the input list is left untouched.
        The number of classes C is taken from the first domain only, so
        every domain is assumed to use labels within that range.
    Xt : np.ndarray (nt, d)
        samples in the target domain
    reg : float
        Entropic regularization term (>0)
    metric : str, optional (default='sqeuclidean')
        metric passed to ``ot.utils.dist`` to build the cost matrices
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the norm of the change of the (unnormalized)
        proportions between two iterations (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        accepted for signature compatibility and ignored

    Returns
    -------
    h : np.ndarray (C,)
        Estimated class proportions in the target domain (sums to 1).
        Uniform if no iteration is performed (e.g. ``numItermax <= 0``).
    log : dict
        returned only if ``log`` is True; contains ``'niter'`` (number of
        iterations performed) and ``'err'`` (list of errors, one per
        iteration)

    References
    ----------

    .. [27] I. Redko, N. Courty, R. Flamary, D. Tuia,
        "Optimal transport for multi-source domain adaptation under target
        shift", International Conference on Artificial Intelligence and
        Statistics (AISTATS), 2019.

    See Also
    --------
    ot.bregman.projR : projection used for the proportion constraint
    ot.bregman.projC : projection used for the target constraint

    """
    nbdomains = len(Xs)

    # Zero-based *local* copies of the labels: the caller's ``Ys`` must not
    # be modified, and converting to ndarray makes ``labels == c`` an
    # element-wise mask even when plain Python lists are passed in.
    labels = []
    for d in range(nbdomains):
        y = np.asarray(Ys[d])
        labels.append(y - np.min(y))

    # the class count comes from the first domain (see docstring)
    nbclasses = len(np.unique(labels[0]))

    # per-domain class matrices D1 / D2 and Gibbs kernels K
    D1, D2, K = [], [], []
    for Xd, yd in zip(Xs, labels):
        nb_elem = Xd.shape[0]
        D1_d = np.zeros((nbclasses, nb_elem))
        D2_d = np.zeros((nbclasses, nb_elem))
        for c in np.unique(yd):
            mask = yd == c
            D1_d[int(c), mask] = 1.
            D2_d[int(c), mask] = 1. / np.sum(mask)
        D1.append(D1_d)
        D2.append(D2_d)

        # cost matrix, rescaled by its median before exponentiation
        M = dist(Xd, Xt, metric=metric)
        M = M / np.median(M)
        K.append(np.exp(-M / reg))

    distribT = unif(np.shape(Xt)[0])

    # keep the boolean flag and the log dictionary in separate names
    log_dict = {'niter': 0, 'err': []} if log else None

    cpt = 0
    err = 1
    old_bary = np.ones(nbclasses)
    # defined up front so that a loop that never runs still yields a result
    bary = old_bary

    while err > stopThr and cpt < numItermax:

        bary = np.zeros(nbclasses)

        # constraint w.r.t. the target weights, then accumulate the
        # equally weighted geometric mean of the per-class masses
        for d in range(nbdomains):
            K[d] = projC(K[d], distribT)
            other = np.sum(K[d], axis=1)
            bary = bary + np.log(np.dot(D1[d], other)) / nbdomains

        bary = np.exp(bary)

        # constraint w.r.t. the current proportion estimate
        for d in range(nbdomains):
            new = np.dot(D2[d].T, bary)
            K[d] = projR(K[d], new)

        err = np.linalg.norm(bary - old_bary)
        cpt = cpt + 1
        old_bary = bary

        if log:
            log_dict['err'].append(err)

        if verbose:
            # header on the first iteration and then every 200 iterations
            if (cpt - 1) % 200 == 0:
                print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(cpt, err))

    bary = bary / np.sum(bary)

    if log:
        log_dict['niter'] = cpt
        return bary, log_dict
    else:
        return bary


# NOTE: in ot/da.py this class is appended at the end of the module, after
# ``UnbalancedSinkhornTransport`` (unchanged, not shown here).
class JCPOTTransport(BaseTransport):

    """Domain Adaptation OT method for multi-source target shift based on
    the sinkhorn-like algorithm of :func:`jcpot_barycenter`.

    Parameters
    ----------
    reg_e : float, optional (default=0.1)
        Entropic regularization parameter
    max_iter : int, optional (default=10)
        The maximum number of iterations before stopping the optimization
        algorithm if it has not converged
    tol : float, optional (default=10e-9)
        Stop threshold on error (>0)
    verbose : bool, optional (default=False)
        Controls the verbosity of the optimization algorithm
    log : bool, optional (default=False)
        Controls the logs of the optimization algorithm
    metric : string, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    norm : string, optional (default=None)
        Stored on the estimator; not used by :meth:`fit`.
    distribution_estimation : callable, optional (defaults to the uniform)
        Stored on the estimator; not used by :meth:`fit`.
    out_of_sample_map : string, optional (default="ferradans")
        Stored on the estimator; not used by :meth:`fit`.

    Attributes
    ----------
    proportions_ : array-like, shape (n_classes,)
        Estimated class proportions in the target domain
    coupling_ : array-like, shape (n_classes,)
        Same object as ``proportions_``. The name is kept for backward
        compatibility; it does NOT hold a coupling matrix.
    log_ : dictionary
        The dictionary of log, empty dict if parameter log is not True

    References
    ----------

    .. [27] I. Redko, N. Courty, R. Flamary, D. Tuia,
        "Optimal transport for multi-source domain adaptation under target
        shift", International Conference on Artificial Intelligence and
        Statistics (AISTATS), 2019.

    """

    def __init__(self, reg_e=.1, max_iter=10,
                 tol=10e-9, verbose=False, log=False,
                 metric="sqeuclidean", norm=None,
                 distribution_estimation=distribution_estimation_uniform,
                 out_of_sample_map='ferradans'):

        self.reg_e = reg_e
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.log = log
        self.metric = metric
        self.norm = norm
        self.distribution_estimation = distribution_estimation
        self.out_of_sample_map = out_of_sample_map

    def fit(self, Xs, ys=None, Xt=None, yt=None):
        """Estimate the target class proportions from several labelled
        source domains (Xs, ys) and one target domain Xt

        Parameters
        ----------
        Xs : list of K array-like, each of shape (n_k, n_features)
            The training input samples of each source domain.
        ys : list of K array-like, each of shape (n_k,)
            The class labels of each source domain.
        Xt : array-like, shape (n_target_samples, n_features)
            The target input samples.
        yt : array-like, shape (n_target_samples,)
            Not used by this method; present for API consistency.

        Returns
        -------
        self : object
            Returns self. Nothing is fitted when ``check_params`` rejects
            the inputs.
        """

        # check the necessary inputs parameters are here
        if check_params(Xs=Xs, Xt=Xt, ys=ys):

            returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt,
                                         reg=self.reg_e,
                                         metric=self.metric,
                                         numItermax=self.max_iter,
                                         stopThr=self.tol,
                                         verbose=self.verbose,
                                         log=self.log)

            # deal with the value of log
            if self.log:
                self.coupling_, self.log_ = returned_
            else:
                self.coupling_ = returned_
                self.log_ = dict()

            # accurately named alias of the estimated proportions
            self.proportions_ = self.coupling_

        return self