summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
Diffstat (limited to 'ot')
-rw-r--r--ot/__init__.py30
-rw-r--r--ot/bregman.py7
-rw-r--r--ot/da.py54
-rw-r--r--ot/datasets.py6
-rw-r--r--ot/dr.py4
-rw-r--r--ot/gpu/__init__.py5
-rw-r--r--ot/gromov.py2
-rw-r--r--ot/lp/__init__.py6
-rw-r--r--ot/optim.py2
-rwxr-xr-xot/partial.py2
-rw-r--r--ot/smooth.py4
-rw-r--r--ot/unbalanced.py2
12 files changed, 58 insertions, 66 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 1e57b78..2d23610 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,35 +1,5 @@
"""
-This is the main module of the POT toolbox. It provides easy access to
-a number of sub-modules and functions described below.
-
-.. note::
-
-
- Here is a list of the submodules and short description of what they contain.
-
- - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
- - :any:`ot.bregman` contains OT solvers for the entropic OT problems using
- Bregman projections.
- - :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
- - :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT
- problems.
- - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
- Wasserstein problems.
- - :any:`ot.optim` contains generic solvers OT based optimization problems
- - :any:`ot.da` contains classes and function related to Monge mapping
- estimation and Domain Adaptation (DA).
- - :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers
- - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
- Discriminant Analysis.
- - :any:`ot.utils` contains utility functions such as distance computation and
- timing.
- - :any:`ot.datasets` contains toy dataset generation functions.
- - :any:`ot.plot` contains visualization functions
- - :any:`ot.stochastic` contains stochastic solvers for regularized OT.
- - :any:`ot.unbalanced` contains solvers for regularized unbalanced OT.
- - :any:`ot.partial` contains solvers for partial OT.
-
.. warning::
The list of automatically imported sub-modules is as follows:
:py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
diff --git a/ot/bregman.py b/ot/bregman.py
index 543dbaa..f1f8437 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Bregman projections for regularized OT
+Bregman projections solvers for entropic regularized OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -909,11 +909,6 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
else:
alpha, beta = warmstart
- def get_K(alpha, beta):
- """log space computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1))
- - beta.reshape((1, dim_b))) / reg)
-
# print(np.min(K))
def get_reg(n): # exponential decreasing
return (epsilon0 - reg) * np.exp(-n) + reg
diff --git a/ot/da.py b/ot/da.py
index 6249f08..b881a8b 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -908,7 +908,8 @@ class BaseTransport(BaseEstimator):
at the class level in their ``__init__`` as explicit keyword
arguments (no ``*args`` or ``**kwargs``).
- fit method should:
+ the fit method should:
+
- estimate a cost matrix and store it in a `cost_` attribute
- estimate a coupling matrix and store it in a `coupling_`
attribute
@@ -933,7 +934,7 @@ class BaseTransport(BaseEstimator):
Xs : array-like, shape (n_source_samples, n_features)
The training input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The training class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -994,7 +995,7 @@ class BaseTransport(BaseEstimator):
Xs : array-like, shape (n_source_samples, n_features)
The training input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The class labels for training samples
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -1018,13 +1019,13 @@ class BaseTransport(BaseEstimator):
Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
+ The source input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The class labels for source samples
Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
+ The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
+ The class labels for target. If some target samples are unlabeled, fill the
yt's elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
@@ -1085,7 +1086,7 @@ class BaseTransport(BaseEstimator):
Parameters
----------
ys : array-like, shape (n_source_samples,)
- The class labels
+ The source class labels
Returns
-------
@@ -1125,18 +1126,18 @@ class BaseTransport(BaseEstimator):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples Xt onto target samples Xs
+ """Transports target samples Xt onto source samples Xs
Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
+ The source input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The source class labels
Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
+ The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
+ The target class labels. If some target samples are unlabeled, fill the
yt's elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
@@ -1227,7 +1228,6 @@ class BaseTransport(BaseEstimator):
class LinearTransport(BaseTransport):
-
""" OT linear operator between empirical distributions
The function estimates the optimal linear operator that aligns the two
@@ -1438,6 +1438,9 @@ class SinkhornTransport(BaseTransport):
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
Transport, Advances in Neural Information Processing Systems (NIPS)
26, 2013
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_e=1., max_iter=1000,
@@ -1536,6 +1539,9 @@ class EMDTransport(BaseTransport):
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE Transactions
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, metric="sqeuclidean", norm=None, log=False,
@@ -1643,7 +1649,9 @@ class SinkhornLpl1Transport(BaseTransport):
.. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
-
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_e=1., reg_cl=0.1,
@@ -1763,6 +1771,9 @@ class EMDLaplaceTransport(BaseTransport):
.. [2] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
"Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean",
@@ -1882,7 +1893,9 @@ class SinkhornL1l2Transport(BaseTransport):
.. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
-
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_e=1., reg_cl=0.1,
@@ -2174,7 +2187,9 @@ class UnbalancedSinkhornTransport(BaseTransport):
.. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems. arXiv preprint
arXiv:1607.05816.
-
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
@@ -2287,6 +2302,11 @@ class JCPOTTransport(BaseTransport):
International Conference on Artificial Intelligence and Statistics (AISTATS),
vol. 89, p.849-858, 2019.
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
+
+
"""
def __init__(self, reg_e=.1, max_iter=10,
diff --git a/ot/datasets.py b/ot/datasets.py
index a1ca7b6..b86ef3b 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -1,5 +1,5 @@
"""
-Simple example datasets for OT
+Simple example datasets
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -147,8 +147,8 @@ def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwa
n2 = np.sum(y == 2)
x = np.zeros((n, 2))
- x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator)
- x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator)
+ x[y == 1, :] = make_2D_samples_gauss(n1, m1, nz, random_state=generator)
+ x[y == 2, :] = make_2D_samples_gauss(n2, m2, nz, random_state=generator)
x = x.dot(rot)
diff --git a/ot/dr.py b/ot/dr.py
index 680dabf..11d2e10 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-Dimension reduction with optimal transport
+Dimension reduction with OT
.. warning::
- Note that by default the module is not import in :mod:`ot`. In order to
+ Note that by default the module is not imported in :mod:`ot`. In order to
use it you need to explicitely import :mod:`ot.dr`
"""
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index 1ab95bb..7478fb9 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
"""
+GPU implementation for several OT solvers and utility
+functions.
-This module provides GPU implementation for several OT solvers and utility
-functions. The GPU backend in handled by `cupy
+The GPU backend in handled by `cupy
<https://cupy.chainer.org/>`_.
.. warning::
diff --git a/ot/gromov.py b/ot/gromov.py
index 43780a4..4427a96 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Gromov-Wasserstein transport method
+Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers
"""
# Author: Erwan Vautier <erwan.vautier@gmail.com>
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index ad390c5..50003ed 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -180,7 +180,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
\gamma = arg\min_\gamma <\gamma,M>_F
s.t. \gamma 1 = a
+
\gamma^T 1= b
+
\gamma\geq 0
where :
@@ -289,10 +291,12 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
r"""Solves the Earth Movers distance problem and returns the loss
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ \min_\gamma <\gamma,M>_F
s.t. \gamma 1 = a
+
\gamma^T 1= b
+
\gamma\geq 0
where :
diff --git a/ot/optim.py b/ot/optim.py
index 4012e0d..b9ca891 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Optimization algorithms for OT
+Generic solvers for regularized OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
diff --git a/ot/partial.py b/ot/partial.py
index c03ec25..eb707d8 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Partial OT
+Partial OT solvers
"""
# Author: Laetitia Chapel <laetitia.chapel@irisa.fr>
diff --git a/ot/smooth.py b/ot/smooth.py
index 5a8e4b5..81f6a3e 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -26,7 +26,9 @@
# Remi Flamary <remi.flamary@unice.fr>
"""
-Implementation of
+Smooth and Sparse Optimal Transport solvers (KL an L2 reg.)
+
+Implementation of :
Smooth and Sparse Optimal Transport.
Mathieu Blondel, Vivien Seguy, Antoine Rolet.
In Proc. of AISTATS 2018.
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 23f6607..e37f10c 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Regularized Unbalanced OT
+Regularized Unbalanced OT solvers
"""
# Author: Hicham Janati <hicham.janati@inria.fr>