summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-12-06 18:02:44 +0100
committerGitHub <noreply@github.com>2022-12-06 18:02:44 +0100
commit8490196dcc982c492b7565e1ec4de5f75f006acf (patch)
tree7f330e9c15dd4dff8d36c2308fa8d416b552ce23
parentac830dd2b85cfd39f4fadd879a721b36ded033ea (diff)
[MRG] Fix bug in regularized OTDA l1lp with log (#413)
* correct bug in DA l1lp with log * better tests and speedup with smaller dataset size * remove jax for log test * remove trndorflow for log test * pep8!
-rw-r--r--RELEASES.md3
-rw-r--r--ot/da.py13
-rw-r--r--test/test_da.py62
3 files changed, 53 insertions, 25 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 68487e8..3bd84c1 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -27,7 +27,8 @@ roughly 2^31) (PR #381)
- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
- Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409)
- Fixed weak optimal transport docstring (Issue #404, PR #410)
-
+- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
+PR #413)
## 0.8.2
diff --git a/ot/da.py b/ot/da.py
index 0b9737e..083663c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -126,8 +126,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
W = nx.zeros(M.shape, type_as=M)
for cpt in range(numItermax):
Mreg = M + eta * W
- transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
- stopThr=stopInnerThr)
+ if log:
+ transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr, log=True)
+ else:
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really
# separated
W = nx.ones(M.shape, type_as=M)
@@ -136,7 +140,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
- return transp
+ if log:
+ return transp, log
+ else:
+ return transp
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
diff --git a/test/test_da.py b/test/test_da.py
index 4bf0ab1..138936f 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -44,12 +44,32 @@ def test_class_jax_tf():
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
+@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])
+def test_log_da(nx, class_to_test):
+
+ ns = 50
+ nt = 50
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
+ otda = class_to_test(log=True)
+
+ # test its computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert hasattr(otda, "log_")
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
"""
ns = 50
- nt = 100
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -558,8 +578,8 @@ def test_mapping_transport_class_specific_seed(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_linear_mapping(nx):
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -579,8 +599,8 @@ def test_linear_mapping(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_linear_mapping_class(nx):
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -609,9 +629,9 @@ def test_jcpot_transport_class(nx):
"""test_jcpot_transport
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
Xs1, ys1 = make_data_classif('3gauss', ns1)
Xs2, ys2 = make_data_classif('3gauss', ns2)
@@ -681,9 +701,9 @@ def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter
"""
- ns1 = 150
- ns2 = 150
- nt = 200
+ ns1 = 50
+ ns2 = 50
+ nt = 50
sigma = 0.1
np.random.seed(1985)
@@ -713,8 +733,8 @@ def test_jcpot_barycenter(nx):
def test_emd_laplace_class(nx):
"""test_emd_laplace_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 50
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)