summaryrefslogtreecommitdiff
path: root/ot/__init__.py
diff options
context:
space:
mode:
authortgnassou <66993815+tgnassou@users.noreply.github.com>2023-01-16 18:09:44 +0100
committerGitHub <noreply@github.com>2023-01-16 18:09:44 +0100
commit97feeb32b6c069d7bb44cd995531c2b820d59771 (patch)
tree18f28e89a925534884c6ed97bfd986bbb61d1279 /ot/__init__.py
parent058d275565f0f65c23e06853812d5eb3a6ebdcef (diff)
[MRG] OT for Gaussian distributions (#428)
* add gaussian modules * add gaussian modules * add PR to release.md * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/gaussian.py * Update ot/gaussian.py * add empirical bures wassertsein distance, fix docstring and test * update to fit with new networkx API * add test for jax et tf" * fix test * fix test? * add empirical_bures_wasserstein_mapping * fix docs * fix doc * fix docstring * add tgnassou to contributors * add more coverage for gaussian.py * add deprecated function * fix doc math" " * fix doc math" " * add remi flamary to authors of gaussiansmodule * fix equation Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/__init__.py')
-rw-r--r--ot/__init__.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 51eb726..0b55e0c 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -35,6 +35,7 @@ from . import regpath
from . import weak
from . import factored
from . import solvers
+from . import gaussian
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -56,7 +57,7 @@ __version__ = "0.8.3dev"
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
- 'emd2_1d', 'wasserstein_1d', 'backend',
+ 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',