summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2019-08-21 14:43:58 +0200
committerGitHub <noreply@github.com>2019-08-21 14:43:58 +0200
commitabfe183a49caaf74a07e595ac40920dae05a3c22 (patch)
tree4d5fe2d98d249a252cda18db29ff47edc472a2ab /ot
parentb2157e9b3458388571f6ae87d80f47f500dfa166 (diff)
parentce86d1476b32771d32b7e55566e7cab45bb57b3a (diff)
Merge pull request #100 from ngayraud/add_unbalanced_da
[MRG] Adds Unbalaced transport to domain adaptation methods + bugfixes
Diffstat (limited to 'ot')
-rw-r--r--ot/da.py121
-rw-r--r--ot/unbalanced.py2
-rw-r--r--ot/utils.py9
3 files changed, 129 insertions, 3 deletions
diff --git a/ot/da.py b/ot/da.py
index 83f9027..2af855d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -6,6 +6,7 @@ Domain adaptation with optimal transport
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+# Nathalie Gayraud <nat.gayraud@gmail.com>
#
# License: MIT License
@@ -16,6 +17,7 @@ from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, BaseEstimator
+from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
@@ -1793,3 +1795,122 @@ class MappingTransport(BaseEstimator):
transp_Xs = K.dot(self.mapping_)
return transp_Xs
+
+
+class UnbalancedSinkhornTransport(BaseTransport):
+
+ """Domain Adapatation unbalanced OT method based on sinkhorn algorithm
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_m : float, optional (default=0.1)
+ Mass regularization parameter
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ max_iter : int, float, optional (default=10)
+ The minimum number of iteration before stopping the optimization
+ algorithm if no it has not converged
+ tol : float, optional (default=10e-9)
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max: float, optional (default=10)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit an infinite cost
+ (10 times the maximum value of the cost matrix)
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
+
+ References
+ ----------
+
+ .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
+ max_iter=10, tol=1e-9, verbose=False, log=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10):
+
+ self.reg_e = reg_e
+ self.reg_m = reg_m
+ self.method = method
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.norm = norm
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = sinkhorn_unbalanced(
+ a=self.mu_s, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, alpha=self.reg_m, method=self.method,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 467fda2..0f0692e 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -364,7 +364,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %s' % cpt)
u = uprev
v = vprev
break
diff --git a/ot/utils.py b/ot/utils.py
index 8419c83..d4127e3 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -178,7 +178,9 @@ def cost_normalization(C, norm=None):
The input cost matrix normalized according to given norm.
"""
- if norm == "median":
+ if norm is None:
+ pass
+ elif norm == "median":
C /= float(np.median(C))
elif norm == "max":
C /= float(np.max(C))
@@ -186,7 +188,10 @@ def cost_normalization(C, norm=None):
C = np.log(1 + C)
elif norm == "loglog":
C = np.log1p(np.log1p(C))
-
+ else:
+ raise ValueError('Norm %s is not a valid option.\n'
+ 'Valid options are:\n'
+ 'median, max, log, loglog' % norm)
return C