summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-20 14:04:32 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-20 14:04:32 +0200
commit36b2e92d9ad5148208cc1bec9bc9133999bcdb1c (patch)
tree87424b5f23cc67c4f1f241c3d350803a9f89ce10 /ot
parentfd115a538deb8fa9dcf3169fcfa6b85aebd36b07 (diff)
added defaults for emd_laplace
Diffstat (limited to 'ot')
-rw-r--r--ot/da.py15
1 files changed, 7 insertions, 8 deletions
diff --git a/ot/da.py b/ot/da.py
index 300af30..e615993 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -748,9 +748,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b
-def emd_laplace(a, b, xs, xt, M, sim, sim_param, reg, eta, alpha,
- numItermax, stopThr, numInnerItermax,
- stopInnerThr, log=False, verbose=False, **kwargs):
+def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,
+ numItermax=100, stopThr=1e-9, numInnerItermax=100000,
+ stopInnerThr=1e-9, log=False, verbose=False):
r"""Solve the optimal transport problem (OT) with Laplacian regularization
.. math::
@@ -1765,15 +1765,14 @@ class EMDLaplaceTransport(BaseTransport):
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
"""
- def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5,
- metric="sqeuclidean", norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9,
+ def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean",
+ norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9,
max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans'):
self.reg = reg_type
self.reg_lap = reg_lap
self.reg_src = reg_src
- self.alpha = alpha
self.metric = metric
self.norm = norm
self.similarity = similarity
@@ -1815,8 +1814,8 @@ class EMDLaplaceTransport(BaseTransport):
super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt)
returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_,
- xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src,
- numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
+ xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap,
+ alpha=self.reg_src, numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose)
# coupling estimation