From 2d4d0b46f88c66ebc5502c840703ba6ce8910376 Mon Sep 17 00:00:00 2001 From: Slasnista Date: Fri, 25 Aug 2017 10:29:41 +0200 Subject: solving log issues to avoid errors and adding further tests --- ot/da.py | 57 ++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 15 deletions(-) (limited to 'ot') diff --git a/ot/da.py b/ot/da.py index 8fa1895..5a34979 100644 --- a/ot/da.py +++ b/ot/da.py @@ -1315,7 +1315,10 @@ class SinkhornTransport(BaseTransport): Attributes ---------- - coupling_ : the optimal coupling + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True References ---------- @@ -1367,11 +1370,18 @@ class SinkhornTransport(BaseTransport): super(SinkhornTransport, self).fit(Xs, ys, Xt, yt) # coupling estimation - self.coupling_ = sinkhorn( + returned_ = sinkhorn( a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e, numItermax=self.max_iter, stopThr=self.tol, verbose=self.verbose, log=self.log) + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + return self @@ -1400,7 +1410,8 @@ class EMDTransport(BaseTransport): Attributes ---------- - coupling_ : the optimal coupling + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling References ---------- @@ -1475,15 +1486,14 @@ class SinkhornLpl1Transport(BaseTransport): The number of iteration in the inner loop verbose : int, optional (default=0) Controls the verbosity of the optimization algorithm - log : int, optional (default=0) - Controls the logs of the optimization algorithm limit_max: float, optional (defaul=np.infty) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an infinite cost Attributes ---------- - coupling_ : the optimal coupling + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling References ---------- @@ -1500,7 +1510,7 @@ class SinkhornLpl1Transport(BaseTransport): def __init__(self, reg_e=1., reg_cl=0.1, max_iter=10, max_inner_iter=200, - tol=10e-9, verbose=False, log=False, + tol=10e-9, verbose=False, metric="sqeuclidean", distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=np.infty): @@ -1511,7 +1521,6 @@ class SinkhornLpl1Transport(BaseTransport): self.max_inner_iter = max_inner_iter self.tol = tol self.verbose = verbose - self.log = log self.metric = metric self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map @@ -1544,7 +1553,7 @@ class SinkhornLpl1Transport(BaseTransport): a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_, reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, - verbose=self.verbose, log=self.log) + verbose=self.verbose) return self @@ -1584,7 +1593,10 @@ class SinkhornL1l2Transport(BaseTransport): Attributes ---------- - coupling_ : the optimal coupling + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True References ---------- @@ -1641,12 +1653,19 @@ class SinkhornL1l2Transport(BaseTransport): super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt) - self.coupling_ = sinkhorn_l1l2_gl( + returned_ = sinkhorn_l1l2_gl( a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_, reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol, verbose=self.verbose, log=self.log) + # deal with the value of log + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + return self @@ -1683,14 +1702,15 @@ class MappingTransport(BaseEstimator): Attributes ---------- - coupling_ : array-like, shape (n_source_samples, n_features) + coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling mapping_ : array-like, shape (n_features (+ 1), n_features) (if bias) for kernel == linear The associated mapping - array-like, shape (n_source_samples (+ 1), n_features) (if bias) for kernel == gaussian + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True References ---------- @@ -1745,19 +1765,26 @@ class MappingTransport(BaseEstimator): self.Xt = Xt if self.kernel == "linear": - self.coupling_, self.mapping_ = joint_OT_mapping_linear( + returned_ = joint_OT_mapping_linear( Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, verbose=self.verbose, verbose2=self.verbose2, numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, stopThr=self.tol, stopInnerThr=self.inner_tol, log=self.log) elif self.kernel == "gaussian": - self.coupling_, self.mapping_ = joint_OT_mapping_kernel( + returned_ = joint_OT_mapping_kernel( Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias, sigma=self.sigma, verbose=self.verbose, verbose2=self.verbose, numItermax=self.max_iter, numInnerItermax=self.max_inner_iter, stopInnerThr=self.inner_tol, stopThr=self.tol, log=self.log) + # deal with the value of log + if self.log: + self.coupling_, self.mapping_, self.log_ = returned_ + else: + self.coupling_, self.mapping_ = returned_ + self.log_ = dict() + return self def transform(self, Xs): -- cgit v1.2.3