From af9f1e3cd75ec3453e53a4cd8940a5542fc9e117 Mon Sep 17 00:00:00 2001 From: Tardy Benjamin Date: Wed, 18 Apr 2018 11:34:46 +0200 Subject: BUG: EMDTransport parameter log unusable --- ot/da.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ot/da.py b/ot/da.py index c688654..a92c7f9 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1332,13 +1332,14 @@ class EMDTransport(BaseTransport): on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 """ - def __init__(self, metric="sqeuclidean", norm=None, + def __init__(self, metric="sqeuclidean", norm=None,log=False, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10, max_iter=100000): - + self.metric = metric self.norm = norm + self.log = log self.limit_max = limit_max self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map @@ -1371,11 +1372,16 @@ class EMDTransport(BaseTransport): super(EMDTransport, self).fit(Xs, ys, Xt, yt) - # coupling estimation - self.coupling_ = emd( - a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter - ) + returned_ = emd( + a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter, + log=self.log) + # coupling estimation + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() return self -- cgit v1.2.3 From 1b5112c22a143e1daec317f27633548a194a32ef Mon Sep 17 00:00:00 2001 From: Tardy Benjamin Date: Wed, 18 Apr 2018 11:36:13 +0200 Subject: ENH: Change the parameter type to bool in class EMDTransport documentation --- ot/da.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/da.py b/ot/da.py index a92c7f9..aed7006 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1310,7 +1310,7 @@ class EMDTransport(BaseTransport): The kind of distribution estimation to employ verbose : int, optional (default=0) Controls the verbosity of the optimization algorithm - log : int, optional (default=0) + log : bool, optional (default=False) Controls the logs of the optimization algorithm limit_max: float, optional (default=10) Controls the semi supervised mode. Transport between labeled source -- cgit v1.2.3 From 94b929b37623e45a61b998f031ad7349943a53d8 Mon Sep 17 00:00:00 2001 From: Tardy Benjamin Date: Wed, 18 Apr 2018 13:26:29 +0200 Subject: BUG: Parameter log unusable in sinkhorn classes --- ot/da.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ot/da.py b/ot/da.py index aed7006..33480f9 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1310,7 +1310,7 @@ class EMDTransport(BaseTransport): The kind of distribution estimation to employ verbose : int, optional (default=0) Controls the verbosity of the optimization algorithm - log : bool, optional (default=False) + log : int, optional (default=0) Controls the logs of the optimization algorithm limit_max: float, optional (default=10) Controls the semi supervised mode. Transport between labeled source @@ -1438,7 +1438,7 @@ class SinkhornLpl1Transport(BaseTransport): """ def __init__(self, reg_e=1., reg_cl=0.1, - max_iter=10, max_inner_iter=200, + max_iter=10, max_inner_iter=200, lo=False, tol=10e-9, verbose=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, @@ -1449,6 +1449,7 @@ class SinkhornLpl1Transport(BaseTransport): self.max_iter = max_iter self.max_inner_iter = max_inner_iter self.tol = tol + self.log = log self.verbose = verbose self.metric = metric self.norm = norm @@ -1486,12 +1487,18 @@ class SinkhornLpl1Transport(BaseTransport): super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt) - self.coupling_ = sinkhorn_lpl1_mm( + returned_ = sinkhorn_lpl1_mm( a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_, reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, - verbose=self.verbose) + verbose=self.verbose,log=self.log) + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() return self -- cgit v1.2.3 From b6687b5a2963f217d11f75e93c408b390a5dc53d Mon Sep 17 00:00:00 2001 From: Tardy Benjamin Date: Wed, 18 Apr 2018 13:31:49 +0200 Subject: BUG: typo error lo->log --- ot/da.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/da.py b/ot/da.py index 33480f9..b3dd61b 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1438,7 +1438,7 @@ class SinkhornLpl1Transport(BaseTransport): """ def __init__(self, reg_e=1., reg_cl=0.1, - max_iter=10, max_inner_iter=200, lo=False, + max_iter=10, max_inner_iter=200, log=False, tol=10e-9, verbose=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, -- cgit v1.2.3 From 54e16a40c5980290c03b654696db51e0aee583bb Mon Sep 17 00:00:00 2001 From: Tardy Benjamin Date: Wed, 18 Apr 2018 13:49:02 +0200 Subject: BUG: correct typo problems --- ot/da.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ot/da.py b/ot/da.py index b3dd61b..532dcd2 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1332,11 +1332,11 @@ class EMDTransport(BaseTransport): on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 """ - def __init__(self, metric="sqeuclidean", norm=None,log=False, + def __init__(self, metric="sqeuclidean", norm=None, log=False, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10, max_iter=100000): - + self.metric = metric self.norm = norm self.log = log @@ -1491,7 +1491,7 @@ class SinkhornLpl1Transport(BaseTransport): a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_, reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, - verbose=self.verbose,log=self.log) + verbose=self.verbose, log=self.log) # deal with the value of log if self.log: -- cgit v1.2.3