author    Rémi Flamary <remi.flamary@gmail.com>  2019-09-09 14:55:04 +0200
committer Rémi Flamary <remi.flamary@gmail.com>  2019-09-09 14:55:04 +0200
commit    b2a7afb848a78570d01f35f9b239be8838520edc (patch)
tree      fc243208d24f5488d5ce06298b2ebb39b76be9bb
parent    c698e0aa20d28e36d25f87082855a490283f3c88 (diff)
parent    f251b4d080a577c2cee890ca43d8ec3658332021 (diff)
merge new unbalanced
-rw-r--r--  .gitignore                                  |   3
-rw-r--r--  .travis.yml                                 |   5
-rw-r--r--  Makefile                                    |   3
-rw-r--r--  docs/source/quickstart.rst                  | 148
-rw-r--r--  examples/plot_barycenter_lp_vs_entropic.py  |   7
-rw-r--r--  examples/plot_free_support_barycenter.py    |   2
-rw-r--r--  ot/__init__.py                              |  34
-rw-r--r--  ot/bregman.py                               | 604
-rw-r--r--  ot/da.py                                    | 121
-rw-r--r--  ot/datasets.py                              |  32
-rw-r--r--  ot/dr.py                                    |  57
-rw-r--r--  ot/gromov.py                                | 360
-rw-r--r--  ot/optim.py                                 |  32
-rw-r--r--  ot/plot.py                                  |   6
-rw-r--r--  ot/stochastic.py                            | 199
-rw-r--r--  ot/unbalanced.py                            | 808
-rw-r--r--  ot/utils.py                                 |  56
-rw-r--r--  pytest.ini                                  |   0
-rw-r--r--  setup.cfg                                   |  18
-rw-r--r--  test/test_bregman.py                        |  97
-rw-r--r--  test/test_da.py                             |  65
-rw-r--r--  test/test_unbalanced.py                     | 163
22 files changed, 1943 insertions, 877 deletions
diff --git a/.gitignore b/.gitignore
index 42a9aad..dadf84c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -59,6 +59,9 @@ coverage.xml
*.mo
*.pot
+# xml
+*.xml
+
# Django stuff:
*.log
local_settings.py
diff --git a/.travis.yml b/.travis.yml
index 0dfb0d0..5b3a26e 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -25,6 +25,9 @@ matrix:
# language: shell # 'language: python' is an error on Travis CI Windows
# before_install: choco install python
# env: PATH=/c/Python37:/c/Python37/Scripts:$PATH
+# before_script: # configure a headless display to test plot generation
+# - "export DISPLAY=:99.0"
+# - sleep 3 # give xvfb some time to start
before_install:
- ./.travis/before_install.sh
# command to install dependencies
@@ -34,6 +37,8 @@ install:
- pip install flake8 pytest "pytest-cov<2.6"
- pip install .
# command to run tests + check syntax style
+services:
+ - xvfb
script:
- python setup.py develop
- flake8 examples/ ot/ test/
diff --git a/Makefile b/Makefile
index 729fd8c..cafda8e 100644
--- a/Makefile
+++ b/Makefile
@@ -85,4 +85,7 @@ dist : wheels
$(PYTHON) setup.py sdist
+pydocstyle :
+ pydocstyle ot
+
FORCE :
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index 1640d6a..978eaff 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -2,13 +2,13 @@
Quick start guide
=================
-In the following we provide some pointers about which functions and classes
+In the following we provide some pointers about which functions and classes
to use for different problems related to optimal transport (OT) and machine
learning. When possible, we refer to concrete examples in the documentation that
are also available as notebooks on the POT Github.
This document is not a tutorial on numerical optimal transport. For this we strongly
-recommend to read the very nice book [15]_ .
+recommend reading the very nice book [15]_.
Optimal transport and Wasserstein distance
@@ -55,8 +55,8 @@ solver is quite efficient and uses sparsity of the solution.
Examples of use for :any:`ot.emd` are available in :
- :any:`auto_examples/plot_OT_2D_samples`
- - :any:`auto_examples/plot_OT_1D`
- - :any:`auto_examples/plot_OT_L1_vs_L2`
+ - :any:`auto_examples/plot_OT_1D`
+ - :any:`auto_examples/plot_OT_L1_vs_L2`
Computing Wasserstein distance
@@ -102,13 +102,13 @@ distance.
An example of use for :any:`ot.emd2` is available in :
- :any:`auto_examples/plot_compute_emd`
-
+
Special cases
^^^^^^^^^^^^^
Note that the OT problem and the corresponding Wasserstein distance can in some
-special cases be computed very efficiently.
+special cases be computed very efficiently.
For instance when the samples are in 1D, the OT problem can be solved in
:math:`O(n\log(n))` using a simple sort. In this case we provide the
@@ -117,13 +117,13 @@ matrix and value. Note that since the solution is very sparse the :code:`sparse`
parameter of :any:`ot.emd_1d` allows for solving and returning the solution for
very large problems. Note that in order to compute directly the :math:`W_p`
Wasserstein distance in 1D we provide the function :any:`ot.wasserstein_1d` that
-takes :code:`p` as a parameter.
+takes :code:`p` as a parameter.
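
For illustration, a minimal sketch of the 1D solver on toy samples (the
arrays below are made up)::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    x_a = rng.randn(100)        # 1D source samples, uniform weights assumed
    x_b = rng.randn(100) + 2.   # 1D target samples
    w2 = ot.wasserstein_1d(x_a, x_b, p=2.)  # sorting-based, O(n log n)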
Another special case for estimating OT and the Monge mapping is between Gaussian
distributions. In this case there exists a closed-form solution given in Remark
2.29 in [15]_ and the Monge mapping is an affine function that can
also be computed from the covariances and means of the source and target
-distributions. In the case when the finite sample dataset is supposed gaussian, we provide
+distributions. When the finite-sample dataset is assumed Gaussian, we provide
:any:`ot.da.OT_mapping_linear` that returns the parameters for the Monge
mapping.
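
A minimal sketch of the linear mapping estimation between two empirical
Gaussian clouds (toy data; we assume the function returns the affine map
parameters A, b)::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(200, 2)                                  # source samples
    xt = rng.randn(200, 2).dot([[2., 0.], [0., .5]]) + 4.   # target samples
    A, b = ot.da.OT_mapping_linear(xs, xt)  # mean/covariance based Monge map
    xs_mapped = xs.dot(A) + b               # affine push-forward of the sources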
@@ -176,7 +176,7 @@ solution of the resulting optimization problem can be expressed as:
where :math:`u` and :math:`v` are vectors and :math:`K=\exp(-M/\lambda)` where
the :math:`\exp` is taken component-wise. In order to solve the optimization
problem, one can use an alternating projection algorithm called Sinkhorn-Knopp that can be very
-efficient for large values if regularization.
+efficient for large values of the regularization.
The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and
:any:`ot.sinkhorn2` that return respectively the OT matrix and the value of the
@@ -201,12 +201,12 @@ More details about the algorithms used are given in the following note.
+ :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the
classic algorithm [2]_.
+ :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized` the
- log stabilized version of the algorithm [9]_.
+ log stabilized version of the algorithm [9]_.
+ :code:`method='sinkhorn_epsilon_scaling'` calls
:any:`ot.bregman.sinkhorn_epsilon_scaling` the epsilon scaling version
- of the algorithm [9]_.
+ of the algorithm [9]_.
+ :code:`method='greenkhorn'` calls :any:`ot.bregman.greenkhorn` the
- greedy sinkhorn verison of the algorithm [22]_.
+ greedy Sinkhorn version of the algorithm [22]_.
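
For illustration, a quick sketch of switching solvers through the wrapper
(toy histograms; note that an unknown :code:`method` now raises a
:code:`ValueError`, see the change in :code:`ot/bregman.py` below)::

    import numpy as np
    import ot

    a = b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])
    G = ot.sinkhorn(a, b, M, 1.)  # default Sinkhorn-Knopp
    # log-stabilized variant, more robust for small regularization
    G_stab = ot.sinkhorn(a, b, M, 1e-2, method='sinkhorn_stabilized')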
In addition to all those variants of sinkhorn, we have another
implementation solving the problem in the smooth dual or semi-dual in
@@ -236,7 +236,7 @@ of algorithms in [18]_ [19]_.
Examples of use for :any:`ot.sinkhorn` are available in :
- :any:`auto_examples/plot_OT_2D_samples`
- - :any:`auto_examples/plot_OT_1D`
+ - :any:`auto_examples/plot_OT_1D`
- :any:`auto_examples/plot_OT_1D_smooth`
- :any:`auto_examples/plot_stochastic`
@@ -248,13 +248,13 @@ While entropic OT is the most common and favored in practice, there exist other
kinds of regularization. We provide in POT two specific solvers for other
regularization terms, namely quadratic regularization and group lasso
regularization. But we also provide in :any:`ot.optim` two generic solvers that allow solving any
-smooth regularization in practice.
+smooth regularization in practice.
Quadratic regularization
""""""""""""""""""""""""
The first general regularization term we can solve is the quadratic
-regularization of the form
+regularization of the form
.. math::
\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}^2
@@ -264,7 +264,7 @@ densifying the OT matrix but it keeps some sort of sparsity that is lost with
entropic regularization as soon as :math:`\lambda>0` [17]_. This problem can be
solved with POT using solvers from :any:`ot.smooth`, more specifically
functions :any:`ot.smooth.smooth_ot_dual` or
-:any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='l2'` to
+:any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='l2'` to
select the quadratic regularization.
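
A minimal sketch with toy histograms (values are illustrative only)::

    import numpy as np
    import ot

    a = b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])
    # quadratic (l2) regularization keeps part of the sparsity of the plan
    G_l2 = ot.smooth.smooth_ot_dual(a, b, M, 1., reg_type='l2')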
.. hint::
@@ -300,7 +300,7 @@ gradient algorithm [7]_ in function
.. hint::
Examples of group Lasso regularization are available in :
- - :any:`auto_examples/plot_otda_classes`
+ - :any:`auto_examples/plot_otda_classes`
- :any:`auto_examples/plot_otda_d2`
@@ -311,7 +311,7 @@ Finally we propose in POT generic solvers that can be used to solve any
regularization as long as you can provide a function computing the
regularization and a function computing its gradient (or sub-gradient).
-In order to solve
+In order to solve
.. math::
\gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \lambda\Omega(\gamma)
@@ -336,12 +336,12 @@ Another generic solver is proposed to solve the problem
where :math:`\Omega_e` is the entropic regularization. In this case we use a
generalized conditional gradient [7]_ implemented in :any:`ot.optim.gcg` that
does not linearize the entropic term but
-relies on :any:`ot.sinkhorn` for its iterations.
+relies on :any:`ot.sinkhorn` for its iterations.
.. hint::
An example of the generic solvers is available in :
- - :any:`auto_examples/plot_optim_OTreg`
+ - :any:`auto_examples/plot_optim_OTreg`
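
For illustration, a minimal sketch of the generic conditional-gradient solver
with a hand-written quadratic regularizer (toy data; any smooth
:math:`\Omega` with a gradient would do)::

    import numpy as np
    import ot

    a = b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])

    def f(G):   # Omega(G) = 0.5 * ||G||_F^2
        return 0.5 * np.sum(G ** 2)

    def df(G):  # gradient of Omega
        return G

    G = ot.optim.cg(a, b, M, 1e-1, f, df)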
Wasserstein Barycenters
@@ -382,7 +382,7 @@ solver :any:`ot.lp.barycenter` that relies on generic LP solvers. By default the
function uses :any:`scipy.optimize.linprog`, but more efficient LP solvers from
cvxopt can also be used by changing the parameter :code:`solver`. Note that this problem
requires solving a very large linear program and can be very slow in
-practice.
+practice.
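
As an illustration, a small sketch on a 1D grid (toy histograms; by default
the :code:`scipy.optimize.linprog` backend is used, which can be slow beyond
a few hundred bins)::

    import numpy as np
    import ot

    n = 10
    x = np.arange(n, dtype=np.float64).reshape(n, 1)
    M = ot.dist(x, x)            # squared Euclidean ground cost on the grid
    a1 = ot.unif(n)              # uniform toy histogram
    a2 = np.zeros(n)
    a2[:2] = .5                  # second, concentrated toy histogram
    A = np.vstack((a1, a2)).T    # histograms as columns, shape (n, 2)
    bary = ot.lp.barycenter(A, M, weights=np.array([.5, .5]))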
Similarly to the OT problem, OT barycenters can be computed in the regularized
case. When entropic regularization is used, the problem can be solved with a
@@ -403,11 +403,11 @@ operators. We provide an implementation of this algorithm in function
Examples of Wasserstein (:any:`ot.lp.barycenter`) and regularized Wasserstein
barycenter (:any:`ot.bregman.barycenter`) computation are available in :
- - :any:`auto_examples/plot_barycenter_1D`
- - :any:`auto_examples/plot_barycenter_lp_vs_entropic`
+ - :any:`auto_examples/plot_barycenter_1D`
+ - :any:`auto_examples/plot_barycenter_lp_vs_entropic`
An example of convolutional barycenter
- (:any:`ot.bregman.convolutional_barycenter2d`) computation
+ (:any:`ot.bregman.convolutional_barycenter2d`) computation
for 2D images is available
in :
@@ -451,13 +451,13 @@ optimal mapping is still an open problem in the general case but has been proven
for smooth distributions by Brenier in his eponymous `theorem
<https://who.rocq.inria.fr/Jean-David.Benamou/demiheure.pdf>`__. We provide in
:any:`ot.da` several solvers for smooth Monge mapping estimation and domain
-adaptation from discrete distributions.
+adaptation from discrete distributions.
Monge Mapping estimation
^^^^^^^^^^^^^^^^^^^^^^^^
We now discuss several approaches that are implemented in POT to estimate or
-approximate a Monge mapping from finite distributions.
+approximate a Monge mapping from finite distributions.
First note that when the source and target distributions are supposed to be Gaussian
distributions, there exists a closed-form solution for the mapping and it is an
@@ -513,16 +513,16 @@ A list of the provided implementations is given in the following note.
Here is a list of the OT mapping classes inheriting from
:any:`ot.da.BaseTransport`
-
+
* :any:`ot.da.EMDTransport` : Barycentric mapping with EMD transport
* :any:`ot.da.SinkhornTransport` : Barycentric mapping with Sinkhorn transport
* :any:`ot.da.SinkhornL1l2Transport` : Barycentric mapping with Sinkhorn +
group Lasso regularization [5]_
* :any:`ot.da.SinkhornLpl1Transport` : Barycentric mapping with Sinkhorn +
- non convex group Lasso regularization [5]_
+ non convex group Lasso regularization [5]_
* :any:`ot.da.LinearTransport` : Linear mapping estimation between Gaussians
[14]_
- * :any:`ot.da.MappingTransport` : Nonlinear mapping estimation [8]_
+ * :any:`ot.da.MappingTransport` : Nonlinear mapping estimation [8]_
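
A minimal sketch of the fit/transform pattern shared by these classes,
assuming the :any:`ot.da.BaseTransport` API (toy samples, illustrative
only)::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(50, 2)         # source samples
    xt = rng.randn(50, 2) + 3.    # target samples
    mapping = ot.da.SinkhornTransport(reg_e=1e-1)
    mapping.fit(Xs=xs, Xt=xt)
    xs_mapped = mapping.transform(Xs=xs)  # barycentric mapping of the sources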
.. hint::
@@ -550,7 +550,7 @@ consist in finding a linear projector optimizing the following criterion
.. math::
P = \text{arg}\min_P \frac{\sum_i OT_e(\mu_i\#P,\mu_i\#P)}{\sum_{i,j\neq i}
OT_e(\mu_i\#P,\mu_j\#P)}
-
+
where :math:`\#` is the push-forward operator, :math:`OT_e` is the entropic OT
loss and :math:`\mu_i` is the
distribution of samples from class :math:`i`. :math:`P` is also constrained to
@@ -575,12 +575,12 @@ respectively. Note that we also provide the Fisher discriminant estimator in
Unbalanced optimal transport
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Unbalanced OT is a relaxation of the original OT problem where the violation of
+Unbalanced OT is a relaxation of the entropy regularized OT problem where the violation of
the constraint on the marginals is added to the objective of the optimization
-problem:
-
+problem. The unbalanced OT metric between two unbalanced histograms a and b is defined as [25]_ [10]_:
+
.. math::
- \min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + reg\cdot\Omega(\gamma) + \alpha KL(\gamma 1, a) + \alpha KL(\gamma^T 1, b)
+ W_u(a, b) = \min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t. \quad \gamma\geq 0
@@ -589,18 +589,60 @@ where KL is the Kullback-Leibler divergence. This formulation allows for
computing approximate mapping between distributions that do not have the same
amount of mass. Interestingly the problem can be solved with a generalization of
the Bregman projections algorithm [10]_. We provide a solver for unbalanced OT
-in :any:`ot.unbalanced` and more specifically
-in function :any:`ot.sinkhorn_unbalanced`. A solver for unbalanced OT barycenter
-is available in :any:`ot.barycenter_unbalanced`.
+in :any:`ot.unbalanced`. Computing the optimal transport
+plan or the transport cost is similar to the balanced case. The Sinkhorn-Knopp
+algorithm is implemented in :any:`ot.sinkhorn_unbalanced` and :any:`ot.sinkhorn_unbalanced2`
+that return respectively the OT matrix and the value of the
+linear term.
+
+.. note::
+ The main function to solve entropic regularized UOT is :any:`ot.sinkhorn_unbalanced`.
+ This function is a wrapper and the parameter :code:`method` helps you select
+ the actual algorithm used to solve the problem:
+
+ + :code:`method='sinkhorn'` calls :any:`ot.unbalanced.sinkhorn_knopp_unbalanced`
+ the generalized Sinkhorn algorithm [25]_ [10]_.
+ + :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.sinkhorn_stabilized_unbalanced`
+ the log stabilized version of the algorithm [10]_.
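
For illustration, a minimal sketch with histograms of different total mass
(toy data; :code:`reg` is the entropic term and :code:`reg_m` the marginal
relaxation from the formula above)::

    import numpy as np
    import ot

    a = np.array([.5, .5])
    b = np.array([.5, 1.])              # total masses need not match
    M = np.array([[0., 1.], [1., 0.]])
    G = ot.sinkhorn_unbalanced(a, b, M, reg=0.1, reg_m=1.)   # OT plan
    w = ot.sinkhorn_unbalanced2(a, b, M, reg=0.1, reg_m=1.)  # loss only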
.. hint::
- Examples of the use of :any:`ot.sinkhorn_unbalanced` and
- :any:`ot.barycenter_unbalanced` are available in :
+ Examples of the use of :any:`ot.sinkhorn_unbalanced` are available in :
- :any:`auto_examples/plot_UOT_1D`
- - :any:`auto_examples/plot_UOT_barycenter_1D`
+
+
+Unbalanced Barycenters
+^^^^^^^^^^^^^^^^^^^^^^
+
+As with balanced distributions, we can define a barycenter of a set of
+histograms with different masses as a Fréchet Mean:
+
+ .. math::
+ \min_{\mu} \quad \sum_{k} w_kW_u(\mu,\mu_k)
+
+where :math:`W_u` is the unbalanced Wasserstein metric defined above. This problem
+can also be solved using a generalized version of Sinkhorn's algorithm and it is
+implemented in the main function :any:`ot.barycenter_unbalanced`.
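
A minimal sketch (toy histograms with different masses, stacked as columns;
we assume the same :code:`reg`/:code:`reg_m` parameters as the solver
above)::

    import numpy as np
    import ot

    A = np.array([[.5, 1.],
                  [.5, 2.]])            # two histograms as columns
    M = np.array([[0., 1.], [1., 0.]])
    q = ot.barycenter_unbalanced(A, M, reg=0.1, reg_m=1.)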
+
+
+.. note::
+ The main function to compute UOT barycenters is :any:`ot.barycenter_unbalanced`.
+ This function is a wrapper and the parameter :code:`method` helps you select
+ the actual algorithm used to solve the problem:
+
+ + :code:`method='sinkhorn'` calls :any:`ot.unbalanced.barycenter_unbalanced_sinkhorn`
+ the generalized Sinkhorn algorithm [10]_.
+ + :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.barycenter_unbalanced_stabilized`
+ the log stabilized version of the algorithm [10]_.
+
+
+.. hint::
+
+ Examples of the use of :any:`ot.barycenter_unbalanced` are available in :
+
+ - :any:`auto_examples/plot_UOT_barycenter_1D`
Gromov-Wasserstein
@@ -636,7 +678,7 @@ barycenters that can be expressed as
where :math:`C_k` is the distance matrix between samples in distribution
:math:`k`. Note that interestingly the barycenter is defined as a symmetric
-positive matrix. We provide a block coordinate optimization procedure in
+positive matrix. We provide a block coordinate optimization procedure in
:any:`ot.gromov.gromov_barycenters` and
:any:`ot.gromov.entropic_gromov_barycenters` for non-regularized and regularized
barycenters respectively.
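
For illustration, a minimal sketch of a GW barycenter between two toy point
clouds living in different spaces, assuming the
`(N, Cs, ps, p, lambdas, loss_fun)` calling convention of
:any:`ot.gromov.gromov_barycenters` (values are made up)::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(20, 2), rng.randn(30, 3)   # clouds in different spaces
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)     # intra-domain distance matrices
    C1, C2 = C1 / C1.max(), C2 / C2.max()
    ps, pt = ot.unif(20), ot.unif(30)
    # barycenter distance matrix supported on 10 points
    Cb = ot.gromov.gromov_barycenters(10, [C1, C2], [ps, pt], ot.unif(10),
                                      [.5, .5], 'square_loss', max_iter=100)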
@@ -654,19 +696,19 @@ The implementations of FGW and FGW barycenter are provided in functions
Examples of computation of GW, regularized GW and FGW are available in :
- :any:`auto_examples/plot_gromov`
- - :any:`auto_examples/plot_fgw`
+ - :any:`auto_examples/plot_fgw`
Examples of GW, regularized GW and FGW barycenters are available in :
- :any:`auto_examples/plot_gromov_barycenter`
- - :any:`auto_examples/plot_barycenter_fgw`
+ - :any:`auto_examples/plot_barycenter_fgw`
GPU acceleration
^^^^^^^^^^^^^^^^
We provide several implementations of our OT solvers in :any:`ot.gpu`. Those
-implementations use the :code:`cupy` toolbox that obviously need to be installed.
+implementations use the :code:`cupy` toolbox, which needs to be installed.
.. note::
@@ -701,7 +743,7 @@ FAQ
1. **How to solve a discrete optimal transport problem ?**
The solver for discrete OT is the function :py:mod:`ot.emd` that returns
- the OT transport matrix. If you want to solve a regularized OT you can
+ the OT transport matrix. If you want to solve a regularized OT you can
use :py:mod:`ot.sinkhorn`.
@@ -714,9 +756,9 @@ FAQ
T=ot.emd(a,b,M) # exact linear program
T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
- More detailed examples can be seen on this example:
+ More detailed examples can be seen on this example:
:doc:`auto_examples/plot_OT_2D_samples`
-
+
2. **pip install POT fails with error : ImportError: No module named Cython.Build**
@@ -726,7 +768,7 @@ FAQ
installing POT.
Note that this problem does not occur when using conda-forge since the packages
- there are pre-compiled.
+ there are pre-compiled.
See `Issue #59 <https://github.com/rflamary/POT/issues/59>`__ for more
details.
@@ -751,7 +793,7 @@ FAQ
In order to limit import time and hard dependencies in POT, we do not import
some sub-modules automatically with :code:`import ot`. In order to use the
acceleration in :any:`ot.gpu` you first need to import it with
- :code:`import ot.gpu`.
+ :code:`import ot.gpu`.
See `Issue #85 <https://github.com/rflamary/POT/issues/85>`__ and :any:`ot.gpu`
for more details.
@@ -763,7 +805,7 @@ References
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
December). `Displacement interpolation using Lagrangian mass transport
<https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf>`__.
- In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
+ In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
.. [2] Cuturi, M. (2013). `Sinkhorn distances: Lightspeed computation of
optimal transport <https://arxiv.org/pdf/1306.0895.pdf>`__. In Advances
@@ -874,4 +916,8 @@ References
.. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
(2019). `Optimal Transport for structured data with application on
graphs <http://proceedings.mlr.press/v97/titouan19a.html>`__ Proceedings
- of the 36th International Conference on Machine Learning (ICML). \ No newline at end of file
+ of the 36th International Conference on Machine Learning (ICML).
+
+.. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/plot_barycenter_lp_vs_entropic.py
index b82765e..d7c72d0 100644
--- a/examples/plot_barycenter_lp_vs_entropic.py
+++ b/examples/plot_barycenter_lp_vs_entropic.py
@@ -102,7 +102,7 @@ pl.tight_layout()
problems.append([A, [bary_l2, bary_wass, bary_wass2]])
##############################################################################
-# Dirac Data
+# Stair Data
# ----------
#%% parameters
@@ -168,6 +168,11 @@ pl.legend()
pl.title('Barycenters')
pl.tight_layout()
+
+##############################################################################
+# Dirac Data
+# ----------
+
#%% parameters
a1 = np.zeros(n)
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py
index b6efc59..64b89e4 100644
--- a/examples/plot_free_support_barycenter.py
+++ b/examples/plot_free_support_barycenter.py
@@ -62,7 +62,7 @@ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init,
pl.figure(1)
for (x_i, b_i) in zip(measures_locations, measures_weights):
color = np.random.randint(low=1, high=10 * N)
- pl.scatter(x_i[:, 0], x_i[:, 1], s=b * 1000, label='input measure')
+ pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
pl.legend(loc=0)
diff --git a/ot/__init__.py b/ot/__init__.py
index 571f235..89c7936 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,7 +1,7 @@
"""
-This is the main module of the POT toolbox. It provides easy access to
-a number of sub-modules and functions described below.
+This is the main module of the POT toolbox. It provides easy access to
+a number of sub-modules and functions described below.
.. note::
@@ -14,27 +14,27 @@ a number of sub-modules and functions described below.
- :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
- :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT
problems.
- - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
+ - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
Wasserstein problems.
- - :any:`ot.optim` contains generic solvers OT based optimization problems
+ - :any:`ot.optim` contains generic solvers OT based optimization problems
- :any:`ot.da` contains classes and function related to Monge mapping
estimation and Domain Adaptation (DA).
- :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers
- - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
+ - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
Discriminant Analysis.
- - :any:`ot.utils` contains utility functions such as distance computation and
- timing.
+ - :any:`ot.utils` contains utility functions such as distance computation and
+ timing.
- :any:`ot.datasets` contains toy dataset generation functions.
- :any:`ot.plot` contains visualization functions
- :any:`ot.stochastic` contains stochastic solvers for regularized OT.
- :any:`ot.unbalanced` contains solvers for regularized unbalanced OT.
.. warning::
- The list of automatically imported sub-modules is as follows:
+ The list of automatically imported sub-modules is as follows:
:py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
:py:mod:`ot.utils`, :py:mod:`ot.datasets`,
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
- :py:mod:`ot.stochastic`
+ :py:mod:`ot.stochastic`
The following sub-modules are not imported due to additional dependencies:
@@ -65,17 +65,17 @@ from . import unbalanced
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
+from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
from .da import sinkhorn_lpl1_mm
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "1.0.0"
+__version__ = "0.6.0"
-__all__ = ["emd", "emd2", 'emd_1d','emd2_1d', 'wasserstein_1d',
- "sinkhorn", "sinkhorn2", 'barycenter',
- 'sinkhorn_lpl1_mm',
- 'sinkhorn_unbalanced', "barycenter_unbalanced",
- 'dist', 'unif', 'tic', 'toc', 'toq',
- "utils", 'datasets', 'bregman', 'lp', 'gromov', 'da', 'optim']
+__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets',
+ 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d',
+ 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
+ 'sinkhorn_unbalanced', 'barycenter_unbalanced',
+ 'sinkhorn_unbalanced2']
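
With these exports, the unbalanced solvers are reachable from the top-level
namespace without importing the submodule (a small sketch, toy arrays)::

    import numpy as np
    import ot

    a, b = np.array([.5, .5]), np.array([.5, 1.])
    M = np.array([[0., 1.], [1., 0.]])
    G = ot.sinkhorn_unbalanced(a, b, M, reg=0.1, reg_m=1.)    # transport plan
    w = ot.sinkhorn_unbalanced2(a, b, M, reg=0.1, reg_m=1.)   # cost only, new export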
diff --git a/ot/bregman.py b/ot/bregman.py
index 13dfa3b..2cd832b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,10 +7,12 @@ Bregman projections for regularized OT
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Hicham Janati <hicham.janati@inria.fr>
#
# License: MIT License
import numpy as np
+import warnings
from .utils import unif, dist
@@ -31,21 +33,21 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -64,7 +66,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -103,30 +105,23 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
"""
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'greenkhorn':
- def sink():
- return greenkhorn(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log)
+ return greenkhorn(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log)
elif method.lower() == 'sinkhorn_stabilized':
- def sink():
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- def sink():
- return sinkhorn_epsilon_scaling(
- a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_epsilon_scaling(a, b, M, reg,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- print('Warning : unknown method using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
-
- return sink()
+ raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -146,21 +141,21 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -176,11 +171,10 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
+ W : (n_hists) ndarray or float
+ Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -219,31 +213,23 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]
"""
-
+ b = np.asarray(b, dtype=np.float64)
+ if len(b.shape) < 2:
+ b = b[:, None]
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- def sink():
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- def sink():
- return sinkhorn_epsilon_scaling(
- a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- print('Warning : unknown method using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp(a, b, M, reg, **kwargs)
-
- b = np.asarray(b, dtype=np.float64)
- if len(b.shape) < 2:
- b = b[:, None]
-
- return sink()
+ raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
@@ -263,21 +249,21 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -290,10 +276,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -333,25 +318,25 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
- Nini = len(a)
- Nfin = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
if len(b.shape) > 1:
- nbb = b.shape[1]
+ n_hists = b.shape[1]
else:
- nbb = 0
+ n_hists = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
- if nbb:
- u = np.ones((Nini, nbb)) / Nini
- v = np.ones((Nfin, nbb)) / Nfin
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
else:
- u = np.ones(Nini) / Nini
- v = np.ones(Nfin) / Nfin
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
# print(reg)
@@ -386,13 +371,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ if n_hists:
+ np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
+ err = np.linalg.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
@@ -406,7 +390,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['u'] = u
log['v'] = v
- if nbb: # return only loss
+ if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
@@ -421,7 +405,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
+ log=False):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -445,20 +430,20 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -469,10 +454,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -512,16 +496,16 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- n = a.shape[0]
- m = b.shape[0]
+ dim_a = a.shape[0]
+ dim_b = b.shape[0]
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty_like(M)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- u = np.full(n, 1. / n)
- v = np.full(m, 1. / m)
+ u = np.full(dim_a, 1. / dim_a)
+ v = np.full(dim_b, 1. / dim_b)
G = u[:, np.newaxis] * K * v[np.newaxis, :]
viol = G.sum(1) - a
@@ -575,7 +559,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
- warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
+ warmstart=None, verbose=False, print_period=20,
+ log=False, **kwargs):
r"""
Solve the entropic regularization OT problem with log stabilization
@@ -591,9 +576,10 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in [2]_ but with the log stabilization
@@ -602,11 +588,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (dim_b,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -623,10 +609,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -638,7 +623,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.bregman.sinkhorn_stabilized(a,b,M,1)
+ >>> ot.bregman.sinkhorn_stabilized(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
@@ -671,14 +656,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# test if multiple target
if len(b.shape) > 1:
- nbb = b.shape[1]
+ n_hists = b.shape[1]
a = a[:, np.newaxis]
else:
- nbb = 0
+ n_hists = 0
# init data
- na = len(a)
- nb = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
cpt = 0
if log:
@@ -687,24 +672,25 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(na), np.zeros(nb)
+ alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
else:
alpha, beta = warmstart
- if nbb:
- u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
else:
- u, v = np.ones(na) / na, np.ones(nb) / nb
+ u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1))
- - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((dim_a, 1))
+ - beta.reshape((1, dim_b))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb)))
- / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
+ return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
+ / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b))))
# print(np.min(K))
@@ -724,26 +710,29 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if np.abs(u).max() > tau or np.abs(v).max() > tau:
- if nbb:
+ if n_hists:
alpha, beta = alpha + reg * \
np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
else:
alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
- if nbb:
- u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
+ if n_hists:
+ u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b
else:
- u, v = np.ones(na) / na, np.ones(nb) / nb
+ u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
K = get_K(alpha, beta)
if cpt % print_period == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ if n_hists:
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
+ err = np.linalg.norm((np.sum(transp, axis=0) - b))
if log:
log['err'].append(err)
@@ -769,33 +758,39 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
if log:
- log['logu'] = alpha / reg + np.log(u)
- log['logv'] = beta / reg + np.log(v)
+ if n_hists:
+ alpha = alpha[:, None]
+ beta = beta[:, None]
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ log['logu'] = logu
+ log['logv'] = logv
log['alpha'] = alpha + reg * np.log(u)
log['beta'] = beta + reg * np.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
- if nbb:
- res = np.zeros((nbb))
- for i in range(nbb):
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
return res, log
else:
return get_Gamma(alpha, beta, u, v), log
else:
- if nbb:
- res = np.zeros((nbb))
- for i in range(nbb):
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
return res
else:
return get_Gamma(alpha, beta, u, v)
-def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100,
- tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
+def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
+ numInnerItermax=100, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=10,
+ log=False, **kwargs):
r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
@@ -812,9 +807,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in [2]_ but with the log stabilization
@@ -823,19 +819,17 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (dim_b,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
threshold for max value in u or v for log scaling
- tau : float
- thershold for max value in u or v for log scaling
- warmstart : tible of vectors
+ warmstart : tuple of vectors
if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
@@ -850,10 +844,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -894,8 +887,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
- na = len(a)
- nb = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
# relative numerical precision with 64 bits
numItermin = 35
@@ -908,14 +901,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(na), np.zeros(nb)
+ alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
else:
alpha, beta = warmstart
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1))
- - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((dim_a, 1))
+ - beta.reshape((1, dim_b))) / reg)
# print(np.min(K))
def get_reg(n): # exponential decreasing
@@ -928,8 +921,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
regi = get_reg(cpt)
- G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=(
- alpha, beta), verbose=False, print_period=20, tau=tau, log=True)
+ G, logi = sinkhorn_stabilized(a, b, M, regi,
+ numItermax=numInnerItermax, stopThr=1e-9,
+ warmstart=(alpha, beta), verbose=False,
+ print_period=20, tau=tau, log=True)
alpha = logi['alpha']
beta = logi['beta']
@@ -987,8 +982,8 @@ def projC(gamma, q):
return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
-def barycenter(A, M, reg, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
+def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False, **kwargs):
r"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
@@ -1006,13 +1001,15 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Parameters
----------
- A : np.ndarray (d,n)
- n training distributions a_i of size d
- M : np.ndarray (d,d)
- loss matrix for OT
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
reg : float
- Regularization term >0
- weights : np.ndarray (n,)
+ Regularization term > 0
+ method : str (optional)
+ method used for the solver, either 'sinkhorn' or 'sinkhorn_stabilized'
+ weights : ndarray, shape (n_hists,)
Weights of each histogram a_i on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
@@ -1026,7 +1023,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : (dim,) ndarray
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1037,7 +1034,69 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ """
+
+ if method.lower() == 'sinkhorn':
+ return barycenter_sinkhorn(A, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return barycenter_stabilized(A, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
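# A minimal usage sketch of the method dispatch above (toy arrays,
# illustrative only, not part of the diff):
#
#   import numpy as np
#   import ot
#   A = np.array([[.5, .3], [.5, .7]])   # two histograms as columns
#   M = np.array([[0., 1.], [1., 0.]])
#   bary = ot.bregman.barycenter(A, M, reg=1e-1, method='sinkhorn_stabilized')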
+def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ weights : ndarray, shape (n_hists,)
+ Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
@@ -1083,7 +1142,138 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
-def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
+def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ with stabilization.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ tau : float
+ threshold for max value in u or v for log scaling
+ weights : ndarray, shape (n_hists,)
+ Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+ """
+
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # print(reg)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ u = A / (Kv + 1e-16)
+ Ktu = K.T.dot(u)
+ q = geometricBar(weights, Ktu)
+ Q = q[:, None]
+ v = Q / (Ktu + 1e-16)
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(u * Kv - A).max()
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg`" +
+ "Or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
+def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-9, stabThr=1e-30, verbose=False,
+ log=False):
r"""Compute the entropic regularized wasserstein barycenter of distributions A
where A is a collection of 2D images.
@@ -1102,16 +1292,16 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Parameters
----------
- A : np.ndarray (n,w,h)
- n distributions (2D images) of size w x h
+ A : ndarray, shape (n_hists, width, height)
+ n distributions (2D images) of size width x height
reg : float
Regularization term >0
- weights : np.ndarray (n,)
+ weights : ndarray, shape (n_hists,)
Weights of each image on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (> 0)
stabThr : float, optional
Stabilization threshold to avoid numerical precision issue
verbose : bool, optional
@@ -1119,15 +1309,13 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
log : bool, optional
record log if True
-
Returns
-------
- a : (w,h) ndarray
+ a : ndarray, shape (width, height)
2D Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
-
References
----------
@@ -1207,9 +1395,12 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
where :
- :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}` is an observed distribution, :math:`\mathbf{h}_0` is aprior on unmixing
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT data fitting
- - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix for regularization
+ - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
+ - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
+ - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
+ - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting
+ - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) for regularization
- :math:`\\alpha` weights the data fitting and the regularization

The optimization problem is solved using the algorithm described in [4]
@@ -1217,16 +1408,16 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (d)
- observed distribution
- D : np.ndarray (d,n)
+ a : ndarray, shape (dim_a)
+ observed distribution (histogram, sums to 1)
+ D : ndarray, shape (dim_a, n_atoms)
dictionary matrix
- M : np.ndarray (d,d)
+ M : ndarray, shape (dim_a, dim_a)
loss matrix
- M0 : np.ndarray (n,n)
+ M0 : ndarray, shape (n_atoms, dim_prior)
loss matrix
- h0 : np.ndarray (n,)
- prior on h
+ h0 : ndarray, shape (n_atoms,)
+ prior on the estimated unmixing h
reg : float
Regularization term >0 (Wasserstein data fitting)
reg0 : float
@@ -1245,7 +1436,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ h : ndarray, shape (n_atoms,)
Estimated unmixing of the observed distribution a
log : dict
log dictionary return only if log==True in parameters
@@ -1301,7 +1492,9 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
return np.sum(K0, axis=1)
-def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9, verbose=False,
+ log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -1318,22 +1511,22 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
\gamma\geq 0
where :
- - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1347,7 +1540,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
@@ -1355,11 +1548,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Examples
--------
- >>> n_s = 2
- >>> n_t = 2
+ >>> n_samples_a = 2
+ >>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
array([[4.99977301e-01, 2.26989344e-05],
[2.26989344e-05, 4.99977301e-01]])
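For intuition, `empirical_sinkhorn` behaves like building the cost matrix from the samples and then calling the histogram-based solver; a sketch of that equivalent two-step call, with uniform weights assumed when `a` and `b` are omitted:

    import numpy as np
    import ot

    X_s = np.reshape(np.arange(2), (2, 1))
    X_t = np.reshape(np.arange(2), (2, 1))
    M = ot.dist(X_s, X_t, metric='sqeuclidean')  # cost built from the samples
    a, b = ot.unif(2), ot.unif(2)                # default uniform weights
    G = ot.sinkhorn(a, b, M, reg=0.1)            # should match empirical_sinkhorn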
@@ -1408,22 +1601,22 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
\gamma\geq 0
where :
- - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1437,7 +1630,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
@@ -1445,11 +1638,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Examples
--------
- >>> n_s = 2
- >>> n_t = 2
+ >>> n_samples_a = 2
+ >>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
array([4.53978687e-05])
@@ -1516,22 +1709,22 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
\gamma_b\geq 0
where :
- - :math:`M` (resp. :math:`M_a, M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt))
+ - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp. (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b))
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1542,29 +1735,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
Examples
--------
-
- >>> n_s = 2
- >>> n_t = 4
+ >>> n_samples_a = 2
+ >>> n_samples_b = 4
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
array([1.499...])
References
----------
-
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
'''
if log:
diff --git a/ot/da.py b/ot/da.py
index 83f9027..108a38d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -6,6 +6,7 @@ Domain adaptation with optimal transport
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+# Nathalie Gayraud <nat.gayraud@gmail.com>
#
# License: MIT License
@@ -16,6 +17,7 @@ from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
from .utils import check_params, BaseEstimator
+from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
@@ -1793,3 +1795,122 @@ class MappingTransport(BaseEstimator):
transp_Xs = K.dot(self.mapping_)
return transp_Xs
+
+
+class UnbalancedSinkhornTransport(BaseTransport):
+
+ """Domain Adapatation unbalanced OT method based on sinkhorn algorithm
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_m : float, optional (default=0.1)
+ Mass regularization parameter
+ method : str
+ method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
+ 'sinkhorn_reg_scaling', see those functions for specific parameters
+ max_iter : int, optional (default=10)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ tol : float, optional (default=1e-9)
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ norm : string, optional (default=None)
+ If given, normalize the ground metric to avoid numerical errors that
+ can occur with large metric values.
+ distribution_estimation : callable, optional (defaults to the uniform)
+ The kind of distribution estimation to employ
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. Currently the only possible option is
+ "ferradans" which uses the method proposed in [6].
+ limit_max : float, optional (default=10)
+ Controls the semi-supervised mode. Transport between labeled source
+ and target samples of different classes will be given a prohibitively
+ large cost (limit_max times the maximum value of the cost matrix)
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, empty dict if parameter log is not True
+
+ References
+ ----------
+
+ .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
+ max_iter=10, tol=1e-9, verbose=False, log=False,
+ metric="sqeuclidean", norm=None,
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10):
+
+ self.reg_e = reg_e
+ self.reg_m = reg_m
+ self.method = method
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.norm = norm
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention, -1 cannot be used as a
+ class label
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ super(UnbalancedSinkhornTransport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = sinkhorn_unbalanced(
+ a=self.mu_s, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, reg_m=self.reg_m, method=self.method,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
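A short usage sketch for the new estimator, mirroring how the other ot.da transport classes are driven; the toy datasets and parameter values are illustrative:

    import ot

    # Toy source/target classification data from the POT dataset helpers.
    Xs, ys = ot.datasets.make_data_classif('3gauss', n=50, random_state=42)
    Xt, yt = ot.datasets.make_data_classif('3gauss2', n=60, random_state=43)

    mapping = ot.da.UnbalancedSinkhornTransport(reg_e=1., reg_m=0.1)
    mapping.fit(Xs=Xs, Xt=Xt)              # estimate the unbalanced coupling
    transp_Xs = mapping.transform(Xs=Xs)   # map source samples onto the target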
diff --git a/ot/datasets.py b/ot/datasets.py
index e76e75d..ba0cfd9 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -17,7 +17,6 @@ def make_1D_gauss(n, m, s):
Parameters
----------
-
n : int
number of bins in the histogram
m : float
@@ -25,12 +24,10 @@ def make_1D_gauss(n, m, s):
s : float
standard deviation of the gaussian distribution
-
Returns
-------
- h : np.array (n,)
- 1D histogram for a gaussian distribution
-
+ h : ndarray, shape (n,)
+ 1D histogram for a gaussian distribution
"""
x = np.arange(n, dtype=np.float64)
h = np.exp(-(x - m)**2 / (2 * s**2))
@@ -44,16 +41,15 @@ def get_1D_gauss(n, m, sigma):
def make_2D_samples_gauss(n, m, sigma, random_state=None):
- """return n samples drawn from 2D gaussian N(m,sigma)
+ """Return n samples drawn from 2D gaussian N(m,sigma)
Parameters
----------
-
n : int
number of samples to make
- m : np.array (2,)
+ m : ndarray, shape (2,)
mean value of the gaussian distribution
- sigma : np.array (2,2)
+ sigma : ndarray, shape (2, 2)
covariance matrix of the gaussian distribution
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
@@ -63,9 +59,8 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None):
Returns
-------
- X : np.array (n,2)
- n samples drawn from N(m,sigma)
-
+ X : ndarray, shape (n, 2)
+ n samples drawn from N(m, sigma).
"""
generator = check_random_state(random_state)
@@ -86,11 +81,10 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
- """ dataset generation for classification problems
+ """Dataset generation for classification problems
Parameters
----------
-
dataset : str
type of classification problem (see code)
n : int
@@ -105,13 +99,11 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
Returns
-------
- X : np.array (n,d)
- n observation of size d
- y : np.array (n,)
- labels of the samples
-
+ X : ndarray, shape (n, d)
+ n observations of size d
+ y : ndarray, shape (n,)
+ labels of the samples.
"""
-
generator = check_random_state(random_state)
if dataset.lower() == '3gauss':
diff --git a/ot/dr.py b/ot/dr.py
index d2bf6e2..680dabf 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -49,30 +49,25 @@ def split_classes(X, y):
def fda(X, y, p=2, reg=1e-16):
- """
- Fisher Discriminant Analysis
-
+ """Fisher Discriminant Analysis
Parameters
----------
- X : numpy.ndarray (n,d)
- Training samples
- y : np.ndarray (n,)
- labels for training samples
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
p : int, optional
- size of dimensionnality reduction
+ Size of the dimensionality reduction.
reg : float, optional
Regularization term >0 (ridge regularization)
-
Returns
-------
- P : (d x p) ndarray
+ P : ndarray, shape (d, p)
Optimal projection matrix for the given parameters
- proj : fun
+ proj : callable
projection function including mean centering
-
-
"""
mx = np.mean(X)
@@ -130,37 +125,33 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
Parameters
----------
- X : numpy.ndarray (n,d)
- Training samples
- y : np.ndarray (n,)
- labels for training samples
+ X : ndarray, shape (n, d)
+ Training samples.
+ y : ndarray, shape (n,)
+ Labels for training samples.
p : int, optional
- size of dimensionnality reduction
+ Size of the dimensionality reduction.
reg : float, optional
Regularization term >0 (entropic regularization)
- solver : str, optional
- None for steepest decsent or 'TrustRegions' for trust regions algorithm
- else shoudl be a pymanopt.solvers
- P0 : numpy.ndarray (d,p)
- Initial starting point for projection
+ solver : None | str, optional
+ None for steepest descent or 'TrustRegions' for trust regions algorithm
+ else should be a pymanopt.solvers
+ P0 : ndarray, shape (d, p)
+ Initial starting point for projection.
verbose : int, optional
- Print information along iterations
-
-
+ Print information along iterations.
Returns
-------
- P : (d x p) ndarray
+ P : ndarray, shape (d, p)
Optimal projection matrix for the given parameters
- proj : fun
- projection function including mean centering
-
+ proj : callable
+ Projection function including mean centering.
References
----------
-
- .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
-
+ .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+ Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
""" # noqa
mx = np.mean(X)
diff --git a/ot/gromov.py b/ot/gromov.py
index 3a7e24c..699ae4c 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -1,9 +1,6 @@
-
# -*- coding: utf-8 -*-
"""
Gromov-Wasserstein transport method
-
-
"""
# Author: Erwan Vautier <erwan.vautier@gmail.com>
@@ -22,7 +19,7 @@ from .optim import cg
def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
- """ Return loss matrices and tensors for Gromov-Wasserstein fast computation
+ """Return loss matrices and tensors for Gromov-Wasserstein fast computation
Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
function as the loss function of Gromov-Wasserstein discrepancy.
@@ -51,23 +48,21 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
+ Metric cost matrix in the target space
T : ndarray, shape (ns, nt)
- Coupling between source and target spaces
+ Coupling between source and target spaces
p : ndarray, shape (ns,)
-
Returns
-------
-
constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
+ Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ h2(C) matrix in Eq. (6)
References
----------
@@ -114,25 +109,23 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
def tensor_product(constC, hC1, hC2, T):
- """ Return the tensor for Gromov-Wasserstein fast computation
+ """Return the tensor for Gromov-Wasserstein fast computation
The tensor is computed as described in Proposition 1 Eq. (6) in [12].
Parameters
----------
constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
+ Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
-
+ h2(C) matrix in Eq. (6)
Returns
-------
-
tens : ndarray, shape (ns, nt)
- \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+ \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
References
----------
@@ -148,26 +141,25 @@ def tensor_product(constC, hC1, hC2, T):
def gwloss(constC, hC1, hC2, T):
- """ Return the Loss for Gromov-Wasserstein
+ """Return the Loss for Gromov-Wasserstein
The loss is computed as described in Proposition 1 Eq. (6) in [12].
Parameters
----------
constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
+ Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ h2(C) matrix in Eq. (6)
T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ Current value of transport matrix T
Returns
-------
-
loss : float
- Gromov Wasserstein loss
+ Gromov Wasserstein loss
References
----------
@@ -183,24 +175,23 @@ def gwloss(constC, hC1, hC2, T):
def gwggrad(constC, hC1, hC2, T):
- """ Return the gradient for Gromov-Wasserstein
+ """Return the gradient for Gromov-Wasserstein
The gradient is computed as described in Proposition 2 in [12].
Parameters
----------
constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
+ Constant C matrix in Eq. (6)
hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
+ h1(C1) matrix in Eq. (6)
hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ h2(C) matrix in Eq. (6)
T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ Current value of transport matrix T
Returns
-------
-
grad : ndarray, shape (ns, nt)
Gromov Wasserstein gradient
@@ -222,19 +213,19 @@ def update_square_loss(p, lambdas, T, Cs):
Parameters
----------
- p : ndarray, shape (N,)
- masses in the targeted barycenter
+ p : ndarray, shape (N,)
+ Masses in the targeted barycenter.
lambdas : list of float
- list of the S spaces' weights
- T : list of S np.ndarray(ns,N)
- the S Ts couplings calculated at each iteration
+ List of the S spaces' weights.
+ T : list of S np.ndarray of shape (ns,N)
+ The S Ts couplings calculated at each iteration.
Cs : list of S ndarray, shape(ns,ns)
- Metric cost matrices
+ Metric cost matrices.
Returns
-------
- C : ndarray, shape (nt,nt)
- updated C matrix
+ C : ndarray, shape (nt, nt)
+ Updated C matrix.
"""
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
for s in range(len(T))])
@@ -251,12 +242,12 @@ def update_kl_loss(p, lambdas, T, Cs):
Parameters
----------
p : ndarray, shape (N,)
- weights in the targeted barycenter
+ Weights in the targeted barycenter.
lambdas : list of the S spaces' weights
- T : list of S np.ndarray(ns,N)
- the S Ts couplings calculated at each iteration
+ T : list of S np.ndarray of shape (ns,N)
+ The S Ts couplings calculated at each iteration.
Cs : list of S ndarray, shape(ns,ns)
- Metric cost matrices
+ Metric cost matrices.
Returns
-------
@@ -290,14 +281,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
- distribution in the source space
- q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string
+ Metric costfr matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
max_iter : int, optional
@@ -317,10 +308,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
Returns
-------
T : ndarray, shape (ns, nt)
- coupling between the two spaces that minimizes :
+ Coupling between the two spaces that minimizes:
\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
log : dict
- convergence information and loss
+ Convergence information and loss.
References
----------
@@ -374,18 +365,18 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Parameters
----------
- M : ndarray, shape (ns, nt)
- Metric cost matrix between features across domains
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
C1 : ndarray, shape (ns, ns)
- Metric cost matrix representative of the structure in the source space
+ Metric cost matrix representative of the structure in the source space
C2 : ndarray, shape (nt, nt)
- Metric cost matrix representative of the structure in the target space
- p : ndarray, shape (ns,)
- distribution in the source space
- q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string,optional
- loss function used for the solver
+ Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : str, optional
+ Loss function used for the solver
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -402,11 +393,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
+ gamma : ndarray, shape (ns, nt)
+ Optimal transportation matrix for the given parameters.
log : dict
- log dictionary return only if log==True in parameters
-
+ Log dictionary returned only if log==True in parameters.
References
----------
@@ -414,7 +404,6 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
and Courty Nicolas "Optimal Transport for structured data with
application on graphs", International Conference on Machine Learning
(ICML). 2019.
-
"""
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -457,18 +446,18 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
Parameters
----------
- M : ndarray, shape (ns, nt)
- Metric cost matrix between features across domains
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
C1 : ndarray, shape (ns, ns)
- Metric cost matrix respresentative of the structure in the source space
+ Metric cost matrix representative of the structure in the source space.
C2 : ndarray, shape (nt, nt)
- Metric cost matrix espresentative of the structure in the target space
+ Metric cost matrix representative of the structure in the target space.
p : ndarray, shape (ns,)
- distribution in the source space
+ Distribution in the source space.
q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string,optional
- loss function used for the solver
+ Distribution in the target space.
+ loss_fun : str, optional
+ Loss function used for the solver.
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -476,19 +465,19 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
+ Record log if True.
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an Armijo search.
+ Else closed form is used. If there are convergence issues use False.
**kwargs : dict
- parameters can be directly pased to the ot.optim.cg solver
+ Parameters can be directly pased to the ot.optim.cg solver.
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
+ gamma : ndarray, shape (ns, nt)
+ Optimal transportation matrix for the given parameters.
log : dict
- log dictionary return only if log==True in parameters
+ Log dictionary returned only if log==True in parameters.
References
----------
@@ -537,16 +526,15 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric cost matrix in the target space
- p : ndarray, shape (ns,)
- distribution in the source space
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space.
q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string
+ Distribution in the target space.
+ loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
-
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -558,6 +546,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
armijo : bool, optional
If True the step of the line-search is found via an Armijo search. Else closed form is used.
If there are convergence issues use False.
+
Returns
-------
gw_dist : float
@@ -624,25 +613,25 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
+ Metric cost matrix in the target space
p : ndarray, shape (ns,)
- distribution in the source space
+ Distribution in the source space
q : ndarray, shape (nt,)
- distribution in the target space
+ Distribution in the target space
loss_fun : string
- loss function used for the solver either 'square_loss' or 'kl_loss'
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
max_iter : int, optional
- Max number of iterations
+ Max number of iterations
tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
+ Record log if True.
Returns
-------
@@ -725,15 +714,15 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
Parameters
----------
C1 : ndarray, shape (ns, ns)
- Metric cost matrix in the source space
+ Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
+ Metric cost matrix in the target space
p : ndarray, shape (ns,)
- distribution in the source space
+ Distribution in the source space
q : ndarray, shape (nt,)
- distribution in the target space
- loss_fun : string
- loss function used for the solver either 'square_loss' or 'kl_loss'
+ Distribution in the target space
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
epsilon : float
Regularization term >0
max_iter : int, optional
@@ -743,7 +732,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
verbose : bool, optional
Print information along iterations
log : bool, optional
- record log if True
+ Record log if True.
Returns
-------
@@ -757,7 +746,6 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
International Conference on Machine Learning (ICML). 2016.
"""
-
gw, logv = entropic_gromov_wasserstein(
C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
@@ -789,19 +777,21 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Parameters
----------
- N : Integer
- Size of the targeted barycenter
- Cs : list of S np.ndarray(ns,ns)
- Metric cost matrices
- ps : list of S np.ndarray(ns,)
- sample weights in the S spaces
- p : ndarray, shape(N,)
- weights in the targeted barycenter
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray of shape (ns,ns)
+ Metric cost matrices
+ ps : list of S np.ndarray of shape (ns,)
+ Sample weights in the S spaces
+ p : ndarray, shape (N,)
+ Weights in the targeted barycenter
lambdas : list of float
- list of the S spaces' weights
- loss_fun : tensor-matrix multiplication function based on specific loss function
- update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
- with the S Ts couplings calculated at each iteration
+ List of the S spaces' weights.
+ loss_fun : callable
+ Tensor-matrix multiplication function based on specific loss function.
+ update : callable
+ function(p,lambdas,T,Cs) that updates C according to a specific Kernel
+ with the S Ts couplings calculated at each iteration
epsilon : float
Regularization term >0
max_iter : int, optional
@@ -809,11 +799,11 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
- Print information along iterations
+ Print information along iterations.
log : bool, optional
- record log if True
- init_C : bool, ndarray, shape(N,N)
- random initial value for the C matrix provided by user
+ Record log if True.
+ init_C : bool | ndarray, shape (N, N)
+ Initial value for the C matrix provided by the user; a random value is used if not set.
Returns
-------
@@ -825,7 +815,6 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
-
"""
S = len(Cs)
@@ -835,6 +824,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
+ # XXX use random state
xalea = np.random.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
@@ -846,7 +836,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
error = []
- while(err > tol and cpt < max_iter):
+ while (err > tol) and (cpt < max_iter):
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
@@ -890,7 +880,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
.. math::
C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
-
Where :
- Cs : metric cost matrix
@@ -898,29 +887,29 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
Parameters
----------
- N : Integer
- Size of the targeted barycenter
- Cs : list of S np.ndarray(ns,ns)
- Metric cost matrices
- ps : list of S np.ndarray(ns,)
- sample weights in the S spaces
- p : ndarray, shape(N,)
- weights in the targeted barycenter
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S np.ndarray of shape (ns, ns)
+ Metric cost matrices
+ ps : list of S np.ndarray of shape (ns,)
+ Sample weights in the S spaces
+ p : ndarray, shape (N,)
+ Weights in the targeted barycenter
lambdas : list of float
- list of the S spaces' weights
+ List of the S spaces' weights
loss_fun : tensor-matrix multiplication function based on a specific loss function
update : function(p, lambdas, T, Cs) that updates C according to a specific kernel
with the S Ts couplings calculated at each iteration
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0).
verbose : bool, optional
- Print information along iterations
+ Print information along iterations.
log : bool, optional
- record log if True
- init_C : bool, ndarray, shape(N,N)
- random initial value for the C matrix provided by user
+ Record log if True.
+ init_C : bool | ndarray, shape (N, N)
+ Initial value for the C matrix provided by the user; a random value is used if not set.
Returns
-------
@@ -934,7 +923,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
International Conference on Machine Learning (ICML). 2016.
"""
-
S = len(Cs)
Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
@@ -942,6 +930,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
+ # XXX : should use a random state and not use the global seed
xalea = np.random.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
@@ -987,8 +976,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
verbose=False, log=False, init_C=None, init_X=None):
- """
- Compute the fgw barycenter as presented eq (5) in [24].
+ """Compute the fgw barycenter as presented eq (5) in [24].
Parameters
----------
@@ -997,30 +985,32 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ys: list of ndarray, each element has shape (ns,d)
Features of all samples
Cs : list of ndarray, each element has shape (ns,ns)
- Structure matrices of all samples
+ Structure matrices of all samples
ps : list of ndarray, each element has shape (ns,)
- masses of all samples
+ Masses of all samples.
lambdas : list of float
- list of the S spaces' weights
+ List of the S spaces' weights
alpha : float
- Alpha parameter for the fgw distance
- fixed_structure : bool
- Wether to fix the structure of the barycenter during the updates
- fixed_features : bool
- Wether to fix the feature of the barycenter during the updates
- init_C : ndarray, shape (N,N), optional
- initialization for the barycenters' structure matrix. If not set random init
- init_X : ndarray, shape (N,d), optional
- initialization for the barycenters' features. If not set random init
+ Alpha parameter for the fgw distance
+ fixed_structure : bool
+ Whether to fix the structure of the barycenter during the updates
+ fixed_features : bool
+ Whether to fix the feature of the barycenter during the updates
+ init_C : ndarray, shape (N,N), optional
+ Initialization for the barycenters' structure matrix. If not set
+ a random init is used.
+ init_X : ndarray, shape (N,d), optional
+ Initialization for the barycenters' features. If not set a
+ random init is used.
Returns
-------
- X : ndarray, shape (N,d)
+ X : ndarray, shape (N, d)
Barycenters' features
- C : ndarray, shape (N,N)
+ C : ndarray, shape (N, N)
Barycenters' structure matrix
- log_: dictionary
- Only returned when log=True
+ log_: dict
+ Only returned when log=True. It contains the keys:
T : list of (N,ns) transport matrices
Ms : all distance matrices between the feature of the barycenter and the
other features dist(X,Ys) shape (N,ns)
@@ -1032,7 +1022,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
-
S = len(Cs)
d = Ys[0].shape[1] # dimension on the node features
if p is None:
@@ -1095,7 +1084,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
- T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
+ numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
@@ -1114,6 +1104,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
print('{:5d}|{:8e}|'.format(cpt, err_feature))
cpt += 1
+
if log:
log_['T'] = T # from target to Ys
log_['p'] = p
@@ -1126,25 +1117,25 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
def update_sructure_matrix(p, lambdas, T, Cs):
- """
- Updates C according to the L2 Loss kernel with the S Ts couplings
- calculated at each iteration
+ """Updates C according to the L2 Loss kernel with the S Ts couplings.
+
+ The couplings are the ones computed at each iteration.
Parameters
----------
- p : ndarray, shape (N,)
- masses in the targeted barycenter
+ p : ndarray, shape (N,)
+ Masses in the targeted barycenter.
lambdas : list of float
- list of the S spaces' weights
- T : list of S np.ndarray(ns,N)
- the S Ts couplings calculated at each iteration
- Cs : list of S ndarray, shape(ns,ns)
- Metric cost matrices
+ List of the S spaces' weights.
+ T : list of S ndarray of shape (ns, N)
+ The S Ts couplings calculated at each iteration.
+ Cs : list of S ndarray, shape (ns, ns)
+ Metric cost matrices.
Returns
- ----------
- C : ndarray, shape (nt,nt)
- updated C matrix
+ -------
+ C : ndarray, shape (nt, nt)
+ Updated C matrix.
"""
tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
ppt = np.outer(p, p)
@@ -1153,24 +1144,26 @@ def update_sructure_matrix(p, lambdas, T, Cs):
def update_feature_matrix(lambdas, Ys, Ts, p):
- """
- Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24]
- calculated at each iteration
+ """Updates the feature with respect to the S Ts couplings.
+
+
+ See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
+ in [24] calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
- masses in the targeted barycenter
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
lambdas : list of float
- list of the S spaces' weights
+ List of the S spaces' weights
Ts : list of S np.ndarray(ns,N)
the S Ts couplings calculated at each iteration
Ys : list of S ndarray, shape(d,ns)
- The features
+ The features.
Returns
- ----------
- X : ndarray, shape (d,N)
+ -------
+ X : ndarray, shape (d, N)
+ Updated feature matrix.
References
----------
@@ -1179,9 +1172,8 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
- p = np.diag(np.array(1 / p).reshape(-1,))
-
- tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))])
+ p = np.array(1. / p).reshape(-1,)
+ tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))])
return tmpsum
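To ground the docstring edits above, a hedged sketch of a plain Gromov-Wasserstein computation between two spaces of different dimension; all data below is illustrative:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(20, 2)                 # source samples in R^2
    xt = rng.randn(30, 3)                 # target samples in R^3
    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)
    C1 /= C1.max()
    C2 /= C2.max()
    p = ot.unif(20)
    q = ot.unif(30)
    T = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')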
diff --git a/ot/optim.py b/ot/optim.py
index f94aceb..0abd9e9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -26,14 +26,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
Parameters
----------
-
- f : function
+ f : callable
loss function
- xk : np.ndarray
+ xk : ndarray
initial position
- pk : np.ndarray
+ pk : ndarray
descent direction
- gfk : np.ndarray
+ gfk : ndarray
gradient of f at xk
old_fval : float
loss value at xk
@@ -161,15 +160,15 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
- G0 : np.ndarray (ns,nt), optional
+ G0 : ndarray, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
@@ -299,17 +298,17 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg1 : float
Entropic Regularization term >0
reg2 : float
Second Regularization term >0
- G0 : np.ndarray (ns,nt), optional
+ G0 : ndarray, shape (ns, nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
@@ -326,15 +325,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
-
References
----------
-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
@@ -422,13 +419,12 @@ def solve_1d_linesearch_quad(a, b, c):
Parameters
----------
a,b,c : float
- The coefficients of the quadratic function
+ The coefficients of the quadratic function
Returns
-------
x : float
The optimal value which leads to the minimal cost
-
"""
f0 = c
df0 = b
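As a sketch of how the generic cg solver documented above is driven, here is a conditional-gradient solve with a caller-supplied squared-norm regularizer; the data and reg value are illustrative:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a = ot.unif(10)
    b = ot.unif(12)
    M = ot.dist(rng.randn(10, 2), rng.randn(12, 2))
    M /= M.max()

    def f(G):           # regularizer value
        return 0.5 * np.sum(G ** 2)

    def df(G):          # regularizer gradient
        return G

    G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df)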
diff --git a/ot/plot.py b/ot/plot.py
index a409d4a..f403e98 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -26,11 +26,11 @@ def plot1D_mat(a, b, M, title=''):
Parameters
----------
- a : np.array, shape (na,)
+ a : ndarray, shape (na,)
Source distribution
- b : np.array, shape (nb,)
+ b : ndarray, shape (nb,)
Target distribution
- M : np.array, shape (na,nb)
+ M : ndarray, shape (na, nb)
Matrix to plot
"""
na, nb = M.shape
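A usage sketch for plot1D_mat (matplotlib is required); the Gaussian histograms are illustrative:

    import ot
    import ot.plot

    n = 50
    a = ot.datasets.make_1D_gauss(n, m=15, s=5)   # source histogram
    b = ot.datasets.make_1D_gauss(n, m=35, s=8)   # target histogram
    M = ot.utils.dist0(n)
    M /= M.max()
    ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')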
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 5754968..13ed9cc 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -38,22 +38,20 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
Parameters
----------
-
- b : np.ndarray(nt,)
- target measure
- M : np.ndarray(ns, nt)
- cost matrix
- reg : float nu
- Regularization term > 0
- v : np.ndarray(nt,)
- dual variable
- i : number int
- picked number i
+ b : ndarray, shape (nt,)
+ Target measure.
+ M : ndarray, shape (ns, nt)
+ Cost matrix.
+ reg : float
+ Regularization term > 0.
+ beta : ndarray, shape (nt,)
+ Dual variable.
+ i : int
+ Picked number i.
Returns
-------
-
- coordinate gradient : np.ndarray(nt,)
+ coordinate gradient : ndarray, shape (nt,)
Examples
--------
@@ -78,14 +76,11 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
References
----------
-
[Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
-
+ Stochastic Optimization for Large-scale Optimal Transport,
+ Advances in Neural Information Processing Systems (2016),
+ arXiv preprint arxiv:1605.08527.
'''
-
r = M[i, :] - beta
exp_beta = np.exp(-r / reg) * b
khi = exp_beta / (np.sum(exp_beta))
@@ -121,24 +116,23 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
Parameters
----------
- a : np.ndarray(ns,),
- source measure
- b : np.ndarray(nt,),
- target measure
- M : np.ndarray(ns, nt),
- cost matrix
- reg : float number,
+ a : ndarray, shape (ns,)
+ Source measure.
+ b : ndarray, shape (nt,)
+ Target measure.
+ M : ndarray, shape (ns, nt)
+ Cost matrix.
+ reg : float
Regularization term > 0
- numItermax : int number
- number of iteration
- lr : float number
- learning rate
+ numItermax : int
+ Number of iterations.
+ lr : float
+ Learning rate.
Returns
-------
-
- v : np.ndarray(nt,)
- dual variable
+ v : ndarray, shape (nt,)
+ Dual variable.
Examples
--------
@@ -213,23 +207,20 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
Parameters
----------
-
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
- numItermax : int number
- number of iteration
- lr : float number
- learning rate
-
+ numItermax : int
+ Number of iterations.
+ lr : float
+ Learning rate.
Returns
-------
-
- ave_v : np.ndarray(nt,)
+ ave_v : ndarray, shape (nt,)
dual variable
Examples
@@ -256,9 +247,9 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
----------
[Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ Stochastic Optimization for Large-scale Optimal Transport,
+ Advances in Neural Information Processing Systems (2016),
+ arXiv preprint arxiv:1605.08527.
'''
if lr is None:
@@ -298,21 +289,19 @@ def c_transform_entropic(b, M, reg, beta):
Parameters
----------
-
- b : np.ndarray(nt,)
- target measure
- M : np.ndarray(ns, nt)
- cost matrix
+ b : ndarray, shape (nt,)
+ Target measure
+ M : ndarray, shape (ns, nt)
+ Cost matrix
reg : float
- regularization term > 0
- v : np.ndarray(nt,)
- dual variable
+ Regularization term > 0
+ beta : ndarray, shape (nt,)
+ Dual variable.
Returns
-------
-
- u : np.ndarray(ns,)
- dual variable
+ u : ndarray, shape (ns,)
+ Dual variable.
Examples
--------
@@ -338,9 +327,9 @@ def c_transform_entropic(b, M, reg, beta):
----------
[Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ Stochastic Optimization for Large-scale Optimal Transport,
+ Advances in Neural Information Processing Systems (2016),
+ arXiv preprint arxiv:1605.08527.
'''
n_source = np.shape(M)[0]
@@ -382,31 +371,30 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
Parameters
----------
- a : np.ndarray(ns,)
+ a : ndarray, shape (ns,)
source measure
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
method : str
used method (SAG or ASGD)
- numItermax : int number
+ numItermax : int
number of iterations
- lr : float number
+ lr : float
learning rate
- n_source : int number
+ n_source : int
size of the source measure
- n_target : int number
+ n_target : int
size of the target measure
log : bool, optional
record log if True
Returns
-------
-
- pi : np.ndarray(ns, nt)
+ pi : ndarray, shape (ns, nt)
transportation matrix
log : dict
log dictionary returned only if log==True in parameters
@@ -495,30 +483,28 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
Parameters
----------
-
- a : np.ndarray(ns,)
+ a : ndarray, shape (ns,)
source measure
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
- alpha : np.ndarray(ns,)
+ alpha : ndarray, shape (ns,)
dual variable
- beta : np.ndarray(nt,)
+ beta : ndarray, shape (nt,)
dual variable
- batch_size : int number
+ batch_size : int
size of the batch
- batch_alpha : np.ndarray(bs,)
+ batch_alpha : ndarray, shape (bs,)
batch of index of alpha
- batch_beta : np.ndarray(bs,)
+ batch_beta : ndarray, shape (bs,)
batch of index of beta
Returns
-------
-
- grad : np.ndarray(ns,)
+ grad : ndarray, shape (ns,)
partial gradient of F
Examples
@@ -591,28 +577,26 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
Parameters
----------
-
- a : np.ndarray(ns,)
+ a : ndarray, shape (ns,)
source measure
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
- batch_size : int number
+ batch_size : int
size of the batch
- numItermax : int number
+ numItermax : int
number of iterations
- lr : float number
+ lr : float
learning rate
Returns
-------
-
- alpha : np.ndarray(ns,)
+ alpha : ndarray, shape (ns,)
dual variable
- beta : np.ndarray(nt,)
+ beta : ndarray, shape (nt,)
dual variable
Examples
@@ -648,10 +632,9 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
References
----------
-
[Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ International Conference on Learning Representation (2018),
+ arXiv preprint arxiv:1711.02283.
'''
n_source = np.shape(M)[0]
@@ -696,28 +679,26 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
Parameters
----------
-
- a : np.ndarray(ns,)
+ a : ndarray, shape (ns,)
source measure
- b : np.ndarray(nt,)
+ b : ndarray, shape (nt,)
target measure
- M : np.ndarray(ns, nt)
+ M : ndarray, shape (ns, nt)
cost matrix
- reg : float number
+ reg : float
Regularization term > 0
- batch_size : int number
+ batch_size : int
size of the batch
- numItermax : int number
+ numItermax : int
number of iterations
- lr : float number
+ lr : float
learning rate
log : bool, optional
record log if True
Returns
-------
-
- pi : np.ndarray(ns, nt)
+ pi : ndarray, shape (ns, nt)
transportation matrix
log : dict
log dictionary returned only if log==True in parameters
@@ -757,8 +738,8 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
----------
[Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ International Conference on Learning Representation (2018),
+ arXiv preprint arxiv:1711.02283.
'''
opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
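A hedged end-to-end sketch of the semi-dual stochastic solver documented in this file, using the SAG method; sizes and the reg value are illustrative:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    n_source, n_target = 7, 4
    a = ot.unif(n_source)
    b = ot.unif(n_target)
    M = ot.dist(rng.randn(n_source, 2), rng.randn(n_target, 2))
    # Transport plan estimated from stochastic updates of the semi-dual.
    pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1., method='SAG',
                                                numItermax=5000)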
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 50ec03c..d516dfc 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -9,51 +9,56 @@ Regularized Unbalanced OT
from __future__ import division
import warnings
import numpy as np
+from scipy.special import logsumexp
+
# from .utils import unif, dist
-def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
- Solve the unbalanced entropic regularization optimal transport problem and return the loss
+ Solve the unbalanced entropic regularization optimal transport problem
+ and return the OT plan
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns, nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m : float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -62,10 +67,16 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -82,83 +93,96 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_stabilized_unbalanced:
+ Unbalanced Stabilized sinkhorn [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
+ Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
+ raise ValueError("Unknown method '%s'." % method)
- return sink()
-
-def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
- numItermax=1000, stopThr=1e-9, verbose=False,
+def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
+ numItermax=1000, stopThr=1e-6, verbose=False,
log=False, **kwargs):
r"""
- Solve the entropic regularization unbalanced optimal transport problem and return the loss
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m: float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ 'sinkhorn_reg_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -171,10 +195,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
log : dict
- log dictionary return only if log==True in parameters
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -191,64 +215,70 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
-
- if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
- warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
- else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
-
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
b = b[:, None]
-
- return sink()
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
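As a sketch of the multi-histogram path prepared by the reshaping above: when `b` has two dimensions, `sinkhorn_unbalanced2` returns one loss per column (the values below are illustrative, not taken from the test suite)::

    import numpy as np
    import ot

    a = np.array([.5, .5])
    b = np.array([[.5, .1],
                  [.5, .9]])  # two histograms stored as columns
    M = np.array([[0., 1.], [1., 0.]])

    losses = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=1., reg_m=1.)
    assert losses.shape == (2,)  # one OT loss per histogram b_i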
-def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
@@ -256,21 +286,21 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m: float
Marginal relaxation term > 0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -279,11 +309,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -298,9 +333,13 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
References
----------
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
@@ -313,12 +352,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- n_a, n_b = M.shape
+ dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(n_a, dtype=np.float64) / n_a
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
if len(b) == 0:
- b = np.ones(n_b, dtype=np.float64) / n_b
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -331,21 +370,19 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((n_a, 1)) / n_a
- v = np.ones((n_b, n_hists)) / n_b
- a = a.reshape(n_a, 1)
+ u = np.ones((dim_a, 1)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
else:
- u = np.ones(n_a) / n_a
- v = np.ones(n_b) / n_b
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
- # print(reg)
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- # print(np.min(K))
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
@@ -364,15 +401,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %s' % cpt)
u = uprev
v = vprev
break
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -380,10 +418,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
+ cpt += 1
+
if log:
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -400,9 +439,224 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
return u[:, None] * K * v[None, :]
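The u/v scaling updates themselves are unchanged by this patch; a self-contained sketch of the generalized Sinkhorn-Knopp fixed point being iterated (a fixed iteration count replaces the `stopThr` test for brevity, and `knopp_updates` is an illustrative helper, not part of the module)::

    import numpy as np

    def knopp_updates(a, b, M, reg, reg_m, n_iter=100):
        K = np.exp(-M / reg)
        # fi interpolates between balanced Sinkhorn (reg_m -> inf, fi -> 1)
        # and no marginal constraint at all (reg_m -> 0, fi -> 0)
        fi = reg_m / (reg_m + reg)
        u = np.ones(len(a)) / len(a)
        v = np.ones(len(b)) / len(b)
        for _ in range(n_iter):
            u = (a / (K.dot(v) + 1e-16)) ** fi
            v = (b / (K.T.dot(u) + 1e-16)) ** fi
        return u[:, None] * K * v[None, :]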
-def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
+def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False,
+ **kwargs):
+ r"""
+ Solve the entropic regularization unbalanced optimal transport
+ problem and return the loss
+
+ The function solves the following optimization problem using log-domain
+ stabilization as proposed in [10]:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ tau : float
+ Threshold for max value in u or v for log scaling
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.],[1., 0.]]
+ >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
+
+ References
+ ----------
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
+ if len(b) == 0:
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
+ else:
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ fi = reg_m / (reg_m + reg)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim_a)
+ beta = np.zeros(dim_b)
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+
+ if n_hists:
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ if n_hists:
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ else:
+ alpha = alpha + reg * np.log(np.max(u))
+ beta = beta + reg * np.log(np.max(v))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ u = uprev
+ v = vprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
+ 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if n_hists:
+ logu = alpha[:, None] / reg + np.log(u)
+ logv = beta[:, None] / reg + np.log(v)
+ else:
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ if log:
+ log['logu'] = logu
+ log['logv'] = logv
+ if n_hists: # return only loss
+ res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
+ logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
+ res = np.exp(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+ ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ if log:
+ return ot_matrix, log
+ else:
+ return ot_matrix
+
+
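The absorption step is the heart of the stabilization: whenever a scaling vector exceeds `tau`, its magnitude is folded into the dual potentials and the kernel is rebuilt. In isolation, for the 1-D case of the loop above (`absorb` is an illustrative helper name)::

    import numpy as np

    def absorb(alpha, beta, u, v, M, reg):
        # fold the large scalings into the dual potentials ...
        alpha = alpha + reg * np.log(np.max(u))
        beta = beta + reg * np.log(np.max(v))
        # ... then rebuild the kernel around them and reset v
        K = np.exp((alpha[:, None] + beta[None, :] - M) / reg)
        return alpha, beta, K, np.ones_like(v)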
+def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
The function solves the following optimization problem:
@@ -411,28 +665,35 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- - alpha is the marginal relaxation hyperparameter
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of
+ matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
Parameters
----------
- A : np.ndarray (d,n)
- n training distributions a_i of size d
- M : np.ndarray (d,d)
- loss matrix for OT
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m : float
Marginal relaxation term > 0
- weights : np.ndarray (n,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ tau : float
+ Stabilization threshold for log domain absorption.
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -441,7 +702,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : (dim,) ndarray
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -450,12 +711,165 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
+ G. (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
"""
- p, n_hists = A.shape
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ fi = reg_m / (reg_m + reg)
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ q = (Ktu ** (1 - fi)) * f_beta
+ q = q.dot(weights) ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(q - qprev).max() / max(abs(q).max(),
+ abs(qprev).max(), 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
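Setting the absorption factors `f_alpha`/`f_beta` aside, the barycenter update above reduces to a weighted power mean of the `Ktu` columns; a sketch of that aggregation step alone (`aggregate` is an illustrative helper name)::

    import numpy as np

    def aggregate(Ktu, weights, fi):
        q = Ktu ** (1 - fi)        # (dim, n_hists)
        q = q.dot(weights)         # weighted mean across histograms
        return q ** (1 / (1 - fi))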
+def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+
+ """
+ dim, n_hists = A.shape
if weights is None:
weights = np.ones(n_hists) / n_hists
else:
@@ -466,10 +880,10 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
K = np.exp(- M / reg)
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
- v = np.ones((p, n_hists)) / p
- u = np.ones((p, 1)) / p
+ v = np.ones((dim, n_hists)) / dim
+ u = np.ones((dim, 1)) / dim
cpt = 0
err = 1.
@@ -498,8 +912,11 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
- np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -511,8 +928,95 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
cpt += 1
if log:
log['niter'] = cpt
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
return q, log
else:
return q
+
+
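Throughout this patch the old squared-relative-error criterion is replaced by the relative max-change of the scalings, with the denominator clipped at 1 so near-zero scalings cannot blow the ratio up. In isolation (`relative_change` is an illustrative helper name)::

    import numpy as np

    def relative_change(u, uprev, v, vprev):
        err_u = np.abs(u - uprev).max() / max(np.abs(u).max(),
                                              np.abs(uprev).max(), 1.)
        err_v = np.abs(v - vprev).max() / max(np.abs(v).max(),
                                              np.abs(vprev).max(), 1.)
        return 0.5 * (err_u + err_v)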
+def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False, **kwargs):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+ - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+ reg_m: float
+ Marginal relaxation term > 0
+ weights : np.ndarray (n_hists,) optional
+ Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ if method.lower() == 'sinkhorn':
+ return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return barycenter_unbalanced(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
diff --git a/ot/utils.py b/ot/utils.py
index 5707d9b..b71458b 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -111,12 +111,12 @@ def dist(x1, x2=None, metric='sqeuclidean'):
Parameters
----------
- x1 : np.array (n1,d)
+ x1 : ndarray, shape (n1,d)
matrix with n1 samples of size d
- x2 : np.array (n2,d), optional
+ x2 : ndarray, shape (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
- metric : str, fun, optional
- name of the metric to be computed (full list in the doc of scipy), If a string,
+ metric : str | callable, optional
+ Name of the metric to be computed (full list in the doc of scipy). If a string,
the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
@@ -138,26 +138,21 @@ def dist(x1, x2=None, metric='sqeuclidean'):
def dist0(n, method='lin_square'):
- """Compute standard cost matrices of size (n,n) for OT problems
+ """Compute standard cost matrices of size (n, n) for OT problems
Parameters
----------
-
n : int
- size of the cost matrix
+ Size of the cost matrix.
method : str, optional
Type of loss matrix chosen from:
* 'lin_square' : linear sampling between 0 and n-1, quadratic loss
-
Returns
-------
-
- M : np.array (n1,n2)
- distance matrix computed with given metric
-
-
+ M : ndarray, shape (n, n)
+ Distance matrix computed with given metric.
"""
res = 0
if method == 'lin_square':
@@ -169,33 +164,34 @@ def dist0(n, method='lin_square'):
def cost_normalization(C, norm=None):
""" Apply normalization to the loss matrix
-
Parameters
----------
- C : np.array (n1, n2)
+ C : ndarray, shape (n1, n2)
The cost matrix to normalize.
norm : str
- type of normalization from 'median','max','log','loglog'. Any other
- value do not normalize.
-
+ Type of normalization from 'median', 'max', 'log', 'loglog'. If None,
+ the cost matrix is returned unchanged; any other value raises a ValueError.
Returns
-------
-
- C : np.array (n1, n2)
+ C : ndarray, shape (n1, n2)
The input cost matrix normalized according to given norm.
-
"""
- if norm == "median":
+ if norm is None:
+ pass
+ elif norm == "median":
C /= float(np.median(C))
elif norm == "max":
C /= float(np.max(C))
elif norm == "log":
C = np.log(1 + C)
elif norm == "loglog":
- C = np.log(1 + np.log(1 + C))
-
+ C = np.log1p(np.log1p(C))
+ else:
+ raise ValueError('Norm %s is not a valid option.\n'
+ 'Valid options are:\n'
+ 'median, max, log, loglog' % norm)
return C
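The accepted values are now enforced; a quick sketch of each mode (`C` must be a float array, since 'median' and 'max' divide in place)::

    import numpy as np
    from ot.utils import cost_normalization

    C = np.random.rand(5, 5)
    cost_normalization(C.copy(), 'median')  # divide by the median cost
    cost_normalization(C.copy(), 'max')     # divide by the largest cost
    cost_normalization(C.copy(), 'log')     # log(1 + C)
    cost_normalization(C.copy(), 'loglog')  # log1p(log1p(C))
    cost_normalization(C.copy(), None)      # returned unchanged
    # any other value now raises ValueError instead of passing silently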
@@ -261,6 +257,7 @@ def check_params(**kwargs):
def check_random_state(seed):
"""Turn seed into a np.random.RandomState instance
+
Parameters
----------
seed : None | int | instance of RandomState
@@ -280,7 +277,6 @@ def check_random_state(seed):
class deprecated(object):
-
"""Decorator to mark a function or class as deprecated.
deprecated class from scikit-learn package
@@ -296,8 +292,8 @@ class deprecated(object):
Parameters
----------
- extra : string
- to be added to the deprecation messages
+ extra : str
+ To be added to the deprecation messages.
"""
# Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
@@ -378,9 +374,9 @@ def _is_deprecated(func):
class BaseEstimator(object):
-
"""Base class for most objects in POT
- adapted from sklearn BaseEstimator class
+
+ Code adapted from sklearn BaseEstimator class
Notes
-----
@@ -422,7 +418,7 @@ class BaseEstimator(object):
Parameters
----------
- deep : boolean, optional
+ deep : bool, optional
If True, will return the parameters for this estimator and
contained subobjects that are estimators.
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pytest.ini
diff --git a/setup.cfg b/setup.cfg
index aa0ff62..6be91fe 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -4,3 +4,21 @@ description-file = README.md
[flake8]
exclude = __init__.py
ignore = E265,E501,W605,W503,W504
+
+[tool:pytest]
+addopts =
+ --showlocals --durations=20 --doctest-modules -ra --cov-report= --cov=ot
+ --doctest-ignore-import-errors --junit-xml=junit-results.xml
+ --ignore=docs --ignore=examples --ignore=notebooks
+
+[pycodestyle]
+exclude = __init__.py,*externals*,constants.py,fixes.py
+ignore = E241,E305,W504
+
+[pydocstyle]
+convention = pep257
+match_dir = ^(?!\.|docs|examples).*$
+match = (?!tests/__init__\.py|fixes).*\.py
+add-ignore = D100,D104,D107,D413
+add-select = D214,D215,D404,D405,D406,D407,D408,D409,D410,D411
+ignore-decorators = ^(copy_.*_doc_to_|on_trait_change|cached_property|deprecated|property|.*setter).*
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 7f4972c..f70df10 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -7,6 +7,7 @@
import numpy as np
import ot
+import pytest
def test_sinkhorn():
@@ -71,13 +72,11 @@ def test_sinkhorn_variants():
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
- np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)
@@ -96,18 +95,17 @@ def test_sinkhorn_variants_log():
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
- Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True)
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
- np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)
-def test_bary():
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_barycenter(method):
n_bins = 100 # nb bins
@@ -126,14 +124,42 @@ def test_bary():
weights = np.array([1 - alpha, alpha])
# wasserstein
- reg = 1e-3
- bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ reg = 1e-2
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
np.testing.assert_allclose(1, np.sum(bary_wass))
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+def test_barycenter_stabilization():
+
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # wasserstein
+ reg = 1e-2
+ bar_stable = ot.bregman.barycenter(A, M, reg, weights,
+ method="sinkhorn_stabilized",
+ stopThr=1e-8)
+ bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
+ stopThr=1e-8)
+ np.testing.assert_allclose(bar, bar_stable)
+
+
def test_wasserstein_bary_2d():
size = 100 # size of a square image
@@ -254,3 +280,60 @@ def test_empirical_sinkhorn_divergence():
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
np.testing.assert_allclose(
emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+
+
+def test_stabilized_vs_sinkhorn_multidim():
+ # test if stable version matches sinkhorn
+ # for multidimensional inputs
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ log=True)
+ G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2)
+
+
+def test_implemented_methods():
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling']
+ NOT_VALID_TOKENS = ['foo']
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 3
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n)
+ A = rng.rand(n, 2)
+ M = ot.dist(x, x)
+ epsilon = 1.
+
+ for method in IMPLEMENTED_METHODS:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+ ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+ with pytest.raises(ValueError):
+ for method in set(NOT_VALID_TOKENS):
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+ ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+ for method in ONLY_1D_methods:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ with pytest.raises(ValueError):
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
diff --git a/test/test_da.py b/test/test_da.py
index f7f3a9d..2a5e50e 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -245,6 +245,71 @@ def test_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
+def test_unbalanced_sinkhorn_transport_class():
+ """test_sinkhorn_transport
+ """
+
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ otda = ot.da.UnbalancedSinkhornTransport()
+
+ # test its computed
+ otda.fit(Xs=Xs, Xt=Xt)
+ assert hasattr(otda, "cost_")
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ assert_equal(otda.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
+
+ # test inverse transform
+ transp_Xt = otda.inverse_transform(Xt=Xt)
+ assert_equal(transp_Xt.shape, Xt.shape)
+
+ Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xt_new.shape, Xt_new.shape)
+
+ # test fit_transform
+ transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
+ assert_equal(transp_Xs.shape, Xs.shape)
+
+ # test unsupervised vs semi-supervised mode
+ otda_unsup = ot.da.UnbalancedSinkhornTransport()
+ otda_unsup.fit(Xs=Xs, Xt=Xt)
+ n_unsup = np.sum(otda_unsup.cost_)
+
+ otda_semi = ot.da.UnbalancedSinkhornTransport()
+ otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
+ assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+ n_semisup = np.sum(otda_semi.cost_)
+
+ # check that the cost matrix norms are indeed different
+ assert n_unsup != n_semisup, "semisupervised mode not working"
+
+ # check everything runs well with log=True
+ otda = ot.da.UnbalancedSinkhornTransport(log=True)
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ assert len(otda.log_.keys()) != 0
+
+
def test_emd_transport_class():
"""test_sinkhorn_transport
"""
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 1395fe1..ca1efba 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -7,9 +7,12 @@
import numpy as np
import ot
import pytest
+from ot.unbalanced import barycenter_unbalanced
+from scipy.special import logsumexp
-@pytest.mark.parametrize("method", ["sinkhorn"])
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_convergence(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -23,29 +26,35 @@ def test_unbalanced_convergence(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10, method=method,
+ G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method=method,
log=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
- u_final = (a / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)
+ logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
-@pytest.mark.parametrize("method", ["sinkhorn"])
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_multiple_inputs(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -59,28 +68,59 @@ def test_unbalanced_multiple_inputs(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- alpha=alpha,
- stopThr=1e-10, method=method,
+ reg_m=reg_m,
+ method=method,
log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
-
- u_final = (a[:, None] / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
assert len(loss) == b.shape[1]
-def test_unbalanced_barycenter():
+def test_stabilized_vs_sinkhorn():
+ # test if stable version matches sinkhorn
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ reg_m = 1.
+ G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ reg_m=reg_m,
+ log=True)
+ G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2, atol=1e-5)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_barycenter(method):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -92,27 +132,56 @@ def test_unbalanced_barycenter():
A = A * np.array([1, 2])[None, :]
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
- q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10,
- log=True)
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method, log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
- u_final = (A / K.dot(log["v"])) ** fi
+ fi = reg_m / (reg_m + epsilon)
+ logA = np.log(A + 1e-16)
+ logq = np.log(q + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logq - logKtu)
+ u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
+
+
+def test_barycenter_stabilized_vs_sinkhorn():
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ A = rng.rand(n, 2)
+
+ # make dists unbalanced
+ A = A * np.array([1, 4])[None, :]
+ M = ot.dist(x, x)
+ epsilon = 0.5
+ reg_m = 10
+
+ qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
+ reg_m=reg_m, log=True,
+ tau=100,
+ method="sinkhorn_stabilized",
+ )
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method="sinkhorn",
+ log=True)
+
+ np.testing.assert_allclose(
+ q, qstable, atol=1e-05)
def test_implemented_methods():
- IMPLEMENTED_METHODS = ['sinkhorn']
- TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized',
- 'sinkhorn_epsilon_scaling']
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
# test generalized sinkhorn for unbalanced OT barycenter
n = 3
@@ -123,24 +192,30 @@ def test_implemented_methods():
# make dists unbalanced
b = ot.utils.unif(n) * 1.5
-
+ A = rng.rand(n, 2)
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
+ reg_m = 1.
for method in IMPLEMENTED_METHODS:
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
with pytest.warns(UserWarning, match='not implemented'):
for method in set(TO_BE_IMPLEMENTED_METHODS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
with pytest.raises(ValueError):
for method in set(NOT_VALID_TOKENS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
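The fixed-point checks in the tests above all follow the same log-domain identity; as a sketch for a single histogram, using only the calls that appear in the tests (`log_fixed_point` is an illustrative helper name)::

    import numpy as np
    from scipy.special import logsumexp

    def log_fixed_point(loga, logb, logu, logv, M, reg, reg_m):
        fi = reg_m / (reg_m + reg)
        logKtu = logsumexp(logu[None, :] - M.T / reg, axis=1)
        logKv = logsumexp(logv[None, :] - M / reg, axis=1)
        # at convergence these equal logu and logv respectively
        return fi * (loga - logKv), fi * (logb - logKtu)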