summaryrefslogtreecommitdiff
path: root/test/test_utils.py
diff options
context:
space:
mode:
authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-02-23 08:31:01 +0100
committerGitHub <noreply@github.com>2023-02-23 08:31:01 +0100
commit80e3c23bc968f866fd20344ddc443a3c7fcb3b0d (patch)
treee4c2e938896243842e290d8fcf78879a8f6960bf /test/test_utils.py
parent97feeb32b6c069d7bb44cd995531c2b820d59771 (diff)
[WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434)
* W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
Diffstat (limited to 'test/test_utils.py')
-rw-r--r--test/test_utils.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 666c157..31b12ef 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -330,3 +330,13 @@ def test_OTResult():
for at in lst_attributes:
with pytest.raises(NotImplementedError):
getattr(res, at)
+
+
+def test_get_coordinate_circle():
+
+ u = np.random.rand(1, 100)
+ x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi))
+ x = np.concatenate([x1, y1]).T
+ x_p = ot.utils.get_coordinate_circle(x)
+
+ np.testing.assert_allclose(u[0], x_p)