From 24a7a0439e631e90ff84ce84d0a78bc22846cf71 Mon Sep 17 00:00:00 2001
From: Panayiotis Panayiotou
Date: Mon, 24 Aug 2020 15:40:05 +0300
Subject: Check if alpha is not None when restricting it to be at most 1 (#199)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Check if alpha is not None when restricting it to be at most 1
* Write check more clearly
* Add non-regression test for line_search_armijo returning None for alpha
Co-authored-by: Rémi Flamary
---
test/test_optim.py | 10 ++++++++++
1 file changed, 10 insertions(+)
(limited to 'test')
diff --git a/test/test_optim.py b/test/test_optim.py
index 87b0268..48de38a 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -104,3 +104,13 @@ def test_solve_1d_linesearch_quad_funct():
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
+
+
+def test_line_search_armijo():
+ xk = np.array([[0.25, 0.25], [0.25, 0.25]])
+ pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
+ gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
+ old_fval = -123
+ # Should not raise an exception and should return None for alpha
+ alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
+ assert alpha is None
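The guard that this test exercises lives in ot/optim.py's line_search_armijo and is not part of this test-only hunk. A plausible sketch of the fix, assuming SciPy's scalar_search_armijo (which returns None when no admissible step is found); the real code in ot/optim.py may differ:

    from scipy.optimize.linesearch import scalar_search_armijo

    def clipped_armijo_step(phi, phi0, derphi0, c1=1e-4, alpha0=0.99):
        # scalar_search_armijo returns (alpha, phi1); alpha is None on failure
        alpha, phi1 = scalar_search_armijo(phi, phi0, derphi0, c1=c1, alpha0=alpha0)
        if alpha is not None:      # the old code called min(1, alpha) unconditionally
            alpha = min(1, alpha)  # restrict the step size to be at most 1
        return alpha, phi1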
--
cgit v1.2.3
From 7adc1b1aa73c55dc07983ff08dcb23fd71e9e8b6 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Thu, 22 Oct 2020 10:16:40 +0200
Subject: [MRG] Cleanup minimal build and add separate build for pep8 (#210)
* cleanup minimal requirements
* add pep8 build
* cleanup sklearn
* skip test if no sklearn
* debug build yaml
* comment out error in test (test sklearn)
* maybe small stuff for better robustness: copy the sub-array
* bump version of minimal build
* update version strict requirement
* update version strict requirement last change
---
.github/requirements_strict.txt | 9 +++------
.github/workflows/build_tests.yml | 37 +++++++++++++++++++++++++------------
.gitignore | 3 +++
Makefile | 4 ++--
ot/lp/__init__.py | 2 +-
test/test_da.py | 8 ++++++++
6 files changed, 42 insertions(+), 21 deletions(-)
(limited to 'test')
diff --git a/.github/requirements_strict.txt b/.github/requirements_strict.txt
index d7539c5..9a1ada4 100644
--- a/.github/requirements_strict.txt
+++ b/.github/requirements_strict.txt
@@ -1,7 +1,4 @@
-numpy==1.16.*
-scipy==1.0.*
-cython==0.23.*
-matplotlib
-cvxopt
-scikit-learn
+numpy
+scipy>=1.3
+cython
pytest
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 41b08b3..fa814ba 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -30,14 +30,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- pip install flake8 pytest "pytest-cov<2.6" codecov
- pip install -U "sklearn"
- - name: Lint with flake8
- run: |
- # stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
- # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics
+ pip install pytest "pytest-cov<2.6" codecov
- name: Install POT
run: |
pip install -e .
@@ -48,6 +41,29 @@ jobs:
run: |
codecov
+ pep8:
+ runs-on: ubuntu-latest
+ strategy:
+ max-parallel: 4
+ matrix:
+ python-version: [3.8]
+
+ steps:
+ - uses: actions/checkout@v1
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v1
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install flake8
+ - name: Lint with flake8
+ run: |
+ # stop the build if there are Python syntax errors or undefined names
+ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+ # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+ flake8 examples/ ot/ test/ --count --max-line-length=127 --statistics
linux-minimal-deps:
@@ -55,7 +71,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
- python-version: [3.6]
+ python-version: [3.8]
steps:
- uses: actions/checkout@v1
@@ -68,7 +84,6 @@ jobs:
python -m pip install --upgrade pip
pip install -r .github/requirements_strict.txt
pip install pytest
- pip install -U "sklearn"
- name: Install POT
run: |
pip install -e .
@@ -95,7 +110,6 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest "pytest-cov<2.6"
- pip install -U "sklearn"
- name: Install POT
run: |
pip install -e .
@@ -122,7 +136,6 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest "pytest-cov<2.6"
- pip install -U "sklearn"
- name: Install POT
run: |
pip install -e .
diff --git a/.gitignore b/.gitignore
index a2ace7c..b44ea43 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,6 +40,9 @@ var/
*.manifest
*.spec
+# env
+pythonenv3.8/
+
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
diff --git a/Makefile b/Makefile
index 70cdbdd..32332b4 100644
--- a/Makefile
+++ b/Makefile
@@ -45,10 +45,10 @@ pep8 :
flake8 examples/ ot/ test/
test : FORCE pep8
- $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot --cov-report html:cov_html
+ $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/
pytest : FORCE
- $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/
release :
twine upload dist/*
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 2a1b082..f08e020 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -426,7 +426,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
nb = b.shape[1]
if processes > 1:
- res = parmap(f, [b[:, i] for i in range(nb)], processes)
+ res = parmap(f, [b[:, i].copy() for i in range(nb)], processes)
else:
res = list(map(f, [b[:, i].copy() for i in range(nb)]))
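The sequential branch just below already copies each column; the change makes the parallel branch consistent, so each worker gets its own contiguous array rather than a strided view sharing memory with b (the commit message itself hedges: "maybe small stuff for better robustness"). A quick NumPy illustration of why a column slice is a non-contiguous view:

    import numpy as np

    b = np.ones((3, 4))
    col = b[:, 0]                              # strided view into b
    print(col.base is b)                       # True: shares b's memory
    print(col.flags['C_CONTIGUOUS'])           # False
    print(col.copy().flags['C_CONTIGUOUS'])    # True: an independent buffer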
diff --git a/test/test_da.py b/test/test_da.py
index 3b28119..52c6a48 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -6,11 +6,18 @@
import numpy as np
from numpy.testing import assert_allclose, assert_equal
+import pytest
import ot
from ot.datasets import make_data_classif
from ot.utils import unif
+try:  # test if sklearn is installed
+ import sklearn # noqa: F401
+ nosklearn = False
+except ImportError:
+ nosklearn = True
+
def test_sinkhorn_lpl1_transport_class():
"""test_sinkhorn_transport
@@ -691,6 +698,7 @@ def test_jcpot_barycenter():
np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)
+@pytest.mark.skipif(nosklearn, reason="No sklearn available")
def test_emd_laplace_class():
"""test_emd_laplace_transport
"""
--
cgit v1.2.3
From 78b44af2434f494c8f9e4c8c91003fbc0e1d4415 Mon Sep 17 00:00:00 2001
From: AdrienCorenflos
Date: Thu, 22 Oct 2020 09:28:53 +0100
Subject: [MRG] Sliced wasserstein (#203)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* example for log treatment in bregman.py
* Improve doc
* Revert "example for log treatment in bregman.py"
This reverts commit 9f51c14e
* Add comments by Flamary
* Delete repetitive description
* Added raw string to avoid problems with backslashes
* Implements sliced wasserstein
* Changed formatting of string for py3.5 support
* Doctest: expected 0.0 and not 0.
* Addressed comments by @rflamary
* No 3d plot here
* add sliced to the docs
* Incorporate comments by @rflamary
* add link to pdf
Co-authored-by: Rémi Flamary
---
README.md | 4 +
docs/source/all.rst | 1 +
examples/sliced-wasserstein/README.txt | 4 +
examples/sliced-wasserstein/plot_variance.py | 84 ++++++++++++++++
ot/__init__.py | 3 +-
ot/sliced.py | 144 +++++++++++++++++++++++++++
test/test_sliced.py | 85 ++++++++++++++++
7 files changed, 324 insertions(+), 1 deletion(-)
create mode 100644 examples/sliced-wasserstein/README.txt
create mode 100644 examples/sliced-wasserstein/plot_variance.py
create mode 100644 ot/sliced.py
create mode 100644 test/test_sliced.py
(limited to 'test')
diff --git a/README.md b/README.md
index e3598f1..6fe528a 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples):
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
+* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32].
POT provides the following Machine Learning related solvers:
@@ -180,6 +181,7 @@ The contributors to this library are
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
+* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
This toolbox benefits a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -263,3 +265,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276.
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+
+[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
diff --git a/docs/source/all.rst b/docs/source/all.rst
index d7b878f..f1f7075 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -27,6 +27,7 @@ API and modules
stochastic
unbalanced
partial
+ sliced
.. autosummary::
:toctree: ../modules/generated/
diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt
new file mode 100644
index 0000000..a575345
--- /dev/null
+++ b/examples/sliced-wasserstein/README.txt
@@ -0,0 +1,4 @@
+
+
+Sliced Wasserstein Distance
+---------------------------
\ No newline at end of file
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
new file mode 100644
index 0000000..f3deeff
--- /dev/null
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+2D Sliced Wasserstein Distance
+==============================
+
+This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31].
+
+[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+"""
+
+# Author: Adrien Corenflos
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+# %% parameters and data generation
+
+n = 500 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+##############################################################################
+# Plot data
+# ---------
+
+# %% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+###################################################################################
+# Compute Sliced Wasserstein distance for different seeds and number of projections
+# ----------------------------------------------------------------------------------
+
+n_seed = 50
+n_projections_arr = np.logspace(0, 3, 25, dtype=int)
+res = np.empty((n_seed, 25))
+
+# %% Compute statistics
+for seed in range(n_seed):
+ for i, n_projections in enumerate(n_projections_arr):
+ res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed)
+
+res_mean = np.mean(res, axis=0)
+res_std = np.std(res, axis=0)
+
+###################################################################################
+# Plot Sliced Wasserstein Distance
+# --------------------------------
+
+pl.figure(2)
+pl.plot(n_projections_arr, res_mean, label="SWD")
+pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)
+
+pl.legend()
+pl.xscale('log')
+
+pl.xlabel("Number of projections")
+pl.ylabel("Distance")
+pl.title('Sliced Wasserstein Distance with 95% confidence interval')
+
+pl.show()
diff --git a/ot/__init__.py b/ot/__init__.py
index 0e6e2e2..ec3ede2 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -39,6 +39,7 @@ from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
from .da import sinkhorn_lpl1_mm
+from .sliced import sliced_wasserstein_distance
# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -50,4 +51,4 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets'
'emd_1d', 'emd2_1d', 'wasserstein_1d',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
- 'sinkhorn_unbalanced2']
+ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance']
diff --git a/ot/sliced.py b/ot/sliced.py
new file mode 100644
index 0000000..4792576
--- /dev/null
+++ b/ot/sliced.py
@@ -0,0 +1,144 @@
+"""
+Sliced Wasserstein Distance.
+
+"""
+
+# Author: Adrien Corenflos
+#
+# License: MIT License
+
+
+import numpy as np
+
+
+def get_random_projections(n_projections, d, seed=None):
+ r"""
+ Generates n_projections samples from the uniform distribution on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})`
+
+ Parameters
+ ----------
+ n_projections : int
+ number of samples requested
+ d : int
+ dimension of the space
+ seed: int or RandomState, optional
+ Seed used for numpy random number generator
+
+ Returns
+ -------
+ out: ndarray, shape (n_projections, d)
+ The uniform unit vectors on the sphere
+
+ Examples
+ --------
+ >>> n_projections = 100
+ >>> d = 5
+ >>> projs = get_random_projections(n_projections, d)
+ >>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE
+ True
+
+ """
+
+ if not isinstance(seed, np.random.RandomState):
+ random_state = np.random.RandomState(seed)
+ else:
+ random_state = seed
+
+ projections = random_state.normal(0., 1., [n_projections, d])
+ norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True)
+ projections = projections / norm
+ return projections
+
+
+def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False):
+ r"""
+ Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance
+
+ .. math::
+ \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}}
+
+ where :
+
+ - :math:`\theta_\# \mu` stands for the pushforward of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle`
+
+
+ Parameters
+ ----------
+ X_s : ndarray, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : ndarray, shape (n_samples_b, dim)
+ samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ seed: int or RandomState or None, optional
+ Seed used for numpy random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Sliced Wasserstein Cost
+ log : dict, optional
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_samples_a = 20
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+ >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+
+ .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+ """
+ from .lp import emd2_1d
+
+ X_s = np.asanyarray(X_s)
+ X_t = np.asanyarray(X_t)
+
+ n = X_s.shape[0]
+ m = X_t.shape[0]
+
+ if X_s.shape[1] != X_t.shape[1]:
+ raise ValueError(
+ "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
+ X_t.shape[1]))
+
+ if a is None:
+ a = np.full(n, 1 / n)
+ if b is None:
+ b = np.full(m, 1 / m)
+
+ d = X_s.shape[1]
+
+ projections = get_random_projections(n_projections, d, seed)
+
+ X_s_projections = np.dot(projections, X_s.T)
+ X_t_projections = np.dot(projections, X_t.T)
+
+ if log:
+ projected_emd = np.empty(n_projections)
+ else:
+ projected_emd = None
+
+ res = 0.
+
+ for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)):
+ emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False)
+ if projected_emd is not None:
+ projected_emd[i] = emd
+ res += emd
+
+ res = (res / n_projections) ** 0.5
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
diff --git a/test/test_sliced.py b/test/test_sliced.py
new file mode 100644
index 0000000..a07d975
--- /dev/null
+++ b/test/test_sliced.py
@@ -0,0 +1,85 @@
+"""Tests for module sliced"""
+
+# Author: Adrien Corenflos
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+from ot.sliced import get_random_projections
+
+
+def test_get_random_projections():
+ rng = np.random.RandomState(0)
+ projections = get_random_projections(1000, 50, rng)
+ np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.)
+
+
+def test_sliced_same_dist():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+
+ res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
+ np.testing.assert_almost_equal(res, 0.)
+
+
+def test_sliced_bad_shapes():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ y = rng.randn(n, 4)
+ u = ot.utils.unif(n)
+
+ with pytest.raises(ValueError):
+ _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
+
+
+def test_sliced_log():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 4)
+ y = rng.randn(n, 4)
+ u = ot.utils.unif(n)
+
+ res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
+ assert len(log) == 2
+ projections = log["projections"]
+ projected_emds = log["projected_emds"]
+
+ assert len(projections) == len(projected_emds) == 10
+ for emd in projected_emds:
+ assert emd > 0
+
+
+def test_sliced_different_dists():
+ n = 100
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 2)
+ u = ot.utils.unif(n)
+ y = rng.randn(n, 2)
+
+ res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
+ assert res > 0.
+
+
+def test_1d_sliced_equals_emd():
+ n = 100
+ m = 120
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n, 1)
+ a = rng.uniform(0, 1, n)
+ a /= a.sum()
+ y = rng.randn(m, 1)
+ u = ot.utils.unif(m)
+ res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42)
+ expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u)
+ np.testing.assert_almost_equal(res ** 2, expected)
--
cgit v1.2.3
From 93785eba11b59d544f1edde6661e93ee587148ee Mon Sep 17 00:00:00 2001
From: Laetitia Chapel
Date: Thu, 22 Oct 2020 10:58:31 +0200
Subject: [MRG] Fix bugs for partial OT (#215)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* bugfix
* update refs partial OT
* fixes small typos in plot_partial_wass_and_gromov
* fix small bugs in partial.py
* update README
* pep8 bugfix
* modify doctest
* fix bugs in tests
* update test_partial and test the numerical precision in ot/partial
* resolve merge problem
Co-authored-by: Rémi Flamary
---
README.md | 2 +-
.../plot_partial_wass_and_gromov.py | 23 ++++---
ot/partial.py | 71 +++++++++++++---------
test/test_partial.py | 6 +-
4 files changed, 60 insertions(+), 42 deletions(-)
(limited to 'test')
diff --git a/README.md b/README.md
index 6fe528a..238faed 100644
--- a/README.md
+++ b/README.md
@@ -262,7 +262,7 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[28] Caffarelli, L. A., McCann, R. J. (2010). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730.
-[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276.
+[29] Chapel, L., Alaya, M., Gasso, G. (2020). [Partial Optimal Transport with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), Advances in Neural Information Processing Systems (NeurIPS), 2020.
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
index 0c5cbf9..ac4194c 100755
--- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
+++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
@@ -4,7 +4,7 @@
Partial Wasserstein and Gromov-Wasserstein example
==================================================
-This example is designed to show how to use the Partial (Gromov-)Wassertsein
+This example is designed to show how to use the Partial (Gromov-)Wasserstein
distance computation in POT.
"""
@@ -123,11 +123,12 @@ C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)
# transport 100% of the mass
-print('-----m = 1')
+print('------m = 1')
m = 1
res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
- m=m, log=True)
+ m=m, log=True,
+ verbose=True)
print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
@@ -136,18 +137,20 @@ pl.figure(1, (10, 5))
pl.title("mass to be transported m = 1")
pl.subplot(1, 2, 1)
pl.imshow(res0, cmap='jet')
-pl.title('Wasserstein')
+pl.title('Gromov-Wasserstein')
pl.subplot(1, 2, 2)
pl.imshow(res, cmap='jet')
-pl.title('Entropic Wasserstein')
+pl.title('Entropic Gromov-Wasserstein')
pl.show()
# transport 2/3 of the mass
-print('-----m = 2/3')
+print('------m = 2/3')
m = 2 / 3
-res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True,
+ verbose=True)
res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
- m=m, log=True)
+ m=m, log=True,
+ verbose=True)
print('Partial Wasserstein distance (m = 2/3): ' +
str(log0['partial_gw_dist']))
@@ -158,8 +161,8 @@ pl.figure(1, (10, 5))
pl.title("mass to be transported m = 2/3")
pl.subplot(1, 2, 1)
pl.imshow(res0, cmap='jet')
-pl.title('Partial Wasserstein')
+pl.title('Partial Gromov-Wasserstein')
pl.subplot(1, 2, 2)
pl.imshow(res, cmap='jet')
-pl.title('Entropic partial Wasserstein')
+pl.title('Entropic partial Gromov-Wasserstein')
pl.show()
diff --git a/ot/partial.py b/ot/partial.py
index eb707d8..814d779 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -230,9 +230,9 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
.. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
optimal transport and Monge-Ampere obstacle problems. Annals of
mathematics, 673-730.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
See Also
--------
@@ -254,7 +254,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-1, -1] = np.max(M) * 1e5
+ M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
M_extended[:len(a), :len(b)] = M
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
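The widened fill matters when nb_dummies > 1: the prohibitive cost is there to forbid transport between dummy points (so that exactly m units of mass move between the real points), but the old code penalized only the single cell M_extended[-1, -1], leaving every other dummy-dummy cell at zero cost. A toy illustration of the two fills (values are hypothetical, for clarity only):

    import numpy as np

    M = np.ones((2, 2))
    nb_dummies = 2
    M_ext = np.zeros((2 + nb_dummies, 2 + nb_dummies))
    M_ext[:2, :2] = M

    M_ext[-1, -1] = np.max(M) * 1e5                      # old: one cell penalized
    M_ext[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5  # fix: the whole dummy block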
@@ -344,14 +344,13 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
.. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
optimal transport and Monge-Ampere obstacle problems. Annals of
mathematics, 673-730.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
**kwargs)
-
log_w['T'] = partial_gw
if log:
@@ -501,14 +500,14 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
>>> np.round(partial_gromov_wasserstein(C1, C2, a, b, m=0.25),2)
array([[0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. ],
- [0. , 0. , 0. , 0. ],
- [0. , 0. , 0. , 0.25]])
+ [0. , 0. , 0.25, 0. ],
+ [0. , 0. , 0. , 0. ]])
References
----------
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
@@ -530,20 +529,18 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
cpt = 0
err = 1
- eps = 1e-20
+
if log:
log = {'err': []}
while (err > tol and cpt < numItermax):
- Gprev = G0
+ Gprev = np.copy(G0)
M = gwgrad_partial(C1, C2, G0)
- M[M < eps] = np.quantile(M, thres)
-
M_emd = np.zeros(dim_G_extended)
M_emd[:len(p), :len(q)] = M
- M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
+ M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2
M_emd = np.asarray(M_emd, dtype=np.float64)
Gc, logemd = emd(p_extended, q_extended, M_emd, log=True, **kwargs)
@@ -565,6 +562,22 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
print('{:5d}|{:8e}|{:8e}'.format(cpt, err,
gwloss_partial(C1, C2, G0)))
+ deltaG = G0 - Gprev
+ a = gwloss_partial(C1, C2, deltaG)
+ b = 2 * np.sum(M * deltaG)
+ if b > 0: # due to numerical precision
+ gamma = 0
+ cpt = numItermax
+ elif a > 0:
+ gamma = min(1, np.divide(-b, 2.0 * a))
+ else:
+ if (a + b) < 0:
+ gamma = 1
+ else:
+ gamma = 0
+ cpt = numItermax
+
+ G0 = Gprev + gamma * deltaG
cpt += 1
if log:
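The inserted block performs the exact line search of the conditional-gradient (Frank-Wolfe) update: the restriction of the quadratic partial-GW objective to the segment G(gamma) = Gprev + gamma * deltaG has the form a * gamma**2 + b * gamma + const, and the code picks its minimizer over [0, 1] case by case. A standalone sketch of that case analysis (illustrative only; the boolean mirrors the cpt = numItermax early exit):

    def quadratic_linesearch(a, b):
        # minimize a * gamma**2 + b * gamma over gamma in [0, 1]
        if b > 0:       # not a descent direction (numerical precision): stop
            return 0.0, True
        if a > 0:       # convex parabola: clip the interior minimizer to [0, 1]
            return min(1.0, -b / (2.0 * a)), False
        # concave or linear: the minimum is attained at an endpoint
        return (1.0, False) if (a + b) < 0 else (0.0, True)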
@@ -665,9 +678,9 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
References
----------
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
@@ -887,12 +900,12 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
>>> y = np.array([3,2,98,199]).reshape((-1,1))
>>> C1 = sp.spatial.distance.cdist(x, x)
>>> C2 = sp.spatial.distance.cdist(y, y)
- >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b,50), 2)
+ >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50), 2)
array([[0.12, 0.13, 0. , 0. ],
[0.13, 0.12, 0. , 0. ],
[0. , 0. , 0.25, 0. ],
[0. , 0. , 0. , 0.25]])
- >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50, m=0.25), 2)
+ >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50, 0.25), 2)
array([[0.02, 0.03, 0. , 0.03],
[0.03, 0.03, 0. , 0.03],
[0. , 0. , 0.03, 0. ],
@@ -910,9 +923,9 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
See Also
--------
@@ -1044,9 +1057,9 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
partial_gw, log_gw = entropic_partial_gromov_wasserstein(C1, C2, p, q, reg,
diff --git a/test/test_partial.py b/test/test_partial.py
index 510e081..121f345 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -51,10 +51,12 @@ def test_raise_errors():
ot.partial.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2,
+ log=True)
with pytest.raises(ValueError):
- ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True)
+ ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1,
+ log=True)
def test_partial_wasserstein_lagrange():
--
cgit v1.2.3
From 2e97be778d2d72d7a66b3721ee697399522538ba Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Thu, 8 Apr 2021 11:09:50 +0200
Subject: [MRG] ADD JMLR paper to the readme and documentation (#231)
* add JMLR reference to readme and doc
---
README.md | 20 ++++++++++++--------
docs/source/readme.rst | 50 +++++++++++++++++++++++++++++++++-----------------
test/test_ot.py | 6 +++---
3 files changed, 48 insertions(+), 28 deletions(-)
(limited to 'test')
diff --git a/README.md b/README.md
index 238faed..7321aff 100644
--- a/README.md
+++ b/README.md
@@ -50,19 +50,23 @@ Some other examples are available in the [documentation](https://pythonot.githu
#### Using and citing the toolbox
If you use this toolbox in your research and find it useful, please cite POT
-using the following reference:
+using the following reference from our [JMLR paper](https://jmlr.org/papers/v22/20-451.html):
```
-Rémi Flamary and Nicolas Courty, POT Python Optimal Transport library,
-Website: https://pythonot.github.io/, 2017
+Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
+Website: https://pythonot.github.io/
```
In Bibtex format:
```
-@misc{flamary2017pot,
-title={POT Python Optimal Transport library},
-author={Flamary, R{'e}mi and Courty, Nicolas},
-url={https://pythonot.github.io/},
-year={2017}
+@article{flamary2021pot,
+ author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
+ title = {POT: Python Optimal Transport},
+ journal = {Journal of Machine Learning Research},
+ year = {2021},
+ volume = {22},
+ number = {78},
+ pages = {1-8},
+ url = {http://jmlr.org/papers/v22/20-451.html}
}
```
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index b8cb48c..f35f01b 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -66,6 +66,9 @@ POT provides the following generic OT solvers (links to examples):
- `Partial Wasserstein and
Gromov-Wasserstein `__
(exact [29] and entropic [3] formulations).
+- `Sliced
+ Wasserstein `__
+ [31, 32].
POT provides the following Machine Learning related solvers:
@@ -96,22 +99,27 @@ Using and citing the toolbox
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you use this toolbox in your research and find it useful, please cite
-POT using the following reference:
+POT using the following reference from our `JMLR
+paper `__:
::
- Rémi Flamary and Nicolas Courty, POT Python Optimal Transport library,
- Website: https://pythonot.github.io/, 2017
+ Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
+ Website: https://pythonot.github.io/
In Bibtex format:
::
- @misc{flamary2017pot,
- title={POT Python Optimal Transport library},
- author={Flamary, R{'e}mi and Courty, Nicolas},
- url={https://pythonot.github.io/},
- year={2017}
+ @article{flamary2021pot,
+ author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
+ title = {POT: Python Optimal Transport},
+ journal = {Journal of Machine Learning Research},
+ year = {2021},
+ volume = {22},
+ number = {78},
+ pages = {1-8},
+ url = {http://jmlr.org/papers/v22/20-451.html}
}
Installation
@@ -269,6 +277,8 @@ The contributors to this library are
- `Romain Tavenard `__ (1d Wasserstein)
- `Mokhtar Z. Alaya `__ (Screenkhorn)
- `Ievgen Redko `__ (Laplacian DA, JCPOT)
+- `Adrien Corenflos `__ (Sliced
+ Wasserstein Distance)
This toolbox benefits a lot from open source research and we would like
to thank the following persons for providing some code (in various
@@ -285,20 +295,21 @@ Contributions and code of conduct
---------------------------------
Every contribution is welcome and should respect the `contribution
-guidelines `__. Each member of the project is expected
-to follow the `code of conduct `__.
+guidelines <.github/CONTRIBUTING.md>`__. Each member of the project is
+expected to follow the `code of conduct <.github/CODE_OF_CONDUCT.md>`__.
Support
-------
You can ask questions and join the development discussion:
-- On the `POT Slack channel `__
+- On the POT `slack channel `__
+- On the POT `gitter channel `__
- On the POT `mailing
list `__
You can also post bug reports and feature requests in Github issues.
-Make sure to read our `guidelines `__ first.
+Make sure to read our `guidelines <.github/CONTRIBUTING.md>`__ first.
References
----------
@@ -439,10 +450,10 @@ optimal transport and Monge-Ampere obstacle
problems `__,
Annals of mathematics, 673-730.
-[29] Chapel, L., Alaya, M., Gasso, G. (2019). `Partial
-Gromov-Wasserstein with Applications on Positive-Unlabeled
-Learning `__, arXiv preprint
-arXiv:2002.08276.
+[29] Chapel, L., Alaya, M., Gasso, G. (2020). `Partial Optimal Transport
+with Applications on Positive-Unlabeled
+Learning `__, Advances in Neural
+Information Processing Systems (NeurIPS), 2020.
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). `Optimal
transport with Laplacian regularization: Applications to domain
@@ -450,11 +461,16 @@ adaptation and shape
matching `__,
NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+[31] Bonneel, Nicolas, et al. `Sliced and radon wasserstein barycenters
+of
+measures `__,
+Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
:target: https://anaconda.org/conda-forge/pot
-.. |Build Status| image:: https://github.com/PythonOT/POT/workflows/build/badge.svg
+.. |Build Status| image:: https://github.com/PythonOT/POT/workflows/build/badge.svg?branch=master&event=push
:target: https://github.com/PythonOT/POT/actions
.. |Codecov Status| image:: https://codecov.io/gh/PythonOT/POT/branch/master/graph/badge.svg
:target: https://codecov.io/gh/PythonOT/POT
diff --git a/test/test_ot.py b/test/test_ot.py
index b7306f6..f45e4c9 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -291,17 +291,17 @@ def test_warnings():
print('Computing {} EMD '.format(1))
ot.emd(a, b, M, numItermax=1)
assert "numItermax" in str(w[-1].message)
- assert len(w) == 1
+ # assert len(w) == 1
a[0] = 100
print('Computing {} EMD '.format(2))
ot.emd(a, b, M)
assert "infeasible" in str(w[-1].message)
- assert len(w) == 2
+ # assert len(w) == 2
a[0] = -1
print('Computing {} EMD '.format(2))
ot.emd(a, b, M)
assert "infeasible" in str(w[-1].message)
- assert len(w) == 3
+ # assert len(w) == 3
def test_dual_variables():
--
cgit v1.2.3
From 2a3f2241951ea9cc044b4fba8a382b6ae9630513 Mon Sep 17 00:00:00 2001
From: AdrienCorenflos
Date: Mon, 19 Apr 2021 14:57:51 +0300
Subject: BUG/DOC FIX - Sinkhorn divergence used the wrong weights, and
sinkhorn2 didn't support epsilon_scaling method. (#235)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* FIX:
1. Documentation of loss specific functions
2. Sinkhorn divergence weights handling
3. Sinkhorn2 does not support epsilon scaling, so I removed it (it *should* arguably support it, but this would require a refactoring of the sinkhorn iterates pretty much everywhere, maybe should be done in torch first?)
* Had some PEP8 issues
Co-authored-by: Rémi Flamary
---
ot/bregman.py | 53 +++++++++++++++++++++++++---------------------------
test/test_bregman.py | 13 +++++++------
2 files changed, 32 insertions(+), 34 deletions(-)
(limited to 'test')
diff --git a/ot/bregman.py b/ot/bregman.py
index dcd35e1..559db14 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -14,11 +14,13 @@ Bregman projections solvers for entropic regularized OT
#
# License: MIT License
-import numpy as np
import warnings
-from .utils import unif, dist
+
+import numpy as np
from scipy.optimize import fmin_l_bfgs_b
+from ot.utils import unif, dist
+
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
@@ -179,8 +181,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver: either 'sinkhorn' or 'sinkhorn_stabilized'; see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -207,7 +208,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- W : (n_hists) ndarray or float
+ W : (n_hists) ndarray
Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -244,12 +245,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
ot.bregman.greenkhorn : Greenkhorn [21]
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
"""
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
b = b[:, None]
+
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
@@ -258,10 +259,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
**kwargs)
- elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
@@ -745,8 +742,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if np.abs(u).max() > tau or np.abs(v).max() > tau:
if n_hists:
- alpha, beta = alpha + reg * \
- np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
+ alpha, beta = alpha + reg * np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
else:
alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
if n_hists:
@@ -1747,7 +1743,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
>>> reg = 0.1
>>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
>>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
+ >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
array([[4.99977301e-01, 2.26989344e-05],
[2.26989344e-05, 4.99977301e-01]])
@@ -1825,8 +1821,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (n_hists) ndarray or float
+ Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1838,8 +1834,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
>>> reg = 0.1
>>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
>>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
- array([4.53978687e-05])
+ >>> b = np.full((n_samples_b, 3), 1/n_samples_b)
+ >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False)
+ array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05])
References
@@ -1935,8 +1932,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (1,) ndarray
+ Optimal transportation symmetrized loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1959,13 +1956,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax,
stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax,
stopThr=1e-9, verbose=verbose, log=log, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -1981,13 +1978,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
verbose=verbose, log=log, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
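Spelled out, the corrected weighting computes the empirical Sinkhorn divergence with each self-term weighted by its own marginal. A minimal sketch built on the function this diff modifies (positional arguments as in empirical_sinkhorn2; illustrative, not the library code):

    import ot

    def emp_sinkhorn_div(X_s, X_t, reg, a, b):
        loss_ab = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg, a, b)
        loss_a = ot.bregman.empirical_sinkhorn2(X_s, X_s, reg, a, a)  # was (a, b)
        loss_b = ot.bregman.empirical_sinkhorn2(X_t, X_t, reg, b, b)  # was (a, b)
        return max(0, loss_ab - 0.5 * (loss_a + loss_b))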
@@ -2212,11 +2209,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
# box constraints in L-BFGS-B (see Proposition 1 in [26])
bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
- ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+ ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
bounds_v = [(
- max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
- epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+ max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+ epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
# pre-calculated constants for the objective
vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6aa4e08..331acd3 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -6,9 +6,10 @@
# License: MIT License
import numpy as np
-import ot
import pytest
+import ot
+
def test_sinkhorn():
# test sinkhorn
@@ -257,7 +258,8 @@ def test_empirical_sinkhorn():
def test_empirical_sinkhorn_divergence():
# Test sinkhorn divergence
n = 10
- a = ot.unif(n)
+ a = np.linspace(1, n, n)
+ a /= a.sum()
b = ot.unif(n)
X_s = np.reshape(np.arange(n), (n, 1))
X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
@@ -265,16 +267,15 @@ def test_empirical_sinkhorn_divergence():
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
- emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, log=True)
+ emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True)
sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
-
- # check constratints
+ # check constraints
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
np.testing.assert_allclose(
--
cgit v1.2.3
From cd3ce6140d7a2dbe2bcf05927a8dd8289f4ce9e2 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Mon, 19 Apr 2021 15:03:57 +0200
Subject: [MRG] Cleanup test warnings (#242)
* remove warnings in tests from docstrings
* working tests for bregman implemented methods
* pep8
---
ot/da.py | 12 ++++++------
ot/dr.py | 2 +-
ot/gpu/bregman.py | 2 +-
ot/gromov.py | 20 ++++++++++----------
ot/lp/cvx.py | 3 +--
ot/optim.py | 4 ++--
test/test_bregman.py | 3 ++-
7 files changed, 23 insertions(+), 23 deletions(-)
(limited to 'test')
diff --git a/ot/da.py b/ot/da.py
index f1e4769..cdc747c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -26,7 +26,7 @@ from .optim import gcg
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
log=False):
- """
+ r"""
Solve the entropic regularization optimal transport problem with nonconvex
group lasso regularization
@@ -137,7 +137,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
log=False):
- """
+ r"""
Solve the entropic regularization optimal transport problem with group
lasso regularization
@@ -245,7 +245,7 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
verbose2=False, numItermax=100, numInnerItermax=10,
stopInnerThr=1e-6, stopThr=1e-5, log=False,
**kwargs):
- """Joint OT and linear mapping estimation as proposed in [8]
+ r"""Joint OT and linear mapping estimation as proposed in [8]
The function solves the following optimization problem:
@@ -434,7 +434,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
numItermax=100, numInnerItermax=10,
stopInnerThr=1e-6, stopThr=1e-5, log=False,
**kwargs):
- """Joint OT and nonlinear mapping estimation with kernels as proposed in [8]
+ r"""Joint OT and nonlinear mapping estimation with kernels as proposed in [8]
The function solves the following optimization problem:
@@ -645,7 +645,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
wt=None, bias=True, log=False):
- """ return OT linear operator between samples
+ r""" return OT linear operator between samples
The function estimates the optimal linear operator that aligns the two
empirical distributions. This is equivalent to estimating the closed
@@ -1228,7 +1228,7 @@ class BaseTransport(BaseEstimator):
class LinearTransport(BaseTransport):
- """ OT linear operator between empirical distributions
+ r""" OT linear operator between empirical distributions
The function estimates the optimal linear operator that aligns the two
empirical distributions. This is equivalent to estimating the closed
diff --git a/ot/dr.py b/ot/dr.py
index 11d2e10..b7a1af0 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -109,7 +109,7 @@ def fda(X, y, p=2, reg=1e-16):
def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
- """
+ r"""
Wasserstein Discriminant Analysis [11]_
The function solves the following optimization problem:
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 2e2df83..82f34f3 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -15,7 +15,7 @@ from . import utils
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
verbose=False, log=False, to_numpy=True, **kwargs):
- """
+ r"""
Solve the entropic regularization optimal transport on GPU
If the input matrix are in numpy format, they will be uploaded to the
diff --git a/ot/gromov.py b/ot/gromov.py
index 4427a96..8f457e9 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -19,7 +19,7 @@ from .optim import cg
def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
- """Return loss matrices and tensors for Gromov-Wasserstein fast computation
+ r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation
Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
function as the loss function of Gromov-Wasserstein discrepancy.
@@ -109,7 +109,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
def tensor_product(constC, hC1, hC2, T):
- """Return the tensor for Gromov-Wasserstein fast computation
+ r"""Return the tensor for Gromov-Wasserstein fast computation
The tensor is computed as described in Proposition 1 Eq. (6) in [12].
@@ -262,7 +262,7 @@ def update_kl_loss(p, lambdas, T, Cs):
def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
- """
+ r"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
The function solves the following optimization problem:
@@ -343,7 +343,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
- """
+ r"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
The function solves the following optimization problem:
@@ -420,7 +420,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
- """
+ r"""
Computes the FGW transport between two graphs see [24]
.. math::
@@ -496,7 +496,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
- """
+ r"""
Computes the FGW distance between two graphs see [24]
.. math::
@@ -574,7 +574,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
- """
+ r"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -681,7 +681,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
- """
+ r"""
Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices
(C1,p) and (C2,q)
@@ -747,7 +747,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
- """
+ r"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
(Cs)_{s=1}^{s=S}
@@ -857,7 +857,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
- """
+ r"""
Returns the gromov-wasserstein barycenters of S measured similarity matrices
(Cs)_{s=1}^{s=S}
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 8e763be..869d450 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
- """Compute the Wasserstein barycenter of distributions A
+ r"""Compute the Wasserstein barycenter of distributions A
The function solves the following optimization problem [16]:
@@ -76,7 +76,6 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
.. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
-
"""
if weights is None:
diff --git a/ot/optim.py b/ot/optim.py
index 1902907..abe9e6a 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -139,7 +139,7 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
- """
+ r"""
Solve the general regularized OT problem with conditional gradient
The function solves the following optimization problem:
@@ -278,7 +278,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
- """
+ r"""
Solve the general regularized OT problem with the generalized conditional gradient
The function solves the following optimization problem:
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 331acd3..1ebd21f 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -321,8 +321,9 @@ def test_implemented_methods():
# make dists unbalanced
b = ot.utils.unif(n)
A = rng.rand(n, 2)
+ A /= A.sum(0, keepdims=True)
M = ot.dist(x, x)
- epsilon = 1.
+ epsilon = 1.0
for method in IMPLEMENTED_METHODS:
ot.bregman.sinkhorn(a, b, M, epsilon, method=method)
--
cgit v1.2.3
From 184f8f4f7ac78f1dd7f653496d2753211a4e3426 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Tue, 1 Jun 2021 10:10:54 +0200
Subject: [MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backend
* add jax backend
* cleanup windows
* proper convert for jax backend
* pep8
* try again windows tests
* test jax conversion
* try proper windows tests
* emd function uses backend
* better test partial OT
* proper tests to_numpy and template Backend
* pep8
* pep8 x2
* freaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no test same for jax
* new independent tests per backend
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* working dist function
* working dist
* dist done in backend
* not in
* remove indexing
* change accuracy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backend discussion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* correct doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backend works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort
* Update ot/utils.py
Co-authored-by: Alexandre Gramfort
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort
* Update ot/lp/__init__.py
Co-authored-by: Alexandre Gramfort
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient descent example on the weights
* pep8 of course
* proof read example by alex
* pep8 again
* encoding oos in translation
* correct legend
Co-authored-by: Nicolas Courty
Co-authored-by: Alexandre Gramfort
---
.github/requirements_test_windows.txt | 10 +
.github/workflows/build_tests.yml | 9 +-
README.md | 8 +-
docs/source/quickstart.rst | 68 +++-
docs/source/readme.rst | 70 ++--
examples/README.txt | 2 +-
examples/backends/README.txt | 4 +
examples/backends/plot_unmix_optim_torch.py | 161 +++++++++
ot/__init__.py | 1 +
ot/backend.py | 536 ++++++++++++++++++++++++++++
ot/bregman.py | 141 ++++----
ot/gpu/__init__.py | 4 +-
ot/lp/__init__.py | 137 ++++---
ot/utils.py | 128 +++++--
requirements.txt | 3 +
test/test_backend.py | 364 +++++++++++++++++++
test/test_bregman.py | 74 ++++
test/test_gromov.py | 10 +-
test/test_ot.py | 91 ++++-
test/test_partial.py | 4 +-
test/test_utils.py | 76 +++-
21 files changed, 1692 insertions(+), 209 deletions(-)
create mode 100644 .github/requirements_test_windows.txt
create mode 100644 examples/backends/README.txt
create mode 100644 examples/backends/plot_unmix_optim_torch.py
create mode 100644 ot/backend.py
create mode 100644 test/test_backend.py
(limited to 'test')
diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
new file mode 100644
index 0000000..331dd57
--- /dev/null
+++ b/.github/requirements_test_windows.txt
@@ -0,0 +1,10 @@
+numpy
+scipy>=1.3
+cython
+matplotlib
+autograd
+pymanopt==0.2.4; python_version <'3'
+pymanopt; python_version >= '3'
+cvxopt
+scikit-learn
+pytest
\ No newline at end of file
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 2fc6770..92a07b5 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -40,7 +40,7 @@ jobs:
pip install -e .
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
- name: Upload codecov
run: |
codecov
@@ -142,11 +142,12 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install -r requirements.txt
- pip install pytest "pytest-cov<2.6"
+ python -m pip install -r .github/requirements_test_windows.txt
+ python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ python -m pip install pytest "pytest-cov<2.6"
- name: Install POT
run: |
- pip install -e .
+ python -m pip install -e .
- name: Run tests
run: |
python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
diff --git a/README.md b/README.md
index f5d18c1..e5e16e0 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples):
* [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] .
* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7].
-* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html) with optional GPU implementation (requires cupy).
+* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4].
* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
@@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples):
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32].
+* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -77,8 +78,7 @@ The library has been tested on Linux, MacOSX and Windows. It requires a C++ comp
- Numpy (>=1.16)
- Scipy (>=1.0)
-- Cython (>=0.23)
-- Matplotlib (>=1.5)
+- Cython (>=0.23) (build only, not necessary when installing wheels from pip or conda)
#### Pip installation
@@ -129,7 +129,7 @@ Some sub-modules require additional dependences which are discussed below
pip install pymanopt autograd
```
-* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU.
+* **ot.gpu** (GPU accelerated OT) depends on cupy, which has to be installed following the instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solvers can already run on GPU using the Pytorch backend.
## Examples
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index cf5d6aa..fd046a1 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -15,6 +15,12 @@ are also available as notebooks on the POT Github.
in ML applications we refer the reader to the following `OTML tutorial
`_.
+.. note::
+
+ Since version 0.8, POT provides a backend to automatically solve some OT
+ problems independently from the toolbox used by the user (numpy/torch/jax).
+ We discuss which functions are compatible in the
+ `Backend section <#solving-ot-with-multiple-backends>`_.
Why Optimal Transport ?
@@ -158,7 +164,6 @@ Wasserstein but has better computational and
`statistical properties `_.
-
Optimal transport and Wasserstein distance
------------------------------------------
@@ -922,6 +927,13 @@ The implementations of FGW and FGW barycenter is provided in functions
GPU acceleration
^^^^^^^^^^^^^^^^
+.. warning::
+
+ The :any:`ot.gpu` module has been deprecated since release 0.8 of POT and
+ should not be used. GPU implementations (in Pytorch for instance) can be
+ used with the new backends through the compatible functions of POT.
+
+
We provide several implementation of our OT solvers in :any:`ot.gpu`. Those
implementations use the :code:`cupy` toolbox that obviously need to be installed.
@@ -950,6 +962,60 @@ explicitly.
use it you have to specifically import it with :code:`import ot.gpu` .
+Solving OT with Multiple backends
+---------------------------------
+
+.. _backends_section:
+
+Since version 0.8, POT provides a backend that allows writing solvers
+independently of the type of the input arrays. The idea is to provide the user
+with a package that works seamlessly and returns a solution as, for instance, a
+Pytorch tensor when the function receives Pytorch tensors as input.
+
+
+How it works
+^^^^^^^^^^^^
+
+The aim of the backend is to use the same function independently of the type of
+the input arrays.
+
+For instance when executing the following code
+
+.. code:: python
+
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T = ot.emd(a, b, M) # exact linear program
+ w = ot.emd2(a, b, M) # Wasserstein computation
+
+the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of type
+:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of
+the function will have the same type as the inputs and live on the same device.
+When possible, all computations are done on that device, and the output is
+differentiable with respect to the inputs of the function.
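+
+For instance, a minimal sketch with the Pytorch backend (the values below are
+purely illustrative):
+
+.. code:: python
+
+    import torch
+
+    a = torch.tensor([0.5, 0.5], requires_grad=True)
+    b = torch.tensor([0.5, 0.5])
+    M = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
+
+    w = ot.emd2(a, b, M)  # returns a differentiable torch scalar
+    w.backward()  # gradients w.r.t. a (and M) are now available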
+
+
+
+List of compatible Backends
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- `Numpy `_ (all functions and solvers)
+- `Pytorch `_ (all outputs differentiable w.r.t. inputs)
+- `Jax `_ (some functions are differentiable, some require a wrapper)
+
+List of compatible functions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This list will get longer with new releases and will hopefully disappear once POT
+is fully implemented with the backend.
+
+- :any:`ot.emd`
+- :any:`ot.emd2`
+- :any:`ot.sinkhorn`
+- :any:`ot.sinkhorn2`
+- :any:`ot.dist`
+
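+As a rough sketch (with illustrative values), the same call runs on any backend
+and results can be brought back to numpy with :any:`ot.backend.to_numpy`:
+
+.. code:: python
+
+    import numpy as np
+    from ot.backend import to_numpy
+
+    a = np.array([0.5, 0.5])
+    b = np.array([0.5, 0.5])
+    M = np.array([[0.0, 1.0], [1.0, 0.0]])
+
+    T = ot.sinkhorn(a, b, M, reg=1e-1)  # numpy in, numpy out
+    T = to_numpy(T)  # no-op for numpy, converts torch/jax outputs
+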
+
FAQ
---
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index 3b594c2..82d3e6c 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -26,8 +26,7 @@ POT provides the following generic OT solvers (links to examples):
Algorithm `__
[2] , stabilized version [9] [10], greedy Sinkhorn [22] and
`Screening Sinkhorn
- [26] `__
- with optional GPU implementation (requires cupy).
+ [26] `__.
- Bregman projections for `Wasserstein
barycenter `__
[3], `convolutional
@@ -69,6 +68,11 @@ POT provides the following generic OT solvers (links to examples):
- `Sliced
Wasserstein `__
[31, 32].
+- `Several
+ backends `__
+ for easy use of POT with
+ `Pytorch `__/`jax `__/`Numpy `__
+ arrays.
POT provides the following Machine Learning related solvers:
@@ -104,12 +108,14 @@ paper `__:
::
- Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
+ Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer,
+ POT Python Optimal Transport library,
+ Journal of Machine Learning Research, 22(78):1−8, 2021.
Website: https://pythonot.github.io/
In Bibtex format:
-::
+.. code:: bibtex
@article{flamary2021pot,
author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
@@ -131,8 +137,8 @@ following Python modules:
- Numpy (>=1.16)
- Scipy (>=1.0)
-- Cython (>=0.23)
-- Matplotlib (>=1.5)
+- Cython (>=0.23) (build only, not necessary when installing wheels
+ from pip or conda)
Pip installation
^^^^^^^^^^^^^^^^
@@ -140,19 +146,19 @@ Pip installation
Note that due to a limitation of pip, ``cython`` and ``numpy`` need to
be installed prior to installing POT. This can be done easily with
-::
+.. code:: console
pip install numpy cython
You can install the toolbox through PyPI with:
-::
+.. code:: console
pip install POT
or get the very latest version by running:
-::
+.. code:: console
pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root)
@@ -163,7 +169,7 @@ If you use the Anaconda python distribution, POT is available in
`conda-forge `__. To install it and the
required dependencies:
-::
+.. code:: console
conda install -c conda-forge pot
@@ -188,15 +194,17 @@ below
- **ot.dr** (Wasserstein dimensionality reduction) depends on autograd
and pymanopt that can be installed with:
- ::
+.. code:: shell
- pip install pymanopt autograd
+ pip install pymanopt autograd
- **ot.gpu** (GPU accelerated OT) depends on cupy that have to be
installed following instructions on `this
page `__.
-
-obviously you need CUDA installed and a compatible GPU.
+ Obviously you will need CUDA installed and a compatible GPU. Note
+ that this module is deprecated since version 0.8 and will be deleted
+ in the future. GPU is now handled automatically through the backends
+ and several solvers can already run on GPU using the Pytorch backend.
Examples
--------
@@ -206,36 +214,36 @@ Short examples
- Import the toolbox
- .. code:: python
+.. code:: python
- import ot
+ import ot
- Compute Wasserstein distances
- .. code:: python
+.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- Wd=ot.emd2(a,b,M) # exact linear program
- Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
- # if b is a matrix compute all distances to a and return a vector
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ Wd = ot.emd2(a, b, M) # exact linear program
+ Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT
+ # if b is a matrix compute all distances to a and return a vector
- Compute OT matrix
- .. code:: python
+.. code:: python
- # a,b are 1D histograms (sum to 1 and positive)
- # M is the ground cost matrix
- T=ot.emd(a,b,M) # exact linear program
- T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
+ # a and b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T = ot.emd(a, b, M) # exact linear program
+ T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
- Compute Wasserstein barycenter
- .. code:: python
+.. code:: python
- # A is a n*d matrix containing d 1D histograms
- # M is the ground cost matrix
- ba=ot.barycenter(A,M,reg) # reg is regularization parameter
+ # A is a n*d matrix containing d 1D histograms
+ # M is the ground cost matrix
+ ba = ot.barycenter(A, M, reg) # reg is regularization parameter
Examples and Notebooks
~~~~~~~~~~~~~~~~~~~~~~
diff --git a/examples/README.txt b/examples/README.txt
index 69a9f84..b48487f 100644
--- a/examples/README.txt
+++ b/examples/README.txt
@@ -1,7 +1,7 @@
Examples gallery
================
-This is a gallery of all the POT example files.
+This is a gallery of all the POT example files.
OT and regularized OT
diff --git a/examples/backends/README.txt b/examples/backends/README.txt
new file mode 100644
index 0000000..3ee0e27
--- /dev/null
+++ b/examples/backends/README.txt
@@ -0,0 +1,4 @@
+
+
+POT backend examples
+--------------------
\ No newline at end of file
diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py
new file mode 100644
index 0000000..9ae66e9
--- /dev/null
+++ b/examples/backends/plot_unmix_optim_torch.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+r"""
+=================================
+Wasserstein unmixing with PyTorch
+=================================
+
+In this example we estimate mixing parameters from distributions that minimize
+the Wasserstein distance. In other words we suppose that a target
+distribution :math:`\mu^t` can be expressed as a weighted sum of source
+distributions :math:`\mu^s_k` with the following model:
+
+.. math::
+ \mu^t = \sum_{k=1}^K w_k\mu^s_k
+
+where :math:`\mathbf{w}` is a vector of size :math:`K` that belongs to the
+distribution simplex :math:`\Delta_K`.
+
+In order to estimate this weight vector we propose to optimize the Wasserstein
+distance between the model and the observed :math:`\mu^t` with respect to
+the vector. This leads to the following optimization problem:
+
+.. math::
+ \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right)
+
+This minimization is done in this example with a simple projected gradient
+descent in PyTorch. We use the automatic backend of POT, which allows us to
+compute the Wasserstein distance with :any:`ot.emd2` as a
+differentiable loss.
+
+"""
+
+# Author: Remi Flamary
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import torch
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% Data
+
+nt = 100
+nt1 = 10  # number of target samples in the first mode
+
+ns1 = 50
+ns = 2 * ns1
+
+rng = np.random.RandomState(2)
+
+xt = rng.randn(nt, 2) * 0.2
+xt[:nt1, 0] += 1
+xt[nt1:, 1] += 1
+
+
+xs1 = rng.randn(ns1, 2) * 0.2
+xs1[:, 0] += 1
+xs2 = rng.randn(ns1, 2) * 0.2
+xs2[:, 1] += 1
+
+xs = np.concatenate((xs1, xs2))
+
+# Sample reweighting matrix H
+H = np.zeros((ns, 2))
+H[:ns1, 0] = 1 / ns1
+H[ns1:, 1] = 1 / ns1
+# each column sums to 1 and has weights only for samples from the
+# corresponding source distribution
+
+M = ot.dist(xs, xt)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.scatter(xt[:, 0], xt[:, 1], label=r'Target $\mu^t$', alpha=0.5)
+pl.scatter(xs1[:, 0], xs1[:, 1], label=r'Source $\mu^s_1$', alpha=0.5)
+pl.scatter(xs2[:, 0], xs2[:, 1], label=r'Source $\mu^s_2$', alpha=0.5)
+pl.title('Sources and Target distributions')
+pl.legend()
+
+
+##############################################################################
+# Optimization of the model wrt the Wasserstein distance
+# ------------------------------------------------------
+
+
+#%% Weights optimization with gradient descent
+
+# convert numpy arrays to torch tensors
+H2 = torch.tensor(H)
+M2 = torch.tensor(M)
+
+# weights for the source distributions
+w = torch.tensor(ot.unif(2), requires_grad=True)
+
+# uniform weights for target
+b = torch.tensor(ot.unif(nt))
+
+lr = 2e-3 # learning rate
+niter = 500 # number of iterations
+losses = [] # loss along the iterations
+
+# loss for the minimal Wasserstein estimator
+
+
+def get_loss(w):
+ a = torch.mv(H2, w) # distribution reweighting
+ return ot.emd2(a, b, M2) # squared Wasserstein 2
+
+
+for i in range(niter):
+
+ loss = get_loss(w)
+ losses.append(float(loss))
+
+ loss.backward()
+
+ with torch.no_grad():
+ w -= lr * w.grad # gradient step
+ w[:] = ot.utils.proj_simplex(w) # projection on the simplex
+
+ w.grad.zero_()
+
+
+##############################################################################
+# Estimated weights and convergence of the objective
+# ---------------------------------------------------
+
+we = w.detach().numpy()
+print('Estimated mixture:', we)
+
+pl.figure(2)
+pl.semilogy(losses)
+pl.grid()
+pl.title('Wasserstein distance')
+pl.xlabel("Iterations")
+
+##############################################################################
+# Plotting the reweighted source distribution
+# ------------------------------------------
+
+pl.figure(3)
+
+# compute source weights
+ws = H.dot(we)
+
+pl.scatter(xt[:, 0], xt[:, 1], label=r'Target $\mu^t$', alpha=0.5)
+pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label=r'Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5)
+pl.title('Target and reweighted source distributions')
+pl.legend()
diff --git a/ot/__init__.py b/ot/__init__.py
index 5a8a415..3b072c6 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -33,6 +33,7 @@ from . import smooth
from . import stochastic
from . import unbalanced
from . import partial
+from . import backend
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
diff --git a/ot/backend.py b/ot/backend.py
new file mode 100644
index 0000000..d68f5cf
--- /dev/null
+++ b/ot/backend.py
@@ -0,0 +1,536 @@
+# -*- coding: utf-8 -*-
+"""
+Multi-lib backend for POT
+"""
+
+# Author: Remi Flamary
+# Nicolas Courty
+#
+# License: MIT License
+
+import numpy as np
+
+try:
+ import torch
+ torch_type = torch.Tensor
+except ImportError:
+ torch = False
+ torch_type = float
+
+try:
+ import jax
+ import jax.numpy as jnp
+ jax_type = jax.numpy.ndarray
+except ImportError:
+ jax = False
+ jax_type = float
+
+str_type_error = "All arrays should be from the same type/backend. Current types are: {}"
+
+
+def get_backend_list():
+ """ returns the list of available backends)"""
+ lst = [NumpyBackend(), ]
+
+ if torch:
+ lst.append(TorchBackend())
+
+ if jax:
+ lst.append(JaxBackend())
+
+ return lst
+
+
+def get_backend(*args):
+ """returns the proper backend for a list of input arrays
+
+ Also raises TypeError if all arrays are not from the same backend
+ """
+ # check that at least one array was given
+ if not len(args) > 0:
+ raise ValueError("The function takes at least one parameter")
+ # check all same type
+
+ if isinstance(args[0], np.ndarray):
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+ return NumpyBackend()
+ elif torch and isinstance(args[0], torch_type):
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+ return TorchBackend()
+ elif isinstance(args[0], jax_type):
+ return JaxBackend()
+ else:
+ raise ValueError("Unknown type of non implemented backend.")
+
+
+def to_numpy(*args):
+ """returns numpy arrays from any compatible backend"""
+
+ if len(args) == 1:
+ return get_backend(args[0]).to_numpy(args[0])
+ else:
+ return [get_backend(a).to_numpy(a) for a in args]
+
+
+class Backend():
+
+ __name__ = None
+ __type__ = None
+
+ def __str__(self):
+ return self.__name__
+
+ # convert to numpy
+ def to_numpy(self, a):
+ raise NotImplementedError()
+
+ # convert from numpy
+ def from_numpy(self, a, type_as=None):
+ raise NotImplementedError()
+
+ def set_gradients(self, val, inputs, grads):
+ """ define the gradients for the value val wrt the inputs """
+ raise NotImplementedError()
+
+ def zeros(self, shape, type_as=None):
+ raise NotImplementedError()
+
+ def ones(self, shape, type_as=None):
+ raise NotImplementedError()
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ raise NotImplementedError()
+
+ def full(self, shape, fill_value, type_as=None):
+ raise NotImplementedError()
+
+ def eye(self, N, M=None, type_as=None):
+ raise NotImplementedError()
+
+ def sum(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def cumsum(self, a, axis=None):
+ raise NotImplementedError()
+
+ def max(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def min(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def maximum(self, a, b):
+ raise NotImplementedError()
+
+ def minimum(self, a, b):
+ raise NotImplementedError()
+
+ def dot(self, a, b):
+ raise NotImplementedError()
+
+ def abs(self, a):
+ raise NotImplementedError()
+
+ def exp(self, a):
+ raise NotImplementedError()
+
+ def log(self, a):
+ raise NotImplementedError()
+
+ def sqrt(self, a):
+ raise NotImplementedError()
+
+ def norm(self, a):
+ raise NotImplementedError()
+
+ def any(self, a):
+ raise NotImplementedError()
+
+ def isnan(self, a):
+ raise NotImplementedError()
+
+ def isinf(self, a):
+ raise NotImplementedError()
+
+ def einsum(self, subscripts, *operands):
+ raise NotImplementedError()
+
+ def sort(self, a, axis=-1):
+ raise NotImplementedError()
+
+ def argsort(self, a, axis=None):
+ raise NotImplementedError()
+
+ def flip(self, a, axis=None):
+ raise NotImplementedError()
+
+
+class NumpyBackend(Backend):
+
+ __name__ = 'numpy'
+ __type__ = np.ndarray
+
+ def to_numpy(self, a):
+ return a
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return a
+ elif isinstance(a, float):
+ return a
+ else:
+ return a.astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # no gradients for numpy
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return np.zeros(shape)
+ else:
+ return np.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return np.ones(shape)
+ else:
+ return np.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return np.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return np.full(shape, fill_value)
+ else:
+ return np.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return np.eye(N, M)
+ else:
+ return np.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return np.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return np.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return np.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return np.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return np.maximum(a, b)
+
+ def minimum(self, a, b):
+ return np.minimum(a, b)
+
+ def dot(self, a, b):
+ return np.dot(a, b)
+
+ def abs(self, a):
+ return np.abs(a)
+
+ def exp(self, a):
+ return np.exp(a)
+
+ def log(self, a):
+ return np.log(a)
+
+ def sqrt(self, a):
+ return np.sqrt(a)
+
+ def norm(self, a):
+ return np.sqrt(np.sum(np.square(a)))
+
+ def any(self, a):
+ return np.any(a)
+
+ def isnan(self, a):
+ return np.isnan(a)
+
+ def isinf(self, a):
+ return np.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return np.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return np.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return np.argsort(a, axis)
+
+ def flip(self, a, axis=None):
+ return np.flip(a, axis)
+
+
+class JaxBackend(Backend):
+
+ __name__ = 'jax'
+ __type__ = jax_type
+
+ def to_numpy(self, a):
+ return np.array(a)
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return jnp.array(a)
+ else:
+ return jnp.array(a).astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # no gradients for jax because it is functional
+
+ # does not work
+ # from jax import custom_jvp
+ # @custom_jvp
+ # def f(*inputs):
+ # return val
+ # f.defjvps(*grads)
+ # return f(*inputs)
+
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.zeros(shape)
+ else:
+ return jnp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.ones(shape)
+ else:
+ return jnp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return jnp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return jnp.full(shape, fill_value)
+ else:
+ return jnp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return jnp.eye(N, M)
+ else:
+ return jnp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return jnp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return jnp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return jnp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return jnp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return jnp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return jnp.minimum(a, b)
+
+ def dot(self, a, b):
+ return jnp.dot(a, b)
+
+ def abs(self, a):
+ return jnp.abs(a)
+
+ def exp(self, a):
+ return jnp.exp(a)
+
+ def log(self, a):
+ return jnp.log(a)
+
+ def sqrt(self, a):
+ return jnp.sqrt(a)
+
+ def norm(self, a):
+ return jnp.sqrt(jnp.sum(jnp.square(a)))
+
+ def any(self, a):
+ return jnp.any(a)
+
+ def isnan(self, a):
+ return jnp.isnan(a)
+
+ def isinf(self, a):
+ return jnp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return jnp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return jnp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return jnp.argsort(a, axis)
+
+ def flip(self, a, axis=None):
+ return jnp.flip(a, axis)
+
+
+class TorchBackend(Backend):
+
+ __name__ = 'torch'
+ __type__ = torch_type
+
+ def to_numpy(self, a):
+ return a.cpu().detach().numpy()
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return torch.from_numpy(a)
+ else:
+ return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
+
+ def set_gradients(self, val, inputs, grads):
+ from torch.autograd import Function
+
+ # define a function that takes inputs and return val
+ class ValFunction(Function):
+ @staticmethod
+ def forward(ctx, *inputs):
+ return val
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # return the precomputed input gradients (grad_output is not chained here)
+ return grads
+
+ return ValFunction.apply(*inputs)
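+
+    # Usage sketch (hypothetical tensors x, g): for a value computed outside
+    # autograd, val2 = set_gradients(val, (x,), (g,)) re-attaches val to the
+    # graph so that val2.backward() deposits g into x.grad.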
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return torch.zeros(shape)
+ else:
+ return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return torch.ones(shape)
+ else:
+ return torch.ones(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ if type_as is None:
+ return torch.arange(start, stop, step)
+ else:
+ return torch.arange(start, stop, step, device=type_as.device)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return torch.full(shape, fill_value)
+ else:
+ return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device)
+
+ def eye(self, N, M=None, type_as=None):
+ if M is None:
+ M = N
+ if type_as is None:
+ return torch.eye(N, m=M)
+ else:
+ return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device)
+
+ def sum(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.sum(a)
+ else:
+ return torch.sum(a, axis, keepdim=keepdims)
+
+ def cumsum(self, a, axis=None):
+ if axis is None:
+ return torch.cumsum(a.flatten(), 0)
+ else:
+ return torch.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.max(a)
+ else:
+ return torch.max(a, axis, keepdim=keepdims)[0]
+
+ def min(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.min(a)
+ else:
+ return torch.min(a, axis, keepdim=keepdims)[0]
+
+ def maximum(self, a, b):
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ return torch.maximum(a, b)
+
+ def minimum(self, a, b):
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ return torch.minimum(a, b)
+
+ def dot(self, a, b):
+ if len(a.shape) == len(b.shape) == 1:
+ return torch.dot(a, b)
+ elif len(a.shape) == 2 and len(b.shape) == 1:
+ return torch.mv(a, b)
+ else:
+ return torch.mm(a, b)
+
+ def abs(self, a):
+ return torch.abs(a)
+
+ def exp(self, a):
+ return torch.exp(a)
+
+ def log(self, a):
+ return torch.log(a)
+
+ def sqrt(self, a):
+ return torch.sqrt(a)
+
+ def norm(self, a):
+ return torch.sqrt(torch.sum(torch.square(a)))
+
+ def any(self, a):
+ return torch.any(a)
+
+ def isnan(self, a):
+ return torch.isnan(a)
+
+ def isinf(self, a):
+ return torch.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return torch.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ sorted0, indices = torch.sort(a, dim=axis)
+ return sorted0
+
+ def argsort(self, a, axis=-1):
+ _, indices = torch.sort(a, dim=axis)
+ return indices
+
+ def flip(self, a, axis=None):
+ if axis is None:
+ return torch.flip(a, tuple(i for i in range(len(a.shape))))
+ if isinstance(axis, int):
+ return torch.flip(a, (axis,))
+ else:
+ return torch.flip(a, dims=axis)
diff --git a/ot/bregman.py b/ot/bregman.py
index 559db14..b10effd 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -19,7 +19,8 @@ import warnings
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
-from ot.utils import unif, dist
+from ot.utils import unif, dist, list_to_array
+from .backend import get_backend
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -43,17 +44,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in [2]_
+
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :any:`ot.bregman.screenkhorn` aims at providing a
+ fast approximation of the Sinkhorn problem.
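+
+    For instance, switching solvers only changes the ``method`` parameter
+    (a sketch, with ``a``, ``b``, ``M`` and ``reg`` as above)::
+
+        G = ot.sinkhorn(a, b, M, reg, method='sinkhorn_stabilized')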
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -69,25 +89,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
- **Choosing a Sinkhorn solver**
-
- By default and when using a regularization parameter that is not too small
- the default sinkhorn solver should be enough. If you need to use a small
- regularization to get sharper OT matrices, you should use the
- :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
- errors. This last solver can be very slow in practice and might not even
- converge to a reasonable OT matrix in a finite time. This is why
- :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value
- of the regularization (and using warm start) sometimes leads to better
- solutions. Note that the greedy version of the sinkhorn
- :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
- version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
- fast approximation of the Sinkhorn problem.
-
-
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -166,17 +170,35 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (histograms, both sum to 1)
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
+ fast approximation of the Sinkhorn problem.
+
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -191,28 +213,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
- **Choosing a Sinkhorn solver**
-
- By default and when using a regularization parameter that is not too small
- the default sinkhorn solver should be enough. If you need to use a small
- regularization to get sharper OT matrices, you should use the
- :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
- errors. This last solver can be very slow in practice and might not even
- converge to a reasonable OT matrix in a finite time. This is why
- :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value
- of the regularization (and using warm start) sometimes leads to better
- solutions. Note that the greedy version of the sinkhorn
- :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
- version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
- fast approximation of the Sinkhorn problem.
-
Returns
-------
- W : (n_hists) ndarray
+ W : (n_hists) float/array-like
Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
+
Examples
--------
@@ -247,7 +255,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
"""
- b = np.asarray(b, dtype=np.float64)
+
+ b = list_to_array(b)
if len(b.shape) < 2:
b = b[:, None]
@@ -339,14 +348,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
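+ # default to uniform marginals, allocated on the same backend/device as M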
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
# init data
dim_a = len(a)
@@ -363,21 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # print(reg)
-
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
-
- # print(np.min(K))
- tmp2 = np.empty(b.shape, dtype=M.dtype)
+ K = nx.exp(M / (-reg))
Kp = (1 / a).reshape(-1, 1) * K
cpt = 0
@@ -386,13 +387,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
uprev = u
vprev = v
- KtransposeU = np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.dot(Kp, v)
+ KtransposeU = nx.dot(K.T, u)
+ v = b / KtransposeU
+ u = 1. / nx.dot(Kp, v)
- if (np.any(KtransposeU == 0)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(KtransposeU == 0)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
@@ -403,11 +404,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
- np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
+ tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
- np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b) # violation of marginal
+ tmp2 = nx.einsum('i,ij,j->j', u, K, v)
+ err = nx.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
@@ -422,7 +423,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['v'] = v
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index 7478fb9..e939610 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -25,6 +25,8 @@ result of the function with parameter ``to_numpy=False``.
#
# License: MIT License
+import warnings
+
from . import bregman
from . import da
from .bregman import sinkhorn
@@ -34,7 +36,7 @@ from . import utils
from .utils import dist, to_gpu, to_np
-
+warnings.warn('This module will be deprecated in the next minor release of POT', category=DeprecationWarning)
__all__ = ["utils", "dist", "sinkhorn",
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index d5c3a5e..c8c9da6 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -18,8 +18,9 @@ from . import cvx
from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import dist
+from ..utils import dist, list_to_array
from ..utils import parmap
+from ..backend import get_backend
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -176,8 +177,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ .. math:: \gamma = arg\min_\gamma <\gamma,M>_F
s.t. \gamma 1 = a
@@ -189,37 +189,41 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
- M is the metric cost matrix
- a and b are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. warning:: Note that the M matrix in numpy needs to be a C-order
+ numpy.array in float64 format. It will be converted if not in this
+ format.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
- Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
- numItermax : int, optional (default=100000)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list)
+ M : (ns,nt) array-like, float
+ Loss matrix (c-order array in numpy with type float64)
+ numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
- log: bool, optional (default=False)
- If True, returns a dictionary containing the cost and dual
- variables. Otherwise returns only the optimal transportation matrix.
+ algorithm if it has not converged.
+ log: bool, optional (default=False)
+ If True, returns a dictionary containing the cost and dual variables.
+ Otherwise returns only the optimal transportation matrix.
center_dual: boolean, optional (default=True)
- If True, centers the dual potential using function
+ If True, centers the dual potential using function
:ref:`center_ot_dual`.
Returns
-------
- gamma: (ns x nt) numpy.ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is true, a dictionary containing the cost and dual
- variables and exit status
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
Examples
@@ -232,26 +236,37 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.emd(a,b,M)
+ >>> ot.emd(a, b, M)
array([[0.5, 0. ],
[0. , 0.5]])
References
----------
- .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
- (2011, December). Displacement interpolation using Lagrangian mass
- transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
- 158). ACM.
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
+ December). Displacement interpolation using Lagrangian mass transport.
+ In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
See Also
--------
- ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT"""
-
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT"""
+
+ # convert to numpy if list
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
+ # ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -262,6 +277,11 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure a and b have the same total mass
+ np.testing.assert_almost_equal(a.sum(0), b.sum(0),
+ err_msg='a and b vectors must have the same sum')
+ b = b * a.sum() / b.sum()
+
asel = a != 0
bsel = b != 0
@@ -277,12 +297,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
if log:
log = {}
log['cost'] = cost
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
- return G, log
- return G
+ return nx.from_numpy(G, type_as=M0), log
+ return nx.from_numpy(G, type_as=M0)
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
@@ -303,20 +323,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- M is the metric cost matrix
- a and b are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float64
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
+ b : (nt,) array-like, float64
Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
+ M : (ns,nt) array-like, float64
+ Loss matrix (for numpy: C-order array with dtype float64)
processes : int, optional (default=nb cpu)
Nb of processes used for multiple emd computation (not used on windows)
numItermax : int, optional (default=100000)
@@ -333,9 +352,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
Returns
-------
- W: float
+ W: float, array-like
Optimal transportation loss for the given parameters
- log: dictnp
+ log: dict
If input log is true, a dictionary containing dual
variables and exit status
@@ -367,12 +386,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
ot.bregman.sinkhorn : Entropic regularized OT
ot.optim.cg : General regularized OT"""
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order='C')
# problem with pickling Forks
- if sys.platform.endswith('win32'):
+ if sys.platform.endswith('win32') or not nx.__name__ == 'numpy':
processes = 1
# if empty array given then use uniform distributions
@@ -400,12 +429,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
result_code_string = check_result(result_code)
log = {}
+ G = nx.from_numpy(G, type_as=M0)
if return_matrix:
log['G'] = G
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
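+ # the dual potentials u, v and the plan G serve as the gradients of the
+ # cost w.r.t. a, b and M (re-attached via set_gradients for torch)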
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
def f(b):
@@ -418,6 +450,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
+ G = nx.from_numpy(G, type_as=M0)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
+ nx.from_numpy(v, type_as=b0), G))
+
check_result(result_code)
return cost
@@ -637,6 +674,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+ # ensure a and b have the same total mass
+ np.testing.assert_almost_equal(a.sum(0), b.sum(0),
+ err_msg='a and b vectors must have the same sum')
+ b = b * a.sum() / b.sum()
+
x_a_1d = x_a.reshape((-1,))
x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)
diff --git a/ot/utils.py b/ot/utils.py
index 544c569..4dac0c5 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
+from .backend import get_backend
__time_tic_toc = time.time()
@@ -41,8 +42,11 @@ def toq():
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
"""Compute kernel matrix"""
+
+ nx = get_backend(x1, x2)
+
if method.lower() in ['gaussian', 'gauss', 'rbf']:
- K = np.exp(-dist(x1, x2) / (2 * sigma**2))
+ K = nx.exp(-dist(x1, x2) / (2 * sigma**2))
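+ # Gaussian (RBF) kernel: K_ij = exp(-||x1_i - x2_j||**2 / (2 * sigma**2))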
return K
@@ -52,6 +56,66 @@ def laplacian(x):
return L
+def list_to_array(*lst):
+ """ Convert a list if in numpy format """
+ if len(lst) > 1:
+ return [np.array(a) if isinstance(a, list) else a for a in lst]
+ else:
+ return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
+
+
+def proj_simplex(v, z=1):
+ r""" compute the closest point (orthogonal projection) on the
+ generalized (n-1)-simplex of a vector v wrt. to the Euclidean
+ distance, thus solving:
+ .. math::
+ \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2
+
+ s.t. \gamma^T 1 = z
+
+ \gamma\geq 0
+
+ If v is a 2d array, compute all the projections wrt. axis 0
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ Parameters
+ ----------
+ v : {array-like}, shape (n, d)
+ z : int, optional
+ 'size' of the simplex (each vector sums to z, 1 by default)
+
+ Returns
+ -------
+ h : array-like, shape (n, d)
+ Array of projections on the simplex
+ """
+ nx = get_backend(v)
+ n = v.shape[0]
+ if v.ndim == 1:
+ d1 = 1
+ v = v[:, None]
+ else:
+ d1 = 0
+ d = v.shape[1]
+
+ # sort u in ascending order
+ u = nx.sort(v, axis=0)
+ # take the descending order
+ u = nx.flip(u, 0)
+ cssv = nx.cumsum(u, axis=0) - z
+ ind = nx.arange(n, type_as=v)[:, None] + 1
+ cond = u - cssv / ind > 0
+ rho = nx.sum(cond, 0)
+ theta = cssv[rho - 1, nx.arange(d)] / rho
+ w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v))
+ if d1:
+ return w[:, 0]
+ else:
+ return w
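+
+# Usage sketch (illustrative values): the projection is non-negative and sums
+# to z, e.g. proj_simplex(np.array([0.7, 0.2, -0.4])) gives w = [0.75, 0.25, 0.]
+# with w >= 0 and w.sum() == 1.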
+
+
def unif(n):
""" return a uniform histogram of length n (simplex)
@@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False):
"""
Considering the rows of X (and Y=X) as vectors, compute the
distance matrix between each pair of vectors.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
Parameters
----------
X : {array-like}, shape (n_samples_1, n_features)
Y : {array-like}, shape (n_samples_2, n_features)
squared : boolean, optional
Return squared Euclidean distances.
+
Returns
-------
distances : {array}, shape (n_samples_1, n_samples_2)
"""
- XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
- YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
- distances = np.dot(X, Y.T)
- distances *= -2
- distances += XX
- distances += YY
- np.maximum(distances, 0, out=distances)
+
+ nx = get_backend(X, Y)
+
+ a2 = nx.einsum('ij,ij->i', X, X)
+ b2 = nx.einsum('ij,ij->i', Y, Y)
+
+ c = -2 * nx.dot(X, Y.T)
+ c += a2[:, None]
+ c += b2[None, :]
+
+ c = nx.maximum(c, 0)
+
+ if not squared:
+ c = nx.sqrt(c)
+
if X is Y:
- # Ensure that distances between vectors and themselves are set to 0.0.
- # This may not be the case due to floating point rounding errors.
- distances.flat[::distances.shape[0] + 1] = 0.0
- return distances if squared else np.sqrt(distances, out=distances)
+ c = c * (1 - nx.eye(X.shape[0], type_as=c))
+
+ return c
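The rewrite relies on the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, which only needs reductions and a matrix product and is therefore available on every backend. A NumPy check of the identity it implements:

import numpy as np

rng = np.random.RandomState(0)
X, Y = rng.randn(5, 3), rng.randn(4, 3)

# expansion used by euclidean_distances
D2 = (X ** 2).sum(1)[:, None] + (Y ** 2).sum(1)[None, :] - 2 * X.dot(Y.T)
# direct pairwise computation
D2_ref = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
np.testing.assert_allclose(D2, D2_ref, atol=1e-12)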
def dist(x1, x2=None, metric='sqeuclidean'):
- """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+ """Compute distance between samples in x1 and x2
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
- x1 : ndarray, shape (n1,d)
+ x1 : array-like, shape (n1,d)
matrix with n1 samples of size d
- x2 : array, shape (n2,d), optional
+ x2 : array-like, shape (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
metric : str | callable, optional
- Name of the metric to be computed (full list in the doc of scipy), If a string,
- the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
- 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
- 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
+ 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also
+ accepts the metrics of the scipy.spatial.distance.cdist function: 'braycurtis',
+ 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
+ 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
+ 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
Returns
-------
- M : np.array (n1,n2)
+ M : array-like, shape (n1, n2)
distance matrix computed with given metric
"""
@@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'):
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
- return cdist(x1, x2, metric=metric)
+ elif metric == "euclidean":
+ return euclidean_distances(x1, x2, squared=False)
+ else:
+ if get_backend(x1, x2).__name__ != 'numpy':
+ raise NotImplementedError()
+ else:
+ return cdist(x1, x2, metric=metric)
def dist0(n, method='lin_square'):
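After this change 'sqeuclidean' and 'euclidean' are computed in-backend, while every other metric still goes through scipy's cdist and is therefore numpy-only. A usage sketch of the resulting behavior (the torch branch only runs when torch is available):

import numpy as np
import ot

x = np.random.RandomState(0).randn(6, 2)

D = ot.dist(x, x)                       # squared euclidean, any backend
De = ot.dist(x, x, metric='euclidean')  # also backend-generic
Dc = ot.dist(x, x, metric='cityblock')  # numpy-only cdist fallback

try:
    import torch
    ot.dist(torch.from_numpy(x), torch.from_numpy(x), metric='cityblock')
except ImportError:
    pass
except NotImplementedError:
    print('cdist metrics are numpy-only, as expected')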
diff --git a/requirements.txt b/requirements.txt
index 331dd57..4353247 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,7 @@ pymanopt==0.2.4; python_version <'3'
pymanopt; python_version >= '3'
cvxopt
scikit-learn
+torch
+jax
+jaxlib
pytest
\ No newline at end of file
diff --git a/test/test_backend.py b/test/test_backend.py
new file mode 100644
index 0000000..bc5b00c
--- /dev/null
+++ b/test/test_backend.py
@@ -0,0 +1,364 @@
+"""Tests for backend module """
+
+# Author: Remi Flamary
+#
+# License: MIT License
+
+import ot
+import ot.backend
+from ot.backend import torch, jax
+
+import pytest
+
+import numpy as np
+from numpy.testing import assert_array_almost_equal_nulp
+
+from ot.backend import get_backend, get_backend_list, to_numpy
+
+
+backend_list = get_backend_list()
+
+
+def test_get_backend_list():
+
+ lst = get_backend_list()
+
+ assert len(lst) > 0
+ assert isinstance(lst[0], ot.backend.NumpyBackend)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_to_numpy(nx):
+
+ v = nx.zeros(10)
+ M = nx.ones((10, 10))
+
+ v2 = to_numpy(v)
+ assert isinstance(v2, np.ndarray)
+
+ v2, M2 = to_numpy(v, M)
+ assert isinstance(M2, np.ndarray)
+
+
+def test_get_backend():
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ nx = get_backend(A)
+ assert nx.__name__ == 'numpy'
+
+ nx = get_backend(A, B)
+ assert nx.__name__ == 'numpy'
+
+ # error if no parameters
+ with pytest.raises(ValueError):
+ get_backend()
+
+ # error if unknown types
+ with pytest.raises(ValueError):
+ get_backend(1, 2.0)
+
+ # test torch
+ if torch:
+
+ A2 = torch.from_numpy(A)
+ B2 = torch.from_numpy(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'torch'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'torch'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+ if jax:
+
+ A2 = jax.numpy.array(A)
+ B2 = jax.numpy.array(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'jax'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'jax'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_convert_between_backends(nx):
+
+ A = np.zeros((3, 2))
+ B = np.zeros((3, 1))
+
+ A2 = nx.from_numpy(A)
+ B2 = nx.from_numpy(B)
+
+ assert isinstance(A2, nx.__type__)
+ assert isinstance(B2, nx.__type__)
+
+ nx2 = get_backend(A2, B2)
+
+ assert nx2.__name__ == nx.__name__
+
+ assert_array_almost_equal_nulp(nx.to_numpy(A2), A)
+ assert_array_almost_equal_nulp(nx.to_numpy(B2), B)
+
+
+def test_empty_backend():
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+
+ nx = ot.backend.Backend()
+
+ with pytest.raises(NotImplementedError):
+ nx.from_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.to_numpy(M)
+ with pytest.raises(NotImplementedError):
+ nx.set_gradients(0, 0, 0)
+ with pytest.raises(NotImplementedError):
+ nx.zeros((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.ones((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.arange(10, 1, 2)
+ with pytest.raises(NotImplementedError):
+ nx.full((10, 3), 3.14)
+ with pytest.raises(NotImplementedError):
+ nx.eye((10, 3))
+ with pytest.raises(NotImplementedError):
+ nx.sum(M)
+ with pytest.raises(NotImplementedError):
+ nx.cumsum(M)
+ with pytest.raises(NotImplementedError):
+ nx.max(M)
+ with pytest.raises(NotImplementedError):
+ nx.min(M)
+ with pytest.raises(NotImplementedError):
+ nx.maximum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.minimum(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.abs(M)
+ with pytest.raises(NotImplementedError):
+ nx.log(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.sqrt(M)
+ with pytest.raises(NotImplementedError):
+ nx.dot(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.norm(M)
+ with pytest.raises(NotImplementedError):
+ nx.exp(M)
+ with pytest.raises(NotImplementedError):
+ nx.any(M)
+ with pytest.raises(NotImplementedError):
+ nx.isnan(M)
+ with pytest.raises(NotImplementedError):
+ nx.isinf(M)
+ with pytest.raises(NotImplementedError):
+ nx.einsum('ij->i', M)
+ with pytest.raises(NotImplementedError):
+ nx.sort(M)
+ with pytest.raises(NotImplementedError):
+ nx.argsort(M)
+ with pytest.raises(NotImplementedError):
+ nx.flip(M)
+
+
+@pytest.mark.parametrize('backend', backend_list)
+def test_func_backends(backend):
+
+ rnd = np.random.RandomState(0)
+ M = rnd.randn(10, 3)
+ v = rnd.randn(3)
+ val = np.array([1.0])
+
+ lst_tot = []
+
+ for nx in [ot.backend.NumpyBackend(), backend]:
+
+ print('Backend: ', nx.__name__)
+
+ lst_b = []
+ lst_name = []
+
+ Mb = nx.from_numpy(M)
+ vb = nx.from_numpy(v)
+ val = nx.from_numpy(val)
+
+ A = nx.set_gradients(val, v, v)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('set_gradients')
+
+ A = nx.zeros((10, 3))
+ A = nx.zeros((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zeros')
+
+ A = nx.ones((10, 3))
+ A = nx.ones((10, 3), type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('ones')
+
+ A = nx.arange(10, 1, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('arange')
+
+ A = nx.full((10, 3), 3.14)
+ A = nx.full((10, 3), 3.14, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('full')
+
+ A = nx.eye(10, 3)
+ A = nx.eye(10, 3, type_as=Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('eye')
+
+ A = nx.sum(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum')
+
+ A = nx.sum(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sum(axis)')
+
+ A = nx.cumsum(Mb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('cumsum(axis)')
+
+ A = nx.max(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max')
+
+ A = nx.max(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('max(axis)')
+
+ A = nx.min(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min')
+
+ A = nx.min(Mb, axis=1, keepdims=True)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('min(axis)')
+
+ A = nx.maximum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('maximum')
+
+ A = nx.minimum(vb, 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('minimum')
+
+ A = nx.abs(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('abs')
+
+ A = nx.log(A)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('log')
+
+ A = nx.exp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('exp')
+
+ A = nx.sqrt(nx.abs(Mb))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sqrt')
+
+ A = nx.dot(vb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(v,v)')
+
+ A = nx.dot(Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,v)')
+
+ A = nx.dot(Mb, Mb.T)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('dot(M,M)')
+
+ A = nx.norm(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('norm')
+
+ A = nx.any(vb > 0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('any')
+
+ A = nx.isnan(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isnan')
+
+ A = nx.isinf(vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('isinf')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('einsum(ij->i)')
+
+ A = nx.einsum('ij,j->i', Mb, vb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij,j->i)')
+
+ A = nx.einsum('ij->i', Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('nx.einsum(ij->i)')
+
+ A = nx.sort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('sort')
+
+ A = nx.argsort(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argsort')
+
+ A = nx.flip(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('flip')
+
+ lst_tot.append(lst_b)
+
+ lst_np = lst_tot[0]
+ lst_b = lst_tot[1]
+
+ for a1, a2, name in zip(lst_np, lst_b, lst_name):
+ if not np.allclose(a1, a2):
+ print('Assert fail on: ', name)
+ assert np.allclose(a1, a2, atol=1e-7)
+
+
+def test_gradients_backends():
+
+ rnd = np.random.RandomState(0)
+ v = rnd.randn(10)
+ c = rnd.randn(1)
+
+ if torch:
+
+ nx = ot.backend.TorchBackend()
+
+ v2 = torch.tensor(v, requires_grad=True)
+ c2 = torch.tensor(c, requires_grad=True)
+
+ val = c2 * torch.sum(v2 * v2)
+
+ val2 = nx.set_gradients(val, (v2, c2), (v2, c2))
+
+ val2.backward()
+
+ assert torch.equal(v2.grad, v2)
+ assert torch.equal(c2.grad, c2)
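The torch test above pins down the contract of set_gradients: the values passed as gradients are returned verbatim by backward(), without flowing through the original graph of val. A compact restatement with a custom gradient, assuming the TorchBackend behaves as exercised here:

import numpy as np
import torch
import ot

nx = ot.backend.TorchBackend()
v = torch.tensor(np.arange(3.0), requires_grad=True)
val = torch.sum(v * v)

# declare that d(val)/d(v) should be reported as 2 * v
val2 = nx.set_gradients(val, (v,), (2 * v,))
val2.backward()
assert torch.allclose(v.grad, 2 * v)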
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 1ebd21f..7c5162a 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -9,6 +9,10 @@ import numpy as np
import pytest
import ot
+from ot.backend import get_backend_list
+from ot.backend import torch
+
+backend_list = get_backend_list()
def test_sinkhorn():
@@ -30,6 +34,76 @@ def test_sinkhorn():
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+@pytest.mark.parametrize('nx', backend_list)
+def test_sinkhorn_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn(ab, ab, Mb, 1)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_sinkhorn2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.sinkhorn(a, a, M, 1)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.sinkhorn2(ab, ab, Mb, 1)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
+def test_sinkhorn2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.sinkhorn2(a1, b1, M1, 1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
+
def test_sinkhorn_empty():
# test sinkhorn
n = 100
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 43da9fc..81138ca 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -181,7 +181,7 @@ def test_fgw():
M = ot.dist(ys, yt)
M /= M.max()
- G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
# check constraints
np.testing.assert_allclose(
@@ -242,9 +242,9 @@ def test_fgw_barycenter():
init_X = np.random.randn(n_samples, ys.shape[1])
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3, log=True)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_ot.py b/test/test_ot.py
index f45e4c9..3e953dc 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,9 +12,12 @@ from scipy.stats import wasserstein_distance
import ot
from ot.datasets import make_1D_gauss as gauss
+from ot.backend import get_backend_list, torch
+backend_list = get_backend_list()
-def test_emd_dimension_mismatch():
+
+def test_emd_dimension_and_mass_mismatch():
# test emd and emd2 for dimension mismatch
n_samples = 100
n_features = 2
@@ -29,6 +32,80 @@ def test_emd_dimension_mismatch():
np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)
+ b = a.copy()
+ a[0] = 100
+ np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_emd_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ G = ot.emd(a, a, M)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ Gb = ot.emd(ab, ab, Mb)
+
+ np.testing.assert_allclose(G, nx.to_numpy(Gb))
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_emd2_backends(nx):
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ val = ot.emd2(a, a, M)
+
+ ab = nx.from_numpy(a)
+ Mb = nx.from_numpy(M)
+
+ valb = ot.emd2(ab, ab, Mb)
+
+ np.testing.assert_allclose(val, nx.to_numpy(valb))
+
+
+def test_emd2_gradients():
+ n_samples = 100
+ n_features = 2
+ rng = np.random.RandomState(0)
+
+ x = rng.randn(n_samples, n_features)
+ y = rng.randn(n_samples, n_features)
+ a = ot.utils.unif(n_samples)
+
+ M = ot.dist(x, y)
+
+ if torch:
+
+ a1 = torch.tensor(a, requires_grad=True)
+ b1 = torch.tensor(a, requires_grad=True)
+ M1 = torch.tensor(M, requires_grad=True)
+
+ val = ot.emd2(a1, b1, M1)
+
+ val.backward()
+
+ assert a1.shape == a1.grad.shape
+ assert b1.shape == b1.grad.shape
+ assert M1.shape == M1.grad.shape
+
def test_emd_emd2():
# test emd and emd2 for simple identity
@@ -83,7 +160,7 @@ def test_emd_1d_emd2_1d():
np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
# check G is similar
- np.testing.assert_allclose(G, G_1d)
+ np.testing.assert_allclose(G, G_1d, atol=1e-15)
# check AssertionError is raised if called on non 1d arrays
u = np.random.randn(n, 2)
@@ -292,16 +369,6 @@ def test_warnings():
ot.emd(a, b, M, numItermax=1)
assert "numItermax" in str(w[-1].message)
#assert len(w) == 1
- a[0] = 100
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- #assert len(w) == 2
- a[0] = -1
- print('Computing {} EMD '.format(2))
- ot.emd(a, b, M)
- assert "infeasible" in str(w[-1].message)
- #assert len(w) == 3
def test_dual_variables():
diff --git a/test/test_partial.py b/test/test_partial.py
index 121f345..3571e2a 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -129,9 +129,9 @@ def test_partial_wasserstein():
# check constraints
np.testing.assert_equal(
- G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
+ G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
- G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein
+ G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein
np.testing.assert_allclose(
np.sum(G), m, atol=1e-04)
diff --git a/test/test_utils.py b/test/test_utils.py
index db9cda6..76b1faa 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,11 +4,47 @@
#
# License: MIT License
-
+import pytest
import ot
import numpy as np
import sys
+from ot.backend import get_backend_list
+
+backend_list = get_backend_list()
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_proj_simplex(nx):
+ n = 10
+ rng = np.random.RandomState(0)
+
+ # test on matrix when projection is done on axis 0
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # all projections should sum to 3
+ proj = ot.utils.proj_simplex(x1, 3)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = 3 * np.ones(2)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
+ # test on vector
+ x = rng.randn(n)
+ x1 = nx.from_numpy(x)
+
+ # all projections should sum to 1
+ proj = ot.utils.proj_simplex(x1)
+ l1 = np.sum(nx.to_numpy(proj), axis=0)
+ l2 = np.ones(1)
+ np.testing.assert_allclose(l1, l2, atol=1e-5)
+
def test_parmap():
@@ -45,8 +81,8 @@ def test_tic_toc():
def test_kernel():
n = 100
-
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
K = ot.utils.kernel(x, x)
@@ -67,7 +103,8 @@ def test_dist():
n = 100
- x = np.random.randn(n, 2)
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
D = np.zeros((n, n))
for i in range(n):
@@ -78,8 +115,27 @@ def test_dist():
D3 = ot.dist(x)
# dist should return squared euclidean
- np.testing.assert_allclose(D, D2)
- np.testing.assert_allclose(D, D3)
+ np.testing.assert_allclose(D, D2, atol=1e-14)
+ np.testing.assert_allclose(D, D3, atol=1e-14)
+
+
+@ pytest.mark.parametrize('nx', backend_list)
+def test_dist_backends(nx):
+
+ n = 100
+ rng = np.random.RandomState(0)
+ x = rng.randn(n, 2)
+ x1 = nx.from_numpy(x)
+
+ lst_metric = ['euclidean', 'sqeuclidean']
+
+ for metric in lst_metric:
+
+ D = ot.dist(x, x, metric=metric)
+ D1 = ot.dist(x1, x1, metric=metric)
+
+ # loose atol because jax forces float32
+ np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5)
def test_dist0():
@@ -95,9 +151,11 @@ def test_dots():
n1, n2, n3, n4 = 100, 50, 200, 100
- A = np.random.randn(n1, n2)
- B = np.random.randn(n2, n3)
- C = np.random.randn(n3, n4)
+ rng = np.random.RandomState(0)
+
+ A = rng.randn(n1, n2)
+ B = rng.randn(n2, n3)
+ C = rng.randn(n3, n4)
X1 = ot.utils.dots(A, B, C)
--
cgit v1.2.3
From 2dbeeda9308029a8e8db56bed07d48f4d5718efb Mon Sep 17 00:00:00 2001
From: Huy Tran
Date: Mon, 14 Jun 2021 13:06:40 +0200
Subject: [MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259)
* Add batch implementation of Sinkhorn
* Reformat to pep8 and modify parameter
* Fix error in batch size
* Code review and add test
* Fix accidental typo in test_empirical_sinkhorn
* Remove whitespace
* Edit config.yml
---
.circleci/config.yml | 1 +
ot/bregman.py | 134 +++++++++++++++++++++++++++++++++++++++++++--------
test/test_bregman.py | 44 +++++++++++++++++
3 files changed, 158 insertions(+), 21 deletions(-)
(limited to 'test')
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 29c9a07..e4c71dd 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -73,6 +73,7 @@ jobs:
command: |
cd docs;
make html;
+ no_output_timeout: 30m
# Save the outputs
- store_artifacts:
diff --git a/ot/bregman.py b/ot/bregman.py
index b10effd..105b38b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -11,6 +11,7 @@ Bregman projections solvers for entropic regularized OT
# Mokhtar Z. Alaya
# Alexander Tong
# Ievgen Redko
+# Quang Huy Tran
#
# License: MIT License
@@ -18,6 +19,7 @@ import warnings
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
+from scipy.special import logsumexp
from ot.utils import unif, dist, list_to_array
from .backend import get_backend
@@ -1684,7 +1686,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, verbose=False,
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
@@ -1723,6 +1725,12 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
+ isLazy: boolean, optional
+ If True, compute the cost matrix by blocks and return only the dual potentials (to save memory).
+ If False, compute the full cost matrix and return the outputs of the sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1758,24 +1766,78 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
-
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = unif(ns)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = unif(nt)
+
+ if isLazy:
+ if log:
+ dict_log = {"err": []}
- M = dist(X_s, X_t, metric=metric)
+ log_a, log_b = np.log(a), np.log(b)
+ f, g = np.zeros(ns), np.zeros(nt)
+
+ if isinstance(batchSize, int):
+ bs, bt = batchSize, batchSize
+ elif isinstance(batchSize, tuple) and len(batchSize) == 2:
+ bs, bt = batchSize[0], batchSize[1]
+ else:
+ raise ValueError("Batch size must be in integer or a tuple of two integers")
+
+ range_s, range_t = range(0, ns, bs), range(0, nt, bt)
+
+ lse_f = np.zeros(ns)
+ lse_g = np.zeros(nt)
+
+ for i_ot in range(numIterMax):
+
+ for i in range_s:
+ M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+ lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1)
+ f = log_a - lse_f
+
+ for j in range_t:
+ M = dist(X_s, X_t[j:j + bt, :], metric=metric)
+ lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0)
+ g = log_b - lse_g
+
+ if (i_ot + 1) % 10 == 0:
+ m1 = np.zeros_like(a)
+ for i in range_s:
+ M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+ m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1)
+ err = np.abs(m1 - a).sum()
+ if log:
+ dict_log["err"].append(err)
+
+ if verbose and (i_ot + 1) % 100 == 0:
+ print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+
+ if err <= stopThr:
+ break
+
+ if log:
+ dict_log["u"] = f
+ dict_log["v"] = g
+ return (f, g, dict_log)
+ else:
+ return (f, g)
- if log:
- pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
- return pi, log
else:
- pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
- return pi
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+ return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+ return pi
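In lazy mode only the dual potentials (f, g) are returned, and any block of the implicit plan can be rebuilt on demand as exp(f_i + g_j - M_ij / reg). A sketch of reconstructing one row block, assuming the isLazy API above:

import numpy as np
import ot

n = 50
rng = np.random.RandomState(0)
X_s, X_t = rng.randn(n, 2), rng.randn(n, 2)

f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1.0,
                                     isLazy=True, batchSize=10)

# rebuild the first 10 rows of the plan without forming the full matrix
M_block = ot.dist(X_s[:10, :], X_t)
pi_block = np.exp(f[:10, None] + g[None, :] - M_block / 1.0)
print(pi_block.sum(axis=1))  # each row sums to about 1 / n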
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+ isLazy=False, batchSize=100, verbose=False, log=False, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -1814,6 +1876,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
+ isLazy: boolean, optional
+ If True, compute the cost matrix by blocks and return only the dual potentials (to save memory).
+ If False, compute the full cost matrix and return the outputs of the sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1850,21 +1918,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = unif(ns)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = unif(nt)
- M = dist(X_s, X_t, metric=metric)
+ if isLazy:
+ if log:
+ f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+ else:
+ f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+
+ bs = batchSize if isinstance(batchSize, int) else batchSize[0]
+ range_s = range(0, ns, bs)
+
+ loss = 0
+ for i in range_s:
+ M_block = dist(X_s[i:i + bs, :], X_t, metric=metric)
+ pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
+ loss += np.sum(M_block * pi_block)
+
+ if log:
+ return loss, dict_log
+ else:
+ return loss
- if log:
- sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss, log
else:
- sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss
+ M = dist(X_s, X_t, metric=metric)
+
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ return sinkhorn_loss
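The lazy loss mirrors the same block structure: a sum over blocks of <M_block, pi_block>. A usage sketch comparing the lazy and dense estimates on a problem small enough for both (assuming the API above):

import numpy as np
import ot

n = 30
rng = np.random.RandomState(0)
X_s, X_t = rng.randn(n, 2), rng.randn(n, 2)

loss_lazy = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.0,
                                           isLazy=True, batchSize=5)
loss_dense = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.0)
print(float(loss_lazy), float(loss_dense))  # should agree to a few decimals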
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 7c5162a..9665229 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -2,6 +2,7 @@
# Author: Remi Flamary
# Kilian Fatras
+# Quang Huy Tran
#
# License: MIT License
@@ -329,6 +330,49 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+def test_lazy_empirical_sinkhorn():
+ # test lazy sinkhorn
+ n = 100
+ a = ot.unif(n)
+ b = ot.unif(n)
+ numIterMax = 1000
+
+ X_s = np.reshape(np.arange(n), (n, 1))
+ X_t = np.reshape(np.arange(0, n), (n, 1))
+ M = ot.dist(X_s, X_t)
+ M_m = ot.dist(X_s, X_t, metric='minkowski')
+
+ f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True)
+ G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
+ sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+ f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
+ sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
+ sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+ # check constraints
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidean
+ np.testing.assert_allclose(
+ sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05) # metric sqeuclidean
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05) # log
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05) # metric minkowski
+ np.testing.assert_allclose(
+ sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric minkowski
+ np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
def test_empirical_sinkhorn_divergence():
# Test sinkhorn divergence
n = 10
--
cgit v1.2.3
From 8ef3341a472909f223ec0f678f11f136f55c1406 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Thu, 17 Jun 2021 11:46:37 +0200
Subject: [MRG] Speedup tests (#262)
* speedup tests
* add color to tests and timings
* add test unbalanced
* stupid missing -
---
.github/workflows/build_tests.yml | 8 ++++----
Makefile | 4 ++--
test/test_bregman.py | 7 ++++---
test/test_da.py | 8 ++++----
test/test_gromov.py | 15 +++++++++------
test/test_optim.py | 6 +++---
test/test_stochastic.py | 40 +++++++++++++++++++--------------------
test/test_unbalanced.py | 33 ++++++++++++++++++++++++++++++--
8 files changed, 77 insertions(+), 44 deletions(-)
(limited to 'test')
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 92a07b5..fd0ade6 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -40,7 +40,7 @@ jobs:
pip install -e .
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
+ python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
- name: Upload codecov
run: |
codecov
@@ -95,7 +95,7 @@ jobs:
pip install -e .
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --ignore ot/gpu/
+ python -m pytest --durations=20 -v test/ ot/ --ignore ot/gpu/ --color=yes
macos:
@@ -122,7 +122,7 @@ jobs:
pip install -e .
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
windows:
@@ -150,4 +150,4 @@ jobs:
python -m pip install -e .
- name: Run tests
run: |
- python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
+ python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
diff --git a/Makefile b/Makefile
index 32332b4..315218d 100644
--- a/Makefile
+++ b/Makefile
@@ -45,10 +45,10 @@ pep8 :
flake8 examples/ ot/ test/
test : FORCE pep8
- $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/
+ $(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/
pytest : FORCE
- $(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/
+ $(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/
release :
twine upload dist/*
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 9665229..88166a5 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -293,7 +293,7 @@ def test_unmix():
def test_empirical_sinkhorn():
# test sinkhorn
- n = 100
+ n = 10
a = ot.unif(n)
b = ot.unif(n)
@@ -332,7 +332,7 @@ def test_empirical_sinkhorn():
def test_lazy_empirical_sinkhorn():
# test sinkhorn
- n = 100
+ n = 10
a = ot.unif(n)
b = ot.unif(n)
numIterMax = 1000
@@ -342,7 +342,7 @@ def test_lazy_empirical_sinkhorn():
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='minkowski')
- f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True)
+ f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
@@ -458,6 +458,7 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+@pytest.mark.filterwarnings("ignore:Bottleneck")
def test_screenkhorn():
# test screenkhorn
rng = np.random.RandomState(0)
diff --git a/test/test_da.py b/test/test_da.py
index 52c6a48..44bb2e9 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -106,8 +106,8 @@ def test_sinkhorn_l1l2_transport_class():
"""test_sinkhorn_transport
"""
- ns = 150
- nt = 200
+ ns = 50
+ nt = 100
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
@@ -448,8 +448,8 @@ def test_mapping_transport_class():
"""test_mapping_transport
"""
- ns = 60
- nt = 120
+ ns = 20
+ nt = 30
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 81138ca..56414a8 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -9,6 +9,8 @@
import numpy as np
import ot
+import pytest
+
def test_gromov():
n_samples = 50 # nb samples
@@ -128,9 +130,10 @@ def test_gromov_barycenter():
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter():
- ns = 50
- nt = 60
+ ns = 20
+ nt = 30
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
@@ -138,19 +141,19 @@ def test_gromov_entropic_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
- n_samples = 3
+ n_samples = 2
Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
[ot.unif(ns), ot.unif(nt)
], ot.unif(n_samples), [.5, .5],
- 'square_loss', 2e-3,
- max_iter=100, tol=1e-3,
+ 'square_loss', 1e-3,
+ max_iter=50, tol=1e-5,
verbose=True)
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
[ot.unif(ns), ot.unif(nt)
], ot.unif(n_samples), [.5, .5],
- 'kl_loss', 2e-3,
+ 'kl_loss', 1e-3,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
diff --git a/test/test_optim.py b/test/test_optim.py
index 48de38a..fd194c2 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -37,8 +37,8 @@ def test_conditional_gradient():
np.testing.assert_allclose(b, G.sum(0))
-def test_conditional_gradient2():
- n = 1000 # nb samples
+def test_conditional_gradient_itermax():
+ n = 100 # nb samples
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
@@ -63,7 +63,7 @@ def test_conditional_gradient2():
reg = 1e-1
- G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
+ G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000,
verbose=True, log=True)
np.testing.assert_allclose(a, G.sum(1))
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 155622c..98e93ec 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -30,7 +30,7 @@ import ot
def test_stochastic_sag():
# test sag
- n = 15
+ n = 10
reg = 1
numItermax = 30000
rng = np.random.RandomState(0)
@@ -45,9 +45,9 @@ def test_stochastic_sag():
# check constraints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-04) # cf convergence sag
+ u, G.sum(1), atol=1e-03) # cf convergence sag
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-04) # cf convergence sag
+ u, G.sum(0), atol=1e-03) # cf convergence sag
#############################################################################
@@ -60,9 +60,9 @@ def test_stochastic_sag():
def test_stochastic_asgd():
# test asgd
- n = 15
+ n = 10
reg = 1
- numItermax = 100000
+ numItermax = 10000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -75,9 +75,9 @@ def test_stochastic_asgd():
# check constraints
np.testing.assert_allclose(
- u, G.sum(1), atol=1e-03) # cf convergence asgd
+ u, G.sum(1), atol=1e-02) # cf convergence asgd
np.testing.assert_allclose(
- u, G.sum(0), atol=1e-03) # cf convergence asgd
+ u, G.sum(0), atol=1e-02) # cf convergence asgd
#############################################################################
@@ -90,9 +90,9 @@ def test_stochastic_asgd():
def test_sag_asgd_sinkhorn():
# test all algorithms
- n = 15
+ n = 10
reg = 1
- nb_iter = 100000
+ nb_iter = 10000
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
@@ -107,17 +107,17 @@ def test_sag_asgd_sinkhorn():
# check constraints
np.testing.assert_allclose(
- G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag
+ G_sag, G_sinkhorn, atol=1e-02) # cf convergence sag
np.testing.assert_allclose(
- G_asgd, G_sinkhorn, atol=1e-03) # cf convergence asgd
+ G_asgd, G_sinkhorn, atol=1e-02) # cf convergence asgd
#############################################################################
@@ -136,7 +136,7 @@ def test_stochastic_dual_sgd():
# test sgd
n = 10
reg = 1
- numItermax = 15000
+ numItermax = 5000
batch_size = 10
rng = np.random.RandomState(0)
@@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn():
# test all dual algorithms
n = 10
reg = 1
- nb_iter = 15000
+ nb_iter = 5000
batch_size = 10
rng = np.random.RandomState(0)
@@ -183,11 +183,11 @@ def test_dual_sgd_sinkhorn():
# check constraints
np.testing.assert_allclose(
- G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
+ G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
- G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
+ G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
np.testing.assert_allclose(
- G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
+ G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd
# Test gaussian
n = 30
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index dfeaad9..e8349d1 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -115,7 +115,8 @@ def test_stabilized_vs_sinkhorn():
G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
method="sinkhorn_stabilized",
reg_m=reg_m,
- log=True)
+ log=True,
+ verbose=True)
G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
method="sinkhorn", log=True)
@@ -138,7 +139,7 @@ def test_unbalanced_barycenter(method):
reg_m = 1.
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True)
+ method=method, log=True, verbose=True)
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
logA = np.log(A + 1e-16)
@@ -173,6 +174,7 @@ def test_barycenter_stabilized_vs_sinkhorn():
reg_m=reg_m, log=True,
tau=100,
method="sinkhorn_stabilized",
+ verbose=True
)
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
method="sinkhorn",
@@ -182,6 +184,33 @@ def test_barycenter_stabilized_vs_sinkhorn():
q, qstable, atol=1e-05)
+def test_wrong_method():
+
+ n = 10
+ rng = np.random.RandomState(42)
+
+ x = rng.randn(n, 2)
+ a = ot.utils.unif(n)
+
+ # make dists unbalanced
+ b = ot.utils.unif(n) * 1.5
+
+ M = ot.dist(x, x)
+ epsilon = 1.
+ reg_m = 1.
+
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
+ reg_m=reg_m,
+ method='badmethod',
+ log=True,
+ verbose=True)
+ with pytest.raises(ValueError):
+ ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
+ method='badmethod',
+ verbose=True)
+
+
def test_implemented_methods():
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
--
cgit v1.2.3
From 96bf1a46e74d6985419e14222afb0b9241a7bb36 Mon Sep 17 00:00:00 2001
From: Minhui Huang <32522773+mhhuang95@users.noreply.github.com>
Date: Mon, 6 Sep 2021 08:06:50 -0700
Subject: [MRG] Projection Robust Wasserstein (#267)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* ot.dr: PRW code; text.text_dr: PRW test code.
* ot.dr: PRW code; test.test_dr: PRW test code.
* fix errors: pep8(3.8)
* fix errors: pep8(3.8)
* modified readme; prw code review
* fix pep error
* edit comment
* modified math comment
Co-authored-by: Rémi Flamary
---
README.md | 3 ++
ot/dr.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
test/test_dr.py | 37 ++++++++++++++++++
3 files changed, 154 insertions(+)
(limited to 'test')
diff --git a/README.md b/README.md
index 20e0606..6a2cf15 100644
--- a/README.md
+++ b/README.md
@@ -198,6 +198,7 @@ The contributors to this library are
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
+* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -283,3 +284,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML).
diff --git a/ot/dr.py b/ot/dr.py
index b7a1af0..64588cf 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -10,6 +10,7 @@ Dimension reduction with OT
"""
# Author: Remi Flamary
+# Minhui Huang
#
# License: MIT License
@@ -198,3 +199,116 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
return (X - mx.reshape((1, -1))).dot(Popt)
return Popt, proj
+
+
+def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
+ r"""
+ Projection Robust Wasserstein Distance [32]
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi)
+
+ - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold
+ - :math:`H(\pi)` is the entropy regularizer
+ - :math:`x_i`, :math:`y_j` are samples of measures :math:`\mu` and :math:`\nu` respectively
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, d)
+ Samples from measure \mu
+ Y : ndarray, shape (n, d)
+ Samples from measure \nu
+ a : ndarray, shape (n, )
+ weights for measure \mu
+ b : ndarray, shape (n, )
+ weights for measure \nu
+ tau : float
+ stepsize for Riemannian Gradient Descent
+ U0 : ndarray, shape (d, k)
+ Initial starting point for projection.
+ reg : float, optional
+ Regularization term >0 (entropic regularization)
+ k : int
+ Subspace dimension
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : int, optional
+ Print information along iterations.
+
+ Returns
+ -------
+ pi : ndarray, shape (n, n)
+ Optimal transportation matrix for the given parameters
+ U : ndarray, shape (d, k)
+ Projection operator.
+
+ References
+ ----------
+ .. [32] Huang, M. , Ma S. & Lai L. (2021).
+ A Riemannian Block Coordinate Descent Method for Computing
+ the Projection Robust Wasserstein Distance, ICML.
+ """ # noqa
+
+ # initialization
+ n, d = X.shape
+ m, d = Y.shape
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ u = np.ones(n) / n
+ v = np.ones(m) / m
+ ones = np.ones((n, m))
+
+ assert d > k
+
+ if U0 is None:
+ U = np.random.randn(d, k)
+ U, _ = np.linalg.qr(U)
+ else:
+ U = U0
+
+ def Vpi(X, Y, a, b, pi):
+ # Return the second order matrix of the displacements: sum_ij { (pi)_ij (X_i-Y_j)(X_i-Y_j)^T }.
+ A = X.T.dot(pi).dot(Y)
+ return X.T.dot(np.diag(a)).dot(X) + Y.T.dot(np.diag(np.sum(pi, 0))).dot(Y) - A - A.T
+
+ err = 1
+ iter = 0
+
+ while err > stopThr and iter < maxiter:
+
+ # Projected cost matrix
+ UUT = U.dot(U.T)
+ M = np.diag(np.diag(X.dot(UUT.dot(X.T)))).dot(ones) + ones.dot(
+ np.diag(np.diag(Y.dot(UUT.dot(Y.T))))) - 2 * X.dot(UUT.dot(Y.T))
+
+ A = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=A)
+ np.exp(A, out=A)
+
+ # Sinkhorn update
+ Ap = (1 / a).reshape(-1, 1) * A
+ AtransposeU = np.dot(A.T, u)
+ v = np.divide(b, AtransposeU)
+ u = 1. / np.dot(Ap, v)
+ pi = u.reshape((-1, 1)) * A * v.reshape((1, -1))
+
+ V = Vpi(X, Y, a, b, pi)
+
+ # Riemannian gradient descent
+ G = 2 / reg * V.dot(U)
+ GTU = G.T.dot(U)
+ xi = G - U.dot(GTU + GTU.T) / 2 # Riemannian gradient
+ U, _ = np.linalg.qr(U + tau * xi) # Retraction by QR decomposition
+
+ grad_norm = np.linalg.norm(xi)
+ err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1))
+
+ f_val = np.trace(U.T.dot(V.dot(U)))
+ if verbose:
+ print('RBCD Iteration: ', iter, ' error', err, '\t fval: ', f_val)
+
+ iter = iter + 1
+
+ return pi, U
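A minimal usage sketch of the new solver on two small point clouds (tau and reg values mirror the test below; ot.dr needs the optional autograd and pymanopt dependencies):

import numpy as np
import ot.dr  # requires autograd and pymanopt

n, d, k = 20, 10, 2
rng = np.random.RandomState(0)
X = rng.randn(n, d)
Y = rng.randn(n, d) + 1.0  # shifted cloud
a = b = np.ones(n) / n

pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau=0.002,
                                            reg=0.2, k=k, maxiter=100)
print(pi.shape, U.shape)  # (20, 20) plan, (10, 2) projector
# U is a Stiefel point: orthonormal columns by QR retraction
np.testing.assert_allclose(U.T.dot(U), np.eye(k), atol=1e-8)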
diff --git a/test/test_dr.py b/test/test_dr.py
index c5df287..fa75a18 100644
--- a/test/test_dr.py
+++ b/test/test_dr.py
@@ -1,6 +1,7 @@
"""Tests for module dr on Dimensionality Reduction """
# Author: Remi Flamary
+# Minhui Huang
#
# License: MIT License
@@ -57,3 +58,39 @@ def test_wda():
projwda(xs)
np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))
+
+
+@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
+def test_prw():
+ d = 100 # Dimension
+ n = 100 # Number samples
+ k = 3 # Subspace dimension
+ dim = 3
+
+ def fragmented_hypercube(n, d, dim):
+ assert dim <= d
+ assert dim >= 1
+ assert dim == int(dim)
+
+ a = (1. / n) * np.ones(n)
+ b = (1. / n) * np.ones(n)
+
+ # First measure : uniform on the hypercube
+ X = np.random.uniform(-1, 1, size=(n, d))
+
+ # Second measure : fragmentation
+ tmp_y = np.random.uniform(-1, 1, size=(n, d))
+ Y = tmp_y + 2 * np.sign(tmp_y) * np.array(dim * [1] + (d - dim) * [0])
+ return a, b, X, Y
+
+ a, b, X, Y = fragmented_hypercube(n, d, dim)
+
+ tau = 0.002
+ reg = 0.2
+
+ pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1)
+
+ U0 = np.random.randn(d, k)
+ U0, _ = np.linalg.qr(U0)
+
+ pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1)
--
cgit v1.2.3
From e0ba31ce39a7d9e65e50ea970a574b3db54e4207 Mon Sep 17 00:00:00 2001
From: Tanguy
Date: Fri, 17 Sep 2021 18:36:33 +0200
Subject: [MRG] Implementation of two news algorithms: SaGroW and PoGroW.
(#275)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Add two new algorithms to solve Gromov Wasserstein: Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein.
* Correct some lines in SaGroW and PoGroW to follow pep8 guide.
* Change nb_samples name. Use rdm state. Change symmetric check.
* Change names of len(p) and len(q) in SaGroW and PoGroW.
* Re-add some deleted lines in the comments of gromov.py
Co-authored-by: Rémi Flamary
---
README.md | 4 +
examples/gromov/plot_gromov.py | 34 ++++
ot/gromov.py | 376 +++++++++++++++++++++++++++++++++++++++++
test/test_gromov.py | 88 +++++++++-
4 files changed, 496 insertions(+), 6 deletions(-)
(limited to 'test')
diff --git a/README.md b/README.md
index 6a2cf15..266d847 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@ POT provides the following generic OT solvers (links to examples):
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
* [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
+* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
@@ -198,6 +199,7 @@ The contributors to this library are
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
+* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -286,3 +288,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML).
+
+[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index deb2f86..5a362cf 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet')
pl.title('Entropic Gromov Wasserstein')
pl.show()
+
+#############################################################################
+#
+# Compute GW with a scalable stochastic method with any loss function
+# ----------------------------------------------------------------------
+
+
+def loss(x, y):
+ return np.abs(x - y)
+
+
+pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
+ log=True)
+
+sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
+ log=True)
+
+print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
+print('Variance estimated: ' + str(plog['gw_dist_std']))
+print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
+print('Variance estimated: ' + str(slog['gw_dist_std']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(pgw.toarray(), cmap='jet')
+pl.title('Pointwise Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(sgw, cmap='jet')
+pl.title('Sampled Gromov Wasserstein')
+
+pl.show()
diff --git a/ot/gromov.py b/ot/gromov.py
index 8f457e9..a27217a 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -16,6 +16,10 @@ import numpy as np
from .bregman import sinkhorn
from .utils import dist, UndefinedParameter
from .optim import cg
+from .lp import emd_1d, emd
+from .utils import check_random_state
+
+from scipy.sparse import issparse
def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
@@ -572,6 +576,378 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
return log['fgw_dist']
+def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
+ nb_samples_p=None, nb_samples_q=None, std=True, random_state=None):
+ r"""
+ Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q)
+ with a fixed transport plan T.
+
+ The function gives an unbiased approximation of the following equation:
+
+ .. math::
+ GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ Where :
+
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - L : Loss function to account for the misfit between the similarity matrices
+ - T : Matrix with marginal p and q
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ T : csr or ndarray, shape (ns, nt)
+ Transport plan matrix, either a sparse csr matrix or a dense ndarray
+ nb_samples_p : int, optional
+ nb_samples_p is the number of samples (without replacement) along the first dimension of T.
+ nb_samples_q : int, optional
+ nb_samples_q is the number of samples along the second dimension of T, for each sample along the first.
+ std : bool, optional
+ If True, also return the standard deviation associated with the estimation of the gromov-wasserstein cost.
+ random_state : int or RandomState instance, optional
+ Fix the seed to allow reproducibility
+
+ Returns
+ -------
+ : float
+ Gromov-wasserstein cost
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ generator = check_random_state(random_state)
+
+ len_p = len(p)
+ len_q = len(q)
+
+ # It is always better to sample from the biggest distribution first.
+ if len_p < len_q:
+ p, q = q, p
+ len_p, len_q = len_q, len_p
+ C1, C2 = C2, C1
+ T = T.T
+
+ if nb_samples_p is None:
+ if issparse(T):
+ # If T is sparse, it probably means that PoGroW was used, thus the number of samples is reduced
+ nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p)
+ else:
+ nb_samples_p = len_p
+ else:
+ # The samples along the first dimension are drawn without replacement.
+ nb_samples_p = min(nb_samples_p, len_p)
+ if nb_samples_q is None:
+ nb_samples_q = 1
+ if std:
+ nb_samples_q = max(2, nb_samples_q)
+
+ index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+ index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+ list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q))
+
+ index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+ index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+
+ for i in range(nb_samples_p):
+ if issparse(T):
+ T_indexi = T[index_i[i], :].toarray()[0]
+ T_indexj = T[index_j[i], :].toarray()[0]
+ else:
+ T_indexi = T[index_i[i], :]
+ T_indexj = T[index_j[i], :]
+ # For each sampled row, a column index is sampled.
+ index_k[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True)
+ index_l[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True)
+
+ for n in range(nb_samples_q):
+ list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])])
+
+ if std:
+ std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5
+ return np.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p)
+ else:
+ return np.mean(list_value_sample)
+
+
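# A minimal usage sketch for the estimator above (illustrative only: the
# point clouds, the loss and the independent coupling T are made-up inputs).
import numpy as np
import ot

rng = np.random.RandomState(42)
xs, xt = rng.randn(20, 2), rng.randn(30, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
p, q = ot.unif(20), ot.unif(30)
T = np.outer(p, q)  # independent coupling used as the fixed plan

def loss(x, y):
    return np.abs(x - y)

gw_est, gw_std = ot.gromov.GW_distance_estimation(
    C1, C2, p, q, loss, T, std=True, random_state=42)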
+def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None):
+ r"""
+ Returns the Gromov-Wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe.
+ This method has a O(max_iter \times P N^2) time complexity, with P the number of Sinkhorn iterations.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ s.t. T 1 = p
+
+ T^T 1= q
+
+ T\geq 0
+
+ Where :
+
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ alpha : float
+ Step of the Frank-Wolfe algorithm, should be between 0 and 1
+ max_iter : int, optional
+ Max number of iterations
+ threshold_plan : float, optional
+ Delete very small values in the transport plan. If above zero, it violates the marginal constraints.
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ If True, also returns the estimated distance and its standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed to allow reproducibility
+
+ Returns
+ -------
+ T : csr or ndarray, shape (ns, nt)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+ p = np.asarray(p, dtype=np.float64)
+ q = np.asarray(q, dtype=np.float64)
+ len_p = len(p)
+ len_q = len(q)
+
+ generator = check_random_state(random_state)
+
+ index = np.zeros(2, dtype=int)
+
+ # Initialize with default marginal
+ index[0] = generator.choice(len_p, size=1, p=p)
+ index[1] = generator.choice(len_q, size=1, p=q)
+ T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr()
+
+ best_gw_dist_estimated = np.inf
+ for cpt in range(max_iter):
+ index[0] = generator.choice(len_p, size=1, p=p)
+ T_index0 = T[index[0], :].toarray()[0]
+ index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+
+ if alpha == 1:
+ T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr()
+ else:
+ new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr()
+ T = (1 - alpha) * T + alpha * new_T
+ # To limit the number of non-zero entries, the values below the threshold are set to 0.
+ T.data[T.data < threshold_plan] = 0
+ T.eliminate_zeros()
+
+ if cpt % 10 == 0 or cpt == (max_iter - 1):
+ gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, std=False, random_state=generator)
+
+ if gw_dist_estimated < best_gw_dist_estimated:
+ best_gw_dist_estimated = gw_dist_estimated
+ best_T = T.copy()
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated))
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=best_T,
+ random_state=generator)
+ return best_T, log
+ return best_T
+
+
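# A minimal sketch of the solver above with a damped update (alpha < 1) and
# pruning of tiny entries in the sparse plan; all inputs below are made up.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(30, 2)
C1 = C2 = ot.dist(xs, xs)
p = q = ot.unif(30)

def loss(x, y):
    return np.abs(x - y)

G = ot.gromov.pointwise_gromov_wasserstein(
    C1, C2, p, q, loss, alpha=0.1, threshold_plan=1e-10,
    max_iter=100, random_state=42)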
+def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False,
+ random_state=None):
+ r"""
+ Returns the Gromov-Wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe.
+ This method has a O(max_iter \times N \log(N)) time complexity by relying on the 1D Optimal Transport solver.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+ s.t. T 1 = p
+
+ T^T 1= q
+
+ T\geq 0
+
+ Where :
+
+ - C1 : Metric cost matrix in the source space
+ - C2 : Metric cost matrix in the target space
+ - p : distribution in the source space
+ - q : distribution in the target space
+ - L : loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : ndarray, shape (ns,)
+ Distribution in the source space
+ q : ndarray, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ nb_samples_grad : int or tuple of 2 ints
+ Number of samples used to approximate the gradient; a tuple sets the number of samples along each dimension
+ epsilon : float
+ Weight of the Kullback-Leibler regularization
+ max_iter : int, optional
+ Max number of iterations
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ If True, also returns the estimated distance and its standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed to allow reproducibility
+
+ Returns
+ -------
+ T : ndarray, shape (ns, nt)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1 = np.asarray(C1, dtype=np.float64)
+ C2 = np.asarray(C2, dtype=np.float64)
+ p = np.asarray(p, dtype=np.float64)
+ q = np.asarray(q, dtype=np.float64)
+ len_p = len(p)
+ len_q = len(q)
+
+ generator = check_random_state(random_state)
+
+ # The most natural way to define nb_samples_grad is with a simple integer.
+ if isinstance(nb_samples_grad, int):
+ if nb_samples_grad > len_p:
+ # As the sampling along the first dimension is done without replacement, the remainder is carried over
+ # to the second dimension.
+ nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad
+ T = np.outer(p, q)
+ # continue_loop allows stopping the loop if there are several successive small modifications of T.
+ continue_loop = 0
+
+ # The gradient of GW is more complex if the two matrices are not symmetric.
+ C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
+
+ for cpt in range(max_iter):
+ index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+ Lik = 0
+ for i, index0_i in enumerate(index0):
+ index1 = generator.choice(len_q,
+ size=nb_samples_grad_q,
+ p=T[index0_i, :] / T[index0_i, :].sum(),
+ replace=False)
+ # If the matrices C are not symmetric, the gradient has 2 terms; one of them is chosen at random.
+ if (not C_are_symmetric) and generator.rand(1) > 0.5:
+ Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples_grad_q)], 1),
+ np.expand_dims(C2[:, index1], 0)),
+ axis=2)
+ else:
+ Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples_grad_q), :], 2),
+ np.expand_dims(C2[index1, :], 1)),
+ axis=0)
+
+ max_Lik = np.max(Lik)
+ if max_Lik == 0:
+ continue
+ # This division by the max is here to facilitate the choice of epsilon.
+ Lik /= max_Lik
+
+ if epsilon > 0:
+ # Clip values below exp(-200) to avoid taking the log of 0, then set their log to -inf.
+ log_T = np.log(np.clip(T, np.exp(-200), 1))
+ log_T[log_T == -200] = -np.inf
+ Lik = Lik - epsilon * log_T
+
+ try:
+ new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon)
+ except (RuntimeWarning, UserWarning):
+ print("Warning catched in Sinkhorn: Return last stable T")
+ break
+ else:
+ new_T = emd(a=p, b=q, M=Lik)
+
+ change_T = ((T - new_T) ** 2).mean()
+ if change_T <= 10e-20:
+ continue_loop += 1
+ if continue_loop > 100: # Max number of small modifications of T
+ T = new_T.copy()
+ break
+ else:
+ continue_loop = 0
+
+ if verbose and cpt % 10 == 0:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, change_T))
+ T = new_T.copy()
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, random_state=generator)
+ return T, log
+ return T
+
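# A minimal sketch of the solver above; a tuple nb_samples_grad fixes how many
# rows and, per row, how many columns are sampled for the gradient (made-up inputs).
import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(50, 2)
C1 = C2 = ot.dist(xs, xs)
p = q = ot.unif(50)

def loss(x, y):
    return np.abs(x - y)

G, log = ot.gromov.sampled_gromov_wasserstein(
    C1, C2, p, q, loss, nb_samples_grad=(10, 2), epsilon=1,
    max_iter=100, log=True, random_state=42)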
+
def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
r"""
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 56414a8..19d61b1 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -33,7 +33,7 @@ def test_gromov():
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
@@ -54,7 +54,7 @@ def test_gromov():
np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
@@ -83,7 +83,7 @@ def test_entropic_gromov():
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
@@ -96,13 +96,89 @@ def test_entropic_gromov():
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+def test_pointwise_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ def loss(x, y):
+ return np.abs(x - y)
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov
+
+ assert log['gw_dist_estimated'] == 0.0
+ assert log['gw_dist_std'] == 0.0
+
+ G, log = ot.gromov.pointwise_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+
+ assert log['gw_dist_estimated'] == 0.10342276348494964
+ assert log['gw_dist_std'] == 0.0015952535464736394
+
+
+def test_sampled_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ def loss(x, y):
+ return np.abs(x - y)
+
+ G, log = ot.gromov.sampled_gromov_wasserstein(
+ C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+ assert log['gw_dist_estimated'] == 0.05679474884977278
+ assert log['gw_dist_std'] == 0.0005986592106971995
+
+
def test_gromov_barycenter():
ns = 50
nt = 60
@@ -186,7 +262,7 @@ def test_fgw():
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence fgw
np.testing.assert_allclose(
@@ -203,7 +279,7 @@ def test_fgw():
np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
--
cgit v1.2.3
From 7dde9e8e4b6aae756e103d49198caaa4f24150e3 Mon Sep 17 00:00:00 2001
From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>
Date: Tue, 28 Sep 2021 16:34:28 +0200
Subject: [MRG] Regularized OT (optim.cg) bug solve (#286)
* Line search stops when derphi is 0 instead of bugging out like in some instances
* pep8 compliance
* Tests
---
ot/optim.py | 10 ++++++----
test/test_da.py | 8 ++++++++
test/test_optim.py | 25 +++++++++++++++++++++++++
3 files changed, 39 insertions(+), 4 deletions(-)
(limited to 'test')
diff --git a/ot/optim.py b/ot/optim.py
index abe9e6a..0359343 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -178,9 +178,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
numItermaxEmd : int, optional
Max number of iterations for emd
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -249,6 +249,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
# line search
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+ if alpha is None:
+ alpha = 0.0
G = G + alpha * deltaG
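# A minimal sketch of calling cg after this fix (toy problem, made-up f/df):
# a line search that returns alpha=None now yields a zero step instead of a TypeError.
import numpy as np
import ot

a = b = ot.unif(5)
x = np.arange(5, dtype=np.float64).reshape(-1, 1)
M = ot.dist(x, x)

def f(G):
    return 0.5 * np.sum(G ** 2)

def df(G):
    return G

G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df)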
@@ -320,9 +322,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
numInnerItermax : int, optional
Max number of iterations of Sinkhorn
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
diff --git a/test/test_da.py b/test/test_da.py
index 44bb2e9..9f2bb50 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -565,6 +565,14 @@ def test_mapping_transport_class():
otda.fit(Xs=Xs, Xt=Xt)
assert len(otda.log_.keys()) != 0
+ # check that it does not crash when derphi is very close to 0
+ np.random.seed(39)
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+ otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
+ otda.fit(Xs=Xs, Xt=Xt)
+ np.random.seed(None)
+
def test_linear_mapping():
ns = 150
diff --git a/test/test_optim.py b/test/test_optim.py
index fd194c2..94995d5 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -114,3 +114,28 @@ def test_line_search_armijo():
# Should not throw an exception and return None for alpha
alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
assert alpha is None
+
+ # check line search armijo
+ def f(x):
+ return np.sum((x - 5.0) ** 2)
+
+ def grad(x):
+ return 2 * (x - 5.0)
+
+ xk = np.array([[[-5.0, -5.0]]])
+ pk = np.array([[[100.0, 100.0]]])
+ gfk = grad(xk)
+ old_fval = f(xk)
+
+ # check the case where the optimum is in the descent direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
+ np.testing.assert_allclose(alpha, 0.1)
+
+ # check the case where the direction does not go far enough
+ pk = np.array([[[3.0, 3.0]]])
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
+ np.testing.assert_allclose(alpha, 1.0)
+
+ # check the case where the search is given the wrong (ascent) direction
+ alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
+ assert alpha <= 0
--
cgit v1.2.3
From 1c7e7ce2da8bb362c184fb6eae71fe7e36356494 Mon Sep 17 00:00:00 2001
From: kguerda-idris <84066930+kguerda-idris@users.noreply.github.com>
Date: Wed, 29 Sep 2021 15:29:31 +0200
Subject: [MRG] OpenMP support (#260)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Added : OpenMP support
Restored : Epsilon and Debug mode
* Replaced : parmap => multiprocessing is now replaced by multithreading
* Commit clean up
* Number of CPUs correctly calculated on SLURM clusters
* Corrected number of processes for cluster slurm
* Mistake corrected
* parmap is now deprecated
* Now a different solver is used depending on the requested number of threads
* Tiny mistake corrected
* Folders are now in the ot library instead of at the root
* Helpers is now correctly placed
* Attempt to make compilation work smoothly
* OS compatible path
* NumThreads now defaults to 1
* Better flags
* Mistake corrected in case of OpenMP unavailability
* Revert OpenMP flags modification, which do not compile on Windows
* Test helper functions
* Helpers comments
* Documentation update
* File title corrected
* Warning no longer using print
* Last attempt for macos compilation
* pls work
* atempt
* solving a type error
* TypeError OpenMP
* Compilation finally working on Windows
* Bug solve, number of threads now correctly selected
* 64 bits solver to avoid overflows for bigger problems
* 64 bits EMD corrected
Co-authored-by: kguerda-idris
Co-authored-by: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>
Co-authored-by: ncassereau
Co-authored-by: Rémi Flamary
---
ot/helpers/openmp_helpers.py | 85 ++
ot/helpers/pre_build_helpers.py | 87 ++
ot/lp/EMD.h | 5 +-
ot/lp/EMD_wrapper.cpp | 124 ++-
ot/lp/__init__.py | 71 +-
ot/lp/emd_wrap.pyx | 9 +-
ot/lp/full_bipartitegraph.h | 27 +-
ot/lp/full_bipartitegraph_omp.h | 234 +++++
ot/lp/network_simplex_simple.h | 210 ++---
ot/lp/network_simplex_simple_omp.h | 1699 ++++++++++++++++++++++++++++++++++++
ot/utils.py | 38 +-
setup.py | 12 +-
test/test_helpers.py | 26 +
13 files changed, 2442 insertions(+), 185 deletions(-)
create mode 100644 ot/helpers/openmp_helpers.py
create mode 100644 ot/helpers/pre_build_helpers.py
create mode 100644 ot/lp/full_bipartitegraph_omp.h
create mode 100644 ot/lp/network_simplex_simple_omp.h
create mode 100644 test/test_helpers.py
(limited to 'test')
diff --git a/ot/helpers/openmp_helpers.py b/ot/helpers/openmp_helpers.py
new file mode 100644
index 0000000..a6ad38b
--- /dev/null
+++ b/ot/helpers/openmp_helpers.py
@@ -0,0 +1,85 @@
+"""Helpers for OpenMP support during the build."""
+
+# This code is adapted for a large part from the astropy openmp helpers, which
+# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa
+
+
+import os
+import sys
+import textwrap
+import subprocess
+
+from distutils.errors import CompileError, LinkError
+
+from .pre_build_helpers import compile_test_program
+
+
+def get_openmp_flag(compiler):
+ """Get openmp flags for a given compiler"""
+
+ if hasattr(compiler, 'compiler'):
+ compiler = compiler.compiler[0]
+ else:
+ compiler = compiler.__class__.__name__
+
+ if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler):
+ omp_flag = ['/Qopenmp']
+ elif sys.platform == "win32":
+ omp_flag = ['/openmp']
+ elif sys.platform in ("darwin", "linux") and "icc" in compiler:
+ omp_flag = ['-qopenmp']
+ elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''):
+ omp_flag = []
+ else:
+ # Default flag for GCC and clang:
+ omp_flag = ['-fopenmp']
+ if sys.platform.startswith("darwin"):
+ omp_flag += ["-Xpreprocessor", "-lomp"]
+ return omp_flag
+
+
+def check_openmp_support():
+ """Check whether OpenMP test code can be compiled and run"""
+
+ code = textwrap.dedent(
+ """\
+ #include <omp.h>
+ #include <stdio.h>
+ int main(void) {
+ #pragma omp parallel
+ printf("nthreads=%d\\n", omp_get_num_threads());
+ return 0;
+ }
+ """)
+
+ extra_preargs = os.getenv('LDFLAGS', None)
+ if extra_preargs is not None:
+ extra_preargs = extra_preargs.strip().split(" ")
+ extra_preargs = [
+ flag for flag in extra_preargs
+ if flag.startswith(('-L', '-Wl,-rpath', '-l'))]
+
+ extra_postargs = get_openmp_flag
+
+ try:
+ output, compile_flags = compile_test_program(
+ code,
+ extra_preargs=extra_preargs,
+ extra_postargs=extra_postargs
+ )
+
+ if output and 'nthreads=' in output[0]:
+ nthreads = int(output[0].strip().split('=')[1])
+ openmp_supported = len(output) == nthreads
+ elif "PYTHON_CROSSENV" in os.environ:
+ # Since we can't run the test program when cross-compiling
+ # assume that openmp is supported if the program can be
+ # compiled.
+ openmp_supported = True
+ else:
+ openmp_supported = False
+
+ except (CompileError, LinkError, subprocess.CalledProcessError):
+ openmp_supported = False
+ compile_flags = []
+ return openmp_supported, compile_flags
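A sketch of how a build script can consume this helper; the wiring below is illustrative, not the patch's exact setup.py code:

import sys
from ot.helpers.openmp_helpers import check_openmp_support

openmp_supported, flags = check_openmp_support()
compile_args = ["/O2"] if sys.platform == "win32" else ["-O3"]
if openmp_supported:
    compile_args += flags  # e.g. ['-fopenmp'] for GCC/clang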
diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py
new file mode 100644
index 0000000..93ecd6a
--- /dev/null
+++ b/ot/helpers/pre_build_helpers.py
@@ -0,0 +1,87 @@
+"""Helpers to check build environment before actual build of POT"""
+
+import os
+import sys
+import glob
+import tempfile
+import setuptools # noqa
+import subprocess
+
+from distutils.dist import Distribution
+from distutils.sysconfig import customize_compiler
+from numpy.distutils.ccompiler import new_compiler
+from numpy.distutils.command.config_compiler import config_cc
+
+
+def _get_compiler():
+ """Get a compiler equivalent to the one that will be used to build POT
+ Handles compiler specified as follows:
+ - python setup.py build_ext --compiler=<compiler>
+ - CC=<compiler> python setup.py build_ext
+ """
+ dist = Distribution({'script_name': os.path.basename(sys.argv[0]),
+ 'script_args': sys.argv[1:],
+ 'cmdclass': {'config_cc': config_cc}})
+
+ cmd_opts = dist.command_options.get('build_ext')
+ if cmd_opts is not None and 'compiler' in cmd_opts:
+ compiler = cmd_opts['compiler'][1]
+ else:
+ compiler = None
+
+ ccompiler = new_compiler(compiler=compiler)
+ customize_compiler(ccompiler)
+
+ return ccompiler
+
+
+def compile_test_program(code, extra_preargs=[], extra_postargs=[]):
+ """Check that some C code can be compiled and run"""
+ ccompiler = _get_compiler()
+
+ # extra_(pre/post)args can be a callable to make it possible to get its
+ # value from the compiler
+ if callable(extra_preargs):
+ extra_preargs = extra_preargs(ccompiler)
+ if callable(extra_postargs):
+ extra_postargs = extra_postargs(ccompiler)
+
+ start_dir = os.path.abspath('.')
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ try:
+ os.chdir(tmp_dir)
+
+ # Write test program
+ with open('test_program.c', 'w') as f:
+ f.write(code)
+
+ os.mkdir('objects')
+
+ # Compile, test program
+ ccompiler.compile(['test_program.c'], output_dir='objects',
+ extra_postargs=extra_postargs)
+
+ # Link test program
+ objects = glob.glob(
+ os.path.join('objects', '*' + ccompiler.obj_extension))
+ ccompiler.link_executable(objects, 'test_program',
+ extra_preargs=extra_preargs,
+ extra_postargs=extra_postargs)
+
+ if "PYTHON_CROSSENV" not in os.environ:
+ # Run test program if not cross compiling
+ # will raise a CalledProcessError if return code was non-zero
+ output = subprocess.check_output('./test_program')
+ output = output.decode(
+ sys.stdout.encoding or 'utf-8').splitlines()
+ else:
+ # Return an empty output if we are cross compiling
+ # as we cannot run the test_program
+ output = []
+ except Exception:
+ raise
+ finally:
+ os.chdir(start_dir)
+
+ return output, extra_postargs
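A usage sketch, assuming a working C toolchain (the probe program is made up; the helper returns the program's stdout lines plus the post-args it used):

from ot.helpers.pre_build_helpers import compile_test_program

code = '#include <stdio.h>\nint main(void) { printf("ok\\n"); return 0; }\n'
output, flags = compile_test_program(code)
assert output == ["ok"]  # holds only when not cross-compiling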
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index c0fe7a3..8a1f9ac 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -18,19 +18,18 @@
#include <iostream>
#include <vector>
-#include "network_simplex_simple.h"
-using namespace lemon;
typedef unsigned int node_id_type;
enum ProblemType {
INFEASIBLE,
OPTIMAL,
UNBOUNDED,
- MAX_ITER_REACHED
+ MAX_ITER_REACHED
};
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
+int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index bc873ed..2bdc172 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -12,16 +12,22 @@
*
*/
+
+#include "network_simplex_simple.h"
+#include "network_simplex_simple_omp.h"
#include "EMD.h"
+#include <cstdint>
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int maxIter) {
- // beware M and C anre strored in row major C style!!!
- int n, m, i, cur;
+ // beware M and C are stored in row major C style!!!
+
+ using namespace lemon;
+ int n, m, cur;
typedef FullBipartiteDigraph Digraph;
- DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
+ DIGRAPH_TYPEDEFS(Digraph);
// Get the number of non zero coordinates for r and c
n=0;
@@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
std::vector<int> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double,node_id_type> net(di, true, n+m, n*m, maxIter);
+ NetworkSimplexSimple<Digraph,double,double,node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
// Set supply and demand, don't account for 0 values (faster)
@@ -76,10 +82,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
net.supplyMap(&weights1[0], n, &weights2[0], m);
// Set the cost of each edge
+ int64_t idarc = 0;
for (int i=0; i<n; i++) {
for (int j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
- net.setCost(di.arcFromId(i*m+j), val);
+ net.setCost(di.arcFromId(idarc), val);
+ ++idarc;
}
}
+int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
+ double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
+ // beware M and C are stored in row major C style!!!
+ using namespace lemon_omp;
+ int n, m, cur;
+ typedef FullBipartiteDigraph Digraph;
+ DIGRAPH_TYPEDEFS(Digraph);
+ // Get the number of non zero coordinates for r and c
+ n=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ n++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+ m=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ m++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+
+ // Define the graph
+
+ std::vector<int> indI(n), indJ(m);
+ std::vector<double> weights1(n), weights2(m);
+ Digraph di(n, m);
+ NetworkSimplexSimple<Digraph,double,double,node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
+
+ // Set supply and demand, don't account for 0 values (faster)
+
+ cur=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ weights1[ cur ] = val;
+ indI[cur++]=i;
+ }
+ }
+
+ // Demand is actually negative supply...
+
+ cur=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ weights2[ cur ] = -val;
+ indJ[cur++]=i;
+ }
+ }
+
+
+ net.supplyMap(&weights1[0], n, &weights2[0], m);
+
+ // Set the cost of each edge
+ int64_t idarc = 0;
+ for (int i=0; i<n; i++) {
+ for (int j=0; j<m; j++) {
+ double val=*(D+indI[i]*n2+indJ[j]);
+ net.setCost(di.arcFromId(idarc), val);
+ ++idarc;
+ }
+ }
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
- if processes > 1:
- res = parmap(f, [b[:, i].copy() for i in range(nb)], processes)
- else:
- res = list(map(f, [b[:, i].copy() for i in range(nb)]))
+ warnings.warn(
+ "The 'processes' parameter has been deprecated. "
+ "Multiprocessing should be done outside of POT."
+ )
+ res = list(map(f, [b[:, i].copy() for i in range(nb)]))
return res
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
- stopThr=1e-7, verbose=False, log=None):
+ stopThr=1e-7, verbose=False, log=None, numThreads=1):
r"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally:
@@ -512,6 +541,10 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Print information along iterations
log : bool, optional
record log if True
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
+
Returns
-------
@@ -551,7 +584,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
weights.tolist()):
M_i = dist(X, measure_locations_i)
- T_i = emd(b, measure_weights_i, M_i)
+ T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
displacement_square_norm = np.sum(np.square(T_sum - X))
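A minimal sketch of the new keyword from user code, assuming POT was compiled with OpenMP support (problem data is made up):

import numpy as np
import ot

rng = np.random.RandomState(0)
M = ot.dist(rng.randn(100, 2), rng.randn(100, 2))
a = b = ot.unif(100)
G1 = ot.emd(a, b, M)                # numThreads=1: historical solver
G2 = ot.emd(a, b, M, numThreads=4)  # OpenMP network simplex
np.testing.assert_allclose(G1.sum(), G2.sum())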
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index de9a700..42e08f4 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -20,6 +20,7 @@ import warnings
cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
+ int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -38,7 +39,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -109,8 +110,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
# calling the function
with nogil:
- result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter)
-
+ if numThreads == 1:
+ result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter)
+ else:
+ result_code = EMD_wrap_omp(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter, numThreads)
return G, cost, alpha, beta, result_code
diff --git a/ot/lp/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h
index 87a1bec..713ccb5 100644
--- a/ot/lp/full_bipartitegraph.h
+++ b/ot/lp/full_bipartitegraph.h
@@ -23,10 +23,10 @@
*
*/
-#ifndef LEMON_FULL_BIPARTITE_GRAPH_H
-#define LEMON_FULL_BIPARTITE_GRAPH_H
+#pragma once
#include "core.h"
+#include <cstdint>
///\ingroup graphs
///\file
@@ -44,16 +44,16 @@ namespace lemon {
//class Node;
typedef int Node;
//class Arc;
- typedef long long Arc;
+ typedef int64_t Arc;
protected:
int _node_num;
- long long _arc_num;
+ int64_t _arc_num;
FullBipartiteDigraphBase() {}
- void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = n1 * n2; _n1=n1; _n2=n2;}
+ void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;}
public:
@@ -65,25 +65,25 @@ namespace lemon {
Arc arc(const Node& s, const Node& t) const {
if (s<_n1 && t>=_n1)
- return Arc(s * _n2 + (t-_n1) );
+ return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) );
else
return Arc(-1);
}
int nodeNum() const { return _node_num; }
- long long arcNum() const { return _arc_num; }
+ int64_t arcNum() const { return _arc_num; }
int maxNodeId() const { return _node_num - 1; }
- long long maxArcId() const { return _arc_num - 1; }
+ int64_t maxArcId() const { return _arc_num - 1; }
Node source(Arc arc) const { return arc / _n2; }
Node target(Arc arc) const { return (arc % _n2) + _n1; }
static int id(Node node) { return node; }
- static long long id(Arc arc) { return arc; }
+ static int64_t id(Arc arc) { return arc; }
static Node nodeFromId(int id) { return Node(id);}
- static Arc arcFromId(int id) { return Arc(id);}
+ static Arc arcFromId(int64_t id) { return Arc(id);}
Arc findArc(Node s, Node t, Arc prev = -1) const {
@@ -136,7 +136,7 @@ namespace lemon {
///
/// \brief A directed full graph class.
///
- /// FullBipartiteDigraph is a simple and fast implmenetation of directed full
+ /// FullBipartiteDigraph is a simple and fast implementation of directed full
/// (complete) graphs. It contains an arc from each node to each node
/// (including a loop for each node), therefore the number of arcs
/// is the square of the number of nodes.
@@ -203,13 +203,10 @@ namespace lemon {
/// \brief Number of nodes.
int nodeNum() const { return Parent::nodeNum(); }
/// \brief Number of arcs.
- long long arcNum() const { return Parent::arcNum(); }
+ int64_t arcNum() const { return Parent::arcNum(); }
};
} //namespace lemon
-
-
-#endif //LEMON_FULL_GRAPH_H
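The switch to int64_t arc ids above matters because a dense bipartite problem has n1*n2 arcs, which outgrows a signed 32-bit index long before the node counts look extreme; a quick check (illustrative Python):

n1 = n2 = 50_000
assert n1 * n2 > 2**31 - 1  # 2.5e9 arcs exceed the ~2.15e9 range of int32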
diff --git a/ot/lp/full_bipartitegraph_omp.h b/ot/lp/full_bipartitegraph_omp.h
new file mode 100644
index 0000000..8cbed0b
--- /dev/null
+++ b/ot/lp/full_bipartitegraph_omp.h
@@ -0,0 +1,234 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+ *
+ * This file has been adapted by Nicolas Bonneel (2013),
+ * from full_graph.h from LEMON, a generic C++ optimization library,
+ * to implement a lightweight fully connected bipartite graph. A previous
+ * version of this file is used as part of the Displacement Interpolation
+ * project,
+ * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
+ *
+ *
+ **** Original file Copyright Notice :
+ * Copyright (C) 2003-2010
+ * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
+ * (Egervary Research Group on Combinatorial Optimization, EGRES).
+ *
+ * Permission to use, modify and distribute this software is granted
+ * provided that this copyright notice appears in all copies. For
+ * precise terms see the accompanying LICENSE file.
+ *
+ * This software is provided "AS IS" with no warranty of any kind,
+ * express or implied, and with no claim as to its suitability for any
+ * purpose.
+ *
+ */
+
+#pragma once
+
+#include <cstdint>
+
+///\ingroup graphs
+///\file
+///\brief FullBipartiteDigraph and FullBipartiteGraph classes.
+
+
+namespace lemon_omp {
+
+ ///This \c \#define creates convenient type definitions for the following
+ ///types of \c Digraph: \c Node, \c NodeIt, \c Arc, \c ArcIt, \c InArcIt,
+ ///\c OutArcIt, \c BoolNodeMap, \c IntNodeMap, \c DoubleNodeMap,
+ ///\c BoolArcMap, \c IntArcMap, \c DoubleArcMap.
+ ///
+ ///\note If the graph type is a dependent type, ie. the graph type depend
+ ///on a template parameter, then use \c TEMPLATE_DIGRAPH_TYPEDEFS()
+ ///macro.
+#define DIGRAPH_TYPEDEFS(Digraph) \
+ typedef Digraph::Node Node; \
+ typedef Digraph::Arc Arc; \
+
+
+ ///Create convenience typedefs for the digraph types and iterators
+
+ ///\see DIGRAPH_TYPEDEFS
+ ///
+ ///\note Use this macro, if the graph type is a dependent type,
+ ///ie. the graph type depend on a template parameter.
+#define TEMPLATE_DIGRAPH_TYPEDEFS(Digraph) \
+ typedef typename Digraph::Node Node; \
+ typedef typename Digraph::Arc Arc; \
+
+
+ class FullBipartiteDigraphBase {
+ public:
+
+ typedef FullBipartiteDigraphBase Digraph;
+
+ //class Node;
+ typedef int Node;
+ //class Arc;
+ typedef int64_t Arc;
+
+ protected:
+
+ int _node_num;
+ int64_t _arc_num;
+
+ FullBipartiteDigraphBase() {}
+
+ void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;}
+
+ public:
+
+ int _n1, _n2;
+
+
+ Node operator()(int ix) const { return Node(ix); }
+ static int index(const Node& node) { return node; }
+
+ Arc arc(const Node& s, const Node& t) const {
+ if (s<_n1 && t>=_n1)
+ return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) );
+ else
+ return Arc(-1);
+ }
+
+ int nodeNum() const { return _node_num; }
+ int64_t arcNum() const { return _arc_num; }
+
+ int maxNodeId() const { return _node_num - 1; }
+ int64_t maxArcId() const { return _arc_num - 1; }
+
+ Node source(Arc arc) const { return arc / _n2; }
+ Node target(Arc arc) const { return (arc % _n2) + _n1; }
+
+ static int id(Node node) { return node; }
+ static int64_t id(Arc arc) { return arc; }
+
+ static Node nodeFromId(int id) { return Node(id);}
+ static Arc arcFromId(int64_t id) { return Arc(id);}
+
+
+ Arc findArc(Node s, Node t, Arc prev = -1) const {
+ return prev == -1 ? arc(s, t) : -1;
+ }
+
+ void first(Node& node) const {
+ node = _node_num - 1;
+ }
+
+ static void next(Node& node) {
+ --node;
+ }
+
+ void first(Arc& arc) const {
+ arc = _arc_num - 1;
+ }
+
+ static void next(Arc& arc) {
+ --arc;
+ }
+
+ void firstOut(Arc& arc, const Node& node) const {
+ if (node>=_n1)
+ arc = -1;
+ else
+ arc = (node + 1) * _n2 - 1;
+ }
+
+ void nextOut(Arc& arc) const {
+ if (arc % _n2 == 0) arc = 0;
+ --arc;
+ }
+
+ void firstIn(Arc& arc, const Node& node) const {
+ if (node<_n1)
+ arc = -1;
+ else
+ arc = _arc_num + node - _node_num;
+ }
+
+ void nextIn(Arc& arc) const {
+ arc -= _n2;
+ if (arc < 0) arc = -1;
+ }
+
+ };
+
+ /// \ingroup graphs
+ ///
+ /// \brief A directed full graph class.
+ ///
+ /// FullBipartiteDigraph is a simple and fast implementation of directed full
+ /// (complete) graphs. It contains an arc from each node to each node
+ /// (including a loop for each node), therefore the number of arcs
+ /// is the square of the number of nodes.
+ /// This class is completely static and it needs constant memory space.
+ /// Thus you can neither add nor delete nodes or arcs, however
+ /// the structure can be resized using resize().
+ ///
+ /// This type fully conforms to the \ref concepts::Digraph "Digraph concept".
+ /// Most of its member functions and nested classes are documented
+ /// only in the concept class.
+ ///
+ /// This class provides constant time counting for nodes and arcs.
+ ///
+ /// \note FullBipartiteDigraph and FullBipartiteGraph classes are very similar,
+ /// but there are two differences. While this class conforms only
+ /// to the \ref concepts::Digraph "Digraph" concept, FullBipartiteGraph
+ /// conforms to the \ref concepts::Graph "Graph" concept,
+ /// moreover FullBipartiteGraph does not contain a loop for each
+ /// node as this class does.
+ ///
+ /// \sa FullBipartiteGraph
+ class FullBipartiteDigraph : public FullBipartiteDigraphBase {
+ typedef FullBipartiteDigraphBase Parent;
+
+ public:
+
+ /// \brief Default constructor.
+ ///
+ /// Default constructor. The number of nodes and arcs will be zero.
+ FullBipartiteDigraph() { construct(0,0); }
+
+ /// \brief Constructor
+ ///
+ /// Constructor.
+ /// \param n The number of the nodes.
+ FullBipartiteDigraph(int n1, int n2) { construct(n1, n2); }
+
+
+ /// \brief Returns the node with the given index.
+ ///
+ /// Returns the node with the given index. Since this structure is
+ /// completely static, the nodes can be indexed with integers from
+ /// the range [0..nodeNum()-1].
+ /// The index of a node is the same as its ID.
+ /// \sa index()
+ Node operator()(int ix) const { return Parent::operator()(ix); }
+
+ /// \brief Returns the index of the given node.
+ ///
+ /// Returns the index of the given node. Since this structure is
+ /// completely static, the nodes can be indexed with integers from
+ /// the range [0..nodeNum()-1].
+ /// The index of a node is the same as its ID.
+ /// \sa operator()()
+ static int index(const Node& node) { return Parent::index(node); }
+
+ /// \brief Returns the arc connecting the given nodes.
+ ///
+ /// Returns the arc connecting the given nodes.
+ /*Arc arc(Node u, Node v) const {
+ return Parent::arc(u, v);
+ }*/
+
+ /// \brief Number of nodes.
+ int nodeNum() const { return Parent::nodeNum(); }
+ /// \brief Number of arcs.
+ int64_t arcNum() const { return Parent::arcNum(); }
+ };
+
+
+
+
+} //namespace lemon_omp
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 630b595..3b46b9b 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -25,15 +25,17 @@
*
*/
-#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H
-#define LEMON_NETWORK_SIMPLEX_SIMPLE_H
+#pragma once
+#undef DEBUG_LVL
#define DEBUG_LVL 0
#if DEBUG_LVL>0
#include <iomanip>
#endif
-
+#undef EPSILON
+#undef _EPSILON
+#undef MAX_DEBUG_ITER
#define EPSILON 2.2204460492503131e-15
#define _EPSILON 1e-8
#define MAX_DEBUG_ITER 100000
@@ -50,6 +52,7 @@
#include <vector>
#include <limits>
#include <algorithm>
+#include <cstdint>
#include <iostream>
#ifdef HASHMAP
#include <hash_map>
@@ -63,6 +66,8 @@
//#include "sparse_array_n.h"
#include "full_bipartitegraph.h"
+#undef INVALIDNODE
+#undef INVALID
#define INVALIDNODE -1
#define INVALID (-1)
@@ -76,16 +81,16 @@ namespace lemon {
class SparseValueVector
{
public:
- SparseValueVector(int n=0)
+ SparseValueVector(size_t n=0)
{
}
- void resize(int n=0){};
- T operator[](const int id) const
+ void resize(size_t n=0){};
+ T operator[](const size_t id) const
{
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::const_iterator it = data.find(id);
+ typename stdext::hash_map<size_t,T>::const_iterator it = data.find(id);
#else
- typename std::map<int,T>::const_iterator it = data.find(id);
+ typename std::map<size_t,T>::const_iterator it = data.find(id);
#endif
if (it==data.end())
return 0;
@@ -93,16 +98,16 @@ namespace lemon {
return it->second;
}
- ProxyObject operator[](const int id)
+ ProxyObject operator[](const size_t id)
{
return ProxyObject( this, id );
}
//private:
#ifdef HASHMAP
- stdext::hash_map<int,T> data;
+ stdext::hash_map<size_t,T> data;
#else
- std::map<int,T> data;
+ std::map<size_t,T> data;
#endif
};
@@ -110,7 +115,7 @@ namespace lemon {
template <typename T>
class ProxyObject {
public:
- ProxyObject( SparseValueVector<T> *v, int idx ){_v=v; _idx=idx;};
+ ProxyObject( SparseValueVector<T> *v, size_t idx ){_v=v; _idx=idx;};
ProxyObject & operator=( const T &v ) {
// If we get here, we know that operator[] was called to perform a write access,
// so we can insert an item in the vector if needed
@@ -123,9 +128,9 @@ namespace lemon {
// If we get here, we know that operator[] was called to perform a read access,
// so we can simply return the existing object
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+ typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx);
#else
- typename std::map<int,T>::iterator it = _v->data.find(_idx);
+ typename std::map<size_t,T>::iterator it = _v->data.find(_idx);
#endif
if (it==_v->data.end())
return 0;
@@ -137,9 +142,9 @@ namespace lemon {
{
if (val==0) return;
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+ typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx);
#else
- typename std::map<int,T>::iterator it = _v->data.find(_idx);
+ typename std::map<size_t,T>::iterator it = _v->data.find(_idx);
#endif
if (it==_v->data.end())
_v->data[_idx] = val;
@@ -156,9 +161,9 @@ namespace lemon {
{
if (val==0) return;
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+ typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx);
#else
- typename std::map<int,T>::iterator it = _v->data.find(_idx);
+ typename std::map<size_t,T>::iterator it = _v->data.find(_idx);
#endif
if (it==_v->data.end())
_v->data[_idx] = -val;
@@ -173,7 +178,7 @@ namespace lemon {
}
SparseValueVector<T> *_v;
- int _idx;
+ size_t _idx;
};
@@ -204,7 +209,7 @@ namespace lemon {
///
/// \tparam GR The digraph type the algorithm runs on.
/// \tparam V The number type used for flow amounts, capacity bounds
- /// and supply values in the algorithm. By default, it is \c int.
+ /// and supply values in the algorithm. By default, it is \c int64_t.
/// \tparam C The number type used for costs and potentials in the
/// algorithm. By default, it is the same as \c V.
///
@@ -214,7 +219,7 @@ namespace lemon {
/// \note %NetworkSimplexSimple provides five different pivot rule
/// implementations, from which the most efficient one is used
/// by default. For more information, see \ref PivotRule.
- template <typename GR, typename V = int, typename C = V>
+ template <typename GR, typename V = int, typename C = V, typename ArcsType = int64_t>
class NetworkSimplexSimple
{
public:
@@ -228,7 +233,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -288,11 +293,11 @@ namespace lemon {
private:
- int max_iter;
+ size_t max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
typedef std::vector<int> IntVector;
- typedef std::vector<unsigned int> UHalfIntVector;
+ typedef std::vector<ArcsType> ArcVector;
typedef std::vector<Value> ValueVector;
typedef std::vector<Cost> CostVector;
// typedef SparseValueVector<Cost> CostVector;
@@ -315,9 +320,9 @@ namespace lemon {
// Data related to the underlying digraph
const GR &_graph;
int _node_num;
- int _arc_num;
- int _all_arc_num;
- int _search_arc_num;
+ ArcsType _arc_num;
+ ArcsType _all_arc_num;
+ ArcsType _search_arc_num;
// Parameters of the problem
SupplyType _stype;
@@ -325,9 +330,9 @@ namespace lemon {
inline int _node_id(int n) const {return _node_num-n-1;} ;
- //IntArcMap _arc_id;
- UHalfIntVector _source;
- UHalfIntVector _target;
+// IntArcMap _arc_id;
+ IntVector _source; // keep nodes as integers
+ IntVector _target;
bool _arc_mixing;
public:
// Node and arc data
@@ -341,7 +346,7 @@ namespace lemon {
private:
// Data for storing the spanning tree structure
IntVector _parent;
- IntVector _pred;
+ ArcVector _pred;
IntVector _thread;
IntVector _rev_thread;
IntVector _succ_num;
@@ -349,17 +354,17 @@ namespace lemon {
IntVector _dirty_revs;
BoolVector _forward;
StateVector _state;
- int _root;
+ ArcsType _root;
// Temporary data used in the current pivot iteration
- int in_arc, join, u_in, v_in, u_out, v_out;
- int first, second, right, last;
- int stem, par_stem, new_stem;
+ ArcsType in_arc, join, u_in, v_in, u_out, v_out;
+ ArcsType first, second, right, last;
+ ArcsType stem, par_stem, new_stem;
Value delta;
const Value MAX;
- int mixingCoeff;
+ ArcsType mixingCoeff;
public:
@@ -373,27 +378,27 @@ namespace lemon {
private:
// thank you to DVK and MizardX from StackOverflow for this function!
- inline int sequence(int k) const {
- int smallv = (k > num_total_big_subsequence_numbers) & 1;
+ inline ArcsType sequence(ArcsType k) const {
+ ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1;
k -= num_total_big_subsequence_numbers * smallv;
- int subsequence_length2 = subsequence_length- smallv;
- int subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv;
- int subsequence_offset = (k % subsequence_length2) * mixingCoeff;
+ ArcsType subsequence_length2 = subsequence_length- smallv;
+ ArcsType subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv;
+ ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff;
return subsequence_offset + subsequence_num;
}
- int subsequence_length;
- int num_big_subseqiences;
- int num_total_big_subsequence_numbers;
+ ArcsType subsequence_length;
+ ArcsType num_big_subseqiences;
+ ArcsType num_total_big_subsequence_numbers;
- inline int getArcID(const Arc &arc) const
+ inline ArcsType getArcID(const Arc &arc) const
{
//int n = _arc_num-arc._id-1;
- int n = _arc_num-GR::id(arc)-1;
+ ArcsType n = _arc_num-GR::id(arc)-1;
- //int a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
- //int b = _arc_id[arc];
+ //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
+ //ArcsType b = _arc_id[arc];
if (_arc_mixing)
return sequence(n);
else
@@ -401,16 +406,16 @@ namespace lemon {
}
// finally unused because too slow
- inline int getSource(const int arc) const
+ inline ArcsType getSource(const ArcsType arc) const
{
- //int a = _source[arc];
+ //ArcsType a = _source[arc];
//return a;
- int n = _arc_num-arc-1;
+ ArcsType n = _arc_num-arc-1;
if (_arc_mixing)
n = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
- int b;
+ ArcsType b;
if (n>=0)
b = _node_id(_graph.source(GR::arcFromId( n ) ));
else
@@ -436,17 +441,17 @@ namespace lemon {
private:
// References to the NetworkSimplexSimple class
- const UHalfIntVector &_source;
- const UHalfIntVector &_target;
+ const IntVector &_source;
+ const IntVector &_target;
const CostVector &_cost;
const StateVector &_state;
const CostVector &_pi;
- int &_in_arc;
- int _search_arc_num;
+ ArcsType &_in_arc;
+ ArcsType _search_arc_num;
// Pivot rule data
- int _block_size;
- int _next_arc;
+ ArcsType _block_size;
+ ArcsType _next_arc;
NetworkSimplexSimple &_ns;
public:
@@ -460,17 +465,16 @@ namespace lemon {
{
// The main parameters of the pivot rule
const double BLOCK_SIZE_FACTOR = 1.0;
- const int MIN_BLOCK_SIZE = 10;
+ const ArcsType MIN_BLOCK_SIZE = 10;
- _block_size = std::max( int(BLOCK_SIZE_FACTOR *
- std::sqrt(double(_search_arc_num))),
- MIN_BLOCK_SIZE );
+ _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE);
}
+
// Find next entering arc
bool findEnteringArc() {
Cost c, min = 0;
- int e;
- int cnt = _block_size;
+ ArcsType e;
+ ArcsType cnt = _block_size;
double a;
for (e = _next_arc; e != _search_arc_num; ++e) {
c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
@@ -516,7 +520,7 @@ namespace lemon {
int _init_nb_nodes;
- long long _init_nb_arcs;
+ ArcsType _init_nb_arcs;
/// \name Parameters
/// The parameters of the algorithm can be specified using these
@@ -736,7 +740,7 @@ namespace lemon {
for (int i = 0; i != _node_num; ++i) {
_supply[i] = 0;
}
- for (int i = 0; i != _arc_num; ++i) {
+ for (ArcsType i = 0; i != _arc_num; ++i) {
_cost[i] = 1;
}
_stype = GEQ;
@@ -745,7 +749,7 @@ namespace lemon {
- int divid (int x, int y)
+ int64_t divid (int64_t x, int64_t y)
{
return (x-x%y)/y;
}
@@ -775,7 +779,7 @@ namespace lemon {
_node_num = _init_nb_nodes;
_arc_num = _init_nb_arcs;
int all_node_num = _node_num + 1;
- int max_arc_num = _arc_num + 2 * _node_num;
+ ArcsType max_arc_num = _arc_num + 2 * _node_num;
_source.resize(max_arc_num);
_target.resize(max_arc_num);
@@ -798,13 +802,13 @@ namespace lemon {
//_arc_mixing=false;
if (_arc_mixing) {
// Store the arcs in a mixed order
- int k = std::max(int(std::sqrt(double(_arc_num))), 10);
+ const ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10));
mixingCoeff = k;
subsequence_length = _arc_num / mixingCoeff + 1;
num_big_subseqiences = _arc_num % mixingCoeff;
num_total_big_subsequence_numbers = subsequence_length * num_big_subseqiences;
- int i = 0, j = 0;
+ ArcsType i = 0, j = 0;
Arc a; _graph.first(a);
for (; a != INVALID; _graph.next(a)) {
_source[i] = _node_id(_graph.source(a));
@@ -814,7 +818,7 @@ namespace lemon {
}
} else {
// Store the arcs in the original order
- int i = 0;
+ ArcsType i = 0;
Arc a; _graph.first(a);
for (; a != INVALID; _graph.next(a), ++i) {
_source[i] = _node_id(_graph.source(a));
@@ -856,7 +860,7 @@ namespace lemon {
Number totalCost() const {
Number c = 0;
for (ArcIt a(_graph); a != INVALID; ++a) {
- int i = getArcID(a);
+ int64_t i = getArcID(a);
c += Number(_flow[i]) * Number(_cost[i]);
}
return c;
@@ -867,15 +871,15 @@ namespace lemon {
Number c = 0;
/*#ifdef HASHMAP
- typename stdext::hash_map<int,Value>::const_iterator it;
+ typename stdext::hash_map<int64_t,Value>::const_iterator it;
#else
- typename std::map<int,Value>::const_iterator it;
+ typename std::map<int64_t,Value>::const_iterator it;
#endif
for (it = _flow.data.begin(); it!=_flow.data.end(); ++it)
c += Number(it->second) * Number(_cost[it->first]);
return c;*/
- for (unsigned long i=0; i<_flow.size(); i++)
+ for (ArcsType i=0; i<_flow.size(); i++)
c += _flow[i] * Number(_cost[i]);
return c;
@@ -944,14 +948,14 @@ namespace lemon {
// Initialize internal data structures
bool init() {
if (_node_num == 0) return false;
-
+
// Check the sum of supply values
_sum_supply = 0;
for (int i = 0; i != _node_num; ++i) {
_sum_supply += _supply[i];
}
if ( fabs(_sum_supply) > _EPSILON ) return false;
-
+
_sum_supply = 0;
// Initialize artifical cost
@@ -960,14 +964,14 @@ namespace lemon {
ART_COST = std::numeric_limits<Cost>::max() / 2 + 1;
} else {
ART_COST = 0;
- for (int i = 0; i != _arc_num; ++i) {
+ for (ArcsType i = 0; i != _arc_num; ++i) {
if (_cost[i] > ART_COST) ART_COST = _cost[i];
}
ART_COST = (ART_COST + 1) * _node_num;
}
// Initialize arc maps
- for (int i = 0; i != _arc_num; ++i) {
+ for (ArcsType i = 0; i != _arc_num; ++i) {
//_flow[i] = 0; //by default, the sparse matrix is empty
_state[i] = STATE_LOWER;
}
@@ -988,7 +992,7 @@ namespace lemon {
// EQ supply constraints
_search_arc_num = _arc_num;
_all_arc_num = _arc_num + _node_num;
- for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
_parent[u] = _root;
_pred[u] = e;
_thread[u] = u + 1;
@@ -1016,8 +1020,8 @@ namespace lemon {
else if (_sum_supply > 0) {
// LEQ supply constraints
_search_arc_num = _arc_num + _node_num;
- int f = _arc_num + _node_num;
- for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
_parent[u] = _root;
_thread[u] = u + 1;
_rev_thread[u + 1] = u;
@@ -1054,8 +1058,8 @@ namespace lemon {
else {
// GEQ supply constraints
_search_arc_num = _arc_num + _node_num;
- int f = _arc_num + _node_num;
- for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
_parent[u] = _root;
_thread[u] = u + 1;
_rev_thread[u + 1] = u;
@@ -1120,9 +1124,9 @@ namespace lemon {
second = _source[in_arc];
}
delta = INF;
- int result = 0;
+ char result = 0;
Value d;
- int e;
+ ArcsType e;
// Search the cycle along the path form the first node to the root
for (int u = first; u != join; u = _parent[u]) {
@@ -1239,7 +1243,7 @@ namespace lemon {
// Update _rev_thread using the new _thread values
for (int i = 0; i != int(_dirty_revs.size()); ++i) {
- u = _dirty_revs[i];
+ int u = _dirty_revs[i];
_rev_thread[_thread[u]] = u;
}
@@ -1257,7 +1261,7 @@ namespace lemon {
u = w;
}
_pred[u_in] = in_arc;
- _forward[u_in] = ((unsigned int)u_in == _source[in_arc]);
+ _forward[u_in] = (u_in == _source[in_arc]);
_succ_num[u_in] = old_succ_num;
// Set limits for updating _last_succ form v_in and v_out
@@ -1328,7 +1332,7 @@ namespace lemon {
if (_sum_supply > 0) total -= _sum_supply;
if (total <= 0) return true;
- IntVector arc_vector;
+ ArcVector arc_vector;
if (_sum_supply >= 0) {
if (supply_nodes.size() == 1 && demand_nodes.size() == 1) {
// Perform a reverse graph search from the sink to the source
@@ -1345,7 +1349,7 @@ namespace lemon {
Arc a; _graph.firstIn(a, v);
for (; a != INVALID; _graph.nextIn(a)) {
if (reached[u = _graph.source(a)]) continue;
- int j = getArcID(a);
+ ArcsType j = getArcID(a);
if (INF >= total) {
arc_vector.push_back(j);
reached[u] = true;
@@ -1355,7 +1359,7 @@ namespace lemon {
}
} else {
// Find the min. cost incomming arc for each demand node
- for (int i = 0; i != int(demand_nodes.size()); ++i) {
+ for (int i = 0; i != demand_nodes.size(); ++i) {
Node v = demand_nodes[i];
Cost c, min_cost = std::numeric_limits<Cost>::max();
Arc min_arc = INVALID;
@@ -1393,7 +1397,7 @@ namespace lemon {
}
// Perform heuristic initial pivots
- for (int i = 0; i != int(arc_vector.size()); ++i) {
+ for (ArcsType i = 0; i != arc_vector.size(); ++i) {
in_arc = arc_vector[i];
// the error is probably here...
if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -
@@ -1423,7 +1427,7 @@ namespace lemon {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- int iter_number=0;
+ size_t iter_number=0;
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
@@ -1443,7 +1447,7 @@ namespace lemon {
double a;
a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
- for (int i=0; i<_flow.size(); i++) {
+ for (int64_t i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
@@ -1482,12 +1486,12 @@ namespace lemon {
double a;
a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
- for (int i=0; i<_flow.size(); i++) {
+ for (int64_t i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
-
+
std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
-
+
std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
@@ -1505,7 +1509,7 @@ namespace lemon {
#endif
// Check feasibility
if( retVal == OPTIMAL){
- for (int e = _search_arc_num; e != _all_arc_num; ++e) {
+ for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) {
if (_flow[e] != 0){
if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126
return INFEASIBLE;
@@ -1521,20 +1525,20 @@ namespace lemon {
if (_sum_supply == 0) {
if (_stype == GEQ) {
Cost max_pot = -std::numeric_limits<Cost>::max();
- for (int i = 0; i != _node_num; ++i) {
+ for (ArcsType i = 0; i != _node_num; ++i) {
if (_pi[i] > max_pot) max_pot = _pi[i];
}
if (max_pot > 0) {
- for (int i = 0; i != _node_num; ++i)
+ for (ArcsType i = 0; i != _node_num; ++i)
_pi[i] -= max_pot;
}
} else {
Cost min_pot = std::numeric_limits<Cost>::max();
- for (int i = 0; i != _node_num; ++i) {
+ for (ArcsType i = 0; i != _node_num; ++i) {
if (_pi[i] < min_pot) min_pot = _pi[i];
}
if (min_pot < 0) {
- for (int i = 0; i != _node_num; ++i)
+ for (ArcsType i = 0; i != _node_num; ++i)
_pi[i] -= min_pot;
}
}
@@ -1548,5 +1552,3 @@ namespace lemon {
///@}
} //namespace lemon
-
-#endif //LEMON_NETWORK_SIMPLEX_H
diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h
new file mode 100644
index 0000000..87e4c05
--- /dev/null
+++ b/ot/lp/network_simplex_simple_omp.h
@@ -0,0 +1,1699 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+*
+*
+* This file has been adapted by Nicolas Bonneel (2013),
+* from network_simplex.h from LEMON, a generic C++ optimization library,
+* to implement a lightweight network simplex for mass transport, more
+* memory efficient than the original file. A previous version of this file
+* is used as part of the Displacement Interpolation project,
+* Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
+*
+* Revisions:
+* March 2015: added OpenMP parallelization
+* March 2017: included Antoine Rolet's trick to make it more robust
+* April 2018: IMPORTANT bug fix + uses 64bit integers (slightly slower but less risks of overflows), updated to a newer version of the algo by LEMON, sparse flow by default + minor edits.
+*
+*
+**** Original file Copyright Notice :
+*
+* Copyright (C) 2003-2010
+* Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
+* (Egervary Research Group on Combinatorial Optimization, EGRES).
+*
+* Permission to use, modify and distribute this software is granted
+* provided that this copyright notice appears in all copies. For
+* precise terms see the accompanying LICENSE file.
+*
+* This software is provided "AS IS" with no warranty of any kind,
+* express or implied, and with no claim as to its suitability for any
+* purpose.
+*
+*/
+
+#pragma once
+#undef DEBUG_LVL
+#define DEBUG_LVL 0
+
+#if DEBUG_LVL>0
+#include <iomanip>
+#endif
+
+#undef EPSILON
+#undef _EPSILON
+#undef MAX_DEBUG_ITER
+#define EPSILON std::numeric_limits<double>::epsilon()*10
+#define _EPSILON 1e-8
+#define MAX_DEBUG_ITER 100000
+
+/// \ingroup min_cost_flow_algs
+///
+/// \file
+/// \brief Network Simplex algorithm for finding a minimum cost flow.
+
+// if your compiler has troubles with unorderedmaps, just comment the following line to use a slower std::map instead
+#define HASHMAP // now handled with unorderedmaps instead of stdext::hash_map. Should be better supported.
+
+#define SPARSE_FLOW // a sparse flow vector will be 10-15% slower for small problems but uses less memory and becomes faster for large problems (40k total nodes)
+
+#include <vector>
+#include <limits>
+#include <algorithm>
+#include <cstdint>
+#ifdef HASHMAP
+#include <unordered_map>
+#else
+#include <map>