From 6aa0f1f4e275098948d4b312530119e5d95b8884 Mon Sep 17 00:00:00 2001 From: ievred Date: Tue, 31 Mar 2020 17:12:28 +0200 Subject: v1 jcpot example test --- ot/da.py | 263 ++++++++++++++++++++++++++++++++++++++------------------------- 1 file changed, 157 insertions(+), 106 deletions(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index fd5da4b..a3da8c1 100644 --- a/ot/da.py +++ b/ot/da.py @@ -748,79 +748,58 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, stopThr=1e-6, verbose=False, log=False, **kwargs): - """Joint OT and proportion estimation as proposed in [27] + r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] The function solves the following optimization problem: .. math:: - \mathbf{h} = \argmin_{\mathbf{h} \in \Delta_C}\quad \sum_{k=1}^K \lambda_k - W_{reg}\left((\mathbf{D}_2^{(k)} \mathbf{h})^T \mathbf{\delta}_{\mathbf{X}^{(k)}}, \mu\right) + \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k + W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) + s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} - s.t. \gamma^T_k \mathbf{1}_n = \mathbf{1}_n/n - - \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} - - \gamma\geq 0 where : - - M is the (ns,nt) squared euclidean cost matrix between samples in - Xs and Xt (scaled by ns) - - :math:`L` is a ns x d linear operator on a kernel matrix that - approximates the barycentric mapping - - a and b are uniform source and target weights - - The problem consist in solving jointly an optimal transport matrix - :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping - :math:`n_s\gamma X_t`. + - :math:`\lambda_k` is the weight of k-th source domain + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes + - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C + - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)` - One can also estimate a mapping with constant bias (see supplementary - material of [8]) using the bias optional argument. - - The algorithm used for solving the problem is the block coordinate - descent that alternates between updates of G (using conditional gradient) - and the update of L using a classical kernel least square solver. + The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. + The algorithm used for solving the problem is the Iterative Bregman projections algorithm + with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution. Parameters ---------- - xs : np.ndarray (ns,d) - samples in the source domain - xt : np.ndarray (nt,d) + Xs : list of K np.ndarray(nsk,d) + features of all source domains' samples + Ys : list of K np.ndarray(nsk,) + labels of all source domains' samples + Xt : np.ndarray (nt,d) samples in the target domain - mu : float,optional - Weight for the linear OT loss (>0) - eta : float, optional - Regularization term for the linear mapping L (>0) - kerneltype : str,optional - kernel used by calling function ot.utils.kernel (gaussian by default) - sigma : float, optional - Gaussian kernel bandwidth. - bias : bool,optional - Estimate linear mapping with constant bias - verbose : bool, optional - Print information along iterations - verbose2 : bool, optional - Print information along iterations + reg : float + Regularization term > 0 + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem numItermax : int, optional - Max number of BCD iterations - numInnerItermax : int, optional - Max number of iterations (inner CG solver) - stopInnerThr : float, optional - Stop threshold on error (inner CG solver) (>0) + Max number of iterations stopThr : float, optional - Stop threshold on relative loss decrease (>0) + Stop threshold on relative change in the barycenter (>0) log : bool, optional record log if True - + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm Returns ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - L : (ns x d) ndarray - Nonlinear mapping matrix (ns+1 x d if bias) + gamma : List of K (nsk x nt) ndarrays + Optimal transportation matrices for the given parameters for each pair of source and target domains + h : (C,) ndarray + proportion estimation in the target domain log : dict log dictionary return only if log==True in parameters @@ -828,62 +807,59 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, References ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, - "Mapping estimation for discrete optimal transport", - Neural Information Processing Systems (NIPS), 2016. - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. - """ + ''' nbclasses = len(np.unique(Ys[0])) nbdomains = len(Xs) - # we then build, for each source domain, specific information + # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 all_domains = [] + + # log dictionary + if log: + log = {'niter': 0, 'err': [], 'all_domains': []} + for d in range(nbdomains): dom = {} - # get number of elements for this domain - nb_elem = Xs[d].shape[0] - dom['nbelem'] = nb_elem - classes = np.unique(Ys[d]) + nsk = Xs[d].shape[0] # get number of elements for this domain + dom['nbelem'] = nsk + classes = np.unique(Ys[d]) # get number of classes for this domain + # format classes to start from 0 for convenience if np.min(classes) != 0: Ys[d] = Ys[d] - np.min(classes) classes = np.unique(Ys[d]) - # build the corresponding D matrix - D1 = np.zeros((nbclasses, nb_elem)) - D2 = np.zeros((nbclasses, nb_elem)) - classes_d = np.zeros(nbclasses) - - classes_d[np.unique(Ys[d]).astype(int)] = 1 - dom['classes'] = classes_d + # build the corresponding D_1 and D_2 matrices + D1 = np.zeros((nbclasses, nsk)) + D2 = np.zeros((nbclasses, nsk)) for c in classes: nbelemperclass = np.sum(Ys[d] == c) if nbelemperclass != 0: D1[int(c), Ys[d] == c] = 1. - D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) # *nbclasses_d) + D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) dom['D1'] = D1 dom['D2'] = D2 - # build the distance matrix + # build the cost matrix and the Gibbs kernel M = dist(Xs[d], Xt, metric=metric) M = M / np.median(M) - dom['K'] = np.exp(-M/reg) + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + dom['K'] = K all_domains.append(dom) - distribT = unif(np.shape(Xt)[0]) - - if log: - log = {'niter': 0, 'err': []} + # uniform target distribution + a = unif(np.shape(Xt)[0]) - cpt = 0 + cpt = 0 # iterations count err = 1 old_bary = np.ones((nbclasses)) @@ -891,13 +867,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, bary = np.zeros((nbclasses)) + # update coupling matrices for marginal constraints w.r.t. uniform target distribution for d in range(nbdomains): - all_domains[d]['K'] = projC(all_domains[d]['K'], distribT) + all_domains[d]['K'] = projC(all_domains[d]['K'], a) other = np.sum(all_domains[d]['K'], axis=1) bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains bary = np.exp(bary) + # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] for d in range(nbdomains): new = np.dot(all_domains[d]['D2'].T, bary) all_domains[d]['K'] = projR(all_domains[d]['K'], new) @@ -915,12 +893,14 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, print('{:5d}|{:8e}|'.format(cpt, err)) bary = bary / np.sum(bary) + couplings = [all_domains[d]['K'] for d in range(nbdomains)] if log: log['niter'] = cpt - return bary, log + log['all_domains'] = all_domains + return couplings, bary, log else: - return bary + return couplings, bary def distribution_estimation_uniform(X): @@ -2093,9 +2073,10 @@ class UnbalancedSinkhornTransport(BaseTransport): return self + class JCPOTTransport(BaseTransport): - """Domain Adapatation OT method for target shift based on sinkhorn algorithm. + """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm. Parameters ---------- @@ -2104,8 +2085,6 @@ class JCPOTTransport(BaseTransport): max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization algorithm if no it has not converged - max_inner_iter : int, float, optional (default=200) - The number of iteration in the inner loop tol : float, optional (default=10e-9) Stop threshold on error (inner sinkhorn solver) (>0) verbose : bool, optional (default=False) @@ -2126,21 +2105,20 @@ class JCPOTTransport(BaseTransport): Attributes ---------- - coupling_ : array-like, shape (n_source_samples, n_target_samples) - The optimal coupling + coupling_ : list of array-like objects, shape K x (n_source_samples, n_target_samples) + A set of optimal couplings between each source domain and the target domain + proportions_ : array-like, shape (n_classes,) + Estimated class proportions in the target domain log_ : dictionary The dictionary of log, empty dic if parameter log is not True References ---------- - .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, - "Optimal Transport for Domain Adaptation," in IEEE - Transactions on Pattern Analysis and Machine Intelligence , - vol.PP, no.99, pp.1-1 - .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). - Generalized conditional gradient: analysis of convergence - and applications. arXiv preprint arXiv:1510.06567. + .. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), + vol. 89, p.849-858, 2019. """ @@ -2156,20 +2134,18 @@ class JCPOTTransport(BaseTransport): self.verbose = verbose self.log = log self.metric = metric - self.norm = norm - self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + """Building coupling matrices from a list of source and target sets of samples (Xs, ys) and (Xt, yt) Parameters ---------- - Xs : array-like, shape (n_source_samples, n_features) - The training input samples. - ys : array-like, shape (n_source_samples,) - The class labels + Xs : list of K array-like objects, shape K x (nk_source_samples, n_features) + A list of the training input samples. + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) @@ -2188,15 +2164,90 @@ class JCPOTTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): - returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg = self.reg_e, - metric=self.metric, numItermax=self.max_iter, stopThr=self.tol, + self.xs_ = Xs + self.xt_ = Xt + + returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e, + metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol, verbose=self.verbose, log=self.log) # deal with the value of log if self.log: - self.coupling_, self.log_ = returned_ + self.coupling_, self.proportions_, self.log_ = returned_ else: - self.coupling_ = returned_ + self.coupling_, self.proportions_ = returned_ self.log_ = dict() return self + + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + """Transports source samples Xs onto target ones Xt + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + """ + + transp_Xs = [] + + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]): + + # perform standard barycentric mapping for each source domain + + for coupling in self.coupling_: + transp = coupling / np.sum(coupling, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute transported samples + transp_Xs.append(np.dot(transp, self.xt_)) + else: + + # perform out of sample mapping + indices = np.arange(Xs.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] + + transp_Xs = [] + + for bi in batch_ind: + transp_Xs_ = [] + + # get the nearest neighbor in the sources domains + xs = np.concatenate(self.xs_, axis=0) + idx = np.argmin(dist(Xs[bi], xs), axis=1) + + # transport the source samples + for coupling in self.coupling_: + transp = coupling / np.sum( + coupling, 1)[:, None] + transp[~ np.isfinite(transp)] = 0 + transp_Xs_.append(np.dot(transp, self.xt_)) + + transp_Xs_ = np.concatenate(transp_Xs_, axis=0) + + # define the transported points + transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :] + transp_Xs.append(transp_Xs_) + + transp_Xs = np.concatenate(transp_Xs, axis=0) + + return transp_Xs -- cgit v1.2.3