summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorTardy Benjamin <tardybenjamin4@gmail.com>2018-04-18 13:26:29 +0200
committerTardy Benjamin <tardybenjamin4@gmail.com>2018-04-18 13:26:29 +0200
commit94b929b37623e45a61b998f031ad7349943a53d8 (patch)
tree933e6e2fc1d5e00a68251fb011b0f46980571a2e /ot/da.py
parent1b5112c22a143e1daec317f27633548a194a32ef (diff)
BUG: Parameter log unusable in sinkhorn classes
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/ot/da.py b/ot/da.py
index aed7006..33480f9 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1310,7 +1310,7 @@ class EMDTransport(BaseTransport):
The kind of distribution estimation to employ
verbose : int, optional (default=0)
Controls the verbosity of the optimization algorithm
- log : bool, optional (default=False)
+ log : int, optional (default=0)
Controls the logs of the optimization algorithm
limit_max: float, optional (default=10)
Controls the semi supervised mode. Transport between labeled source
@@ -1438,7 +1438,7 @@ class SinkhornLpl1Transport(BaseTransport):
"""
def __init__(self, reg_e=1., reg_cl=0.1,
- max_iter=10, max_inner_iter=200,
+ max_iter=10, max_inner_iter=200, lo=False,
tol=10e-9, verbose=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
@@ -1449,6 +1449,7 @@ class SinkhornLpl1Transport(BaseTransport):
self.max_iter = max_iter
self.max_inner_iter = max_inner_iter
self.tol = tol
+ self.log = log
self.verbose = verbose
self.metric = metric
self.norm = norm
@@ -1486,12 +1487,18 @@ class SinkhornLpl1Transport(BaseTransport):
super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
- self.coupling_ = sinkhorn_lpl1_mm(
+ returned_ = sinkhorn_lpl1_mm(
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
- verbose=self.verbose)
+ verbose=self.verbose,log=self.log)
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
return self