From 80e3c23bc968f866fd20344ddc443a3c7fcb3b0d Mon Sep 17 00:00:00 2001 From: Clément Bonet <32179275+clbonet@users.noreply.github.com> Date: Thu, 23 Feb 2023 08:31:01 +0100 Subject: [WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434) * W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn --- test/test_backend.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) (limited to 'test/test_backend.py') diff --git a/test/test_backend.py b/test/test_backend.py index 3628f61..fd9a761 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -282,6 +282,20 @@ def test_empty_backend(): nx.array_equal(M, M) with pytest.raises(NotImplementedError): nx.is_floating_point(M) + with pytest.raises(NotImplementedError): + nx.tile(M, (10, 1)) + with pytest.raises(NotImplementedError): + nx.floor(M) + with pytest.raises(NotImplementedError): + nx.prod(M) + with pytest.raises(NotImplementedError): + nx.sort2(M) + with pytest.raises(NotImplementedError): + nx.qr(M) + with pytest.raises(NotImplementedError): + nx.atan2(v, v) + with pytest.raises(NotImplementedError): + nx.transpose(M) def test_func_backends(nx): @@ -603,6 +617,38 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("isfinite") + A = nx.tile(vb, (10, 1)) + lst_b.append(nx.to_numpy(A)) + lst_name.append("tile") + + A = nx.floor(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("floor") + + A = nx.prod(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("prod") + + A, B = nx.sort2(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("sort2 sort") + lst_b.append(nx.to_numpy(B)) + lst_name.append("sort2 argsort") + + A, B = nx.qr(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("QR Q") + lst_b.append(nx.to_numpy(B)) + lst_name.append("QR R") + + A = nx.atan2(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("atan2") + + A = nx.transpose(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("transpose") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( -- cgit v1.2.3