summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorngayraud <nat.gayraud@gmail.com>2019-08-12 15:49:25 -0400
committerngayraud <nat.gayraud@gmail.com>2019-08-12 15:49:25 -0400
commit092866815cf906012f9194b87af1e7ae0270f7e7 (patch)
treeb98dfb4a0ac9165d22f8276f6f3a481e8b316f97
parentb2157e9b3458388571f6ae87d80f47f500dfa166 (diff)
Added Unbalaced transport to domain adaptation methods. Corrected small bug related to warnings in unbalaced.py . Added an error message when user wants to normalize with other than expected cost normalization functions.
-rw-r--r--ot/da.py121
-rw-r--r--ot/unbalanced.py2
-rw-r--r--ot/utils.py5
3 files changed, 126 insertions, 2 deletions
diff --git a/ot/da.py b/ot/da.py
index 83f9027..c1d9849 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -6,6 +6,7 @@ Domain adaptation with optimal transport
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+# Nathalie Gayraud <nat.gayraud@gmail.com>
#
# License: MIT License
@@ -16,6 +17,7 @@ from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, BaseEstimator
+from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
@@ -1793,3 +1795,122 @@ class MappingTransport(BaseEstimator):
transp_Xs = K.dot(self.mapping_)
return transp_Xs
+
+
+class UnbalancedSinkhornTransport(BaseTransport):
+
+ """Domain Adapatation unbalanced OT method based on sinkhorn algorithm
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_m : float, optional (default=0.1)
+ Mass regularization parameter
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ max_iter : int, float, optional (default=10)
+ The minimum number of iteration before stopping the optimization
+ algorithm if no it has not converged
+ tol : float, optional (default=10e-9)
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max: float, optional (default=10)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit an infinite cost
+ (10 times the maximum value of the cost matrix)
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
+
+ References
+ ----------
+
+ .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
+ max_iter=10, tol=10e-9, verbose=False, log=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10):
+
+ self.reg_e = reg_e
+ self.reg_m = reg_m
+ self.method = method
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.norm = norm
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = sinkhorn_unbalanced(
+ a=self.mu_s, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, alpha=self.reg_m, method=self.method,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 467fda2..0f0692e 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -364,7 +364,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %s' % cpt)
u = uprev
v = vprev
break
diff --git a/ot/utils.py b/ot/utils.py
index 8419c83..be839f8 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -186,7 +186,10 @@ def cost_normalization(C, norm=None):
C = np.log(1 + C)
elif norm == "loglog":
C = np.log1p(np.log1p(C))
-
+ else:
+ raise ValueError(f'Norm {norm} is not a valid option. '
+ f'Valid options are:\n'
+ f'median, max, log, loglog')
return C