summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHuy Tran <huytran82125@gmail.com>2023-02-23 15:31:20 +0100
committerGitHub <noreply@github.com>2023-02-23 15:31:20 +0100
commita313e21f223af16cf21d3b7dd01bd0c6345d574c (patch)
tree2231393f6f3b15e8237f20f50bab1e9d02b96cf2
parent80e3c23bc968f866fd20344ddc443a3c7fcb3b0d (diff)
[MRG] Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (#437)
* Allow warmstart in sinkhorn and sinkhorn_log * Added argument for warmstart of dual vectors in Sinkhorn-based methods in * Add the number of the PR * [WIP] CO-Optimal Transport * Revert "[WIP] CO-Optimal Transport" This reverts commit f3d36b2705013409ac69b346585e311bc25fcfb7. * reformat with PEP8 * Fix W291 trailing whitespace error in pep8 test * Rearange position of warmstart argument and edit its description --------- Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md88
-rw-r--r--ot/bregman.py228
-rw-r--r--test/test_bregman.py231
3 files changed, 372 insertions, 175 deletions
diff --git a/RELEASES.md b/RELEASES.md
index f8ef653..292d1df 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -12,8 +12,8 @@
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)
- New API for OT solver using function `ot.solve` (PR #388)
-- Backend version of `ot.partial` and `ot.smooth` (PR #388)
-
+- Backend version of `ot.partial` and `ot.smooth` (PR #388)
+- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
#### Closed issues
@@ -35,10 +35,10 @@ roughly 2^31) (PR #381)
- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
- Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409)
- Fixed weak optimal transport docstring (Issue #404, PR #410)
-- Fixed error with parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
+- Fixed error with parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
PR #413)
- Fixed an issue about `warn` parameter in `sinkhorn2` (PR #417)
-- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
+- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
- Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)
@@ -88,7 +88,7 @@ and [Factored coupling OT](https://pythonot.github.io/auto_examples/others/plot_
- Remove deprecated `ot.gpu` submodule (PR #361)
- Update examples in the gallery (PR #359)
-- Add stochastic loss and OT plan computation for regularized OT and
+- Add stochastic loss and OT plan computation for regularized OT and
backend examples(PR #360)
- Implementation of factored OT with emd and sinkhorn (PR #358)
- A brand new logo for POT (PR #357)
@@ -104,9 +104,9 @@ and [Factored coupling OT](https://pythonot.github.io/auto_examples/others/plot_
#### Closed issues
-- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are
+- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are
centered (Issue #364, PR #363)
-- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
+- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
PR #338)
- Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
- Warning when feeding integer cost matrix to EMD solver resulting in an integer transport plan (Issue #345, PR #343)
@@ -156,21 +156,21 @@ As always we want to that the contributors who helped make POT better (and bug f
- Fix bug in older Numpy ABI (<1.20) (Issue #308, PR #326)
- Fix bug in `ot.dist` function when non euclidean distance (Issue #305, PR #306)
-- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309,
+- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309,
PR #310)
-- Fix bug in generalized Conditional gradient solver and SinkhornL1L2
+- Fix bug in generalized Conditional gradient solver and SinkhornL1L2
(Issue #311, PR #313)
- Fix log error in `gromov_barycenters` (Issue #317, PR #3018)
## 0.8.0
*November 2021*
-This new stable release introduces several important features.
+This new stable release introduces several important features.
First we now have
an OpenMP compatible exact ot solver in `ot.emd`. The OpenMP version is used
when the parameter `numThreads` is greater than one and can lead to nice
-speedups on multi-core machines.
+speedups on multi-core machines.
Second we have introduced a backend mechanism that allows to use standard POT
function seamlessly on Numpy, Pytorch and Jax arrays. Other backends are coming
@@ -189,7 +189,7 @@ for a [sliced Wasserstein gradient
flow](https://PythonOT.github.io/auto_examples/backends/plot_sliced_wass_grad_flow_pytorch.html)
and [optimizing the Gromov-Wassersein distance](https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html). Note that the Jax backend is still in early development and quite
slow at the moment, we strongly recommend for Jax users to use the [OTT
-toolbox](https://github.com/google-research/ott) when possible.
+toolbox](https://github.com/google-research/ott) when possible.
As a result of this new feature,
the old `ot.gpu` submodule is now deprecated since GPU
implementations can be done using GPU arrays on the torch backends.
@@ -212,7 +212,7 @@ Finally POT was accepted for publication in the Journal of Machine Learning
Research (JMLR) open source software track and we ask the POT users to cite [this
paper](https://www.jmlr.org/papers/v22/20-451.html) from now on. The documentation has been improved in particular by adding a
"Why OT?" section to the quick start guide and several new examples illustrating
-the new features. The documentation now has two version : the stable version
+the new features. The documentation now has two version : the stable version
[https://pythonot.github.io/](https://pythonot.github.io/)
corresponding to the last release and the master version [https://pythonot.github.io/master](https://pythonot.github.io/master) that corresponds to the
current master branch on GitHub.
@@ -222,7 +222,7 @@ As usual, we want to thank all the POT contributors (now 37 people have
contributed to the toolbox). But for this release we thank in particular Nathan
Cassereau and Kamel Guerda from the AI support team at
[IDRIS](http://www.idris.fr/) for their support to the development of the
-backend and OpenMP implementations.
+backend and OpenMP implementations.
#### New features
@@ -289,7 +289,7 @@ repository for the new documentation is now hosted at
This is the first release where the Python 2.7 tests have been removed. Most of
the toolbox should still work but we do not offer support for Python 2.7 and
-will close related Issues.
+will close related Issues.
A lot of changes have been done to the documentation that is now hosted on
[https://PythonOT.github.io/](https://PythonOT.github.io/) instead of
@@ -322,7 +322,7 @@ problems.
This release is also the moment to thank all the POT contributors (old and new)
for helping making POT such a nice toolbox. A lot of changes (also in the API)
-are coming for the next versions.
+are coming for the next versions.
#### Features
@@ -351,14 +351,14 @@ are coming for the next versions.
- Log bugs for Gromov-Wassertein solver (Issue #107, fixed in PR #108)
- Weight issues in barycenter function (PR #106)
-## 0.6.0
+## 0.6.0
*July 2019*
-This is the first official stable release of POT and this means a jump to 0.6!
+This is the first official stable release of POT and this means a jump to 0.6!
The library has been used in
the wild for a while now and we have reached a state where a lot of fundamental
OT solvers are available and tested. It has been quite stable in the last months
-but kept the beta flag in its Pypi classifiers until now.
+but kept the beta flag in its Pypi classifiers until now.
Note that this release will be the last one supporting officially Python 2.7 (See
https://python3statement.org/ for more reasons). For next release we will keep
@@ -387,7 +387,7 @@ graphs](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fg
A lot of work has been done on the documentation with several new
examples corresponding to the new features and a lot of corrections for the
-docstrings. But the most visible change is a new
+docstrings. But the most visible change is a new
[quick start guide](https://pot.readthedocs.io/en/latest/quickstart.html) for
POT that gives several pointers about which function or classes allow to solve which
specific OT problem. When possible a link is provided to relevant examples.
@@ -425,29 +425,29 @@ bring new features and solvers to the library.
- Issue #72 Macosx build problem
-## 0.5.0
+## 0.5.0
*Sep 2018*
-POT is 2 years old! This release brings numerous new features to the
+POT is 2 years old! This release brings numerous new features to the
toolbox as listed below but also several bug correction.
-Among the new features, we can highlight a [non-regularized Gromov-Wasserstein
-solver](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov.ipynb),
-a new [greedy variant of sinkhorn](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.greenkhorn),
-[non-regularized](https://pot.readthedocs.io/en/latest/all.html#ot.lp.barycenter),
+Among the new features, we can highlight a [non-regularized Gromov-Wasserstein
+solver](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov.ipynb),
+a new [greedy variant of sinkhorn](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.greenkhorn),
+[non-regularized](https://pot.readthedocs.io/en/latest/all.html#ot.lp.barycenter),
[convolutional (2D)](https://github.com/rflamary/POT/blob/master/notebooks/plot_convolutional_barycenter.ipynb)
and [free support](https://github.com/rflamary/POT/blob/master/notebooks/plot_free_support_barycenter.ipynb)
- Wasserstein barycenters and [smooth](https://github.com/rflamary/POT/blob/prV0.5/notebooks/plot_OT_1D_smooth.ipynb)
- and [stochastic](https://pot.readthedocs.io/en/latest/all.html#ot.stochastic.sgd_entropic_regularization)
+ Wasserstein barycenters and [smooth](https://github.com/rflamary/POT/blob/prV0.5/notebooks/plot_OT_1D_smooth.ipynb)
+ and [stochastic](https://pot.readthedocs.io/en/latest/all.html#ot.stochastic.sgd_entropic_regularization)
implementation of entropic OT.
-POT 0.5 also comes with a rewriting of ot.gpu using the cupy framework instead of
-the unmaintained cudamat. Note that while we tried to keed changes to the
-minimum, the OTDA classes were deprecated. If you are happy with the cudamat
+POT 0.5 also comes with a rewriting of ot.gpu using the cupy framework instead of
+the unmaintained cudamat. Note that while we tried to keed changes to the
+minimum, the OTDA classes were deprecated. If you are happy with the cudamat
implementation, we recommend you stay with stable release 0.4 for now.
-The code quality has also improved with 92% code coverage in tests that is now
-printed to the log in the Travis builds. The documentation has also been
+The code quality has also improved with 92% code coverage in tests that is now
+printed to the log in the Travis builds. The documentation has also been
greatly improved with new modules and examples/notebooks.
This new release is so full of new stuff and corrections thanks to the old
@@ -466,24 +466,24 @@ and new POT contributors (you can see the list in the [readme](https://github.co
* Stochastic OT in the dual and semi-dual (PR #52 and PR #62)
* Free support barycenters (PR #56)
* Speed-up Sinkhorn function (PR #57 and PR #58)
-* Add convolutional Wassersein barycenters for 2D images (PR #64)
+* Add convolutional Wassersein barycenters for 2D images (PR #64)
* Add Greedy Sinkhorn variant (Greenkhorn) (PR #66)
* Big ot.gpu update with cupy implementation (instead of un-maintained cudamat) (PR #67)
#### Deprecation
-Deprecated OTDA Classes were removed from ot.da and ot.gpu for version 0.5
-(PR #48 and PR #67). The deprecation message has been for a year here since
+Deprecated OTDA Classes were removed from ot.da and ot.gpu for version 0.5
+(PR #48 and PR #67). The deprecation message has been for a year here since
0.4 and it is time to pull the plug.
#### Closed issues
* Issue #35 : remove import plot from ot/__init__.py (See PR #41)
* Issue #43 : Unusable parameter log for EMDTransport (See PR #44)
-* Issue #55 : UnicodeDecodeError: 'ascii' while installing with pip
+* Issue #55 : UnicodeDecodeError: 'ascii' while installing with pip
-## 0.4
+## 0.4
*15 Sep 2017*
This release contains a lot of contribution from new contributors.
@@ -493,14 +493,14 @@ This release contains a lot of contribution from new contributors.
* Automatic notebooks and doc update (PR #27)
* Add gromov Wasserstein solver and Gromov Barycenters (PR #23)
-* emd and emd2 can now return dual variables and have max_iter (PR #29 and PR #25)
+* emd and emd2 can now return dual variables and have max_iter (PR #29 and PR #25)
* New domain adaptation classes compatible with scikit-learn (PR #22)
* Proper tests with pytest on travis (PR #19)
* PEP 8 tests (PR #13)
#### Closed issues
-* emd convergence problem du to fixed max iterations (#24)
+* emd convergence problem du to fixed max iterations (#24)
* Semi supervised DA error (#26)
## 0.3.1
@@ -508,7 +508,7 @@ This release contains a lot of contribution from new contributors.
* Correct bug in emd on windows
-## 0.3
+## 0.3
*7 Jul 2017*
* emd* and sinkhorn* are now performed in parallel for multiple target distributions
@@ -521,7 +521,7 @@ This release contains a lot of contribution from new contributors.
* GPU implementations for sinkhorn and group lasso regularization
-## V0.2
+## V0.2
*7 Apr 2017*
* New dimensionality reduction method (WDA)
@@ -529,7 +529,7 @@ This release contains a lot of contribution from new contributors.
-## 0.1.11
+## 0.1.11
*5 Jan 2017*
* Add sphinx gallery for better documentation
@@ -537,7 +537,7 @@ This release contains a lot of contribution from new contributors.
* Add simple tic() toc() functions for timing
-## 0.1.10
+## 0.1.10
*7 Nov 2016*
* numerical stabilization for sinkhorn (log domain and epsilon scaling)
diff --git a/ot/bregman.py b/ot/bregman.py
index c33c92c..192a9e2 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -24,9 +24,8 @@ from ot.utils import unif, dist, list_to_array
from .backend import get_backend
-def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=True,
- **kwargs):
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -101,6 +100,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -156,34 +158,33 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn)
+ warn=warn, warmstart=warmstart)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -260,6 +261,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -324,17 +328,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
@@ -348,25 +352,24 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -415,6 +418,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -474,12 +480,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
# we assume that no distances are null except those of the diagonal of
# distances
- if n_hists:
- u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
- v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ if warmstart is None:
+ if n_hists:
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ else:
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
else:
- u = nx.ones(dim_a, type_as=M) / dim_a
- v = nx.ones(dim_b, type_as=M) / dim_b
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
K = nx.exp(M / (-reg))
@@ -547,7 +556,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem in log space
and return the OT matrix
@@ -596,6 +605,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -656,6 +668,10 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
else:
n_hists = 0
+ # in case of multiple historgrams
+ if n_hists > 1 and warmstart is None:
+ warmstart = [None] * n_hists
+
if n_hists: # we do not want to use tensors sor we do a loop
lst_loss = []
@@ -663,8 +679,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
lst_v = []
for k in range(n_hists):
- res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr,
+ verbose=verbose, log=log, warmstart=warmstart[k], **kwargs)
if log:
lst_loss.append(nx.sum(M * res[0]))
@@ -691,9 +707,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
# we assume that no distances are null except those of the diagonal of
# distances
-
- u = nx.zeros(dim_a, type_as=M)
- v = nx.zeros(dim_b, type_as=M)
+ if warmstart is None:
+ u = nx.zeros(dim_a, type_as=M)
+ v = nx.zeros(dim_b, type_as=M)
+ else:
+ u, v = warmstart
def get_logT(u, v):
if n_hists:
@@ -747,7 +765,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
- log=False, warn=True):
+ log=False, warn=True, warmstart=None):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -795,6 +813,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -853,8 +874,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
K = nx.exp(-M / reg)
- u = nx.full((dim_a,), 1. / dim_a, type_as=K)
- v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ if warmstart is None:
+ u = nx.full((dim_a,), 1. / dim_a, type_as=K)
+ v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ else:
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
G = u[:, None] * K * v[None, :]
viol = nx.sum(G, axis=1) - a
@@ -1074,7 +1098,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
- alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
+ alpha, beta = alpha + reg * \
+ nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
@@ -1298,13 +1323,15 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
# we can speed up the process by checking for the error only all
# the 10th iterations
transp = G
- err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2
+ err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \
+ nx.norm(nx.sum(transp, axis=1) - a) ** 2
if log:
log['err'].append(err)
if verbose:
if ii % (print_period * 10) == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err <= stopThr and ii > numItermin:
@@ -1648,8 +1675,10 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
- T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs)
- T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
+ T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg,
+ numItermax=numInnerItermax, **kwargs)
+ T_sum = T_sum + weight_i * 1. / \
+ b[:, None] * nx.dot(T_i, measure_locations_i)
displacement_square_norm = nx.sum((T_sum - X) ** 2)
if log:
@@ -1658,7 +1687,8 @@ def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_ini
X = T_sum
if verbose:
- print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)
+ print('iteration %d, displacement_square_norm=%f\n',
+ iter_count, displacement_square_norm)
iter_count += 1
@@ -2213,7 +2243,8 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2291,7 +2322,8 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2450,7 +2482,8 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
# debiased Sinkhorn does not converge monotonically
@@ -2530,7 +2563,8 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr and ii > 20:
break
@@ -2858,7 +2892,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -2911,6 +2945,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
@@ -2961,14 +2998,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log = {"err": []}
log_a, log_b = nx.log(a), nx.log(b)
- f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ if warmstart is None:
+ f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ else:
+ f, g = warmstart
if isinstance(batchSize, int):
bs, bt = batchSize, batchSize
elif isinstance(batchSize, tuple) and len(batchSize) == 2:
bs, bt = batchSize[0], batchSize[1]
else:
- raise ValueError("Batch size must be in integer or a tuple of two integers")
+ raise ValueError(
+ "Batch size must be in integer or a tuple of two integers")
range_s, range_t = range(0, ns, bs), range(0, nt, bt)
@@ -3006,7 +3047,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
M = nx.from_numpy(M, type_as=a)
m1_cols.append(
- nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1)
+ nx.sum(nx.exp(f[i:i + bs, None] +
+ g[None, :] - M / reg), axis=1)
)
m1 = nx.concatenate(m1_cols, axis=0)
err = nx.sum(nx.abs(m1 - a))
@@ -3014,7 +3056,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log["err"].append(err)
if verbose and (i_ot + 1) % 100 == 0:
- print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+ print("Error in marginal at iteration {} = {}".format(
+ i_ot + 1, err))
if err <= stopThr:
break
@@ -3034,17 +3077,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=True, **kwargs)
+ verbose=verbose, log=True, warmstart=warmstart, **kwargs)
return pi, log
else:
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=False, **kwargs)
+ verbose=verbose, log=False, warmstart=warmstart, **kwargs)
return pi
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, isLazy=False,
- batchSize=100, verbose=False, log=False, warn=True, **kwargs):
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -3101,7 +3144,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
-
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -3157,13 +3202,16 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
isLazy=isLazy,
batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
else:
f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
- numIterMax=numIterMax, stopThr=stopThr,
+ numIterMax=numIterMax,
+ stopThr=stopThr,
isLazy=isLazy, batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
bs = batchSize if isinstance(batchSize, int) else batchSize[0]
range_s = range(0, ns, bs)
@@ -3190,19 +3238,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
if log:
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss, log
else:
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ numIterMax=10000, stopThr=1e-9, verbose=False,
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -3279,6 +3326,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -3308,24 +3358,31 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
X_s, X_t = list_to_array(X_s, X_t)
nx = get_backend(X_s, X_t)
+ if warmstart is None:
+ warmstart_a, warmstart_b = None, None
+ else:
+ u, v = warmstart
+ warmstart_a = (u, u)
+ warmstart_b = (v, v)
if log:
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -3340,20 +3397,21 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
else:
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
numIterMax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
numIterMax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
numIterMax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
return nx.maximum(0, sinkhorn_div)
@@ -3521,7 +3579,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
aK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_cols), ns_budget - 1)[ns_budget - 1],
type_as=M
)
epsilon_u_square = a[0] / aK_sort
@@ -3531,7 +3590,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
else:
bK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_rows), nt_budget - 1)[nt_budget - 1],
type_as=M
)
epsilon_v_square = b[0] / bK_sort
diff --git a/test/test_bregman.py b/test/test_bregman.py
index ce15642..f01bb14 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -59,10 +59,12 @@ def test_convergence_warning(method):
with pytest.warns(UserWarning):
ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
with pytest.warns(UserWarning):
- ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True)
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=True)
with warnings.catch_warnings():
warnings.simplefilter("error")
- ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False)
+ ot.sinkhorn2(a1, a2, M, 1, method=method,
+ stopThr=0, numItermax=1, warn=False)
def test_not_implemented_method():
@@ -266,12 +268,16 @@ def test_sinkhorn_variants(nx):
ub, M_nx = nx.from_numpy(u, M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
- G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(
+ ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -371,9 +377,12 @@ def test_sinkhorn_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -399,9 +408,12 @@ def test_sinkhorn2_variants_multi_b(nx):
ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(
+ ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -419,12 +431,16 @@ def test_sinkhorn_variants_log():
M = ot.dist(x, x)
- G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
- Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
- Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
+ Gs, logs = ot.sinkhorn(
+ u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,)
- G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
+ G_green, loggreen = ot.sinkhorn(
+ u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
@@ -446,7 +462,8 @@ def test_sinkhorn_variants_log_multib(verbose, warn):
M = ot.dist(x, x)
- G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
+ G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn',
+ stopThr=1e-10, log=True)
Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
verbose=verbose, warn=warn)
Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
@@ -485,8 +502,10 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
else:
# wasserstein
- bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -514,7 +533,8 @@ def test_free_support_sinkhorn_barycenter():
# Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
# term to 1, but this should be, in general, fine-tuned to the problem.
- X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1)
+ X = ot.bregman.free_support_sinkhorn_barycenter(
+ measures_locations, measures_weights, X_init, reg=1)
# Verifies if calculated barycenter matches ground-truth
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
@@ -545,8 +565,10 @@ def test_barycenter_assymetric_cost(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, method=method)
else:
# wasserstein
- bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True)
+ bary_wass_np = ot.bregman.barycenter(
+ A, M, reg, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(
+ A_nx, M_nx, reg, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass))
@@ -581,17 +603,20 @@ def test_barycenter_debiased(nx, method, verbose, warn):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights, method=method)
else:
bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
verbose=verbose, warn=warn)
- bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass, _ = ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, weights_nx, method=method, log=True)
bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
- ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
+ ot.bregman.barycenter_debiased(
+ A_nx, M_nx, reg, log=True, verbose=False)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
@@ -616,7 +641,8 @@ def test_convergence_warning_barycenters(method):
weights = np.array([1 - alpha, alpha])
reg = 0.1
with pytest.warns(UserWarning):
- ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ ot.bregman.barycenter_debiased(
+ A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
with pytest.warns(UserWarning):
@@ -648,7 +674,8 @@ def test_barycenter_stabilization(nx):
# wasserstein
reg = 1e-2
- bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(
+ A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
bar_stable = nx.to_numpy(ot.bregman.barycenter(
A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
stopThr=1e-8, verbose=True
@@ -683,8 +710,10 @@ def test_wasserstein_bary_2d(nx, method):
with pytest.raises(NotImplementedError):
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -713,10 +742,13 @@ def test_wasserstein_bary_2d_debiased(nx, method):
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
with pytest.raises(NotImplementedError):
- ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ ot.bregman.convolutional_barycenter2d_debiased(
+ A_nx, reg, method=method)
else:
- bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
+ bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
+ A, reg, method=method, verbose=True, log=True)
+ bary_wass = nx.to_numpy(
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
@@ -750,7 +782,8 @@ def test_unmix(nx):
# wasserstein
reg = 1e-3
um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
- um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
+ um = nx.to_numpy(ot.bregman.unmix(
+ ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
@@ -781,10 +814,12 @@ def test_empirical_sinkhorn(nx):
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_emp_sinkhorn = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
# check constraints
@@ -817,23 +852,27 @@ def test_lazy_empirical_sinkhorn(nx):
ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
- f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g, log_es = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
- f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(
+ X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(
+ X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
@@ -865,22 +904,27 @@ def test_empirical_sinkhorn_divergence(nx):
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t)
+ ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(
+ a, b, X_s, X_t, M, M_s, M_t)
- emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ emp_sinkhorn_div = nx.to_numpy(
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
ot.sinkhorn2(ab, bb, M_nx, 1)
- 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
- 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
)
- emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(
+ X_s, X_t, 1, a=a, b=b)
# check constraints
- np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
+ np.testing.assert_allclose(
+ emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
- ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
+ ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb, log=True)
@pytest.mark.skipif(not torch, reason="No torch available")
@@ -902,7 +946,8 @@ def test_empirical_sinkhorn_divergence_gradient():
X_sb.requires_grad = True
X_tb.requires_grad = True
- emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(
+ X_sb, X_tb, 1, a=ab, b=bb)
emp_sinkhorn_div.backward()
@@ -931,7 +976,8 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
ab, bb, M_nx = nx.from_numpy(a, b, M)
- G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G_np, _ = ot.bregman.sinkhorn(
+ a, b, M, reg=epsilon, method="sinkhorn", log=True)
G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
@@ -996,7 +1042,8 @@ def test_screenkhorn(nx):
# sinkhorn
G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(
+ ab, bb, M_nx, 1e-1, uniform=True, verbose=True))
# check marginals
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
@@ -1013,3 +1060,93 @@ def test_convolutional_barycenter_non_square(nx):
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(b, b_np)
+
+
+def test_sinkhorn_warmstart():
+ m, n = 10, 20
+ a = ot.unif(m)
+ b = ot.unif(n)
+
+ Xs = np.arange(m) * 1.0
+ Xt = np.arange(n) * 1.0
+ M = ot.dist(Xs.reshape(-1, 1), Xt.reshape(-1, 1))
+
+ # Generate warmstart from dual vectors of unregularized OT
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ pi_unif, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ pi_sh, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn", log=True, warmstart=warmstart)
+ pi_sh_log, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_log", log=True, warmstart=warmstart)
+ pi_sh_stab, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_stabilized", log=True, warmstart=warmstart)
+ pi_sh_sc, _ = ot.bregman.sinkhorn(
+ a, b, M, reg, method="sinkhorn_epsilon_scaling", log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_sh, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_log, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_stab, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_sh_sc, atol=1e-05)
+
+
+def test_empirical_sinkhorn_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ pi_unif = np.exp(f[:, None] + g[None, :] - M / reg)
+ # Optimal plan with warmstart generated from unregularized OT
+ f, g, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ pi_ws_lazy = np.exp(f[:, None] + g[None, :] - M / reg)
+ pi_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(pi_unif, pi_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(pi_unif, pi_ws_not_lazy, atol=1e-05)
+
+
+def test_empirical_sinkhorn_divergence_warmstart():
+ m, n = 10, 20
+ Xs = np.arange(m).reshape(-1, 1) * 1.0
+ Xt = np.arange(n).reshape(-1, 1) * 1.0
+ M = ot.dist(Xs, Xt)
+
+ # Generate warmstart from dual vectors of unregularized OT
+ a = ot.unif(m)
+ b = ot.unif(n)
+ _, log = ot.lp.emd(a, b, M, log=True)
+ warmstart = (log["u"], log["v"])
+
+ reg = 1
+
+ # Optimal plan with uniform warmstart
+ sd_unif, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=None)
+ # Optimal plan with warmstart generated from unregularized OT
+ sd_ws_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=True, log=True, warmstart=warmstart)
+ sd_ws_not_lazy, _ = ot.bregman.empirical_sinkhorn_divergence(
+ X_s=Xs, X_t=Xt, reg=reg, isLazy=False, log=True, warmstart=warmstart)
+
+ np.testing.assert_allclose(sd_unif, sd_ws_lazy, atol=1e-05)
+ np.testing.assert_allclose(sd_unif, sd_ws_not_lazy, atol=1e-05)