summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-20 10:39:13 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-20 10:39:13 +0200
commit07463285317ed5c989040edcefb5c0e8cd3ac034 (patch)
treee6b4d5d17495fe2704d10aedfa448259f2fec9a8 /ot/da.py
parent126903381374a1d2f934190b208d134a0495dc65 (diff)
added kwargs to sim + doc
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py48
1 files changed, 32 insertions, 16 deletions
diff --git a/ot/da.py b/ot/da.py
index 9e00dce..8e26e31 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b
-def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
+def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
numItermax, stopThr, numInnerItermax,
stopInnerThr, log=False, verbose=False, **kwargs):
r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -803,7 +803,11 @@ def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
Print information along iterations
log : bool, optional
record log if True
-
+ kwargs : dict
+ Dictionary with keys 'sim' (similarity type, either 'knn' or 'gauss') and
+ 'param' (int, float or None), the parameter of the chosen similarity.
+ If 'param' is None and sim is 'gauss', it is set to 1 / (2 * m ** 2), where m is the mean
+ pairwise squared Euclidean distance between source samples; if sim is 'knn', it is set to 3.
Returns
-------
@@ -830,24 +834,28 @@ def emd_laplace(a, b, xs, xt, M, sim, reg, eta, alpha,
ot.optim.cg : General regularized OT
"""
- if sim == 'gauss':
- if 'rbfparam' not in kwargs:
- kwargs['rbfparam'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
- sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['rbfparam'])
- sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['rbfparam'])
+ if not isinstance(kwargs['param'], (int, float, type(None))):
+ raise ValueError(
+ 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(kwargs['param'])))
+
+ if kwargs['sim'] == 'gauss':
+ if kwargs['param'] is None:
+ kwargs['param'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
+ sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['param'])
+ sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['param'])
- elif sim == 'knn':
- if 'nn' not in kwargs:
- kwargs['nn'] = 5
+ elif kwargs['sim'] == 'knn':
+ if kwargs['param'] is None:
+ kwargs['param'] = 3
from sklearn.neighbors import kneighbors_graph
- sS = kneighbors_graph(xs, kwargs['nn']).toarray()
+ sS = kneighbors_graph(X=xs, n_neighbors=int(kwargs['param'])).toarray()
sS = (sS + sS.T) / 2
- sT = kneighbors_graph(xt, kwargs['nn']).toarray()
+ sT = kneighbors_graph(xt, n_neighbors=int(kwargs['param'])).toarray()
sT = (sT + sT.T) / 2
else:
- raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim))
+ raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=kwargs['sim']))
lS = laplacian(sS)
lT = laplacian(sT)
@@ -1721,6 +1729,9 @@ class EMDLaplaceTransport(BaseTransport):
can occur with large metric values.
similarity : string, optional (default="knn")
The similarity to use either knn or gaussian
+ similarity_param : int or float, optional (default=None)
+ Parameter for the similarity: number of nearest neighbors if similarity="knn",
+ or bandwidth if similarity="gauss". If None, emd_laplace picks a default (3 for "knn").
max_iter : int, optional (default=100)
Max number of BCD iterations
tol : float, optional (default=1e-5)
@@ -1754,7 +1765,7 @@ class EMDLaplaceTransport(BaseTransport):
"""
def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5,
- metric="sqeuclidean", norm=None, similarity="knn", max_iter=100, tol=1e-9,
+ metric="sqeuclidean", norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9,
max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans'):
@@ -1765,6 +1776,7 @@ class EMDLaplaceTransport(BaseTransport):
self.metric = metric
self.norm = norm
self.similarity = similarity
+ self.sim_param = similarity_param
self.max_iter = max_iter
self.tol = tol
self.max_inner_iter = max_inner_iter
@@ -1801,10 +1813,14 @@ class EMDLaplaceTransport(BaseTransport):
super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt)
+ kwargs = dict()
+ kwargs["sim"] = self.similarity
+ kwargs["param"] = self.sim_param
+
returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_,
- xt=self.xt_, M=self.cost_, reg=self.reg, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src,
+ xt=self.xt_, M=self.cost_, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src,
numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
- stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose)
+ stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose, **kwargs)
# coupling estimation
if self.log: