From af9f1e3cd75ec3453e53a4cd8940a5542fc9e117 Mon Sep 17 00:00:00 2001 From: Tardy Benjamin Date: Wed, 18 Apr 2018 11:34:46 +0200 Subject: BUG: EMDTransport parameter log unusable --- ot/da.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) (limited to 'ot/da.py') diff --git a/ot/da.py b/ot/da.py index c688654..a92c7f9 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1332,13 +1332,14 @@ class EMDTransport(BaseTransport): on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 """ - def __init__(self, metric="sqeuclidean", norm=None, + def __init__(self, metric="sqeuclidean", norm=None,log=False, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10, max_iter=100000): - + self.metric = metric self.norm = norm + self.log = log self.limit_max = limit_max self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map @@ -1371,11 +1372,16 @@ class EMDTransport(BaseTransport): super(EMDTransport, self).fit(Xs, ys, Xt, yt) - # coupling estimation - self.coupling_ = emd( - a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter - ) + returned_ = emd( + a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter, + log=self.log) + # coupling estimation + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() return self -- cgit v1.2.3