author    Rémi Flamary <remi.flamary@gmail.com>    2019-09-09 13:13:53 +0200
committer GitHub <noreply@github.com>              2019-09-09 13:13:53 +0200
commit    f251b4d080a577c2cee890ca43d8ec3658332021 (patch)
tree      42d0bcd9b93a771a9ab227ebb75d654b7c262244
parent    abfe183a49caaf74a07e595ac40920dae05a3c22 (diff)
parent    e0c935a865a57bc4603144b27f1b58cbfba87760 (diff)
Merge pull request #99 from hichamjanati/unbalanced-bar
[MRG] Add stabilization for unbalanced OT and barycenter computation
-rw-r--r--  .gitignore                  |   3
-rw-r--r--  docs/source/quickstart.rst  | 148
-rw-r--r--  ot/__init__.py              |  23
-rw-r--r--  ot/bregman.py               | 584
-rw-r--r--  ot/da.py                    |   2
-rw-r--r--  ot/unbalanced.py            | 803
-rw-r--r--  pytest.ini                  |   0
-rw-r--r--  test/test_bregman.py        |  72
-rw-r--r--  test/test_unbalanced.py     | 163
9 files changed, 1340 insertions, 458 deletions
diff --git a/.gitignore b/.gitignore
index 42a9aad..dadf84c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -59,6 +59,9 @@ coverage.xml
*.mo
*.pot
+# xml
+*.xml
+
# Django stuff:
*.log
local_settings.py
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index 1640d6a..978eaff 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -2,13 +2,13 @@
Quick start guide
=================
-In the following we provide some pointers about which functions and classes
+In the following we provide some pointers about which functions and classes
to use for different problems related to optimal transport (OT) and machine
learning. We refer when we can to concrete examples in the documentation that
are also available as notebooks on the POT Github.
This document is not a tutorial on numerical optimal transport. For this we strongly
-recommend reading the very nice book [15]_.
+recommend reading the very nice book [15]_.
Optimal transport and Wasserstein distance
@@ -55,8 +55,8 @@ solver is quite efficient and uses sparsity of the solution.
Examples of use for :any:`ot.emd` are available in :
- :any:`auto_examples/plot_OT_2D_samples`
- - :any:`auto_examples/plot_OT_1D`
- - :any:`auto_examples/plot_OT_L1_vs_L2`
+ - :any:`auto_examples/plot_OT_1D`
+ - :any:`auto_examples/plot_OT_L1_vs_L2`
Computing Wasserstein distance
@@ -102,13 +102,13 @@ distance.
An example of use for :any:`ot.emd2` is available in :
- :any:`auto_examples/plot_compute_emd`
-
+
Special cases
^^^^^^^^^^^^^
Note that the OT problem and the corresponding Wasserstein distance can in some
-special cases be computed very efficiently.
+special cases be computed very efficiently.
For instance when the samples are in 1D, then the OT problem can be solved in
:math:`O(n\log(n))` by using a simple sorting. In this case we provide the
@@ -117,13 +117,13 @@ matrix and value. Note that since the solution is very sparse the :code:`sparse`
parameter of :any:`ot.emd_1d` allows for solving and returning the solution for
very large problems. Note that in order to compute directly the :math:`W_p`
Wasserstein distance in 1D we provide the function :any:`ot.wasserstein_1d` that
-takes :code:`p` as a parameter.
+takes :code:`p` as a parameter.
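As a quick illustration of the 1D fast path, here is a minimal sketch (the
sample values and uniform weights are illustrative, not from the documentation)::

    import numpy as np
    import ot

    # two empirical 1D distributions with uniform weights
    x_a = np.random.randn(100)
    x_b = np.random.randn(100) + 2.

    G = ot.emd_1d(x_a, x_b)                  # OT matrix via O(n log n) sorting
    w2 = ot.wasserstein_1d(x_a, x_b, p=2.)   # W_p distance directly, here p=2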
Another special case for estimating OT and Monge mapping is between Gaussian
distributions. In this case there exists a closed form solution given in Remark
2.29 in [15]_ and the Monge mapping is an affine function and can be
also computed from the covariances and means of the source and target
-distributions. In the case when the finite sample dataset is assumed Gaussian, we provide
+distributions. In the case when the finite sample dataset is assumed Gaussian, we provide
:any:`ot.da.OT_mapping_linear` that returns the parameters for the Monge
mapping.
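A rough sketch of the Gaussian case (the sampled data below is illustrative,
and the affine convention x.dot(A) + b is an assumption consistent with
:any:`ot.da.LinearTransport`)::

    import numpy as np
    import ot

    # samples assumed drawn from two Gaussian distributions
    xs = np.random.randn(200, 2)
    xt = np.random.randn(200, 2).dot(np.array([[2., 0.], [0., .5]])) + np.array([4., 4.])

    # closed form estimate of the affine Monge mapping x -> x.dot(A) + b
    A, b = ot.da.OT_mapping_linear(xs, xt)
    xs_mapped = xs.dot(A) + b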
@@ -176,7 +176,7 @@ solution of the resulting optimization problem can be expressed as:
where :math:`u` and :math:`v` are vectors and :math:`K=\exp(-M/\lambda)` where
the :math:`\exp` is taken component-wise. In order to solve the optimization
problem, one can use an alternating projection algorithm called Sinkhorn-Knopp that can be very
-efficient for large values of regularization.
+efficient for large values of regularization.
The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and
:any:`ot.sinkhorn2` that return respectively the OT matrix and the value of the
@@ -201,12 +201,12 @@ More details about the algorithms used are given in the following note.
+ :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the
classic algorithm [2]_.
+ :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized` the
- log stabilized version of the algorithm [9]_.
+ log stabilized version of the algorithm [9]_.
+ :code:`method='sinkhorn_epsilon_scaling'` calls
:any:`ot.bregman.sinkhorn_epsilon_scaling` the epsilon scaling version
- of the algorithm [9]_.
+ of the algorithm [9]_.
+ :code:`method='greenkhorn'` calls :any:`ot.bregman.greenkhorn` the
- greedy sinkhorn version of the algorithm [22]_.
+ greedy sinkhorn version of the algorithm [22]_.
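As a small usage sketch of the :code:`method` parameter (toy data,
illustrative regularization values)::

    import numpy as np
    import ot

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])

    # same problem, two variants: the stabilized solver is safer for
    # small regularization where exp(-M/reg) under/overflows
    G = ot.sinkhorn(a, b, M, reg=1e-1, method='sinkhorn')
    G_stab = ot.sinkhorn(a, b, M, reg=1e-2, method='sinkhorn_stabilized')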
In addition to all those variants of sinkhorn, we have another
implementation solving the problem in the smooth dual or semi-dual in
@@ -236,7 +236,7 @@ of algorithms in [18]_ [19]_.
Examples of use for :any:`ot.sinkhorn` are available in :
- :any:`auto_examples/plot_OT_2D_samples`
- - :any:`auto_examples/plot_OT_1D`
+ - :any:`auto_examples/plot_OT_1D`
- :any:`auto_examples/plot_OT_1D_smooth`
- :any:`auto_examples/plot_stochastic`
@@ -248,13 +248,13 @@ While entropic OT is the most common and favored in practice, there exist other
kinds of regularization. We provide in POT two specific solvers for other
regularization terms, namely quadratic regularization and group lasso
regularization. But we also provide in :any:`ot.optim` two generic solvers that allow solving any
-smooth regularization in practice.
+smooth regularization in practice.
Quadratic regularization
""""""""""""""""""""""""
The first general regularization term we can solve is the quadratic
-regularization of the form
+regularization of the form
.. math::
\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}^2
@@ -264,7 +264,7 @@ densifying the OT matrix but it keeps some sort of sparsity that is lost with
entropic regularization as soon as :math:`\lambda>0` [17]_. This problem can be
solved with POT using solvers from :any:`ot.smooth`, more specifically
functions :any:`ot.smooth.smooth_ot_dual` or
-:any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='l2'` to
+:any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='l2'` to
choose the quadratic regularization.
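A minimal sketch of the quadratic case (the toy data and uniform histograms
are illustrative)::

    import numpy as np
    import ot

    a, b = ot.unif(50), ot.unif(50)
    M = ot.dist(np.random.randn(50, 2), np.random.randn(50, 2))

    # quadratically regularized OT solved in the dual
    G = ot.smooth.smooth_ot_dual(a, b, M, reg=1e-1, reg_type='l2')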
.. hint::
@@ -300,7 +300,7 @@ gradient algorithm [7]_ in function
.. hint::
Examples of group Lasso regularization are available in :
- - :any:`auto_examples/plot_otda_classes`
+ - :any:`auto_examples/plot_otda_classes`
- :any:`auto_examples/plot_otda_d2`
@@ -311,7 +311,7 @@ Finally we propose in POT generic solvers that can be used to solve any
regularization as long as you can provide a function computing the
regularization and a function computing its gradient (or sub-gradient).
-In order to solve
+In order to solve
.. math::
\gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \lambda\Omega(\gamma)
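one passes the regularization and its gradient to :any:`ot.optim.cg`. A
minimal sketch with a hand-written quadratic regularizer (toy data,
illustrative)::

    import numpy as np
    import ot

    a, b = ot.unif(20), ot.unif(20)
    M = ot.dist(np.random.randn(20, 2), np.random.randn(20, 2))

    def f(G):                       # Omega(G) = 0.5 * ||G||_F^2
        return 0.5 * np.sum(G ** 2)

    def df(G):                      # its gradient
        return G

    G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df)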
@@ -336,12 +336,12 @@ Another generic solver is proposed to solve the problem
where :math:`\Omega_e` is the entropic regularization. In this case we use a
generalized conditional gradient [7]_ implemented in :any:`ot.optim.gcg` that
does not linearize the entropic term but
-relies on :any:`ot.sinkhorn` for its iterations.
+relies on :any:`ot.sinkhorn` for its iterations.
.. hint::
    Examples of generic solvers are available in :
- - :any:`auto_examples/plot_optim_OTreg`
+ - :any:`auto_examples/plot_optim_OTreg`
Wasserstein Barycenters
@@ -382,7 +382,7 @@ solver :any:`ot.lp.barycenter` that relies on generic LP solvers. By default the
function uses :any:`scipy.optimize.linprog`, but more efficient LP solvers from
cvxopt can also be used by changing the parameter :code:`solver`. Note that this problem
requires to solve a very large linear program and can be very slow in
-practice.
+practice.
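A minimal sketch of the exact LP barycenter (toy Dirac histograms;
:code:`ot.utils.dist0` builds a squared distance matrix on a 1D grid, and the
positional signature is an assumption based on the text above)::

    import numpy as np
    import ot

    a1 = np.zeros(10); a1[0] = 1.      # Dirac at the left end
    a2 = np.zeros(10); a2[-1] = 1.     # Dirac at the right end
    A = np.vstack([a1, a2]).T          # histograms stacked as columns
    M = ot.utils.dist0(10)
    weights = np.array([0.5, 0.5])

    bary = ot.lp.barycenter(A, M, weights)   # uses scipy's linprog by default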
Similarly to the OT problem, OT barycenters can be computed in the regularized
case. When entropic regularization is used, the problem can be solved with a
@@ -403,11 +403,11 @@ operators. We provide an implementation of this algorithm in function
Examples of Wasserstein (:any:`ot.lp.barycenter`) and regularized Wasserstein
barycenter (:any:`ot.bregman.barycenter`) computation are available in :
- - :any:`auto_examples/plot_barycenter_1D`
- - :any:`auto_examples/plot_barycenter_lp_vs_entropic`
+ - :any:`auto_examples/plot_barycenter_1D`
+ - :any:`auto_examples/plot_barycenter_lp_vs_entropic`
An example of convolutional barycenter
- (:any:`ot.bregman.convolutional_barycenter2d`) computation
+ (:any:`ot.bregman.convolutional_barycenter2d`) computation
for 2D images is available
in :
@@ -451,13 +451,13 @@ optimal mapping is still an open problem in the general case but has been proven
for smooth distributions by Brenier in his eponymous `theorem
<https://who.rocq.inria.fr/Jean-David.Benamou/demiheure.pdf>`__. We provide in
:any:`ot.da` several solvers for smooth Monge mapping estimation and domain
-adaptation from discrete distributions.
+adaptation from discrete distributions.
Monge Mapping estimation
^^^^^^^^^^^^^^^^^^^^^^^^
We now discuss several approaches that are implemented in POT to estimate or
-approximate a Monge mapping from finite distributions.
+approximate a Monge mapping from finite distributions.
First note that when the source and target distributions are supposed to be Gaussian
distributions, there exists a closed form solution for the mapping and it is an
@@ -513,16 +513,16 @@ A list of the provided implementation is given in the following note.
Here is a list of the OT mapping classes inheriting from
:any:`ot.da.BaseTransport`
-
+
* :any:`ot.da.EMDTransport` : Barycentric mapping with EMD transport
* :any:`ot.da.SinkhornTransport` : Barycentric mapping with Sinkhorn transport
* :any:`ot.da.SinkhornL1l2Transport` : Barycentric mapping with Sinkhorn +
group Lasso regularization [5]_
* :any:`ot.da.SinkhornLpl1Transport` : Barycentric mapping with Sinkhorn +
- non convex group Lasso regularization [5]_
+ non convex group Lasso regularization [5]_
* :any:`ot.da.LinearTransport` : Linear mapping estimation between Gaussians
[14]_
- * :any:`ot.da.MappingTransport` : Nonlinear mapping estimation [8]_
+ * :any:`ot.da.MappingTransport` : Nonlinear mapping estimation [8]_
.. hint::
@@ -550,7 +550,7 @@ consist in finding a linear projector optimizing the following criterion
.. math::
P = \text{arg}\min_P \frac{\sum_i OT_e(\mu_i\#P,\mu_i\#P)}{\sum_{i,j\neq i}
OT_e(\mu_i\#P,\mu_j\#P)}
-
+
where :math:`\#` is the push-forward operator, :math:`OT_e` is the entropic OT
loss and :math:`\mu_i` is the
distribution of samples from class :math:`i`. :math:`P` is also constrained to
@@ -575,12 +575,12 @@ respectively. Note that we also provide the Fisher discriminant estimator in
Unbalanced optimal transport
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Unbalanced OT is a relaxation of the original OT problem where the violation of
+Unbalanced OT is a relaxation of the entropy regularized OT problem where the violation of
the constraint on the marginals is added to the objective of the optimization
-problem:
-
+problem. The unbalanced OT metric between two unbalanced histograms a and b is defined as [25]_ [10]_:
+
.. math::
- \min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + reg\cdot\Omega(\gamma) + \alpha KL(\gamma 1, a) + \alpha KL(\gamma^T 1, b)
+ W_u(a, b) = \min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t. \quad \gamma\geq 0
@@ -589,18 +589,60 @@ where KL is the Kullback-Leibler divergence. This formulation allows for
computing approximate mapping between distributions that do not have the same
amount of mass. Interestingly the problem can be solved with a generalization of
the Bregman projections algorithm [10]_. We provide a solver for unbalanced OT
-in :any:`ot.unbalanced` and more specifically
-in function :any:`ot.sinkhorn_unbalanced`. A solver for unbalanced OT barycenter
-is available in :any:`ot.barycenter_unbalanced`.
+in :any:`ot.unbalanced`. Computing the optimal transport
+plan or the transport cost is similar to the balanced case. The Sinkhorn-Knopp
+algorithm is implemented in :any:`ot.sinkhorn_unbalanced` and :any:`ot.sinkhorn_unbalanced2`
+that return respectively the OT matrix and the value of the
+linear term.
+
+.. note::
+ The main function to solve entropic regularized UOT is :any:`ot.sinkhorn_unbalanced`.
+ This function is a wrapper and the parameter :code:`method` helps you select
+ the actual algorithm used to solve the problem:
+
+ + :code:`method='sinkhorn'` calls :any:`ot.unbalanced.sinkhorn_knopp_unbalanced`
+ the generalized Sinkhorn algorithm [25]_ [10]_.
+ + :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.sinkhorn_stabilized_unbalanced`
+ the log stabilized version of the algorithm [10]_.
.. hint::
- Examples of the use of :any:`ot.sinkhorn_unbalanced` and
- :any:`ot.barycenter_unbalanced` are available in :
+ Examples of the use of :any:`ot.sinkhorn_unbalanced` are available in :
- :any:`auto_examples/plot_UOT_1D`
- - :any:`auto_examples/plot_UOT_barycenter_1D`
+
+
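A small sketch of the solvers on histograms with different total masses (toy
values)::

    import numpy as np
    import ot

    a = np.array([.5, .5])            # total mass 1
    b = np.array([.5, .5, .5])        # total mass 1.5
    M = np.array([[0., 1., 2.], [1., 0., 1.]])

    # entropic term `reg` and marginal relaxation strength `reg_m`
    G = ot.sinkhorn_unbalanced(a, b, M, reg=0.1, reg_m=1.)    # OT plan
    w = ot.sinkhorn_unbalanced2(a, b, M, reg=0.1, reg_m=1.)   # linear term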
+Unbalanced Barycenters
+^^^^^^^^^^^^^^^^^^^^^^
+
+As with balanced distributions, we can define a barycenter of a set of
+histograms with different masses as a Fréchet Mean:
+
+ .. math::
+ \min_{\mu} \quad \sum_{k} w_kW_u(\mu,\mu_k)
+
+Where :math:`W_u` is the unbalanced Wasserstein metric defined above. This problem
+can also be solved using a generalized version of Sinkhorn's algorithm and it is
+implemented in the main function :any:`ot.barycenter_unbalanced`.
+
+
+.. note::
+ The main function to compute UOT barycenters is :any:`ot.barycenter_unbalanced`.
+ This function is a wrapper and the parameter :code:`method` helps you select
+ the actual algorithm used to solve the problem:
+
+ + :code:`method='sinkhorn'` calls :any:`ot.unbalanced.barycenter_unbalanced_sinkhorn_unbalanced`
+ the generalized Sinkhorn algorithm [10]_.
+ + :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.barycenter_unbalanced_stabilized`
+ the log stabilized version of the algorithm [10]_.
+
+
+.. hint::
+
+ Examples of the use of :any:`ot.barycenter_unbalanced` are available in :
+
+ - :any:`auto_examples/plot_UOT_barycenter_1D`
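A rough sketch of an unbalanced barycenter computation (toy histograms; the
signature mirroring :any:`ot.sinkhorn_unbalanced` with reg and reg_m is an
assumption based on the formula above)::

    import numpy as np
    import ot

    # two histograms with different total masses, stacked as columns
    A = np.vstack([np.array([1., 0., 0.]), np.array([0., 0., 2.])]).T
    M = ot.utils.dist0(3)

    bary = ot.barycenter_unbalanced(A, M, reg=0.1, reg_m=1.,
                                    method='sinkhorn_stabilized')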
Gromov-Wasserstein
@@ -636,7 +678,7 @@ barycenters that can be expressed as
where :math:`Ck` is the distance matrix between samples in distribution
:math:`k`. Note that interestingly the barycenter is defined as a symmetric
-positive matrix. We provide a block coordinate optimization procedure in
+positive matrix. We provide a block coordinate optimization procedure in
:any:`ot.gromov.gromov_barycenters` and
:any:`ot.gromov.entropic_gromov_barycenters` for non-regularized and regularized
barycenters respectively.
@@ -654,19 +696,19 @@ The implementations of FGW and FGW barycenter is provided in functions
Examples of computation of GW, regularized GW and FGW are available in :
- :any:`auto_examples/plot_gromov`
- - :any:`auto_examples/plot_fgw`
+ - :any:`auto_examples/plot_fgw`
Examples of GW, regularized GW and FGW barycenters are available in :
- :any:`auto_examples/plot_gromov_barycenter`
- - :any:`auto_examples/plot_barycenter_fgw`
+ - :any:`auto_examples/plot_barycenter_fgw`
GPU acceleration
^^^^^^^^^^^^^^^^
We provide several implementations of our OT solvers in :any:`ot.gpu`. Those
-implementations use the :code:`cupy` toolbox that obviously needs to be installed.
+implementations use the :code:`cupy` toolbox that obviously needs to be installed.
.. note::
@@ -701,7 +743,7 @@ FAQ
1. **How to solve a discrete optimal transport problem ?**
The solver for discrete OT is the function :py:mod:`ot.emd` that returns
- the OT transport matrix. If you want to solve a regularized OT you can
+ the OT transport matrix. If you want to solve a regularized OT you can
use :py:mod:`ot.sinkhorn`.
@@ -714,9 +756,9 @@ FAQ
T=ot.emd(a,b,M) # exact linear program
T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
- More detailed examples can be seen on this example:
+ More detailed examples can be seen on this example:
:doc:`auto_examples/plot_OT_2D_samples`
-
+
2. **pip install POT fails with error : ImportError: No module named Cython.Build**
@@ -726,7 +768,7 @@ FAQ
installing POT.
Note that this problem does not occur when using conda-forge since the packages
- there are pre-compiled.
+ there are pre-compiled.
See `Issue #59 <https://github.com/rflamary/POT/issues/59>`__ for more
details.
@@ -751,7 +793,7 @@ FAQ
In order to limit import time and hard dependencies in POT, we do not import
some sub-modules automatically with :code:`import ot`. In order to use the
acceleration in :any:`ot.gpu` you first need to import it with
- :code:`import ot.gpu`.
+ :code:`import ot.gpu`.
See `Issue #85 <https://github.com/rflamary/POT/issues/85>`__ and :any:`ot.gpu`
for more details.
@@ -763,7 +805,7 @@ References
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
December). `Displacement interpolation using Lagrangian mass transport
<https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf>`__.
- In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
+ In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
.. [2] Cuturi, M. (2013). `Sinkhorn distances: Lightspeed computation of
optimal transport <https://arxiv.org/pdf/1306.0895.pdf>`__. In Advances
@@ -874,4 +916,8 @@ References
.. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
(2019). `Optimal Transport for structured data with application on
graphs <http://proceedings.mlr.press/v97/titouan19a.html>`__ Proceedings
- of the 36th International Conference on Machine Learning (ICML).
\ No newline at end of file
+ of the 36th International Conference on Machine Learning (ICML).
+
+.. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
diff --git a/ot/__init__.py b/ot/__init__.py
index 35ae6fc..df0ef8a 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,7 +1,7 @@
"""
-This is the main module of the POT toolbox. It provides easy access to
-a number of sub-modules and functions described below.
+This is the main module of the POT toolbox. It provides easy access to
+a number of sub-modules and functions described below.
.. note::
@@ -14,27 +14,27 @@ a number of sub-modules and functions described below.
- :any:`ot.lp` contains OT solvers for the exact (Linear Program) OT problems.
- :any:`ot.smooth` contains OT solvers for the regularized (l2 and kl) smooth OT
problems.
- - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
+ - :any:`ot.gromov` contains solvers for Gromov-Wasserstein and Fused Gromov
Wasserstein problems.
- - :any:`ot.optim` contains generic solvers OT based optimization problems
+ - :any:`ot.optim` contains generic solvers for OT based optimization problems
- :any:`ot.da` contains classes and function related to Monge mapping
estimation and Domain Adaptation (DA).
- :any:`ot.gpu` contains GPU (cupy) implementation of some OT solvers
- - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
+ - :any:`ot.dr` contains Dimension Reduction (DR) methods such as Wasserstein
Discriminant Analysis.
- - :any:`ot.utils` contains utility functions such as distance computation and
- timing.
+ - :any:`ot.utils` contains utility functions such as distance computation and
+ timing.
- :any:`ot.datasets` contains toy dataset generation functions.
- :any:`ot.plot` contains visualization functions
- :any:`ot.stochastic` contains stochastic solvers for regularized OT.
- :any:`ot.unbalanced` contains solvers for regularized unbalanced OT.
.. warning::
- The list of automatically imported sub-modules is as follows:
+ The list of automatically imported sub-modules is as follows:
:py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
:py:mod:`ot.utils`, :py:mod:`ot.datasets`,
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
- :py:mod:`ot.stochastic`
+ :py:mod:`ot.stochastic`
The following sub-modules are not imported due to additional dependencies:
@@ -65,7 +65,7 @@ from . import unbalanced
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
+from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
from .da import sinkhorn_lpl1_mm
# utils functions
@@ -77,4 +77,5 @@ __all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets'
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'emd_1d', 'emd2_1d', 'wasserstein_1d',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
- 'sinkhorn_unbalanced', "barycenter_unbalanced"]
+ 'sinkhorn_unbalanced', 'barycenter_unbalanced',
+ 'sinkhorn_unbalanced2']
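Note that the newly exported :code:`sinkhorn_unbalanced2` mirrors the
:code:`sinkhorn`/:code:`sinkhorn2` split: it returns the value of the linear
term instead of the OT plan. A minimal sketch of the difference (toy data)::

    import numpy as np
    import ot

    a, b = np.array([.5, .5]), np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])

    G = ot.sinkhorn_unbalanced(a, b, M, reg=0.1, reg_m=1.)    # (2, 2) plan
    w = ot.sinkhorn_unbalanced2(a, b, M, reg=0.1, reg_m=1.)   # loss value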
diff --git a/ot/bregman.py b/ot/bregman.py
index 70e4208..2cd832b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,10 +7,12 @@ Bregman projections for regularized OT
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Hicham Janati <hicham.janati@inria.fr>
#
# License: MIT License
import numpy as np
+import warnings
from .utils import unif, dist
@@ -31,21 +33,21 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -64,7 +66,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -103,30 +105,23 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
"""
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'greenkhorn':
- def sink():
- return greenkhorn(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log)
+ return greenkhorn(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log)
elif method.lower() == 'sinkhorn_stabilized':
- def sink():
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- def sink():
- return sinkhorn_epsilon_scaling(
- a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_epsilon_scaling(a, b, M, reg,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- print('Warning : unknown method using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
-
- return sink()
+ raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -146,21 +141,21 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -178,8 +173,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
+ W : (n_hists) ndarray or float
+ Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -218,31 +213,23 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]
"""
-
+ b = np.asarray(b, dtype=np.float64)
+ if len(b.shape) < 2:
+ b = b[:, None]
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- def sink():
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- def sink():
- return sinkhorn_epsilon_scaling(
- a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- print('Warning : unknown method using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp(a, b, M, reg, **kwargs)
-
- b = np.asarray(b, dtype=np.float64)
- if len(b.shape) < 2:
- b = b[:, None]
-
- return sink()
+ raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
@@ -262,21 +249,21 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -291,7 +278,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -331,25 +318,25 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
- Nini = len(a)
- Nfin = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
if len(b.shape) > 1:
- nbb = b.shape[1]
+ n_hists = b.shape[1]
else:
- nbb = 0
+ n_hists = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
- if nbb:
- u = np.ones((Nini, nbb)) / Nini
- v = np.ones((Nfin, nbb)) / Nfin
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
else:
- u = np.ones(Nini) / Nini
- v = np.ones(Nfin) / Nfin
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
# print(reg)
@@ -384,13 +371,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking the error only every
# 10th iteration
- if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ if n_hists:
+ np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
+ err = np.linalg.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
@@ -404,7 +390,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['u'] = u
log['v'] = v
- if nbb: # return only loss
+ if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
@@ -419,7 +405,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
+ log=False):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -443,20 +430,20 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -469,7 +456,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -509,16 +496,16 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- n = a.shape[0]
- m = b.shape[0]
+ dim_a = a.shape[0]
+ dim_b = b.shape[0]
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty_like(M)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- u = np.full(n, 1. / n)
- v = np.full(m, 1. / m)
+ u = np.full(dim_a, 1. / dim_a)
+ v = np.full(dim_b, 1. / dim_b)
G = u[:, np.newaxis] * K * v[np.newaxis, :]
viol = G.sum(1) - a
@@ -572,7 +559,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
- warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
+ warmstart=None, verbose=False, print_period=20,
+ log=False, **kwargs):
r"""
Solve the entropic regularization OT problem with log stabilization
@@ -588,9 +576,10 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in [2]_ but with the log stabilization
@@ -599,11 +588,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -622,7 +611,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -634,7 +623,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.bregman.sinkhorn_stabilized(a,b,M,1)
+ >>> ot.bregman.sinkhorn_stabilized(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
@@ -667,14 +656,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# test if multiple target
if len(b.shape) > 1:
- nbb = b.shape[1]
+ n_hists = b.shape[1]
a = a[:, np.newaxis]
else:
- nbb = 0
+ n_hists = 0
# init data
- na = len(a)
- nb = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
cpt = 0
if log:
@@ -683,24 +672,25 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(na), np.zeros(nb)
+ alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
else:
alpha, beta = warmstart
- if nbb:
- u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
else:
- u, v = np.ones(na) / na, np.ones(nb) / nb
+ u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1))
- - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((dim_a, 1))
+ - beta.reshape((1, dim_b))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb)))
- / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
+ return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
+ / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b))))
# print(np.min(K))
@@ -720,26 +710,29 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if np.abs(u).max() > tau or np.abs(v).max() > tau:
- if nbb:
+ if n_hists:
alpha, beta = alpha + reg * \
np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
else:
alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
- if nbb:
- u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
+ if n_hists:
+ u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b
else:
- u, v = np.ones(na) / na, np.ones(nb) / nb
+ u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
K = get_K(alpha, beta)
if cpt % print_period == 0:
# we can speed up the process by checking the error only every
# 10th iteration
- if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ if n_hists:
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
+ err = np.linalg.norm((np.sum(transp, axis=0) - b))
if log:
log['err'].append(err)
@@ -766,7 +759,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
cpt = cpt + 1
if log:
- if nbb:
+ if n_hists:
alpha = alpha[:, None]
beta = beta[:, None]
logu = alpha / reg + np.log(u)
@@ -776,26 +769,28 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
log['alpha'] = alpha + reg * np.log(u)
log['beta'] = beta + reg * np.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
- if nbb:
- res = np.zeros((nbb))
- for i in range(nbb):
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
return res, log
else:
return get_Gamma(alpha, beta, u, v), log
else:
- if nbb:
- res = np.zeros((nbb))
- for i in range(nbb):
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
return res
else:
return get_Gamma(alpha, beta, u, v)
-def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100,
- tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
+def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
+ numInnerItermax=100, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=10,
+ log=False, **kwargs):
r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
@@ -812,9 +807,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - a and b are source and target weights (histograms, both sum to 1)
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in [2]_ but with the log stabilization
@@ -823,18 +819,16 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
threshold for max value in u or v for log scaling
- tau : float
- thershold for max value in u or v for log scaling
warmstart : tuple of vectors
if given then starting values for alpha and beta log scalings
numItermax : int, optional
@@ -852,7 +846,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -893,8 +887,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
- na = len(a)
- nb = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
# relative numerical precision with 64 bits
numItermin = 35
@@ -907,14 +901,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(na), np.zeros(nb)
+ alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
else:
alpha, beta = warmstart
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1))
- - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((dim_a, 1))
+ - beta.reshape((1, dim_b))) / reg)
# print(np.min(K))
def get_reg(n): # exponential decreasing
@@ -927,8 +921,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
regi = get_reg(cpt)
- G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=(
- alpha, beta), verbose=False, print_period=20, tau=tau, log=True)
+ G, logi = sinkhorn_stabilized(a, b, M, regi,
+ numItermax=numInnerItermax, stopThr=1e-9,
+ warmstart=(alpha, beta), verbose=False,
+ print_period=20, tau=tau, log=True)
alpha = logi['alpha']
beta = logi['beta']
@@ -986,8 +982,8 @@ def projC(gamma, q):
return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
-def barycenter(A, M, reg, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
+def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False, **kwargs):
r"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
@@ -1005,13 +1001,15 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Parameters
----------
- A : ndarray, shape (d,n)
- n training distributions a_i of size d
- M : ndarray, shape (d,d)
- loss matrix for OT
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
reg : float
- Regularization term >0
- weights : ndarray, shape (n,)
+ Regularization term > 0
+ method : str (optional)
+ method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
+ weights : ndarray, shape (n_hists,)
Weights of each histogram a_i on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
@@ -1025,7 +1023,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : (dim,) ndarray
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1036,7 +1034,69 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ """
+ if method.lower() == 'sinkhorn':
+ return barycenter_sinkhorn(A, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return barycenter_stabilized(A, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
+def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ weights : ndarray, shape (n_hists,)
+ Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
@@ -1082,7 +1142,138 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
-def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
+def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ with stabilization.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ tau : float
+ threshold for max value in u or v for log scaling
+ weights : ndarray, shape (n_hists,)
+ Weights of each histogram a_i on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+ """
+
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # print(reg)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ u = A / (Kv + 1e-16)
+ Ktu = K.T.dot(u)
+ q = geometricBar(weights, Ktu)
+ Q = q[:, None]
+ v = Q / (Ktu + 1e-16)
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking the error only every
+ # 10th iteration
+ err = abs(u * Kv - A).max()
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Sinkhorn did not converge. " +
+ "Try a larger entropy `reg` " +
+ "or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
+def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-9, stabThr=1e-30, verbose=False,
+ log=False):
r"""Compute the entropic regularized wasserstein barycenter of distributions A
where A is a collection of 2D images.
@@ -1101,16 +1292,16 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Parameters
----------
- A : ndarray, shape (n, w, h)
- n distributions (2D images) of size w x h
+ A : ndarray, shape (n_hists, width, height)
+ n distributions (2D images) of size width x height
reg : float
Regularization term >0
- weights : ndarray, shape (n,)
+ weights : ndarray, shape (n_hists,)
Weights of each image on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (> 0)
stabThr : float, optional
Stabilization threshold to avoid numerical precision issue
verbose : bool, optional
@@ -1120,7 +1311,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Returns
-------
- a : ndarray, shape (w, h)
+ a : ndarray, shape (width, height)
2D Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1204,9 +1395,12 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
where :
- :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}` is an observed distribution, :math:`\mathbf{h}_0` is aprior on unmixing
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT data fitting
- - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix for regularization
+ - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
+ - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
+ - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
+ - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting
+ - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) for the regularization
- :math:`\\alpha` weights data fitting and regularization
The optimization problem is solved using the algorithm described in [4]
@@ -1214,16 +1408,16 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Parameters
----------
- a : ndarray, shape (d)
- observed distribution
- D : ndarray, shape (d, n)
+ a : ndarray, shape (dim_a)
+ observed distribution (histogram, sums to 1)
+ D : ndarray, shape (dim_a, n_atoms)
dictionary matrix
- M : ndarray, shape (d, d)
+ M : ndarray, shape (dim_a, dim_a)
loss matrix
- M0 : ndarray, shape (n, n)
+ M0 : ndarray, shape (n_atoms, dim_prior)
loss matrix
- h0 : ndarray, shape (n,)
- prior on h
+ h0 : ndarray, shape (n_atoms,)
+ prior on the estimated unmixing h
reg : float
Regularization term >0 (Wasserstein data fitting)
reg0 : float
@@ -1242,7 +1436,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Returns
-------
- a : ndarray, shape (d,)
+ h : ndarray, shape (n_atoms,)
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1298,7 +1492,9 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
return np.sum(K0, axis=1)
-def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9, verbose=False,
+ log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -1315,22 +1511,22 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
\gamma\geq 0
where :
- - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1344,7 +1540,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1352,11 +1548,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Examples
--------
- >>> n_s = 2
- >>> n_t = 2
+ >>> n_samples_a = 2
+ >>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
array([[4.99977301e-01, 2.26989344e-05],
[2.26989344e-05, 4.99977301e-01]])
@@ -1405,22 +1601,22 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
\gamma\geq 0
where :
- - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1434,7 +1630,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1442,11 +1638,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Examples
--------
- >>> n_s = 2
- >>> n_t = 2
+ >>> n_samples_a = 2
+ >>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
array([4.53978687e-05])
@@ -1513,22 +1709,22 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
\gamma_b\geq 0
where :
- - :math:`M` (resp. :math:`M_a, M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt))
+    - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp. (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b))
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1541,18 +1737,18 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
Examples
--------
- >>> n_s = 2
- >>> n_t = 4
+ >>> n_samples_a = 2
+ >>> n_samples_b = 4
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
array([1.499...])
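
The divergence combines three entropic losses in the usual Sinkhorn-divergence fashion, S(a, b) = W(a, b) - (W(a, a) + W(b, b)) / 2; a sketch under that reading, reusing empirical_sinkhorn2 from above:

import numpy as np
import ot

X_s = np.arange(2, dtype=float).reshape(-1, 1)
X_t = np.arange(4, dtype=float).reshape(-1, 1)
reg = 0.1
W_ab = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg)  # cross term on M
W_aa = ot.bregman.empirical_sinkhorn2(X_s, X_s, reg)  # self term on M_a
W_bb = ot.bregman.empirical_sinkhorn2(X_t, X_t, reg)  # self term on M_b
S = W_ab - 0.5 * (W_aa + W_bb)  # should match empirical_sinkhorn_divergence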
diff --git a/ot/da.py b/ot/da.py
index 2af855d..108a38d 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1902,7 +1902,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
returned_ = sinkhorn_unbalanced(
a=self.mu_s, b=self.mu_t, M=self.cost_,
- reg=self.reg_e, alpha=self.reg_m, method=self.method,
+ reg=self.reg_e, reg_m=self.reg_m, method=self.method,
numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
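
For reference, a small usage sketch of the class touched above (random data, illustrative hyperparameters); the marginal relaxation strength is now forwarded under the keyword reg_m:

import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(20, 2)
Xt = rng.randn(30, 2) + 1.
mapping = ot.da.UnbalancedSinkhornTransport(reg_e=1., reg_m=1.)
mapping.fit(Xs=Xs, Xt=Xt)          # estimate the unbalanced coupling
Xs_new = mapping.transform(Xs=Xs)  # barycentric mapping onto the target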
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 0f0692e..d516dfc 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -9,51 +9,56 @@ Regularized Unbalanced OT
from __future__ import division
import warnings
import numpy as np
+from scipy.special import logsumexp
+
# from .utils import unif, dist
-def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
- Solve the unbalanced entropic regularization optimal transport problem and return the loss
+ Solve the unbalanced entropic regularization optimal transport problem
+ and return the OT plan
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns, nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+    reg_m : float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+        'sinkhorn_reg_scaling', see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+        Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -62,10 +67,16 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -82,83 +93,96 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_stabilized_unbalanced:
+ Unbalanced Stabilized sinkhorn [9][10]
+ ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
+        Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
-
- return sink()
+ raise ValueError("Unknown method '%s'." % method)
-def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
- numItermax=1000, stopThr=1e-9, verbose=False,
+def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
+ numItermax=1000, stopThr=1e-6, verbose=False,
log=False, **kwargs):
r"""
- Solve the entropic regularization unbalanced optimal transport problem and return the loss
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+    reg_m : float
Marginal relaxation term > 0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+        'sinkhorn_reg_scaling', see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -171,10 +195,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
Returns
-------
- W : (nt) ndarray or float
- Optimal transportation matrix for the given parameters
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
log : dict
- log dictionary return only if log==True in parameters
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -191,64 +215,70 @@ def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems
+ (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+    ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling [9][10]
"""
-
- if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
-
- elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
- warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
- else:
- raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
-
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
b = b[:, None]
-
- return sink()
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+        raise ValueError("Unknown method '%s'." % method)
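
And a sketch of the multi-histogram path: with b of shape (dim_b, n_hists) the function returns one OT loss per column of b (illustrative values; ot.utils.dist0 builds a ground cost on a regular 1D grid):

import numpy as np
import ot

a = ot.unif(4)
b = np.stack([ot.unif(4), 2 * ot.unif(4)], axis=1)  # two targets, one unbalanced
M = ot.utils.dist0(4)
losses = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=1., reg_m=1.)
assert losses.shape == (2,)  # one distance per histogram b_i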
-def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \\alpha KL(\gamma 1, a) + \\alpha KL(\gamma^T 1, b)
+        W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t.
\gamma\geq 0
where :
- - M is the (ns, nt) metric cost matrix
+ - M is the (dim_a, dim_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights
+ - a and b are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
@@ -256,21 +286,21 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
- samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt, n_hists)
- samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
- alpha : float
+    reg_m : float
Marginal relaxation term > 0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+        Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -279,11 +309,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log : dict
- log dictionary return only if log==True in parameters
-
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
Examples
--------
@@ -298,9 +333,13 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
References
----------
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
- .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
See Also
--------
@@ -313,12 +352,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
- n_a, n_b = M.shape
+ dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(n_a, dtype=np.float64) / n_a
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
if len(b) == 0:
- b = np.ones(n_b, dtype=np.float64) / n_b
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -331,21 +370,19 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((n_a, 1)) / n_a
- v = np.ones((n_b, n_hists)) / n_b
- a = a.reshape(n_a, 1)
+ u = np.ones((dim_a, 1)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
else:
- u = np.ones(n_a) / n_a
- v = np.ones(n_b) / n_b
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
- # print(reg)
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- # print(np.min(K))
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
@@ -371,8 +408,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -383,8 +421,8 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
cpt += 1
if log:
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -401,9 +439,224 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
return u[:, None] * K * v[None, :]
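
To make the update explicit, here is the generalized scaling iteration above as a standalone sketch: with fi = reg_m / (reg_m + reg), the unbalanced scalings are damped versions of the balanced Sinkhorn updates (fi -> 1, i.e. reg_m -> infinity, recovers balanced Sinkhorn).

import numpy as np

def generalized_sinkhorn(a, b, M, reg, reg_m, n_iter=1000):
    # minimal sketch, without the stopping test and logging done above
    K = np.exp(-M / reg)
    fi = reg_m / (reg_m + reg)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K.dot(v) + 1e-16)) ** fi
        v = (b / (K.T.dot(u) + 1e-16)) ** fi
    return u[:, None] * K * v[None, :]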
-def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A
+def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
+ stopThr=1e-6, verbose=False, log=False,
+ **kwargs):
+ r"""
+ Solve the entropic regularization unbalanced optimal transport
+ problem and return the loss
+
+ The function solves the following optimization problem using log-domain
+ stabilization as proposed in [10]:
+
+ .. math::
+ W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t.
+ \gamma\geq 0
+ where :
+
+ - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\Omega` is the entropic regularization
+ term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - a and b are source and target unbalanced distributions
+ - KL is the Kullback-Leibler divergence
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Unnormalized histogram of dimension dim_a
+ b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ One or multiple unnormalized histograms of dimension dim_b
+ If many, compute all the OT distances (a, b_i)
+ M : np.ndarray (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Entropy regularization term > 0
+    reg_m : float
+ Marginal relaxation term > 0
+ tau : float
+        Threshold for max value in u or v for log scaling
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ if n_hists == 1:
+ gamma : (dim_a x dim_b) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if `log` is `True`
+ else:
+ ot_distance : (n_hists,) ndarray
+ the OT distance between `a` and each of the histograms `b_i`
+ log : dict
+ log dictionary returned only if `log` is `True`
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.],[1., 0.]]
+ >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.)
+ array([[0.51122823, 0.18807035],
+ [0.18807035, 0.51122823]])
+
+ References
+ ----------
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+ .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64)
+
+ dim_a, dim_b = M.shape
+
+ if len(a) == 0:
+ a = np.ones(dim_a, dtype=np.float64) / dim_a
+ if len(b) == 0:
+ b = np.ones(dim_b, dtype=np.float64) / dim_b
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
+
+ if log:
+ log = {'err': []}
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
+ a = a.reshape(dim_a, 1)
+ else:
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ fi = reg_m / (reg_m + reg)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim_a)
+ beta = np.zeros(dim_b)
+ while (err > stopThr and cpt < numItermax):
+ uprev = u
+ vprev = v
+
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+
+ if n_hists:
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ if n_hists:
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ else:
+ alpha = alpha + reg * np.log(np.max(u))
+ beta = beta + reg * np.log(np.max(v))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ u = uprev
+ v = vprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+ err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
+ 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+ cpt = cpt + 1
+
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if n_hists:
+ logu = alpha[:, None] / reg + np.log(u)
+ logv = beta[:, None] / reg + np.log(v)
+ else:
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ if log:
+ log['logu'] = logu
+ log['logv'] = logv
+ if n_hists: # return only loss
+ res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
+ logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
+ res = np.exp(res)
+ if log:
+ return res, log
+ else:
+ return res
+
+ else: # return OT matrix
+ ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ if log:
+ return ot_matrix, log
+ else:
+ return ot_matrix
+
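
The absorption step used in the loop above, isolated for clarity (single-histogram branch): once a scaling exceeds tau, its log is folded into the dual potentials, the kernel is recomputed against the shifted potentials, and v restarts from ones, keeping u and v in a numerically safe range.

import numpy as np

def absorb(u, v, alpha, beta, M, reg):
    # fold the current scalings into the potentials
    alpha = alpha + reg * np.log(np.max(u))
    beta = beta + reg * np.log(np.max(v))
    # recompute the kernel against the shifted potentials
    K = np.exp((alpha[:, None] + beta[None, :] - M) / reg)
    return K, alpha, beta, np.ones_like(v)  # v is reset to ones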
+
+def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
The function solves the following optimization problem:
@@ -412,28 +665,35 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- - alpha is the marginal relaxation hyperparameter
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of
+ matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+    - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
Parameters
----------
- A : np.ndarray (d,n)
- n training distributions a_i of size d
- M : np.ndarray (d,d)
- loss matrix for OT
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
reg : float
Entropy regularization term > 0
- alpha : float
+ reg_m : float
Marginal relaxation term > 0
- weights : np.ndarray (n,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ tau : float
+ Stabilization threshold for log domain absorption.
+    weights : np.ndarray (n_hists,), optional
+        Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+        Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -442,7 +702,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : (dim,) ndarray
Unbalanced Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -451,12 +711,165 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
+ G. (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
"""
- p, n_hists = A.shape
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+        assert len(weights) == A.shape[1]
+
+ if log:
+ log = {'err': []}
+
+ fi = reg_m / (reg_m + reg)
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ f_alpha = np.exp(- alpha / (reg + reg_m))
+ f_beta = np.exp(- beta / (reg + reg_m))
+ f_alpha = f_alpha[:, None]
+ f_beta = f_beta[:, None]
+ u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
+ Ktu = K.T.dot(u)
+ q = (Ktu ** (1 - fi)) * f_beta
+ q = q.dot(weights) ** (1 / (1 - fi))
+ Q = q[:, None]
+ v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+ err = abs(q - qprev).max() / max(abs(q).max(),
+ abs(qprev).max(), 1.)
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg` or a lower mass `reg_m`." +
+ "Or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
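
The barycenter fixed point iterated above, in isolation (stabilization factors dropped): the candidate barycenter q is a weighted power mean of order 1 - fi of the back-projected marginals.

import numpy as np

def barycenter_step(A, K, v, weights, fi):
    # one iteration, without the log-domain absorption handled above
    u = (A / (K.dot(v) + 1e-16)) ** fi                    # scale towards each a_i
    Ktu = K.T.dot(u)
    q = (Ktu ** (1 - fi)).dot(weights) ** (1 / (1 - fi))  # weighted power mean
    v = (q[:, None] / (Ktu + 1e-16)) ** fi                # rescale towards q
    return q, u, v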
+
+def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+    The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+    - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+    reg_m : float
+ Marginal relaxation term > 0
+    weights : np.ndarray (n_hists,), optional
+        Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+        log dictionary returned only if `log` is `True`
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+        Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+
+ """
+ dim, n_hists = A.shape
if weights is None:
weights = np.ones(n_hists) / n_hists
else:
@@ -467,10 +880,10 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
K = np.exp(- M / reg)
- fi = alpha / (alpha + reg)
+ fi = reg_m / (reg_m + reg)
- v = np.ones((p, n_hists)) / p
- u = np.ones((p, 1)) / p
+ v = np.ones((dim, n_hists)) / dim
+ u = np.ones((dim, 1)) / dim
cpt = 0
err = 1.
@@ -499,8 +912,11 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
- np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
if verbose:
@@ -512,8 +928,95 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
cpt += 1
if log:
log['niter'] = cpt
- log['u'] = u
- log['v'] = v
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
return q, log
else:
return q
+
+
+def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False, **kwargs):
+ r"""Compute the entropic unbalanced wasserstein barycenter of A.
+
+    The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
+ Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+    - reg_m is the marginal relaxation hyperparameter
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ Parameters
+ ----------
+ A : np.ndarray (dim, n_hists)
+ `n_hists` training distributions a_i of dimension dim
+ M : np.ndarray (dim, dim)
+ ground metric matrix for OT.
+ reg : float
+ Entropy regularization term > 0
+    reg_m : float
+ Marginal relaxation term > 0
+    weights : np.ndarray (n_hists,), optional
+        Weight of each distribution (barycentric coordinates)
+ If None, uniform weights are used.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+        Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Unbalanced Wasserstein barycenter
+ log : dict
+        log dictionary returned only if `log` is `True`
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). Iterative Bregman projections for regularized transportation
+ problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+        Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
+ """
+
+ if method.lower() == 'sinkhorn':
+ return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+
+ elif method.lower() == 'sinkhorn_stabilized':
+ return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr,
+ verbose=verbose,
+ log=log, **kwargs)
+ elif method.lower() in ['sinkhorn_reg_scaling']:
+ warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
+ return barycenter_unbalanced(A, M, reg, reg_m,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pytest.ini
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 83ebba8..f70df10 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -7,6 +7,7 @@
import numpy as np
import ot
+import pytest
def test_sinkhorn():
@@ -71,13 +72,11 @@ def test_sinkhorn_variants():
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
- np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)
@@ -96,18 +95,17 @@ def test_sinkhorn_variants_log():
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
- Gerr, logerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10, log=True)
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
- np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)
-def test_bary():
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_barycenter(method):
n_bins = 100 # nb bins
@@ -126,14 +124,42 @@ def test_bary():
weights = np.array([1 - alpha, alpha])
# wasserstein
- reg = 1e-3
- bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ reg = 1e-2
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights, method=method)
np.testing.assert_allclose(1, np.sum(bary_wass))
ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+def test_barycenter_stabilization():
+
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # wasserstein
+ reg = 1e-2
+ bar_stable = ot.bregman.barycenter(A, M, reg, weights,
+ method="sinkhorn_stabilized",
+ stopThr=1e-8)
+ bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
+ stopThr=1e-8)
+ np.testing.assert_allclose(bar, bar_stable)
+
+
def test_wasserstein_bary_2d():
size = 100 # size of a square image
@@ -279,3 +305,35 @@ def test_stabilized_vs_sinkhorn_multidim():
method="sinkhorn", log=True)
np.testing.assert_allclose(G, G2)
+
+
+def test_implemented_methods():
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ ONLY_1D_methods = ['greenkhorn', 'sinkhorn_epsilon_scaling']
+ NOT_VALID_TOKENS = ['foo']
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 3
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n)
+ A = rng.rand(n, 2)
+ M = ot.dist(x, x)
+ epsilon = 1.
+
+ for method in IMPLEMENTED_METHODS:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+ ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+    for method in set(NOT_VALID_TOKENS):
+        with pytest.raises(ValueError):
+            ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+        with pytest.raises(ValueError):
+            ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+        with pytest.raises(ValueError):
+            ot.bregman.barycenter(A, M, reg=epsilon, method=method)
+ for method in ONLY_1D_methods:
+ ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
+ with pytest.raises(ValueError):
+ ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index 1395fe1..ca1efba 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -7,9 +7,12 @@
import numpy as np
import ot
import pytest
+from ot.unbalanced import barycenter_unbalanced
+from scipy.special import logsumexp
-@pytest.mark.parametrize("method", ["sinkhorn"])
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_convergence(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -23,29 +26,35 @@ def test_unbalanced_convergence(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10, method=method,
+ G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method=method,
log=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
- u_final = (a / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)
+ logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
-@pytest.mark.parametrize("method", ["sinkhorn"])
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
def test_unbalanced_multiple_inputs(method):
# test generalized sinkhorn for unbalanced OT
n = 100
@@ -59,28 +68,59 @@ def test_unbalanced_multiple_inputs(method):
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- alpha=alpha,
- stopThr=1e-10, method=method,
+ reg_m=reg_m,
+ method=method,
log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (b / K.T.dot(log["u"])) ** fi
-
- u_final = (a[:, None] / K.dot(log["v"])) ** fi
+ # in log-domain
+ fi = reg_m / (reg_m + epsilon)
+ logb = np.log(b + 1e-16)
+ loga = np.log(a + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logb - logKtu)
+ u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
assert len(loss) == b.shape[1]
-def test_unbalanced_barycenter():
+def test_stabilized_vs_sinkhorn():
+ # test if stable version matches sinkhorn
+ n = 100
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b1 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+ b2 = ot.datasets.make_1D_gauss(n, m=30, s=4)
+
+ # creating matrix A containing all distributions
+ b = np.vstack((b1, b2)).T
+
+ M = ot.utils.dist0(n)
+ M /= np.median(M)
+ epsilon = 0.1
+ reg_m = 1.
+ G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
+ method="sinkhorn_stabilized",
+ reg_m=reg_m,
+ log=True)
+ G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method="sinkhorn", log=True)
+
+ np.testing.assert_allclose(G, G2, atol=1e-5)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_unbalanced_barycenter(method):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -92,27 +132,56 @@ def test_unbalanced_barycenter():
A = A * np.array([1, 2])[None, :]
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
- K = np.exp(- M / epsilon)
+ reg_m = 1.
- q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
- stopThr=1e-10,
- log=True)
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method, log=True)
# check fixed point equations
- fi = alpha / (alpha + epsilon)
- v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
- u_final = (A / K.dot(log["v"])) ** fi
+ fi = reg_m / (reg_m + epsilon)
+ logA = np.log(A + 1e-16)
+ logq = np.log(q + 1e-16)[:, None]
+ logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
+ axis=0)
+ logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ v_final = fi * (logq - logKtu)
+ u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["u"], atol=1e-05)
+ u_final, log["logu"], atol=1e-05)
np.testing.assert_allclose(
- v_final, log["v"], atol=1e-05)
+ v_final, log["logv"], atol=1e-05)
+
+
+def test_barycenter_stabilized_vs_sinkhorn():
+ # test generalized sinkhorn for unbalanced OT barycenter
+ n = 100
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ A = rng.rand(n, 2)
+
+ # make dists unbalanced
+ A = A * np.array([1, 4])[None, :]
+ M = ot.dist(x, x)
+ epsilon = 0.5
+ reg_m = 10
+
+ qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
+ reg_m=reg_m, log=True,
+ tau=100,
+ method="sinkhorn_stabilized",
+ )
+ q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method="sinkhorn",
+ log=True)
+
+ np.testing.assert_allclose(
+ q, qstable, atol=1e-05)
def test_implemented_methods():
- IMPLEMENTED_METHODS = ['sinkhorn']
- TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized',
- 'sinkhorn_epsilon_scaling']
+ IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
# test generalized sinkhorn for unbalanced OT barycenter
n = 3
@@ -123,24 +192,30 @@ def test_implemented_methods():
# make dists unbalanced
b = ot.utils.unif(n) * 1.5
-
+ A = rng.rand(n, 2)
M = ot.dist(x, x)
epsilon = 1.
- alpha = 1.
+ reg_m = 1.
for method in IMPLEMENTED_METHODS:
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
with pytest.warns(UserWarning, match='not implemented'):
for method in set(TO_BE_IMPLEMENTED_METHODS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)
with pytest.raises(ValueError):
for method in set(NOT_VALID_TOKENS):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method=method)
+ barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
+ method=method)