summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-03-31 09:43:15 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-03-31 09:43:15 +0200
commit171b962cea369aee2513884a1fb3dca8920b77cd (patch)
treede507a4fb61c72f59b6a20571772e605818b59e6 /ot/da.py
parentfa06bb377d083c61f1ac0b067aeeab0fca2b5e7b (diff)
added jcpot
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py288
1 files changed, 287 insertions, 1 deletions
diff --git a/ot/da.py b/ot/da.py
index 108a38d..fd5da4b 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -13,13 +13,14 @@ Domain adaptation with optimal transport
import numpy as np
import scipy.linalg as linalg
-from .bregman import sinkhorn
+from .bregman import sinkhorn, projR, projC
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, BaseEstimator
from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
+from functools import reduce
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -745,6 +746,183 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b
+def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
+ """Joint OT and proportion estimation as proposed in [27]
+
+ The function solves the following optimization problem:
+
+ .. math::
+
+ \mathbf{h} = \argmin_{\mathbf{h} \in \Delta_C}\quad \sum_{k=1}^K \lambda_k
+ W_{reg}\left((\mathbf{D}_2^{(k)} \mathbf{h})^T \mathbf{\delta}_{\mathbf{X}^{(k)}}, \mu\right)
+
+
+ s.t. \gamma^T_k \mathbf{1}_n = \mathbf{1}_n/n
+
+ \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) squared euclidean cost matrix between samples in
+ Xs and Xt (scaled by ns)
+ - :math:`L` is a ns x d linear operator on a kernel matrix that
+ approximates the barycentric mapping
+ - a and b are uniform source and target weights
+
+ The problem consist in solving jointly an optimal transport matrix
+ :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping
+ :math:`n_s\gamma X_t`.
+
+ One can also estimate a mapping with constant bias (see supplementary
+ material of [8]) using the bias optional argument.
+
+ The algorithm used for solving the problem is the block coordinate
+ descent that alternates between updates of G (using conditional gradient)
+ and the update of L using a classical kernel least square solver.
+
+
+ Parameters
+ ----------
+ xs : np.ndarray (ns,d)
+ samples in the source domain
+ xt : np.ndarray (nt,d)
+ samples in the target domain
+ mu : float,optional
+ Weight for the linear OT loss (>0)
+ eta : float, optional
+ Regularization term for the linear mapping L (>0)
+ kerneltype : str,optional
+ kernel used by calling function ot.utils.kernel (gaussian by default)
+ sigma : float, optional
+ Gaussian kernel bandwidth.
+ bias : bool,optional
+ Estimate linear mapping with constant bias
+ verbose : bool, optional
+ Print information along iterations
+ verbose2 : bool, optional
+ Print information along iterations
+ numItermax : int, optional
+ Max number of BCD iterations
+ numInnerItermax : int, optional
+ Max number of iterations (inner CG solver)
+ stopInnerThr : float, optional
+ Stop threshold on error (inner CG solver) (>0)
+ stopThr : float, optional
+ Stop threshold on relative loss decrease (>0)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ L : (ns x d) ndarray
+ Nonlinear mapping matrix (ns+1 x d if bias)
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+ nbclasses = len(np.unique(Ys[0]))
+ nbdomains = len(Xs)
+
+ # we then build, for each source domain, specific information
+ all_domains = []
+ for d in range(nbdomains):
+ dom = {}
+ # get number of elements for this domain
+ nb_elem = Xs[d].shape[0]
+ dom['nbelem'] = nb_elem
+ classes = np.unique(Ys[d])
+
+ if np.min(classes) != 0:
+ Ys[d] = Ys[d] - np.min(classes)
+ classes = np.unique(Ys[d])
+
+ # build the corresponding D matrix
+ D1 = np.zeros((nbclasses, nb_elem))
+ D2 = np.zeros((nbclasses, nb_elem))
+ classes_d = np.zeros(nbclasses)
+
+ classes_d[np.unique(Ys[d]).astype(int)] = 1
+ dom['classes'] = classes_d
+
+ for c in classes:
+ nbelemperclass = np.sum(Ys[d] == c)
+ if nbelemperclass != 0:
+ D1[int(c), Ys[d] == c] = 1.
+ D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) # *nbclasses_d)
+ dom['D1'] = D1
+ dom['D2'] = D2
+
+ # build the distance matrix
+ M = dist(Xs[d], Xt, metric=metric)
+ M = M / np.median(M)
+
+ dom['K'] = np.exp(-M/reg)
+
+ all_domains.append(dom)
+
+ distribT = unif(np.shape(Xt)[0])
+
+ if log:
+ log = {'niter': 0, 'err': []}
+
+ cpt = 0
+ err = 1
+ old_bary = np.ones((nbclasses))
+
+ while (err > stopThr and cpt < numItermax):
+
+ bary = np.zeros((nbclasses))
+
+ for d in range(nbdomains):
+ all_domains[d]['K'] = projC(all_domains[d]['K'], distribT)
+ other = np.sum(all_domains[d]['K'], axis=1)
+ bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains
+
+ bary = np.exp(bary)
+
+ for d in range(nbdomains):
+ new = np.dot(all_domains[d]['D2'].T, bary)
+ all_domains[d]['K'] = projR(all_domains[d]['K'], new)
+
+ err = np.linalg.norm(bary - old_bary)
+ cpt = cpt + 1
+ old_bary = bary
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ bary = bary / np.sum(bary)
+
+ if log:
+ log['niter'] = cpt
+ return bary, log
+ else:
+ return bary
+
+
def distribution_estimation_uniform(X):
"""estimates a uniform distribution from an array of samples X
@@ -1914,3 +2092,111 @@ class UnbalancedSinkhornTransport(BaseTransport):
self.log_ = dict()
return self
+
+class JCPOTTransport(BaseTransport):
+
+ """Domain Adapatation OT method for target shift based on sinkhorn algorithm.
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ max_iter : int, float, optional (default=10)
+ The minimum number of iteration before stopping the optimization
+ algorithm if no it has not converged
+ max_inner_iter : int, float, optional (default=200)
+ The number of iteration in the inner loop
+ tol : float, optional (default=10e-9)
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
+
+ References
+ ----------
+
+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence ,
+ vol.PP, no.99, pp.1-1
+ .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
+
+ """
+
+ def __init__(self, reg_e=.1, max_iter=10,
+ tol=10e-9, verbose=False, log=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans'):
+
+ self.reg_e = reg_e
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.norm = norm
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt, ys=ys):
+
+ returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg = self.reg_e,
+ metric=self.metric, numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self