summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2022-04-05 11:57:10 +0200
committerGitHub <noreply@github.com>2022-04-05 11:57:10 +0200
commitad02112d4288f3efdd5bc6fc6e45444313bba871 (patch)
treef6cd539450c2ed36cf5d7014debfd82e8b9fddfb
parent0afd84d744a472903d427e3c7ae32e55fdd7b9a7 (diff)
[MRG] Update examples in the doc (#359)
* add transparent color logo * add transparent color logo * move screenkhorn * move stochastic and install ffmpeg on circleci * try something * add sudo * install ffmpeg before python * cleanup examples * test svg scrapper * add animation for reg path * better example OT sivergence * update ttles and add plots * update free support * proper figure indexes * have less frame sin animation * update readme and release file * add tests for python 3.10
-rw-r--r--.circleci/config.yml7
-rw-r--r--.github/workflows/build_tests.yml6
-rw-r--r--README.md10
-rw-r--r--RELEASES.md3
-rw-r--r--docs/source/_static/images/logo.pngbin5038 -> 4325 bytes
-rw-r--r--docs/source/_static/images/logo.svg174
-rw-r--r--docs/source/conf.py6
-rw-r--r--examples/backends/plot_sliced_wass_grad_flow_pytorch.py2
-rw-r--r--examples/backends/plot_wass1d_torch.py8
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py55
-rw-r--r--examples/others/plot_logo.py8
-rw-r--r--examples/others/plot_screenkhorn_1D.py (renamed from examples/plot_screenkhorn_1D.py)6
-rw-r--r--examples/others/plot_stochastic.py (renamed from examples/plot_stochastic.py)0
-rw-r--r--examples/plot_OT_1D.py12
-rw-r--r--examples/plot_OT_1D_smooth.py6
-rw-r--r--examples/plot_OT_2D_samples.py2
-rw-r--r--examples/plot_OT_L1_vs_L2.py32
-rw-r--r--examples/plot_compute_emd.py72
-rw-r--r--examples/plot_optim_OTreg.py38
-rw-r--r--examples/sliced-wasserstein/README.txt2
-rw-r--r--examples/sliced-wasserstein/plot_variance.py8
-rw-r--r--examples/unbalanced-partial/plot_UOT_1D.py17
-rw-r--r--examples/unbalanced-partial/plot_regpath.py88
23 files changed, 371 insertions, 191 deletions
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 39c19fb..77ab45c 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -36,6 +36,12 @@ jobs:
- pip-cache
- run:
+ name: Install ffmpeg
+ command: |
+ sudo apt update
+ sudo apt install ffmpeg
+
+ - run:
name: Get Python running
command: |
python -m pip install --user --upgrade --progress-bar off pip
@@ -50,6 +56,7 @@ jobs:
paths:
- ~/.cache/pip
+
# Look at what we have and fail early if there is some library conflict
- run:
name: Check installation
diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 3c99da8..ce725c6 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -22,7 +22,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
- python-version: ["3.7", "3.8", "3.9"]
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v1
@@ -93,7 +93,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
- python-version: ["3.7", "3.8", "3.9"]
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v1
@@ -120,7 +120,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
- python-version: ["3.7", "3.8", "3.9"]
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v1
diff --git a/README.md b/README.md
index ec5d221..0c3bd19 100644
--- a/README.md
+++ b/README.md
@@ -29,8 +29,11 @@ POT provides the following generic OT solvers (links to examples):
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
-* [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
-* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
+* [Stochastic
+ solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and
+ [differentiable losses](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html) for
+ Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
+* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
@@ -119,7 +122,7 @@ Note that for easier access the module is named `ot` instead of `pot`.
### Dependencies
-Some sub-modules require additional dependences which are discussed below
+Some sub-modules require additional dependencies which are discussed below
* **ot.dr** (Wasserstein dimensionality reduction) depends on autograd and pymanopt that can be installed with:
@@ -127,7 +130,6 @@ Some sub-modules require additional dependences which are discussed below
pip install pymanopt autograd
```
-* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solver already can run on GPU using the Pytorch backend.
## Examples
diff --git a/RELEASES.md b/RELEASES.md
index 45336f7..7d458f3 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,6 +5,7 @@
#### New features
+- Update examples in the gallery (PR #359).
- Add stochastic loss and OT plan computation for regularized OT and
backend examples(PR #360).
- Implementation of factored OT with emd and sinkhorn (PR #358).
@@ -254,7 +255,7 @@ are coming for the next versions.
#### Closed issues
-- Add JMLR paper to teh readme ad Mathieu Blondel to the Acknoledgments (PR
+- Add JMLR paper to the readme and Mathieu Blondel to the Acknoledgments (PR
#231, #232)
- Bug in Unbalanced OT example (Issue #127)
- Clean Cython output when calling setup.py clean (Issue #122)
diff --git a/docs/source/_static/images/logo.png b/docs/source/_static/images/logo.png
index 7be5df7..2dd6f65 100644
--- a/docs/source/_static/images/logo.png
+++ b/docs/source/_static/images/logo.png
Binary files differ
diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg
index 0bf2cb7..39fe900 100644
--- a/docs/source/_static/images/logo.svg
+++ b/docs/source/_static/images/logo.svg
@@ -1,24 +1,23 @@
<?xml version="1.0" encoding="utf-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
-<!-- Created with matplotlib (https://matplotlib.org/) -->
-<svg height="75.384pt" version="1.1" viewBox="0 0 209.7 75.384" width="209.7pt" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<svg xmlns:xlink="http://www.w3.org/1999/xlink" width="209.7pt" height="75.384pt" viewBox="0 0 209.7 75.384" xmlns="http://www.w3.org/2000/svg" version="1.1">
<metadata>
- <rdf:RDF xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
+ <rdf:RDF xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:cc="http://creativecommons.org/ns#" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<cc:Work>
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
- <dc:date>2022-03-17T17:25:30.736761</dc:date>
+ <dc:date>2022-03-30T17:25:32.476826</dc:date>
<dc:format>image/svg+xml</dc:format>
<dc:creator>
<cc:Agent>
- <dc:title>Matplotlib v3.3.3, https://matplotlib.org/</dc:title>
+ <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>
</cc:Agent>
</dc:creator>
</cc:Work>
</rdf:RDF>
</metadata>
<defs>
- <style type="text/css">*{stroke-linecap:butt;stroke-linejoin:round;}</style>
+ <style type="text/css">*{stroke-linejoin: round; stroke-linecap: butt}</style>
</defs>
<g id="figure_1">
<g id="patch_1">
@@ -26,103 +25,104 @@
L 209.7 75.384
L 209.7 0
L 0 0
+L 0 75.384
z
-" style="fill:#ffffff;"/>
+" style="fill: none"/>
</g>
<g id="axes_1">
<g id="line2d_1">
- <path clip-path="url(#pafc522611b)" d="M 16.077273 11.885975
+ <path d="M 16.077273 11.885975
L 47.044503 11.885975
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_2">
- <path clip-path="url(#pafc522611b)" d="M 16.077273 22.208385
+ <path d="M 16.077273 22.208385
L 57.366913 22.208385
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_3">
- <path clip-path="url(#pafc522611b)" d="M 16.077273 32.530795
+ <path d="M 16.077273 32.530795
L 57.366913 32.530795
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_4">
- <path clip-path="url(#pafc522611b)" d="M 16.077273 42.853205
+ <path d="M 16.077273 42.853205
L 47.044503 42.853205
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_5">
- <path clip-path="url(#pafc522611b)" d="M 16.077273 53.175615
+ <path d="M 16.077273 53.175615
L 26.399683 53.175615
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_6">
- <path clip-path="url(#pafc522611b)" d="M 16.077273 63.498025
+ <path d="M 16.077273 63.498025
L 26.399683 63.498025
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_7">
- <path clip-path="url(#pafc522611b)" d="M 95.353383 11.885975
+ <path d="M 95.353383 11.885975
L 107.740275 11.885975
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_8">
- <path clip-path="url(#pafc522611b)" d="M 82.96649 22.208385
+ <path d="M 82.96649 22.208385
L 120.127167 22.208385
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_9">
- <path clip-path="url(#pafc522611b)" d="M 76.773044 32.530795
+ <path d="M 76.773044 32.530795
L 126.320613 32.530795
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_10">
- <path clip-path="url(#pafc522611b)" d="M 76.773044 42.853205
+ <path d="M 76.773044 42.853205
L 126.320613 42.853205
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_11">
- <path clip-path="url(#pafc522611b)" d="M 82.96649 53.175615
+ <path d="M 82.96649 53.175615
L 120.127167 53.175615
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_12">
- <path clip-path="url(#pafc522611b)" d="M 95.353383 63.498025
+ <path d="M 95.353383 63.498025
L 107.740275 63.498025
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_13">
- <path clip-path="url(#pafc522611b)" d="M 142.010677 11.885975
+ <path d="M 142.010677 11.885975
L 193.622727 11.885975
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_14">
- <path clip-path="url(#pafc522611b)" d="M 142.010677 22.208385
+ <path d="M 142.010677 22.208385
L 193.622727 22.208385
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_15">
- <path clip-path="url(#pafc522611b)" d="M 162.655497 32.530795
+ <path d="M 162.655497 32.530795
L 172.977907 32.530795
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_16">
- <path clip-path="url(#pafc522611b)" d="M 162.655497 42.853205
+ <path d="M 162.655497 42.853205
L 172.977907 42.853205
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_17">
- <path clip-path="url(#pafc522611b)" d="M 162.655497 53.175615
+ <path d="M 162.655497 53.175615
L 172.977907 53.175615
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_18">
- <path clip-path="url(#pafc522611b)" d="M 162.655497 63.498025
+ <path d="M 162.655497 63.498025
L 172.977907 63.498025
-" style="fill:none;stroke:#000000;stroke-linecap:square;stroke-opacity:0.6;stroke-width:3;"/>
+" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/>
</g>
<g id="line2d_19">
<defs>
- <path d="M 0 3
+ <path id="m5ead2df136" d="M 0 3
C 0.795609 3 1.55874 2.683901 2.12132 2.12132
C 2.683901 1.55874 3 0.795609 3 0
C 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132
@@ -132,32 +132,32 @@ C -2.683901 -1.55874 -3 -0.795609 -3 0
C -3 0.795609 -2.683901 1.55874 -2.12132 2.12132
C -1.55874 2.683901 -0.795609 3 0 3
z
-" id="m45a57d22c6" style="stroke:#000000;"/>
+" style="stroke: #000000"/>
</defs>
- <g clip-path="url(#pafc522611b)">
- <use style="fill:#d62728;stroke:#000000;" x="16.077273" xlink:href="#m45a57d22c6" y="11.885975"/>
- <use style="fill:#d62728;stroke:#000000;" x="16.077273" xlink:href="#m45a57d22c6" y="22.208385"/>
- <use style="fill:#d62728;stroke:#000000;" x="16.077273" xlink:href="#m45a57d22c6" y="32.530795"/>
- <use style="fill:#d62728;stroke:#000000;" x="16.077273" xlink:href="#m45a57d22c6" y="42.853205"/>
- <use style="fill:#d62728;stroke:#000000;" x="16.077273" xlink:href="#m45a57d22c6" y="53.175615"/>
- <use style="fill:#d62728;stroke:#000000;" x="16.077273" xlink:href="#m45a57d22c6" y="63.498025"/>
- <use style="fill:#d62728;stroke:#000000;" x="95.353383" xlink:href="#m45a57d22c6" y="11.885975"/>
- <use style="fill:#d62728;stroke:#000000;" x="82.96649" xlink:href="#m45a57d22c6" y="22.208385"/>
- <use style="fill:#d62728;stroke:#000000;" x="76.773044" xlink:href="#m45a57d22c6" y="32.530795"/>
- <use style="fill:#d62728;stroke:#000000;" x="76.773044" xlink:href="#m45a57d22c6" y="42.853205"/>
- <use style="fill:#d62728;stroke:#000000;" x="82.96649" xlink:href="#m45a57d22c6" y="53.175615"/>
- <use style="fill:#d62728;stroke:#000000;" x="95.353383" xlink:href="#m45a57d22c6" y="63.498025"/>
- <use style="fill:#d62728;stroke:#000000;" x="142.010677" xlink:href="#m45a57d22c6" y="11.885975"/>
- <use style="fill:#d62728;stroke:#000000;" x="142.010677" xlink:href="#m45a57d22c6" y="22.208385"/>
- <use style="fill:#d62728;stroke:#000000;" x="162.655497" xlink:href="#m45a57d22c6" y="32.530795"/>
- <use style="fill:#d62728;stroke:#000000;" x="162.655497" xlink:href="#m45a57d22c6" y="42.853205"/>
- <use style="fill:#d62728;stroke:#000000;" x="162.655497" xlink:href="#m45a57d22c6" y="53.175615"/>
- <use style="fill:#d62728;stroke:#000000;" x="162.655497" xlink:href="#m45a57d22c6" y="63.498025"/>
+ <g clip-path="url(#p367fff45ba)">
+ <use xlink:href="#m5ead2df136" x="16.077273" y="11.885975" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="22.208385" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="32.530795" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="42.853205" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="53.175615" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="16.077273" y="63.498025" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="95.353383" y="11.885975" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="82.96649" y="22.208385" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="76.773044" y="32.530795" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="76.773044" y="42.853205" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="82.96649" y="53.175615" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="95.353383" y="63.498025" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="142.010677" y="11.885975" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="142.010677" y="22.208385" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="162.655497" y="32.530795" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="162.655497" y="42.853205" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="162.655497" y="53.175615" style="fill: #d62728; stroke: #000000"/>
+ <use xlink:href="#m5ead2df136" x="162.655497" y="63.498025" style="fill: #d62728; stroke: #000000"/>
</g>
</g>
<g id="line2d_20">
<defs>
- <path d="M 0 3
+ <path id="m39fe4d1791" d="M 0 3
C 0.795609 3 1.55874 2.683901 2.12132 2.12132
C 2.683901 1.55874 3 0.795609 3 0
C 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132
@@ -167,34 +167,34 @@ C -2.683901 -1.55874 -3 -0.795609 -3 0
C -3 0.795609 -2.683901 1.55874 -2.12132 2.12132
C -1.55874 2.683901 -0.795609 3 0 3
z
-" id="mb6ec6fa556" style="stroke:#000000;"/>
+" style="stroke: #000000"/>
</defs>
- <g clip-path="url(#pafc522611b)">
- <use style="fill:#0000ff;stroke:#000000;" x="47.044503" xlink:href="#mb6ec6fa556" y="11.885975"/>
- <use style="fill:#0000ff;stroke:#000000;" x="57.366913" xlink:href="#mb6ec6fa556" y="32.530795"/>
- <use style="fill:#0000ff;stroke:#000000;" x="57.366913" xlink:href="#mb6ec6fa556" y="22.208385"/>
- <use style="fill:#0000ff;stroke:#000000;" x="47.044503" xlink:href="#mb6ec6fa556" y="42.853205"/>
- <use style="fill:#0000ff;stroke:#000000;" x="26.399683" xlink:href="#mb6ec6fa556" y="53.175615"/>
- <use style="fill:#0000ff;stroke:#000000;" x="26.399683" xlink:href="#mb6ec6fa556" y="63.498025"/>
- <use style="fill:#0000ff;stroke:#000000;" x="107.740275" xlink:href="#mb6ec6fa556" y="11.885975"/>
- <use style="fill:#0000ff;stroke:#000000;" x="120.127167" xlink:href="#mb6ec6fa556" y="22.208385"/>
- <use style="fill:#0000ff;stroke:#000000;" x="126.320613" xlink:href="#mb6ec6fa556" y="32.530795"/>
- <use style="fill:#0000ff;stroke:#000000;" x="126.320613" xlink:href="#mb6ec6fa556" y="42.853205"/>
- <use style="fill:#0000ff;stroke:#000000;" x="120.127167" xlink:href="#mb6ec6fa556" y="53.175615"/>
- <use style="fill:#0000ff;stroke:#000000;" x="107.740275" xlink:href="#mb6ec6fa556" y="63.498025"/>
- <use style="fill:#0000ff;stroke:#000000;" x="193.622727" xlink:href="#mb6ec6fa556" y="11.885975"/>
- <use style="fill:#0000ff;stroke:#000000;" x="193.622727" xlink:href="#mb6ec6fa556" y="22.208385"/>
- <use style="fill:#0000ff;stroke:#000000;" x="172.977907" xlink:href="#mb6ec6fa556" y="32.530795"/>
- <use style="fill:#0000ff;stroke:#000000;" x="172.977907" xlink:href="#mb6ec6fa556" y="42.853205"/>
- <use style="fill:#0000ff;stroke:#000000;" x="172.977907" xlink:href="#mb6ec6fa556" y="53.175615"/>
- <use style="fill:#0000ff;stroke:#000000;" x="172.977907" xlink:href="#mb6ec6fa556" y="63.498025"/>
+ <g clip-path="url(#p367fff45ba)">
+ <use xlink:href="#m39fe4d1791" x="47.044503" y="11.885975" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="57.366913" y="32.530795" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="57.366913" y="22.208385" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="47.044503" y="42.853205" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="26.399683" y="53.175615" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="26.399683" y="63.498025" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="107.740275" y="11.885975" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="120.127167" y="22.208385" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="126.320613" y="32.530795" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="126.320613" y="42.853205" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="120.127167" y="53.175615" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="107.740275" y="63.498025" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="193.622727" y="11.885975" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="193.622727" y="22.208385" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="172.977907" y="32.530795" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="172.977907" y="42.853205" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="172.977907" y="53.175615" style="fill: #0000ff; stroke: #000000"/>
+ <use xlink:href="#m39fe4d1791" x="172.977907" y="63.498025" style="fill: #0000ff; stroke: #000000"/>
</g>
</g>
</g>
</g>
<defs>
- <clipPath id="pafc522611b">
- <rect height="60.984" width="195.3" x="7.2" y="7.2"/>
+ <clipPath id="p367fff45ba">
+ <rect x="7.2" y="7.2" width="195.3" height="60.984"/>
</clipPath>
</defs>
</svg>
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 60d0bb7..9526518 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -17,9 +17,15 @@ import os
import re
try:
import sphinx_gallery
+
except ImportError:
print("warning sphinx-gallery not installed")
+
+
+
+
+
# !!!! allow readthedoc compilation
try:
from unittest.mock import MagicMock
diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
index 05b9952..cf5d64d 100644
--- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
+++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py
@@ -27,6 +27,8 @@ Machine Learning (pp. 4104-4113). PMLR.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
# %%
# Loading the data
diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py
index 0abdd6d..cd8e2fd 100644
--- a/examples/backends/plot_wass1d_torch.py
+++ b/examples/backends/plot_wass1d_torch.py
@@ -1,9 +1,9 @@
r"""
-=================================
-Wasserstein 1D with PyTorch
-=================================
+=================================================
+Wasserstein 1D (flow and barycenter) with PyTorch
+=================================================
-In this small example, we consider the following minization problem:
+In this small example, we consider the following minimization problem:
.. math::
\mu^* = \min_\mu W(\mu,\nu)
diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 2d68a39..226dfeb 100644
--- a/examples/barycenters/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -9,61 +9,62 @@ sum of diracs.
"""
-# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
import matplotlib.pylab as pl
import ot
-##############################################################################
+# %%
# Generate data
# -------------
-N = 3
+N = 2
d = 2
-measures_locations = []
-measures_weights = []
-
-for i in range(N):
- n_i = np.random.randint(low=1, high=20) # nb samples
+I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
+I2 = pl.imread('../../data/duck.png').astype(np.float64)[::4, ::4, 2]
- mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+sz = I2.shape[0]
+XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
- A_i = np.random.rand(d, d)
- cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0
+x2 = np.stack((XX[I2 == 0] + 80, -YY[I2 == 0] + 32), 1) * 1.0
+x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0
- x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
- b_i = np.random.uniform(0., 1., (n_i,))
- b_i = b_i / np.sum(b_i) # Dirac weights
+measures_locations = [x1, x2]
+measures_weights = [ot.unif(x1.shape[0]), ot.unif(x2.shape[0])]
- measures_locations.append(x_i)
- measures_weights.append(b_i)
+pl.figure(1, (12, 4))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.title('Distributions')
-##############################################################################
+# %%
# Compute free support barycenter
# -------------------------------
-k = 10 # number of Diracs of the barycenter
+k = 200 # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
-
-##############################################################################
-# Plot data
+# %%
+# Plot the barycenter
# ---------
-pl.figure(1)
-for (x_i, b_i) in zip(measures_locations, measures_weights):
- color = np.random.randint(low=1, high=10 * N)
- pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
-pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+pl.figure(2, (8, 3))
+pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
+pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
-pl.legend(loc=0)
+pl.legend(loc="lower right")
pl.show()
diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py
index afddcad..9414371 100644
--- a/examples/others/plot_logo.py
+++ b/examples/others/plot_logo.py
@@ -7,8 +7,8 @@ Logo of the POT toolbox
In this example we plot the logo of the POT toolbox.
-A specificity of this logo is that it is done 100% in Python and generated using
-matplotlib using the EMD solver from POT.
+This logo is that it is done 100% in Python and generated using
+matplotlib and ploting teh solution of the EMD solver from POT.
"""
@@ -86,8 +86,8 @@ pl.axis('equal')
pl.axis('off')
# Save logo file
-# pl.savefig('logo.svg', dpi=150, bbox_inches='tight')
-# pl.savefig('logo.png', dpi=150, bbox_inches='tight')
+# pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight')
+# pl.savefig('logo.png', dpi=150, transparent=True, bbox_inches='tight')
# %%
# Plot the logo (dark background)
diff --git a/examples/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py
index 785642a..2023649 100644
--- a/examples/plot_screenkhorn_1D.py
+++ b/examples/others/plot_screenkhorn_1D.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-===============================
-1D Screened optimal transport
-===============================
+========================================
+Screened optimal transport (Screenkhorn)
+========================================
This example illustrates the computation of Screenkhorn [26].
diff --git a/examples/plot_stochastic.py b/examples/others/plot_stochastic.py
index 3a1ef31..3a1ef31 100644
--- a/examples/plot_stochastic.py
+++ b/examples/others/plot_stochastic.py
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
index 15ead96..62f0b7d 100644
--- a/examples/plot_OT_1D.py
+++ b/examples/plot_OT_1D.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-====================
-1D optimal transport
-====================
+======================================
+Optimal Transport for 1D distributions
+======================================
This example illustrates the computation of EMD and Sinkhorn transport plans
and their visualization.
@@ -64,7 +64,11 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
#%% EMD
-G0 = ot.emd(a, b, M)
+# use fast 1D solver
+G0 = ot.emd_1d(x, x, a, b)
+
+# Equivalent to
+# G0 = ot.emd(a, b, M)
pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index b07f99f..5415e4f 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-===========================
-1D smooth optimal transport
-===========================
+================================
+Smooth optimal transport example
+================================
This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
and their visualization.
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index c3a7cd8..1d82fb8 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
====================================================
-2D Optimal transport between empirical distributions
+Optimal Transport between 2D empirical distributions
====================================================
Illustration of 2D optimal transport between discributions that are weighted
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index cb94574..cce51f8 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-==========================================
-2D Optimal transport for different metrics
-==========================================
+================================================
+Optimal Transport with different gournd metrics
+================================================
-2D OT on empirical distributio with different gound metric.
+2D OT on empirical distributio with different ground metric.
Stole the figure idea from Fig. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf
@@ -23,7 +23,7 @@ import matplotlib.pylab as pl
import ot
import ot.plot
-##############################################################################
+# %%
# Dataset 1 : uniform sampling
# ----------------------------
@@ -46,7 +46,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
# Data
@@ -71,7 +71,7 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock cost')
pl.tight_layout()
##############################################################################
@@ -109,22 +109,22 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()
-##############################################################################
+# %%
# Dataset 2 : Partial circle
# --------------------------
-n = 50 # nb samples
+n = 20 # nb samples
xtot = np.zeros((n + 1, 2))
xtot[:, 0] = np.cos(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xtot[:, 1] = np.sin(
- (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi)
xs = xtot[:n, :]
xt = xtot[1:, :]
@@ -140,7 +140,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean')
M2 /= M2.max()
# loss matrix
-Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp = ot.dist(xs, xt, metric='cityblock')
Mp /= Mp.max()
@@ -166,13 +166,13 @@ pl.title('Squared Euclidean cost')
pl.subplot(1, 3, 3)
pl.imshow(Mp, interpolation='nearest')
-pl.title('Sqrt Euclidean cost')
+pl.title('L1 (cityblock) cost')
pl.tight_layout()
##############################################################################
# Dataset 2 : Plot OT Matrices
# -----------------------------
-
+#
#%% EMD
G1 = ot.emd(a, b, M1)
@@ -204,7 +204,7 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.axis('equal')
# pl.legend(loc=0)
-pl.title('OT sqrt Euclidean')
+pl.title('OT L1 (cityblock)')
pl.tight_layout()
pl.show()
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 527a847..36cc7da 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
"""
-=================
-Plot multiple EMD
-=================
+==================
+OT distances in 1D
+==================
-Shows how to compute multiple EMD and Sinkhorn with two different
+Shows how to compute multiple Wassersein and Sinkhorn with two different
ground metrics and plot their values for different distributions.
@@ -14,7 +14,7 @@ ground metrics and plot their values for different distributions.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 3
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
@@ -29,7 +29,7 @@ from ot.datasets import make_1D_gauss as gauss
#%% parameters
n = 100 # nb bins
-n_target = 50 # nb target distributions
+n_target = 20 # nb target distributions
# bin positions
@@ -47,9 +47,9 @@ for i, m in enumerate(lst_m):
# loss matrix and normalization
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
-M /= M.max()
+M /= M.max() * 0.1
M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
-M2 /= M2.max()
+M2 /= M2.max() * 0.1
##############################################################################
# Plot data
@@ -59,10 +59,12 @@ M2 /= M2.max()
pl.figure(1)
pl.subplot(2, 1, 1)
-pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, a, 'r', label='Source distribution')
pl.title('Source distribution')
pl.subplot(2, 1, 2)
-pl.plot(x, B, label='Target distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
pl.title('Target distributions')
pl.tight_layout()
@@ -73,14 +75,27 @@ pl.tight_layout()
#%% Compute and plot distributions and loss matrix
-d_emd = ot.emd2(a, B, M) # direct computation of EMD
-d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
-
+d_emd = ot.emd2(a, B, M) # direct computation of OT loss
+d_emd2 = ot.emd2(a, B, M2) # direct computation of OT loss with metrixc M2
+d_tv = [np.sum(abs(a - B[:, i])) for i in range(n_target)]
pl.figure(2)
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
-pl.title('EMD distances')
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
##############################################################################
@@ -88,17 +103,30 @@ pl.legend()
# -----------------------------------------
#%%
-reg = 1e-2
+reg = 1e-1
d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
-pl.figure(2)
+pl.figure(3)
pl.clf()
-pl.plot(d_emd, label='Euclidean EMD')
-pl.plot(d_emd2, label='Squared Euclidean EMD')
+
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'r', label='Source distribution')
+pl.title('Distributions')
+for i in range(n_target):
+ pl.plot(x, B[:, i], 'b', alpha=i / n_target)
+pl.plot(x, B[:, -1], 'b', label='Target distributions')
+pl.ylim((-.01, 0.13))
+pl.xticks(())
+pl.legend()
+pl.subplot(2, 1, 2)
+pl.plot(d_emd, label='Euclidean OT')
+pl.plot(d_emd2, label='Squared Euclidean OT')
pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
-pl.title('EMD distances')
+pl.plot(d_tv, label='Total Variation (TV)')
+#pl.xlim((-7,23))
+pl.xlabel('Displacement')
+pl.title('Divergences')
pl.legend()
-
pl.show()
diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py
index 5eb15bd..7b021d2 100644
--- a/examples/plot_optim_OTreg.py
+++ b/examples/plot_optim_OTreg.py
@@ -24,7 +24,7 @@ arXiv preprint arXiv:1510.06567.
"""
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 5
import numpy as np
import matplotlib.pylab as pl
@@ -58,7 +58,7 @@ M /= M.max()
G0 = ot.emd(a, b, M)
-pl.figure(3, figsize=(5, 5))
+pl.figure(1, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
##############################################################################
@@ -80,7 +80,7 @@ reg = 1e-1
Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
-pl.figure(3)
+pl.figure(2)
ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')
##############################################################################
@@ -102,7 +102,7 @@ reg = 1e-3
Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')
##############################################################################
@@ -125,6 +125,34 @@ reg2 = 1e-1
Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
-pl.figure(5, figsize=(5, 5))
+pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')
pl.show()
+
+
+# %%
+# Comparison of the OT matrices
+
+nvisu = 40
+
+pl.figure(5, figsize=(10, 4))
+
+pl.subplot(2, 2, 1)
+pl.imshow(G0[:nvisu, :])
+pl.axis('off')
+pl.title('Exact OT')
+
+pl.subplot(2, 2, 2)
+pl.imshow(Gl2[:nvisu, :])
+pl.axis('off')
+pl.title('Frobenius reg.')
+
+pl.subplot(2, 2, 3)
+pl.imshow(Ge[:nvisu, :])
+pl.axis('off')
+pl.title('Entropic reg.')
+
+pl.subplot(2, 2, 4)
+pl.imshow(Gel2[:nvisu, :])
+pl.axis('off')
+pl.title('Entropic + Frobenius reg.')
diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt
index a575345..73e6122 100644
--- a/examples/sliced-wasserstein/README.txt
+++ b/examples/sliced-wasserstein/README.txt
@@ -1,4 +1,4 @@
Sliced Wasserstein Distance
---------------------------- \ No newline at end of file
+---------------------------
diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py
index 7d73907..f12b522 100644
--- a/examples/sliced-wasserstein/plot_variance.py
+++ b/examples/sliced-wasserstein/plot_variance.py
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
"""
-==============================
-2D Sliced Wasserstein Distance
-==============================
+===============================================
+Sliced Wasserstein Distance on 2D distributions
+===============================================
This example illustrates the computation of the sliced Wasserstein Distance as
proposed in [31].
@@ -16,6 +16,8 @@ measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import matplotlib.pylab as pl
import numpy as np
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index 183849c..06dd02d 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
import numpy as np
import matplotlib.pylab as pl
import ot
@@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter
alpha = 1. # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
-pl.figure(4, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
pl.show()
+
+
+# %%
+# plot the transported mass
+# -------------------------
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source')
+pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target')
+pl.legend(loc='upper right')
+pl.title('Distributions and transported mass for UOT')
diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py
index 4a51c2d..782e8c2 100644
--- a/examples/unbalanced-partial/plot_regpath.py
+++ b/examples/unbalanced-partial/plot_regpath.py
@@ -15,11 +15,12 @@ penalized linear regression.
# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
import numpy as np
import matplotlib.pylab as pl
import ot
-
+import matplotlib.animation as animation
##############################################################################
# Generate data
# -------------
@@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
##############################################################################
# Plot the regularization path
# ----------------
+#
+# The OT plan is ploted as a function of $\gamma$ that is the inverse of the
+# weight on the marginal relaxations.
#%% fully relaxed l2-penalized UOT
@@ -103,13 +107,53 @@ for p in range(4):
pl.show()
+# %%
+# Animation of the regpath for UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(-.5, -2.5, nv)
+
+pl.figure(3)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list,
+ t_list)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
+
+
##############################################################################
# Plot the semi-relaxed regularization path
# -------------------
#%% semi-relaxed l2-penalized UOT
-pl.figure(3)
+pl.figure(4)
selected_gamma = [10, 1, 1e-1, 1e-2]
for p in range(4):
tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2,
@@ -133,3 +177,43 @@ for p in range(4):
if p < 2:
pl.xticks(())
pl.show()
+
+
+# %%
+# Animation of the regpath for semi-relaxed UOT l2
+# ------------------------
+
+nv = 100
+g_list_v = np.logspace(2.5, -2, nv)
+
+pl.figure(5)
+
+
+def _update_plot(iv):
+ pl.clf()
+ tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2,
+ t_list2)
+ P = tp.reshape((n, n))
+ if P.sum() > 0:
+ P = P / P.max()
+ for i in range(n):
+ for j in range(n):
+ if P[i, j] > 0:
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
+ alpha=P[i, j] * 0.5)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
+ pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4,
+ label='Re-weighted source', alpha=1)
+ pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4,
+ label='Re-weighted target', alpha=1)
+ pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
+ pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]),
+ fontsize=11)
+ return 1
+
+
+i = 0
+_update_plot(i)
+
+ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)