summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-02-23 08:31:01 +0100
committerGitHub <noreply@github.com>2023-02-23 08:31:01 +0100
commit80e3c23bc968f866fd20344ddc443a3c7fcb3b0d (patch)
treee4c2e938896243842e290d8fcf78879a8f6960bf /ot/utils.py
parent97feeb32b6c069d7bb44cd995531c2b820d59771 (diff)
[WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434)
* W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 9093f09..3423a7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -375,6 +375,36 @@ def check_random_state(seed):
' instance'.format(seed))
+def get_coordinate_circle(x):
+ r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the coordinates in
+ turn (in [0,1[).
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ Parameters
+ ----------
+ x: ndarray, shape (n, 2)
+ Samples on the circle with ambient coordinates
+
+ Returns
+ -------
+ x_t: ndarray, shape (n,)
+ Coordinates on [0,1[
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi)
+ >>> x1, y1 = np.cos(u), np.sin(u)
+ >>> x = np.concatenate([x1, y1]).T
+ >>> get_coordinate_circle(x)
+ array([0.2, 0.5, 0.8])
+ """
+ nx = get_backend(x)
+ x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ return x_t
+
+
class deprecated(object):
r"""Decorator to mark a function or class as deprecated.