From 3302cd48cdcc5d4832997bae921952cc3917fb59 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 23 Feb 2022 09:53:13 +0100 Subject: [MRG] Build POT against oldest-supported-numpy (local PR) (#349) * Configure setup to compile against oldest supported numpy version using the meta-package: https://pypi.org/project/oldest-supported-numpy/ - * Set minimum Python requirement to `>=3.7` in setup.py since !328 removed Python 3.6 support * Fix typo in pyproject.toml - * Update setup.py * Update setup.py and * build wheels * remove install dependencies for wheels building and build wheels * Apply suggestions from code review Co-authored-by: David M. Ghiurco <9147386+davidghiurco@users.noreply.github.com> * correct timing test add info in release file and build wheels * pep8 and Co-authored-by: David Ghiurco <9147386+davidghiurco@users.noreply.github.com> --- .github/workflows/build_wheels.yml | 8 +------- .github/workflows/build_wheels_weekly.yml | 2 -- 2 files changed, 1 insertion(+), 9 deletions(-) (limited to '.github') diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index c746eb8..475058c 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -27,8 +27,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -U "cython" - name: Install cibuildwheel run: | @@ -37,7 +35,6 @@ jobs: - name: Build wheels env: CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp36*" # remove pypy on mac and win (wrong version) - CIBW_BEFORE_BUILD: "pip install numpy cython" run: | python -m cibuildwheel --output-dir wheelhouse @@ -65,8 +62,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -U "cython" - name: Install cibuildwheel run: | @@ -80,8 +75,7 @@ jobs: - name: Build wheels env: - CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl* cp36*" # remove pypy on mac and 
win (wrong version) - CIBW_BEFORE_BUILD: "pip install numpy cython" + CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl*" # remove pypy on mac and win (wrong version) CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU CIBW_ARCHS_MACOS: x86_64 universal2 arm64 run: | diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml index dbf342f..b9154c5 100644 --- a/.github/workflows/build_wheels_weekly.yml +++ b/.github/workflows/build_wheels_weekly.yml @@ -26,8 +26,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -U "cython" - name: Install cibuildwheel run: | -- cgit v1.2.3 From ad02112d4288f3efdd5bc6fc6e45444313bba871 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 5 Apr 2022 11:57:10 +0200 Subject: [MRG] Update examples in the doc (#359) * add transparent color logo * add transparent color logo * move screenkhorn * move stochastic and install ffmpeg on circleci * try something * add sudo * install ffmpeg before python * cleanup examples * test svg scrapper * add animation for reg path * better example OT sivergence * update ttles and add plots * update free support * proper figure indexes * have less frame sin animation * update readme and release file * add tests for python 3.10 --- .circleci/config.yml | 7 + .github/workflows/build_tests.yml | 6 +- README.md | 10 +- RELEASES.md | 3 +- docs/source/_static/images/logo.png | Bin 5038 -> 4325 bytes docs/source/_static/images/logo.svg | 174 +++++++++---------- docs/source/conf.py | 6 + .../backends/plot_sliced_wass_grad_flow_pytorch.py | 2 + examples/backends/plot_wass1d_torch.py | 8 +- .../barycenters/plot_free_support_barycenter.py | 55 +++--- examples/others/plot_logo.py | 8 +- examples/others/plot_screenkhorn_1D.py | 71 ++++++++ examples/others/plot_stochastic.py | 189 +++++++++++++++++++++ examples/plot_OT_1D.py | 12 +- examples/plot_OT_1D_smooth.py | 6 +- 
examples/plot_OT_2D_samples.py | 2 +- examples/plot_OT_L1_vs_L2.py | 32 ++-- examples/plot_compute_emd.py | 72 +++++--- examples/plot_optim_OTreg.py | 38 ++++- examples/plot_screenkhorn_1D.py | 71 -------- examples/plot_stochastic.py | 189 --------------------- examples/sliced-wasserstein/README.txt | 2 +- examples/sliced-wasserstein/plot_variance.py | 8 +- examples/unbalanced-partial/plot_UOT_1D.py | 17 +- examples/unbalanced-partial/plot_regpath.py | 88 +++++++++- 25 files changed, 628 insertions(+), 448 deletions(-) create mode 100644 examples/others/plot_screenkhorn_1D.py create mode 100644 examples/others/plot_stochastic.py delete mode 100644 examples/plot_screenkhorn_1D.py delete mode 100644 examples/plot_stochastic.py (limited to '.github') diff --git a/.circleci/config.yml b/.circleci/config.yml index 39c19fb..77ab45c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -35,6 +35,12 @@ jobs: - data-cache-0 - pip-cache + - run: + name: Install ffmpeg + command: | + sudo apt update + sudo apt install ffmpeg + - run: name: Get Python running command: | @@ -50,6 +56,7 @@ jobs: paths: - ~/.cache/pip + # Look at what we have and fail early if there is some library conflict - run: name: Check installation diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 3c99da8..ce725c6 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -22,7 +22,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v1 @@ -93,7 +93,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v1 @@ -120,7 +120,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v1 diff --git a/README.md 
b/README.md index ec5d221..0c3bd19 100644 --- a/README.md +++ b/README.md @@ -29,8 +29,11 @@ POT provides the following generic OT solvers (links to examples): * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] -* [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) -* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] +* [Stochastic + solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and + [differentiable losses](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html) for + Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. 
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] @@ -119,7 +122,7 @@ Note that for easier access the module is named `ot` instead of `pot`. ### Dependencies -Some sub-modules require additional dependences which are discussed below +Some sub-modules require additional dependencies which are discussed below * **ot.dr** (Wasserstein dimensionality reduction) depends on autograd and pymanopt that can be installed with: @@ -127,7 +130,6 @@ Some sub-modules require additional dependences which are discussed below pip install pymanopt autograd ``` -* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solver already can run on GPU using the Pytorch backend. ## Examples diff --git a/RELEASES.md b/RELEASES.md index 45336f7..7d458f3 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features +- Update examples in the gallery (PR #359). - Add stochastic loss and OT plan computation for regularized OT and backend examples(PR #360). - Implementation of factored OT with emd and sinkhorn (PR #358). @@ -254,7 +255,7 @@ are coming for the next versions. 
#### Closed issues -- Add JMLR paper to teh readme ad Mathieu Blondel to the Acknoledgments (PR +- Add JMLR paper to the readme and Mathieu Blondel to the Acknoledgments (PR #231, #232) - Bug in Unbalanced OT example (Issue #127) - Clean Cython output when calling setup.py clean (Issue #122) diff --git a/docs/source/_static/images/logo.png b/docs/source/_static/images/logo.png index 7be5df7..2dd6f65 100644 Binary files a/docs/source/_static/images/logo.png and b/docs/source/_static/images/logo.png differ diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg index 0bf2cb7..39fe900 100644 --- a/docs/source/_static/images/logo.svg +++ b/docs/source/_static/images/logo.svg @@ -1,24 +1,23 @@ - - + - + - 2022-03-17T17:25:30.736761 + 2022-03-30T17:25:32.476826 image/svg+xml - Matplotlib v3.3.3, https://matplotlib.org/ + Matplotlib v3.5.1, https://matplotlib.org/ - + @@ -26,103 +25,104 @@ L 209.7 75.384 L 209.7 0 L 0 0 +L 0 75.384 z -" style="fill:#ffffff;"/> +" style="fill: none"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" 
clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" style="stroke: #000000"/> - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + - +" style="stroke: #000000"/> - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + - - + + diff --git a/docs/source/conf.py b/docs/source/conf.py index 60d0bb7..9526518 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -17,9 +17,15 @@ import os import re try: import sphinx_gallery + except ImportError: print("warning sphinx-gallery not installed") + + + + + # !!!! 
allow readthedoc compilation try: from unittest.mock import MagicMock diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index 05b9952..cf5d64d 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -27,6 +27,8 @@ Machine Learning (pp. 4104-4113). PMLR. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + # %% # Loading the data diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py index 0abdd6d..cd8e2fd 100644 --- a/examples/backends/plot_wass1d_torch.py +++ b/examples/backends/plot_wass1d_torch.py @@ -1,9 +1,9 @@ r""" -================================= -Wasserstein 1D with PyTorch -================================= +================================================= +Wasserstein 1D (flow and barycenter) with PyTorch +================================================= -In this small example, we consider the following minization problem: +In this small example, we consider the following minimization problem: .. math:: \mu^* = \min_\mu W(\mu,\nu) diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 2d68a39..226dfeb 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -9,61 +9,62 @@ sum of diracs. 
""" -# Author: Vivien Seguy +# Authors: Vivien Seguy +# Rémi Flamary # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import numpy as np import matplotlib.pylab as pl import ot -############################################################################## +# %% # Generate data # ------------- -N = 3 +N = 2 d = 2 -measures_locations = [] -measures_weights = [] - -for i in range(N): - n_i = np.random.randint(low=1, high=20) # nb samples +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] +I2 = pl.imread('../../data/duck.png').astype(np.float64)[::4, ::4, 2] - mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean +sz = I2.shape[0] +XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) - A_i = np.random.rand(d, d) - cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix +x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0 +x2 = np.stack((XX[I2 == 0] + 80, -YY[I2 == 0] + 32), 1) * 1.0 +x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0 - x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations - b_i = np.random.uniform(0., 1., (n_i,)) - b_i = b_i / np.sum(b_i) # Dirac weights +measures_locations = [x1, x2] +measures_weights = [ot.unif(x1.shape[0]), ot.unif(x2.shape[0])] - measures_locations.append(x_i) - measures_weights.append(b_i) +pl.figure(1, (12, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.title('Distributions') -############################################################################## +# %% # Compute free support barycenter # ------------------------------- -k = 10 # number of Diracs of the barycenter +k = 200 # number of Diracs of the barycenter X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) - 
-############################################################################## -# Plot data +# %% +# Plot the barycenter # --------- -pl.figure(1) -for (x_i, b_i) in zip(measures_locations, measures_weights): - color = np.random.randint(low=1, high=10 * N) - pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure') -pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter') +pl.figure(2, (8, 3)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') pl.title('Data measures and their barycenter') -pl.legend(loc=0) +pl.legend(loc="lower right") pl.show() diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py index afddcad..9414371 100644 --- a/examples/others/plot_logo.py +++ b/examples/others/plot_logo.py @@ -7,8 +7,8 @@ Logo of the POT toolbox In this example we plot the logo of the POT toolbox. -A specificity of this logo is that it is done 100% in Python and generated using -matplotlib using the EMD solver from POT. +This logo is done 100% in Python and generated using +matplotlib by plotting the solution of the EMD solver from POT. 
""" @@ -86,8 +86,8 @@ pl.axis('equal') pl.axis('off') # Save logo file -# pl.savefig('logo.svg', dpi=150, bbox_inches='tight') -# pl.savefig('logo.png', dpi=150, bbox_inches='tight') +# pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight') +# pl.savefig('logo.png', dpi=150, transparent=True, bbox_inches='tight') # %% # Plot the logo (dark background) diff --git a/examples/others/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py new file mode 100644 index 0000000..2023649 --- /dev/null +++ b/examples/others/plot_screenkhorn_1D.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +""" +======================================== +Screened optimal transport (Screenkhorn) +======================================== + +This example illustrates the computation of Screenkhorn [26]. + +[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). +Screening Sinkhorn Algorithm for Regularized Optimal Transport, +Advances in Neural Information Processing Systems 33 (NeurIPS). +""" + +# Author: Mokhtar Z. 
Alaya +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot.plot +from ot.datasets import make_1D_gauss as gauss +from ot.bregman import screenkhorn + +############################################################################## +# Generate data +# ------------- + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +# plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + +############################################################################## +# Solve Screenkhorn +# ----------------------- + +# Screenkhorn +lambd = 2e-03 # entropy parameter +ns_budget = 30 # budget number of points to be kept in the source distribution +nt_budget = 30 # budget number of points to be kept in the target distribution + +G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True) +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn') +pl.show() diff --git a/examples/others/plot_stochastic.py b/examples/others/plot_stochastic.py new file mode 100644 index 0000000..3a1ef31 --- /dev/null +++ b/examples/others/plot_stochastic.py @@ -0,0 +1,189 @@ +""" +=================== +Stochastic examples +=================== + +This example is designed to show how to use the stochastic optimization +algorithms for discrete and semi-continuous measures from the POT 
library. + +[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. +Stochastic Optimization for Large-scale Optimal Transport. +Advances in Neural Information Processing Systems (2016). + +[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & +Blondel, M. Large-scale Optimal Transport and Mapping Estimation. +International Conference on Learning Representation (2018) + +""" + +# Author: Kilian Fatras +# +# License: MIT License + +import matplotlib.pylab as pl +import numpy as np +import ot +import ot.plot + + +############################################################################# +# Compute the Transportation Matrix for the Semi-Dual Problem +# ----------------------------------------------------------- +# +# Discrete case +# ````````````` +# +# Sample two discrete measures for the discrete case and compute their cost +# matrix c. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 1000 + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# Call the "SAG" method to find the transportation matrix in the discrete case + +method = "SAG" +sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, + numItermax) +print(sag_pi) + +############################################################################# +# Semi-Continuous Case +# ```````````````````` +# +# Sample one general measure a, one discrete measures b for the semicontinous +# case, the points where source and target measures are defined and compute the +# cost matrix. 
+ +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 1000 +log = True + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# Call the "ASGD" method to find the transportation matrix in the semicontinous +# case. + +method = "ASGD" +asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, + numItermax, log=log) +print(log_asgd['alpha'], log_asgd['beta']) +print(asgd_pi) + +############################################################################# +# Compare the results with the Sinkhorn algorithm + +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) +print(sinkhorn_pi) + + +############################################################################## +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SAG + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') +pl.show() + + +############################################################################## +# For ASGD + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') +pl.show() + + +############################################################################## +# For Sinkhorn + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +pl.show() + + +############################################################################# +# Compute the Transportation Matrix for the Dual Problem +# ------------------------------------------------------ +# +# Semi-continuous case +# ```````````````````` +# +# Sample one general measure a, one discrete measures b for the semi-continuous +# case and compute the cost matrix c. 
+ +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 100000 +lr = 0.1 +batch_size = 3 +log = True + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# +# Call the "SGD" dual method to find the transportation matrix in the +# semi-continuous case + +sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, + batch_size, numItermax, + lr, log=log) +print(log_sgd['alpha'], log_sgd['beta']) +print(sgd_dual_pi) + +############################################################################# +# +# Compare the results with the Sinkhorn algorithm +# ``````````````````````````````````````````````` +# +# Call the Sinkhorn algorithm from POT + +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) +print(sinkhorn_pi) + +############################################################################## +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SGD + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') +pl.show() + + +############################################################################## +# For Sinkhorn + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +pl.show() diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index 15ead96..62f0b7d 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -==================== -1D optimal transport -==================== +====================================== +Optimal Transport for 1D distributions +====================================== This example illustrates the computation of EMD and Sinkhorn transport plans and their visualization. 
@@ -64,7 +64,11 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') #%% EMD -G0 = ot.emd(a, b, M) +# use fast 1D solver +G0 = ot.emd_1d(x, x, a, b) + +# Equivalent to +# G0 = ot.emd(a, b, M) pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index b07f99f..5415e4f 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -=========================== -1D smooth optimal transport -=========================== +================================ +Smooth optimal transport example +================================ This example illustrates the computation of EMD, Sinkhorn and smooth OT plans and their visualization. diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index c3a7cd8..1d82fb8 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ==================================================== -2D Optimal transport between empirical distributions +Optimal Transport between 2D empirical distributions ==================================================== Illustration of 2D optimal transport between discributions that are weighted diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py index cb94574..cce51f8 100644 --- a/examples/plot_OT_L1_vs_L2.py +++ b/examples/plot_OT_L1_vs_L2.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """ -========================================== -2D Optimal transport for different metrics -========================================== +================================================ +Optimal Transport with different ground metrics +================================================ -2D OT on empirical distributio with different gound metric. +2D OT on empirical distributions with different ground metric. Stole the figure idea from Fig. 
1 and 2 in https://arxiv.org/pdf/1706.07650.pdf @@ -23,7 +23,7 @@ import matplotlib.pylab as pl import ot import ot.plot -############################################################################## +# %% # Dataset 1 : uniform sampling # ---------------------------- @@ -46,7 +46,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean') M2 /= M2.max() # loss matrix -Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean')) +Mp = ot.dist(xs, xt, metric='cityblock') Mp /= Mp.max() # Data @@ -71,7 +71,7 @@ pl.title('Squared Euclidean cost') pl.subplot(1, 3, 3) pl.imshow(Mp, interpolation='nearest') -pl.title('Sqrt Euclidean cost') +pl.title('L1 (cityblock cost') pl.tight_layout() ############################################################################## @@ -109,22 +109,22 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') # pl.legend(loc=0) -pl.title('OT sqrt Euclidean') +pl.title('OT L1 (cityblock)') pl.tight_layout() pl.show() -############################################################################## +# %% # Dataset 2 : Partial circle # -------------------------- -n = 50 # nb samples +n = 20 # nb samples xtot = np.zeros((n + 1, 2)) xtot[:, 0] = np.cos( - (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) + (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) xtot[:, 1] = np.sin( - (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) + (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) xs = xtot[:n, :] xt = xtot[1:, :] @@ -140,7 +140,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean') M2 /= M2.max() # loss matrix -Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean')) +Mp = ot.dist(xs, xt, metric='cityblock') Mp /= Mp.max() @@ -166,13 +166,13 @@ pl.title('Squared Euclidean cost') pl.subplot(1, 3, 3) pl.imshow(Mp, interpolation='nearest') -pl.title('Sqrt Euclidean cost') +pl.title('L1 (cityblock) cost') pl.tight_layout() 
############################################################################## # Dataset 2 : Plot OT Matrices # ----------------------------- - +# #%% EMD G1 = ot.emd(a, b, M1) @@ -204,7 +204,7 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') # pl.legend(loc=0) -pl.title('OT sqrt Euclidean') +pl.title('OT L1 (cityblock)') pl.tight_layout() pl.show() diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index 527a847..36cc7da 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """ -================= -Plot multiple EMD -================= +================== +OT distances in 1D +================== -Shows how to compute multiple EMD and Sinkhorn with two different +Shows how to compute multiple Wassersein and Sinkhorn with two different ground metrics and plot their values for different distributions. @@ -14,7 +14,7 @@ ground metrics and plot their values for different distributions. 
# # License: MIT License -# sphinx_gallery_thumbnail_number = 3 +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pylab as pl @@ -29,7 +29,7 @@ from ot.datasets import make_1D_gauss as gauss #%% parameters n = 100 # nb bins -n_target = 50 # nb target distributions +n_target = 20 # nb target distributions # bin positions @@ -47,9 +47,9 @@ for i, m in enumerate(lst_m): # loss matrix and normalization M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean') -M /= M.max() +M /= M.max() * 0.1 M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean') -M2 /= M2.max() +M2 /= M2.max() * 0.1 ############################################################################## # Plot data @@ -59,10 +59,12 @@ M2 /= M2.max() pl.figure(1) pl.subplot(2, 1, 1) -pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, a, 'r', label='Source distribution') pl.title('Source distribution') pl.subplot(2, 1, 2) -pl.plot(x, B, label='Target distributions') +for i in range(n_target): + pl.plot(x, B[:, i], 'b', alpha=i / n_target) +pl.plot(x, B[:, -1], 'b', label='Target distributions') pl.title('Target distributions') pl.tight_layout() @@ -73,14 +75,27 @@ pl.tight_layout() #%% Compute and plot distributions and loss matrix -d_emd = ot.emd2(a, B, M) # direct computation of EMD -d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2 - +d_emd = ot.emd2(a, B, M) # direct computation of OT loss +d_emd2 = ot.emd2(a, B, M2) # direct computation of OT loss with metric M2 +d_tv = [np.sum(abs(a - B[:, i])) for i in range(n_target)] pl.figure(2) -pl.plot(d_emd, label='Euclidean EMD') -pl.plot(d_emd2, label='Squared Euclidean EMD') -pl.title('EMD distances') +pl.subplot(2, 1, 1) +pl.plot(x, a, 'r', label='Source distribution') +pl.title('Distributions') +for i in range(n_target): + pl.plot(x, B[:, i], 'b', alpha=i / n_target) +pl.plot(x, B[:, -1], 'b', label='Target distributions') +pl.ylim((-.01, 0.13)) +pl.xticks(()) +pl.legend() +pl.subplot(2, 1, 2) 
+pl.plot(d_emd, label='Euclidean OT') +pl.plot(d_emd2, label='Squared Euclidean OT') +pl.plot(d_tv, label='Total Variation (TV)') +#pl.xlim((-7,23)) +pl.xlabel('Displacement') +pl.title('Divergences') pl.legend() ############################################################################## @@ -88,17 +103,30 @@ pl.legend() # ----------------------------------------- #%% -reg = 1e-2 +reg = 1e-1 d_sinkhorn = ot.sinkhorn2(a, B, M, reg) d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg) -pl.figure(2) +pl.figure(3) pl.clf() -pl.plot(d_emd, label='Euclidean EMD') -pl.plot(d_emd2, label='Squared Euclidean EMD') + +pl.subplot(2, 1, 1) +pl.plot(x, a, 'r', label='Source distribution') +pl.title('Distributions') +for i in range(n_target): + pl.plot(x, B[:, i], 'b', alpha=i / n_target) +pl.plot(x, B[:, -1], 'b', label='Target distributions') +pl.ylim((-.01, 0.13)) +pl.xticks(()) +pl.legend() +pl.subplot(2, 1, 2) +pl.plot(d_emd, label='Euclidean OT') +pl.plot(d_emd2, label='Squared Euclidean OT') pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn') pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn') -pl.title('EMD distances') +pl.plot(d_tv, label='Total Variation (TV)') +#pl.xlim((-7,23)) +pl.xlabel('Displacement') +pl.title('Divergences') pl.legend() - pl.show() diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py index 5eb15bd..7b021d2 100644 --- a/examples/plot_optim_OTreg.py +++ b/examples/plot_optim_OTreg.py @@ -24,7 +24,7 @@ arXiv preprint arXiv:1510.06567. 
""" -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 5 import numpy as np import matplotlib.pylab as pl @@ -58,7 +58,7 @@ M /= M.max() G0 = ot.emd(a, b, M) -pl.figure(3, figsize=(5, 5)) +pl.figure(1, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') ############################################################################## @@ -80,7 +80,7 @@ reg = 1e-1 Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True) -pl.figure(3) +pl.figure(2) ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg') ############################################################################## @@ -102,7 +102,7 @@ reg = 1e-3 Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True) -pl.figure(4, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg') ############################################################################## @@ -125,6 +125,34 @@ reg2 = 1e-1 Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True) -pl.figure(5, figsize=(5, 5)) +pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg') pl.show() + + +# %% +# Comparison of the OT matrices + +nvisu = 40 + +pl.figure(5, figsize=(10, 4)) + +pl.subplot(2, 2, 1) +pl.imshow(G0[:nvisu, :]) +pl.axis('off') +pl.title('Exact OT') + +pl.subplot(2, 2, 2) +pl.imshow(Gl2[:nvisu, :]) +pl.axis('off') +pl.title('Frobenius reg.') + +pl.subplot(2, 2, 3) +pl.imshow(Ge[:nvisu, :]) +pl.axis('off') +pl.title('Entropic reg.') + +pl.subplot(2, 2, 4) +pl.imshow(Gel2[:nvisu, :]) +pl.axis('off') +pl.title('Entropic + Frobenius reg.') diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py deleted file mode 100644 index 785642a..0000000 --- a/examples/plot_screenkhorn_1D.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -""" -=============================== -1D Screened optimal transport -=============================== - -This example illustrates the computation of Screenkhorn [26]. - -[26] Alaya M. 
Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). -Screening Sinkhorn Algorithm for Regularized Optimal Transport, -Advances in Neural Information Processing Systems 33 (NeurIPS). -""" - -# Author: Mokhtar Z. Alaya -# -# License: MIT License - -import numpy as np -import matplotlib.pylab as pl -import ot.plot -from ot.datasets import make_1D_gauss as gauss -from ot.bregman import screenkhorn - -############################################################################## -# Generate data -# ------------- - -#%% parameters - -n = 100 # nb bins - -# bin positions -x = np.arange(n, dtype=np.float64) - -# Gaussian distributions -a = gauss(n, m=20, s=5) # m= mean, s= std -b = gauss(n, m=60, s=10) - -# loss matrix -M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) -M /= M.max() - -############################################################################## -# Plot distributions and loss matrix -# ---------------------------------- - -#%% plot the distributions - -pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') -pl.legend() - -# plot distributions and loss matrix - -pl.figure(2, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') - -############################################################################## -# Solve Screenkhorn -# ----------------------- - -# Screenkhorn -lambd = 2e-03 # entropy parameter -ns_budget = 30 # budget number of points to be keeped in the source distribution -nt_budget = 30 # budget number of points to be keeped in the target distribution - -G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True) -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn') -pl.show() diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py deleted file mode 100644 index 3a1ef31..0000000 --- a/examples/plot_stochastic.py +++ /dev/null @@ -1,189 +0,0 @@ -""" 
-=================== -Stochastic examples -=================== - -This example is designed to show how to use the stochatic optimization -algorithms for discrete and semi-continuous measures from the POT library. - -[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. -Stochastic Optimization for Large-scale Optimal Transport. -Advances in Neural Information Processing Systems (2016). - -[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & -Blondel, M. Large-scale Optimal Transport and Mapping Estimation. -International Conference on Learning Representation (2018) - -""" - -# Author: Kilian Fatras -# -# License: MIT License - -import matplotlib.pylab as pl -import numpy as np -import ot -import ot.plot - - -############################################################################# -# Compute the Transportation Matrix for the Semi-Dual Problem -# ----------------------------------------------------------- -# -# Discrete case -# ````````````` -# -# Sample two discrete measures for the discrete case and compute their cost -# matrix c. - -n_source = 7 -n_target = 4 -reg = 1 -numItermax = 1000 - -a = ot.utils.unif(n_source) -b = ot.utils.unif(n_target) - -rng = np.random.RandomState(0) -X_source = rng.randn(n_source, 2) -Y_target = rng.randn(n_target, 2) -M = ot.dist(X_source, Y_target) - -############################################################################# -# Call the "SAG" method to find the transportation matrix in the discrete case - -method = "SAG" -sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax) -print(sag_pi) - -############################################################################# -# Semi-Continuous Case -# ```````````````````` -# -# Sample one general measure a, one discrete measures b for the semicontinous -# case, the points where source and target measures are defined and compute the -# cost matrix. 
- -n_source = 7 -n_target = 4 -reg = 1 -numItermax = 1000 -log = True - -a = ot.utils.unif(n_source) -b = ot.utils.unif(n_target) - -rng = np.random.RandomState(0) -X_source = rng.randn(n_source, 2) -Y_target = rng.randn(n_target, 2) -M = ot.dist(X_source, Y_target) - -############################################################################# -# Call the "ASGD" method to find the transportation matrix in the semicontinous -# case. - -method = "ASGD" -asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax, log=log) -print(log_asgd['alpha'], log_asgd['beta']) -print(asgd_pi) - -############################################################################# -# Compare the results with the Sinkhorn algorithm - -sinkhorn_pi = ot.sinkhorn(a, b, M, reg) -print(sinkhorn_pi) - - -############################################################################## -# Plot Transportation Matrices -# ```````````````````````````` -# -# For SAG - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') -pl.show() - - -############################################################################## -# For ASGD - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') -pl.show() - - -############################################################################## -# For Sinkhorn - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') -pl.show() - - -############################################################################# -# Compute the Transportation Matrix for the Dual Problem -# ------------------------------------------------------ -# -# Semi-continuous case -# ```````````````````` -# -# Sample one general measure a, one discrete measures b for the semi-continuous -# case and compute the cost matrix c. 
- -n_source = 7 -n_target = 4 -reg = 1 -numItermax = 100000 -lr = 0.1 -batch_size = 3 -log = True - -a = ot.utils.unif(n_source) -b = ot.utils.unif(n_target) - -rng = np.random.RandomState(0) -X_source = rng.randn(n_source, 2) -Y_target = rng.randn(n_target, 2) -M = ot.dist(X_source, Y_target) - -############################################################################# -# -# Call the "SGD" dual method to find the transportation matrix in the -# semi-continuous case - -sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, - batch_size, numItermax, - lr, log=log) -print(log_sgd['alpha'], log_sgd['beta']) -print(sgd_dual_pi) - -############################################################################# -# -# Compare the results with the Sinkhorn algorithm -# ``````````````````````````````````````````````` -# -# Call the Sinkhorn algorithm from POT - -sinkhorn_pi = ot.sinkhorn(a, b, M, reg) -print(sinkhorn_pi) - -############################################################################## -# Plot Transportation Matrices -# ```````````````````````````` -# -# For SGD - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') -pl.show() - - -############################################################################## -# For Sinkhorn - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') -pl.show() diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt index a575345..73e6122 100644 --- a/examples/sliced-wasserstein/README.txt +++ b/examples/sliced-wasserstein/README.txt @@ -1,4 +1,4 @@ Sliced Wasserstein Distance ---------------------------- \ No newline at end of file +--------------------------- diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 7d73907..f12b522 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ 
-1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -============================== -2D Sliced Wasserstein Distance -============================== +=============================================== +Sliced Wasserstein Distance on 2D distributions +=============================================== This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31]. @@ -16,6 +16,8 @@ measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import matplotlib.pylab as pl import numpy as np diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 183849c..06dd02d 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import numpy as np import matplotlib.pylab as pl import ot @@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter alpha = 1. # Unbalanced KL relaxation parameter Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) -pl.figure(4, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') pl.show() + + +# %% +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source') +pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target') +pl.legend(loc='upper right') +pl.title('Distributions and transported mass for UOT') diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py index 4a51c2d..782e8c2 100644 --- a/examples/unbalanced-partial/plot_regpath.py +++ b/examples/unbalanced-partial/plot_regpath.py @@ -15,11 +15,12 @@ penalized linear regression. 
# Author: Haoran Wu # License: MIT License +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pylab as pl import ot - +import matplotlib.animation as animation ############################################################################## # Generate data # ------------- @@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma, ############################################################################## # Plot the regularization path # ---------------- +# +# The OT plan is ploted as a function of $\gamma$ that is the inverse of the +# weight on the marginal relaxations. #%% fully relaxed l2-penalized UOT @@ -103,13 +107,53 @@ for p in range(4): pl.show() +# %% +# Animation of the regpath for UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(-.5, -2.5, nv) + +pl.figure(3) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list, + t_list) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) + + ############################################################################## # Plot the semi-relaxed regularization path # ------------------- #%% semi-relaxed l2-penalized UOT -pl.figure(3) 
+pl.figure(4) selected_gamma = [10, 1, 1e-1, 1e-2] for p in range(4): tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2, @@ -133,3 +177,43 @@ for p in range(4): if p < 2: pl.xticks(()) pl.show() + + +# %% +# Animation of the regpath for semi-relaxed UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(2.5, -2, nv) + +pl.figure(5) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2, + t_list2) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) -- cgit v1.2.3