summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authoraje <leo_g_autheron@hotmail.fr>2017-08-30 09:56:37 +0200
committeraje <leo_g_autheron@hotmail.fr>2017-08-30 09:56:37 +0200
commit982f36cb0a5f3a6a14454238a26053de7251b0f0 (patch)
treee7ad62954515430a418b9825829f7d563bf469ff /ot/da.py
parent308ce24b705dfad9d058138d058da8b18002e081 (diff)
Changes:
- Rename numItermax to max_iter
- Raise the default max_iter from 10000 to 100000
- Add max_iter to class EMDTransport(BaseTransport)
- Add norm to all BaseTransport subclasses and to MappingTransport
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py64
1 files changed, 55 insertions, 9 deletions
diff --git a/ot/da.py b/ot/da.py
index 0dfd02f..5871aba 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -658,7 +658,7 @@ class OTDA(object):
self.metric = metric
self.computed = False
- def fit(self, xs, xt, ws=None, wt=None, norm=None, numItermax=10000):
+ def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
"""Fit domain adaptation between samples is xs and xt
(with optional weights)"""
self.xs = xs
@@ -674,7 +674,7 @@ class OTDA(object):
self.M = dist(xs, xt, metric=self.metric)
self.normalizeM(norm)
- self.G = emd(ws, wt, self.M, numItermax)
+ self.G = emd(ws, wt, self.M, max_iter)
self.computed = True
def interp(self, direction=1):
@@ -1001,6 +1001,7 @@ class BaseTransport(BaseEstimator):
# pairwise distance
self.cost_ = dist(Xs, Xt, metric=self.metric)
+ self.normalizeCost_(self.norm)
if (ys is not None) and (yt is not None):
@@ -1182,6 +1183,26 @@ class BaseTransport(BaseEstimator):
return transp_Xt
+ def normalizeCost_(self, norm):
+ """ Apply normalization to the loss matrix
+
+
+ Parameters
+ ----------
+ norm : str or None
+ type of normalization from 'median', 'max', 'log', 'loglog'; any other value (including None) leaves the cost matrix unchanged
+
+ """
+
+ if norm == "median":
+ self.cost_ /= float(np.median(self.cost_))
+ elif norm == "max":
+ self.cost_ /= float(np.max(self.cost_))
+ elif norm == "log":
+ self.cost_ = np.log(1 + self.cost_)
+ elif norm == "loglog":
+ self.cost_ = np.log(1 + np.log(1 + self.cost_))
+
class SinkhornTransport(BaseTransport):
"""Domain Adapatation OT method based on Sinkhorn Algorithm
@@ -1202,6 +1223,9 @@ class SinkhornTransport(BaseTransport):
be transported from a domain to another one.
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
distribution : string, optional (default="uniform")
The kind of distribution estimation to employ
verbose : int, optional (default=0)
@@ -1231,7 +1255,7 @@ class SinkhornTransport(BaseTransport):
def __init__(self, reg_e=1., max_iter=1000,
tol=10e-9, verbose=False, log=False,
- metric="sqeuclidean",
+ metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
@@ -1241,6 +1265,7 @@ class SinkhornTransport(BaseTransport):
self.verbose = verbose
self.log = log
self.metric = metric
+ self.norm = norm
self.limit_max = limit_max
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
@@ -1296,6 +1321,9 @@ class EMDTransport(BaseTransport):
be transported from a domain to another one.
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
distribution : string, optional (default="uniform")
The kind of distribution estimation to employ
verbose : int, optional (default=0)
@@ -1306,6 +1334,9 @@ class EMDTransport(BaseTransport):
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
(10 times the maximum value of the cost matrix)
+ max_iter : int, optional (default=100000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged.
Attributes
----------
@@ -1319,14 +1350,17 @@ class EMDTransport(BaseTransport):
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
"""
- def __init__(self, metric="sqeuclidean",
+ def __init__(self, metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
- out_of_sample_map='ferradans', limit_max=10):
+ out_of_sample_map='ferradans', limit_max=10,
+ max_iter=100000):
self.metric = metric
+ self.norm = norm
self.limit_max = limit_max
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
+ self.max_iter = max_iter
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
@@ -1353,7 +1387,7 @@ class EMDTransport(BaseTransport):
# coupling estimation
self.coupling_ = emd(
- a=self.mu_s, b=self.mu_t, M=self.cost_,
+ a=self.mu_s, b=self.mu_t, M=self.cost_, max_iter=self.max_iter
)
return self
@@ -1376,6 +1410,9 @@ class SinkhornLpl1Transport(BaseTransport):
be transported from a domain to another one.
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
distribution : string, optional (default="uniform")
The kind of distribution estimation to employ
max_iter : int, float, optional (default=10)
@@ -1410,7 +1447,7 @@ class SinkhornLpl1Transport(BaseTransport):
def __init__(self, reg_e=1., reg_cl=0.1,
max_iter=10, max_inner_iter=200,
tol=10e-9, verbose=False,
- metric="sqeuclidean",
+ metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
@@ -1421,6 +1458,7 @@ class SinkhornLpl1Transport(BaseTransport):
self.tol = tol
self.verbose = verbose
self.metric = metric
+ self.norm = norm
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
self.limit_max = limit_max
@@ -1477,6 +1515,9 @@ class SinkhornL1l2Transport(BaseTransport):
be transported from a domain to another one.
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
distribution : string, optional (default="uniform")
The kind of distribution estimation to employ
max_iter : int, float, optional (default=10)
@@ -1516,7 +1557,7 @@ class SinkhornL1l2Transport(BaseTransport):
def __init__(self, reg_e=1., reg_cl=0.1,
max_iter=10, max_inner_iter=200,
tol=10e-9, verbose=False, log=False,
- metric="sqeuclidean",
+ metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=10):
@@ -1528,6 +1569,7 @@ class SinkhornL1l2Transport(BaseTransport):
self.verbose = verbose
self.log = log
self.metric = metric
+ self.norm = norm
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
self.limit_max = limit_max
@@ -1588,6 +1630,9 @@ class MappingTransport(BaseEstimator):
Estimate linear mapping with constant bias
metric : string, optional (default="sqeuclidean")
The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
kernel : string, optional (default="linear")
The kernel to use either linear or gaussian
sigma : float, optional (default=1)
@@ -1627,11 +1672,12 @@ class MappingTransport(BaseEstimator):
"""
def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean",
- kernel="linear", sigma=1, max_iter=100, tol=1e-5,
+ norm=None, kernel="linear", sigma=1, max_iter=100, tol=1e-5,
max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False,
verbose2=False):
self.metric = metric
+ self.norm = norm
self.mu = mu
self.eta = eta
self.bias = bias