From 092866815cf906012f9194b87af1e7ae0270f7e7 Mon Sep 17 00:00:00 2001 From: ngayraud Date: Mon, 12 Aug 2019 15:49:25 -0400 Subject: Added Unbalaced transport to domain adaptation methods. Corrected small bug related to warnings in unbalaced.py . Added an error message when user wants to normalize with other than expected cost normalization functions. --- ot/da.py | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ ot/unbalanced.py | 2 +- ot/utils.py | 5 ++- 3 files changed, 126 insertions(+), 2 deletions(-) diff --git a/ot/da.py b/ot/da.py index 83f9027..c1d9849 100644 --- a/ot/da.py +++ b/ot/da.py @@ -6,6 +6,7 @@ Domain adaptation with optimal transport # Author: Remi Flamary # Nicolas Courty # Michael Perrot +# Nathalie Gayraud # # License: MIT License @@ -16,6 +17,7 @@ from .bregman import sinkhorn from .lp import emd from .utils import unif, dist, kernel, cost_normalization from .utils import check_params, BaseEstimator +from .unbalanced import sinkhorn_unbalanced from .optim import cg from .optim import gcg @@ -1793,3 +1795,122 @@ class MappingTransport(BaseEstimator): transp_Xs = K.dot(self.mapping_) return transp_Xs + + +class UnbalancedSinkhornTransport(BaseTransport): + + """Domain Adapatation unbalanced OT method based on sinkhorn algorithm + + Parameters + ---------- + reg_e : float, optional (default=1) + Entropic regularization parameter + reg_m : float, optional (default=0.1) + Mass regularization parameter + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_epsilon_scaling', see those function for specific parameters + max_iter : int, float, optional (default=10) + The minimum number of iteration before stopping the optimization + algorithm if no it has not converged + tol : float, optional (default=10e-9) + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + log : bool, optional (default=False) + Controls the logs of the optimization algorithm + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + limit_max: float, optional (default=10) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit an infinite cost + (10 times the maximum value of the cost matrix) + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True + + References + ---------- + + .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. + + """ + + def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn', + max_iter=10, tol=10e-9, verbose=False, log=False, + metric="sqeuclidean", norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans', limit_max=10): + + self.reg_e = reg_e + self.reg_m = reg_m + self.method = method + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.log = log + self.metric = metric + self.norm = norm + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + self.limit_max = limit_max + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt): + + super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt) + + returned_ = sinkhorn_unbalanced( + a=self.mu_s, b=self.mu_t, M=self.cost_, + reg=self.reg_e, alpha=self.reg_m, method=self.method, + numItermax=self.max_iter, stopThr=self.tol, + verbose=self.verbose, log=self.log) + + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + + return self diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 467fda2..0f0692e 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -364,7 +364,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %s' % cpt) u = uprev v = vprev break diff --git a/ot/utils.py b/ot/utils.py index 8419c83..be839f8 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -186,7 +186,10 @@ def cost_normalization(C, norm=None): C = np.log(1 + C) elif norm == "loglog": C = np.log1p(np.log1p(C)) - + else: + raise ValueError(f'Norm {norm} is not a valid option. ' + f'Valid options are:\n' + f'median, max, log, loglog') return C -- cgit v1.2.3