From 092866815cf906012f9194b87af1e7ae0270f7e7 Mon Sep 17 00:00:00 2001 From: ngayraud Date: Mon, 12 Aug 2019 15:49:25 -0400 Subject: Added Unbalaced transport to domain adaptation methods. Corrected small bug related to warnings in unbalaced.py . Added an error message when user wants to normalize with other than expected cost normalization functions. --- ot/da.py | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ ot/unbalanced.py | 2 +- ot/utils.py | 5 ++- 3 files changed, 126 insertions(+), 2 deletions(-) diff --git a/ot/da.py b/ot/da.py index 83f9027..c1d9849 100644 --- a/ot/da.py +++ b/ot/da.py @@ -6,6 +6,7 @@ Domain adaptation with optimal transport # Author: Remi Flamary # Nicolas Courty # Michael Perrot +# Nathalie Gayraud # # License: MIT License @@ -16,6 +17,7 @@ from .bregman import sinkhorn from .lp import emd from .utils import unif, dist, kernel, cost_normalization from .utils import check_params, BaseEstimator +from .unbalanced import sinkhorn_unbalanced from .optim import cg from .optim import gcg @@ -1793,3 +1795,122 @@ class MappingTransport(BaseEstimator): transp_Xs = K.dot(self.mapping_) return transp_Xs + + +class UnbalancedSinkhornTransport(BaseTransport): + + """Domain Adapatation unbalanced OT method based on sinkhorn algorithm + + Parameters + ---------- + reg_e : float, optional (default=1) + Entropic regularization parameter + reg_m : float, optional (default=0.1) + Mass regularization parameter + method : str + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or + 'sinkhorn_epsilon_scaling', see those function for specific parameters + max_iter : int, float, optional (default=10) + The minimum number of iteration before stopping the optimization + algorithm if no it has not converged + tol : float, optional (default=10e-9) + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + log : bool, optional (default=False) + Controls the logs of the optimization algorithm + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + limit_max: float, optional (default=10) + Controls the semi supervised mode. Transport between labeled source + and target samples of different classes will exhibit an infinite cost + (10 times the maximum value of the cost matrix) + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True + + References + ---------- + + .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. + + """ + + def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn', + max_iter=10, tol=10e-9, verbose=False, log=False, + metric="sqeuclidean", norm=None, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans', limit_max=10): + + self.reg_e = reg_e + self.reg_m = reg_m + self.method = method + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.log = log + self.metric = metric + self.norm = norm + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + self.limit_max = limit_max + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt): + + super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt) + + returned_ = sinkhorn_unbalanced( + a=self.mu_s, b=self.mu_t, M=self.cost_, + reg=self.reg_e, alpha=self.reg_m, method=self.method, + numItermax=self.max_iter, stopThr=self.tol, + verbose=self.verbose, log=self.log) + + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + + return self diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 467fda2..0f0692e 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -364,7 +364,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000, or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %s' % cpt) u = uprev v = vprev break diff --git a/ot/utils.py b/ot/utils.py index 8419c83..be839f8 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -186,7 +186,10 @@ def cost_normalization(C, norm=None): C = np.log(1 + C) elif norm == "loglog": C = np.log1p(np.log1p(C)) - + else: + raise ValueError(f'Norm {norm} is not a valid option. ' + f'Valid options are:\n' + f'median, max, log, loglog') return C -- cgit v1.2.3 From 9d4b786a036ac95989825beec819521089fb4feb Mon Sep 17 00:00:00 2001 From: ngayraud Date: Mon, 12 Aug 2019 16:37:58 -0400 Subject: fixes for travis, added test, minor nits --- .travis.yml | 5 ++-- ot/da.py | 2 +- ot/utils.py | 4 +++- test/test_da.py | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 5e5694b..72fd29a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,7 +13,7 @@ matrix: python: 3.5 - os: linux sudo: required - python: 3.6 + python: 3.6 - os: linux sudo: required python: 2.7 @@ -21,7 +21,6 @@ before_install: - ./.travis/before_install.sh before_script: # configure a headless display to test plot generation - "export DISPLAY=:99.0" - - "sh -e /etc/init.d/xvfb start" - sleep 3 # give xvfb some time to start # command to install dependencies install: @@ -30,6 +29,8 @@ install: - pip install flake8 pytest "pytest-cov<2.6" - pip install . # command to run tests + check syntax style +services: + - xvfb script: - python setup.py develop - flake8 examples/ ot/ test/ diff --git a/ot/da.py b/ot/da.py index c1d9849..2af855d 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1852,7 +1852,7 @@ class UnbalancedSinkhornTransport(BaseTransport): """ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn', - max_iter=10, tol=10e-9, verbose=False, log=False, + max_iter=10, tol=1e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10): diff --git a/ot/utils.py b/ot/utils.py index be839f8..a334fea 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -178,7 +178,9 @@ def cost_normalization(C, norm=None): The input cost matrix normalized according to given norm. """ - if norm == "median": + if norm is None: + pass + elif norm == "median": C /= float(np.median(C)) elif norm == "max": C /= float(np.max(C)) diff --git a/test/test_da.py b/test/test_da.py index f7f3a9d..9efd2d9 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -245,6 +245,79 @@ def test_sinkhorn_transport_class(): assert len(otda.log_.keys()) != 0 +def test_unbalanced_sinkhorn_transport_class(): + """test_sinkhorn_transport + """ + + ns = 150 + nt = 200 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + otda = ot.da.UnbalancedSinkhornTransport() + + # test its computed + otda.fit(Xs=Xs, Xt=Xt) + assert hasattr(otda, "cost_") + assert hasattr(otda, "coupling_") + assert hasattr(otda, "log_") + + # test dimensions of coupling + assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) + assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + + # test margin constraints + mu_s = unif(ns) + mu_t = unif(nt) + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + + # test transform + transp_Xs = otda.transform(Xs=Xs) + assert_equal(transp_Xs.shape, Xs.shape) + + Xs_new, _ = make_data_classif('3gauss', ns + 1) + transp_Xs_new = otda.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) + + # test inverse transform + transp_Xt = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt.shape, Xt.shape) + + Xt_new, _ = make_data_classif('3gauss2', nt + 1) + transp_Xt_new = otda.inverse_transform(Xt=Xt_new) + + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) + + # test fit_transform + transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs.shape, Xs.shape) + + # test unsupervised vs semi-supervised mode + otda_unsup = ot.da.SinkhornTransport() + otda_unsup.fit(Xs=Xs, Xt=Xt) + n_unsup = np.sum(otda_unsup.cost_) + + otda_semi = ot.da.SinkhornTransport() + otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) + assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) + n_semisup = np.sum(otda_semi.cost_) + + # check that the cost matrix norms are indeed different + assert n_unsup != n_semisup, "semisupervised mode not working" + + # check everything runs well with log=True + otda = ot.da.SinkhornTransport(log=True) + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + assert len(otda.log_.keys()) != 0 + + def test_emd_transport_class(): """test_sinkhorn_transport """ -- cgit v1.2.3 From b536be73326e20fd3959ba4fe28cc45a344f47d3 Mon Sep 17 00:00:00 2001 From: ngayraud Date: Mon, 12 Aug 2019 16:51:51 -0400 Subject: Attempting to fix docstyle --- ot/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/utils.py b/ot/utils.py index a334fea..b0d95f9 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -189,7 +189,7 @@ def cost_normalization(C, norm=None): elif norm == "loglog": C = np.log1p(np.log1p(C)) else: - raise ValueError(f'Norm {norm} is not a valid option. ' + raise ValueError(f'Norm {norm} is not a valid option.\n' f'Valid options are:\n' f'median, max, log, loglog') return C -- cgit v1.2.3 From 2633116175a09c468d953489c3fc7bab6aa69057 Mon Sep 17 00:00:00 2001 From: ngayraud Date: Mon, 12 Aug 2019 17:01:14 -0400 Subject: Attempting to fix docstyle --- ot/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index b0d95f9..d4127e3 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -189,9 +189,9 @@ def cost_normalization(C, norm=None): elif norm == "loglog": C = np.log1p(np.log1p(C)) else: - raise ValueError(f'Norm {norm} is not a valid option.\n' - f'Valid options are:\n' - f'median, max, log, loglog') + raise ValueError('Norm %s is not a valid option.\n' + 'Valid options are:\n' + 'median, max, log, loglog' % norm) return C -- cgit v1.2.3 From ce86d1476b32771d32b7e55566e7cab45bb57b3a Mon Sep 17 00:00:00 2001 From: ngayraud Date: Mon, 12 Aug 2019 17:03:08 -0400 Subject: Fix in test: no margin constraints here --- test/test_da.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index 9efd2d9..2a5e50e 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -267,14 +267,6 @@ def test_unbalanced_sinkhorn_transport_class(): assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - # test margin constraints - mu_s = unif(ns) - mu_t = unif(nt) - assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) - # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) -- cgit v1.2.3