diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 15 |
1 files changed, 7 insertions, 8 deletions
@@ -748,9 +748,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, return A, b -def emd_laplace(a, b, xs, xt, M, sim, sim_param, reg, eta, alpha, - numItermax, stopThr, numInnerItermax, - stopInnerThr, log=False, verbose=False, **kwargs): +def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5, + numItermax=100, stopThr=1e-9, numInnerItermax=100000, + stopInnerThr=1e-9, log=False, verbose=False): r"""Solve the optimal transport problem (OT) with Laplacian regularization .. math:: @@ -1765,15 +1765,14 @@ class EMDLaplaceTransport(BaseTransport): in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. """ - def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5, - metric="sqeuclidean", norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9, + def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean", + norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9, max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans'): self.reg = reg_type self.reg_lap = reg_lap self.reg_src = reg_src - self.alpha = alpha self.metric = metric self.norm = norm self.similarity = similarity @@ -1815,8 +1814,8 @@ class EMDLaplaceTransport(BaseTransport): super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt) returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_, - xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src, - numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, + xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap, + alpha=self.reg_src, numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose) # coupling estimation |