summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-08-21 14:43:58 +0200
committerGitHub <noreply@github.com>2019-08-21 14:43:58 +0200
commitabfe183a49caaf74a07e595ac40920dae05a3c22 (patch)
tree4d5fe2d98d249a252cda18db29ff47edc472a2ab
parentb2157e9b3458388571f6ae87d80f47f500dfa166 (diff)
parentce86d1476b32771d32b7e55566e7cab45bb57b3a (diff)
Merge pull request #100 from ngayraud/add_unbalanced_da
[MRG] Adds Unbalanced transport to domain adaptation methods + bugfixes
-rw-r--r--.travis.yml5
-rw-r--r--ot/da.py121
-rw-r--r--ot/unbalanced.py2
-rw-r--r--ot/utils.py9
-rw-r--r--test/test_da.py65
5 files changed, 197 insertions, 5 deletions
diff --git a/.travis.yml b/.travis.yml
index 5e5694b..72fd29a 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -13,7 +13,7 @@ matrix:
python: 3.5
- os: linux
sudo: required
- python: 3.6
+ python: 3.6
- os: linux
sudo: required
python: 2.7
@@ -21,7 +21,6 @@ before_install:
- ./.travis/before_install.sh
before_script: # configure a headless display to test plot generation
- "export DISPLAY=:99.0"
- - "sh -e /etc/init.d/xvfb start"
- sleep 3 # give xvfb some time to start
# command to install dependencies
install:
@@ -30,6 +29,8 @@ install:
- pip install flake8 pytest "pytest-cov<2.6"
- pip install .
# command to run tests + check syntax style
+services:
+ - xvfb
script:
- python setup.py develop
- flake8 examples/ ot/ test/
diff --git a/ot/da.py b/ot/da.py
index 83f9027..2af855d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -6,6 +6,7 @@ Domain adaptation with optimal transport
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+# Nathalie Gayraud <nat.gayraud@gmail.com>
#
# License: MIT License
@@ -16,6 +17,7 @@ from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, BaseEstimator
+from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
@@ -1793,3 +1795,122 @@ class MappingTransport(BaseEstimator):
transp_Xs = K.dot(self.mapping_)
return transp_Xs
+
+
+class UnbalancedSinkhornTransport(BaseTransport):
+
+ """Domain Adaptation unbalanced OT method based on sinkhorn algorithm
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_m : float, optional (default=0.1)
+ Mass regularization parameter
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ max_iter : int, float, optional (default=10)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ tol : float, optional (default=1e-9)
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max: float, optional (default=10)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit an infinite cost
+ (10 times the maximum value of the cost matrix)
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dict if parameter log is not True
+
+ References
+ ----------
+
+ .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
+ max_iter=10, tol=1e-9, verbose=False, log=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10):
+
+ self.reg_e = reg_e
+ self.reg_m = reg_m
+ self.method = method
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.norm = norm
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = sinkhorn_unbalanced(
+ a=self.mu_s, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, alpha=self.reg_m, method=self.method,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 467fda2..0f0692e 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -364,7 +364,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %s' % cpt)
u = uprev
v = vprev
break
diff --git a/ot/utils.py b/ot/utils.py
index 8419c83..d4127e3 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -178,7 +178,9 @@ def cost_normalization(C, norm=None):
The input cost matrix normalized according to given norm.
"""
- if norm == "median":
+ if norm is None:
+ pass
+ elif norm == "median":
C /= float(np.median(C))
elif norm == "max":
C /= float(np.max(C))
@@ -186,7 +188,10 @@ def cost_normalization(C, norm=None):
C = np.log(1 + C)
elif norm == "loglog":
C = np.log1p(np.log1p(C))
-
+ else:
+ raise ValueError('Norm %s is not a valid option.\n'
+ 'Valid options are:\n'
+ 'median, max, log, loglog' % norm)
return C
diff --git a/test/test_da.py b/test/test_da.py
index f7f3a9d..2a5e50e 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -245,6 +245,71 @@ def test_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
+def test_unbalanced_sinkhorn_transport_class():
+ """test_unbalanced_sinkhorn_transport
+ """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.UnbalancedSinkhornTransport()
+
+ # test it's computed
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.SinkhornTransport()
+ otda_unsup.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.SinkhornTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check everything runs well with log=True
+ otda = ot.da.SinkhornTransport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
+
+
def test_emd_transport_class():
"""test_sinkhorn_transport
"""