summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2017-08-30 15:47:16 +0200
committerGitHub <noreply@github.com>2017-08-30 15:47:16 +0200
commit16697047eff9326a0ecb483317c13a854a3d3a71 (patch)
treeb9a8659370286820563a1fd1a9ea09ed0a9003a3 /ot/da.py
parenta2ec6e55e458c719484e86a4e6a6e764c2e38dc8 (diff)
parentfadaf2ab3c3844d281b22f8d5c3404c3c4cf7d97 (diff)
Merge pull request #25 from aje/master
Add iter_max to lp solver and fixes #24
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py85
1 files changed, 45 insertions, 40 deletions
diff --git a/ot/da.py b/ot/da.py
index 78dc150..564c7b7 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -13,7 +13,7 @@ import numpy as np
from .bregman import sinkhorn
from .lp import emd
-from .utils import unif, dist, kernel
+from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, deprecated, BaseEstimator
from .optim import cg
from .optim import gcg
@@ -650,15 +650,16 @@ class OTDA(object):
"""
- def __init__(self, metric='sqeuclidean'):
+ def __init__(self, metric='sqeuclidean', norm=None):
""" Class initialization"""
self.xs = 0
self.xt = 0
self.G = 0
self.metric = metric
+ self.norm = norm
self.computed = False
- def fit(self, xs, xt, ws=None, wt=None, norm=None):
+ def fit(self, xs, xt, ws=None, wt=None, max_iter=100000):
"""Fit domain adaptation between samples is xs and xt
(with optional weights)"""
self.xs = xs
@@ -673,8 +674,8 @@ class OTDA(object):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.normalizeM(norm)
- self.G = emd(ws, wt, self.M)
+ self.M = cost_normalization(self.M, self.norm)
+ self.G = emd(ws, wt, self.M, max_iter)
self.computed = True
def interp(self, direction=1):
@@ -741,26 +742,6 @@ class OTDA(object):
# aply the delta to the interpolation
return xf[idx, :] + x - x0[idx, :]
- def normalizeM(self, norm):
- """ Apply normalization to the loss matrix
-
-
- Parameters
- ----------
- norm : str
- type of normalization from 'median','max','log','loglog'
-
- """
-
- if norm == "median":
- self.M /= float(np.median(self.M))
- elif norm == "max":
- self.M /= float(np.max(self.M))
- elif norm == "log":
- self.M = np.log(1 + self.M)
- elif norm == "loglog":
- self.M = np.log(1 + np.log(1 + self.M))
-
@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
" removed in 0.5 \nUse class SinkhornTransport instead.")
@@ -772,7 +753,7 @@ class OTDA_sinkhorn(OTDA):
"""
- def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
+ def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
"""Fit regularized domain adaptation between samples is xs and xt
(with optional weights)"""
self.xs = xs
@@ -787,7 +768,7 @@ class OTDA_sinkhorn(OTDA):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.normalizeM(norm)
+ self.M = cost_normalization(self.M, self.norm)
self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
self.computed = True
@@ -799,8 +780,7 @@ class OTDA_lpl1(OTDA):
"""Class for domain adaptation with optimal transport with entropic and
group regularization"""
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
- **kwargs):
+ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
"""Fit regularized domain adaptation between samples is xs and xt
(with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
parameters"""
@@ -816,7 +796,7 @@ class OTDA_lpl1(OTDA):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.normalizeM(norm)
+ self.M = cost_normalization(self.M, self.norm)
self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
self.computed = True
@@ -828,8 +808,7 @@ class OTDA_l1l2(OTDA):
"""Class for domain adaptation with optimal transport with entropic
and group lasso regularization"""
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
- **kwargs):
+ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
"""Fit regularized domain adaptation between samples is xs and xt
(with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
parameters"""
@@ -845,7 +824,7 @@ class OTDA_l1l2(OTDA):
self.wt = wt
self.M = dist(xs, xt, metric=self.metric)
- self.normalizeM(norm)
+ self.M = cost_normalization(self.M, self.norm)
self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
self.computed = True
@@ -1001,6 +980,7 @@ class BaseTransport(BaseEstimator):
# pairwise distance
self.cost_ = dist(Xs, Xt, metric=self.metric)
+ self.cost_ = cost_normalization(self.cost_, self.norm)
if (ys is not None) and (yt is not None):
@@ -1202,6 +1182,9 @@ class SinkhornTransport(BaseTransport):
be transported from a domain to another one.
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
distribution : string, optional (default="uniform")
The kind of distribution estimation to employ
verbose : int, optional (default=0)
@@ -1231,7 +1214,7 @@ class SinkhornTransport(BaseTransport):
def __init__(self, reg_e=1., max_iter=1000,
tol=10e-9, verbose=False, log=False,
- metric="sqeuclidean",
+ metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
@@ -1241,6 +1224,7 @@ class SinkhornTransport(BaseTransport):
self.verbose = verbose
self.log = log
self.metric = metric
+ self.norm = norm
self.limit_max = limit_max
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
@@ -1296,6 +1280,9 @@ class EMDTransport(BaseTransport):
be transported from a domain to another one.
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
distribution : string, optional (default="uniform")
The kind of distribution estimation to employ
verbose : int, optional (default=0)
@@ -1306,6 +1293,9 @@ class EMDTransport(BaseTransport):
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
(10 times the maximum value of the cost matrix)
+ max_iter : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
Attributes
----------
@@ -1319,14 +1309,17 @@ class EMDTransport(BaseTransport):
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
"""
- def __init__(self, metric="sqeuclidean",
+ def __init__(self, metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
- out_of_sample_map='ferradans', limit_max=10):
+ out_of_sample_map='ferradans', limit_max=10,
+ max_iter=100000):
self.metric = metric
+ self.norm = norm
self.limit_max = limit_max
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
+ self.max_iter = max_iter
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
@@ -1353,7 +1346,7 @@ class EMDTransport(BaseTransport):
# coupling estimation
self.coupling_ = emd(
- a=self.mu_s, b=self.mu_t, M=self.cost_,
+ a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter
)
return self
@@ -1376,6 +1369,9 @@ class SinkhornLpl1Transport(BaseTransport):
be transported from a domain to another one.
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
distribution : string, optional (default="uniform")
The kind of distribution estimation to employ
max_iter : int, float, optional (default=10)
@@ -1410,7 +1406,7 @@ class SinkhornLpl1Transport(BaseTransport):
def __init__(self, reg_e=1., reg_cl=0.1,
max_iter=10, max_inner_iter=200,
tol=10e-9, verbose=False,
- metric="sqeuclidean",
+ metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
@@ -1421,6 +1417,7 @@ class SinkhornLpl1Transport(BaseTransport):
self.tol = tol
self.verbose = verbose
self.metric = metric
+ self.norm = norm
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
self.limit_max = limit_max
@@ -1477,6 +1474,9 @@ class SinkhornL1l2Transport(BaseTransport):
be transported from a domain to another one.
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
distribution : string, optional (default="uniform")
The kind of distribution estimation to employ
max_iter : int, float, optional (default=10)
@@ -1516,7 +1516,7 @@ class SinkhornL1l2Transport(BaseTransport):
def __init__(self, reg_e=1., reg_cl=0.1,
max_iter=10, max_inner_iter=200,
tol=10e-9, verbose=False, log=False,
- metric="sqeuclidean",
+ metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=10):
@@ -1528,6 +1528,7 @@ class SinkhornL1l2Transport(BaseTransport):
self.verbose = verbose
self.log = log
self.metric = metric
+ self.norm = norm
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
self.limit_max = limit_max
@@ -1588,6 +1589,9 @@ class MappingTransport(BaseEstimator):
Estimate linear mapping with constant bias
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
kernel : string, optional (default="linear")
The kernel to use either linear or gaussian
sigma : float, optional (default=1)
@@ -1627,11 +1631,12 @@ class MappingTransport(BaseEstimator):
"""
def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean",
- kernel="linear", sigma=1, max_iter=100, tol=1e-5,
+ norm=None, kernel="linear", sigma=1, max_iter=100, tol=1e-5,
max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False,
verbose2=False):
self.metric = metric
+ self.norm = norm
self.mu = mu
self.eta = eta
self.bias = bias