summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorSlasnista <stan.chambon@gmail.com>2017-08-25 10:29:41 +0200
committerSlasnista <stan.chambon@gmail.com>2017-08-25 10:29:41 +0200
commit2d4d0b46f88c66ebc5502c840703ba6ce8910376 (patch)
tree5f0ead47c2d0bb769edca3e0f9e8f155ea73ffd0 /ot/da.py
parent09302239b3e4e1a90c1a4e2d7a85b0af86b01365 (diff)
solving log issues to avoid errors and adding further tests
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py57
1 files changed, 42 insertions, 15 deletions
diff --git a/ot/da.py b/ot/da.py
index 8fa1895..5a34979 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1315,7 +1315,10 @@ class SinkhornTransport(BaseTransport):
Attributes
----------
- coupling_ : the optimal coupling
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
References
----------
@@ -1367,11 +1370,18 @@ class SinkhornTransport(BaseTransport):
super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
# coupling estimation
- self.coupling_ = sinkhorn(
+ returned_ = sinkhorn(
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
return self
@@ -1400,7 +1410,8 @@ class EMDTransport(BaseTransport):
Attributes
----------
- coupling_ : the optimal coupling
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
References
----------
@@ -1475,15 +1486,14 @@ class SinkhornLpl1Transport(BaseTransport):
The number of iteration in the inner loop
verbose : int, optional (default=0)
Controls the verbosity of the optimization algorithm
- log : int, optional (default=0)
- Controls the logs of the optimization algorithm
limit_max: float, optional (defaul=np.infty)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
Attributes
----------
- coupling_ : the optimal coupling
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
References
----------
@@ -1500,7 +1510,7 @@ class SinkhornLpl1Transport(BaseTransport):
def __init__(self, reg_e=1., reg_cl=0.1,
max_iter=10, max_inner_iter=200,
- tol=10e-9, verbose=False, log=False,
+ tol=10e-9, verbose=False,
metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
@@ -1511,7 +1521,6 @@ class SinkhornLpl1Transport(BaseTransport):
self.max_inner_iter = max_inner_iter
self.tol = tol
self.verbose = verbose
- self.log = log
self.metric = metric
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
@@ -1544,7 +1553,7 @@ class SinkhornLpl1Transport(BaseTransport):
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
- verbose=self.verbose, log=self.log)
+ verbose=self.verbose)
return self
@@ -1584,7 +1593,10 @@ class SinkhornL1l2Transport(BaseTransport):
Attributes
----------
- coupling_ : the optimal coupling
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
References
----------
@@ -1641,12 +1653,19 @@ class SinkhornL1l2Transport(BaseTransport):
super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)
- self.coupling_ = sinkhorn_l1l2_gl(
+ returned_ = sinkhorn_l1l2_gl(
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
verbose=self.verbose, log=self.log)
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
return self
@@ -1683,14 +1702,15 @@ class MappingTransport(BaseEstimator):
Attributes
----------
- coupling_ : array-like, shape (n_source_samples, n_features)
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
mapping_ : array-like, shape (n_features (+ 1), n_features)
(if bias) for kernel == linear
The associated mapping
-
array-like, shape (n_source_samples (+ 1), n_features)
(if bias) for kernel == gaussian
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
References
----------
@@ -1745,19 +1765,26 @@ class MappingTransport(BaseEstimator):
self.Xt = Xt
if self.kernel == "linear":
- self.coupling_, self.mapping_ = joint_OT_mapping_linear(
+ returned_ = joint_OT_mapping_linear(
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
verbose=self.verbose, verbose2=self.verbose2,
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
stopThr=self.tol, stopInnerThr=self.inner_tol, log=self.log)
elif self.kernel == "gaussian":
- self.coupling_, self.mapping_ = joint_OT_mapping_kernel(
+ returned_ = joint_OT_mapping_kernel(
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
sigma=self.sigma, verbose=self.verbose, verbose2=self.verbose,
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
stopInnerThr=self.inner_tol, stopThr=self.tol, log=self.log)
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.mapping_, self.log_ = returned_
+ else:
+ self.coupling_, self.mapping_ = returned_
+ self.log_ = dict()
+
return self
def transform(self, Xs):