diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2022-12-06 18:02:44 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-06 18:02:44 +0100 |
commit | 8490196dcc982c492b7565e1ec4de5f75f006acf (patch) | |
tree | 7f330e9c15dd4dff8d36c2308fa8d416b552ce23 | |
parent | ac830dd2b85cfd39f4fadd879a721b36ded033ea (diff) |
[MRG] Fix bug in regularized OTDA l1lp with log (#413)
* correct bug in DA l1lp with log
* better tests and speedup with smaller dataset size
* remove jax for log test
* remove trndorflow for log test
* pep8!
-rw-r--r-- | RELEASES.md | 3 | ||||
-rw-r--r-- | ot/da.py | 13 | ||||
-rw-r--r-- | test/test_da.py | 62 |
3 files changed, 53 insertions, 25 deletions
diff --git a/RELEASES.md b/RELEASES.md index 68487e8..3bd84c1 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -27,7 +27,8 @@ roughly 2^31) (PR #381) - Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402) - Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409) - Fixed weak optimal transport docstring (Issue #404, PR #410) - +- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, +PR #413) ## 0.8.2 @@ -126,8 +126,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, W = nx.zeros(M.shape, type_as=M) for cpt in range(numItermax): Mreg = M + eta * W - transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, - stopThr=stopInnerThr) + if log: + transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr, log=True) + else: + transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr) # the transport has been computed. Check if classes are really # separated W = nx.ones(M.shape, type_as=M) @@ -136,7 +140,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs - return transp + if log: + return transp, log + else: + return transp def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, diff --git a/test/test_da.py b/test/test_da.py index 4bf0ab1..138936f 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -44,12 +44,32 @@ def test_class_jax_tf(): @pytest.skip_backend("jax") @pytest.skip_backend("tf") +@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport]) +def test_log_da(nx, class_to_test): + + ns = 50 + nt = 50 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + + otda = class_to_test(log=True) + + # test its computed + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + assert hasattr(otda, "log_") + + +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") def test_sinkhorn_lpl1_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx): """ ns = 50 - nt = 100 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -402,8 +422,8 @@ def test_emd_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -558,8 +578,8 @@ def test_mapping_transport_class_specific_seed(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_linear_mapping(nx): - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -579,8 +599,8 @@ def test_linear_mapping(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_linear_mapping_class(nx): - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -609,9 +629,9 @@ def test_jcpot_transport_class(nx): """test_jcpot_transport """ - ns1 = 150 - ns2 = 150 - nt = 200 + ns1 = 50 + ns2 = 50 + nt = 50 Xs1, ys1 = make_data_classif('3gauss', ns1) Xs2, ys2 = make_data_classif('3gauss', ns2) @@ -681,9 +701,9 @@ def test_jcpot_barycenter(nx): """test_jcpot_barycenter """ - ns1 = 150 - ns2 = 150 - nt = 200 + ns1 = 50 + ns2 = 50 + nt = 50 sigma = 0.1 np.random.seed(1985) @@ -713,8 +733,8 @@ def test_jcpot_barycenter(nx): def test_emd_laplace_class(nx): """test_emd_laplace_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) |