diff options
author | Clément Bonet <32179275+clbonet@users.noreply.github.com> | 2023-02-23 08:31:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-23 08:31:01 +0100 |
commit | 80e3c23bc968f866fd20344ddc443a3c7fcb3b0d (patch) | |
tree | e4c2e938896243842e290d8fcf78879a8f6960bf /ot/utils.py | |
parent | 97feeb32b6c069d7bb44cd995531c2b820d59771 (diff) |
[WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434)
* W circle + SSW
* Tests + Example SSW_1
* Example Wasserstein Circle + Tests
* Wasserstein on the circle wrt Unif
* Example SSW unif
* pep8
* np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests
* np qr
* rm test python 3.11
* update names, tests, backend transpose
* Comment error batchs
* semidiscrete_wasserstein2_unif_circle example
* torch permute method instead of torch.permute for previous versions
* update comments and doc
* doc wasserstein circle model as [0,1[
* Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/ot/utils.py b/ot/utils.py index 9093f09..3423a7e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -375,6 +375,36 @@ def check_random_state(seed): ' instance'.format(seed)) +def get_coordinate_circle(x): + r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the coordinates in + turn (in [0,1[). + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + Parameters + ---------- + x: ndarray, shape (n, 2) + Samples on the circle with ambient coordinates + + Returns + ------- + x_t: ndarray, shape (n,) + Coordinates on [0,1[ + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi) + >>> x1, y1 = np.cos(u), np.sin(u) + >>> x = np.concatenate([x1, y1]).T + >>> get_coordinate_circle(x) + array([0.2, 0.5, 0.8]) + """ + nx = get_backend(x) + x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + return x_t + + class deprecated(object): r"""Decorator to mark a function or class as deprecated. |