summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py57
-rw-r--r--test/test_da.py39
2 files changed, 75 insertions, 21 deletions
diff --git a/ot/da.py b/ot/da.py
index 8fa1895..5a34979 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1315,7 +1315,10 @@ class SinkhornTransport(BaseTransport):
Attributes
----------
- coupling_ : the optimal coupling
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
References
----------
@@ -1367,11 +1370,18 @@ class SinkhornTransport(BaseTransport):
super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
# coupling estimation
- self.coupling_ = sinkhorn(
+ returned_ = sinkhorn(
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
return self
@@ -1400,7 +1410,8 @@ class EMDTransport(BaseTransport):
Attributes
----------
- coupling_ : the optimal coupling
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
References
----------
@@ -1475,15 +1486,14 @@ class SinkhornLpl1Transport(BaseTransport):
The number of iteration in the inner loop
verbose : int, optional (default=0)
Controls the verbosity of the optimization algorithm
- log : int, optional (default=0)
- Controls the logs of the optimization algorithm
limit_max: float, optional (defaul=np.infty)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
Attributes
----------
- coupling_ : the optimal coupling
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
References
----------
@@ -1500,7 +1510,7 @@ class SinkhornLpl1Transport(BaseTransport):
def __init__(self, reg_e=1., reg_cl=0.1,
max_iter=10, max_inner_iter=200,
- tol=10e-9, verbose=False, log=False,
+ tol=10e-9, verbose=False,
metric="sqeuclidean",
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
@@ -1511,7 +1521,6 @@ class SinkhornLpl1Transport(BaseTransport):
self.max_inner_iter = max_inner_iter
self.tol = tol
self.verbose = verbose
- self.log = log
self.metric = metric
self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
@@ -1544,7 +1553,7 @@ class SinkhornLpl1Transport(BaseTransport):
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
- verbose=self.verbose, log=self.log)
+ verbose=self.verbose)
return self
@@ -1584,7 +1593,10 @@ class SinkhornL1l2Transport(BaseTransport):
Attributes
----------
- coupling_ : the optimal coupling
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
References
----------
@@ -1641,12 +1653,19 @@ class SinkhornL1l2Transport(BaseTransport):
super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)
- self.coupling_ = sinkhorn_l1l2_gl(
+ returned_ = sinkhorn_l1l2_gl(
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
verbose=self.verbose, log=self.log)
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
return self
@@ -1683,14 +1702,15 @@ class MappingTransport(BaseEstimator):
Attributes
----------
- coupling_ : array-like, shape (n_source_samples, n_features)
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
mapping_ : array-like, shape (n_features (+ 1), n_features)
(if bias) for kernel == linear
The associated mapping
-
array-like, shape (n_source_samples (+ 1), n_features)
(if bias) for kernel == gaussian
+ log_ : dictionary
+ The dictionary of log, empty dic if parameter log is not True
References
----------
@@ -1745,19 +1765,26 @@ class MappingTransport(BaseEstimator):
self.Xt = Xt
if self.kernel == "linear":
- self.coupling_, self.mapping_ = joint_OT_mapping_linear(
+ returned_ = joint_OT_mapping_linear(
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
verbose=self.verbose, verbose2=self.verbose2,
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
stopThr=self.tol, stopInnerThr=self.inner_tol, log=self.log)
elif self.kernel == "gaussian":
- self.coupling_, self.mapping_ = joint_OT_mapping_kernel(
+ returned_ = joint_OT_mapping_kernel(
Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
sigma=self.sigma, verbose=self.verbose, verbose2=self.verbose,
numItermax=self.max_iter, numInnerItermax=self.max_inner_iter,
stopInnerThr=self.inner_tol, stopThr=self.tol, log=self.log)
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.mapping_, self.log_ = returned_
+ else:
+ self.coupling_, self.mapping_ = returned_
+ self.log_ = dict()
+
return self
def transform(self, Xs):
diff --git a/test/test_da.py b/test/test_da.py
index 9578b3d..104a798 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -26,6 +26,8 @@ def test_sinkhorn_lpl1_transport_class():
# test its computed
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -89,6 +91,9 @@ def test_sinkhorn_l1l2_transport_class():
# test its computed
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
+ assert hasattr(clf, "log_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -137,6 +142,11 @@ def test_sinkhorn_l1l2_transport_class():
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check everything runs well with log=True
+ clf = ot.da.SinkhornL1l2Transport(log=True)
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(clf.log_.keys()) != 0
+
def test_sinkhorn_transport_class():
"""test_sinkhorn_transport
@@ -152,6 +162,9 @@ def test_sinkhorn_transport_class():
# test its computed
clf.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
+ assert hasattr(clf, "log_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -200,6 +213,11 @@ def test_sinkhorn_transport_class():
assert n_unsup != n_semisup, "semisupervised mode not working"
+ # check everything runs well with log=True
+ clf = ot.da.SinkhornTransport(log=True)
+ clf.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(clf.log_.keys()) != 0
+
def test_emd_transport_class():
"""test_sinkhorn_transport
@@ -215,6 +233,8 @@ def test_emd_transport_class():
# test its computed
clf.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(clf, "cost_")
+ assert hasattr(clf, "coupling_")
# test dimensions of coupling
assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
@@ -282,6 +302,9 @@ def test_mapping_transport_class():
# check computation and dimensions if bias == False
clf = ot.da.MappingTransport(kernel="linear", bias=False)
clf.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(clf, "coupling_")
+ assert hasattr(clf, "mapping_")
+ assert hasattr(clf, "log_")
assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
assert_equal(clf.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
@@ -369,6 +392,11 @@ def test_mapping_transport_class():
# check that the oos method is working
assert_equal(transp_Xs_new.shape, Xs_new.shape)
+ # check everything runs well with log=True
+ clf = ot.da.MappingTransport(kernel="gaussian", log=True)
+ clf.fit(Xs=Xs, Xt=Xt)
+ assert len(clf.log_.keys()) != 0
+
def test_otda():
@@ -434,9 +462,8 @@ def test_otda():
# if __name__ == "__main__":
- # test_otda()
- # test_sinkhorn_transport_class()
- # test_emd_transport_class()
- # test_sinkhorn_l1l2_transport_class()
- # test_sinkhorn_lpl1_transport_class()
- # test_mapping_transport_class()
+# test_sinkhorn_transport_class()
+# test_emd_transport_class()
+# test_sinkhorn_l1l2_transport_class()
+# test_sinkhorn_lpl1_transport_class()
+# test_mapping_transport_class()