summaryrefslogtreecommitdiff
path: root/ot/lp/__init__.py
diff options
context:
space:
mode:
authorClément Bonet <32179275+clbonet@users.noreply.github.com>2023-02-23 08:31:01 +0100
committerGitHub <noreply@github.com>2023-02-23 08:31:01 +0100
commit80e3c23bc968f866fd20344ddc443a3c7fcb3b0d (patch)
treee4c2e938896243842e290d8fcf78879a8f6960bf /ot/lp/__init__.py
parent97feeb32b6c069d7bb44cd995531c2b820d59771 (diff)
[WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434)
* W circle + SSW * Tests + Example SSW_1 * Example Wasserstein Circle + Tests * Wasserstein on the circle wrt Unif * Example SSW unif * pep8 * np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests * np qr * rm test python 3.11 * update names, tests, backend transpose * Comment error batchs * semidiscrete_wasserstein2_unif_circle example * torch permute method instead of torch.permute for previous versions * update comments and doc * doc wasserstein circle model as [0,1[ * Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
Diffstat (limited to 'ot/lp/__init__.py')
-rw-r--r--ot/lp/__init__.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 17411d0..7d0640f 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -20,14 +20,17 @@ from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from .solver_1d import emd_1d, emd2_1d, wasserstein_1d
+from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d,
+ binary_search_circle, wasserstein_circle,
+ semidiscrete_wasserstein2_unif_circle)
from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter']
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter',
+ 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle']
def check_number_threads(numThreads):