summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples')
-rw-r--r--docs/source/auto_examples/auto_examples_jupyter.zipbin122957 -> 148147 bytes
-rw-r--r--docs/source/auto_examples/auto_examples_python.zipbin81905 -> 99229 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.pngbin22281 -> 20785 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.pngbin20743 -> 21134 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.pngbin9695 -> 9704 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.pngbin90088 -> 79153 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.pngbin15036 -> 14611 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.pngbin103143 -> 97487 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.pngbin0 -> 10846 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.pngbin0 -> 20361 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.pngbin0 -> 21239 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.pngbin0 -> 22051 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.pngbin0 -> 21288 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.pngbin0 -> 22177 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.pngbin0 -> 42539 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.pngbin0 -> 105997 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.pngbin0 -> 103234 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.pngbin0 -> 131827 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.pngbin0 -> 29423 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_fgw_004.pngbin0 -> 19490 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_fgw_010.pngbin0 -> 44747 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_fgw_011.pngbin0 -> 21337 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.pngbin19155 -> 17987 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.pngbin0 -> 14761 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.pngbin0 -> 15099 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.pngbin0 -> 28694 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.pngbin0 -> 17541 bytes
-rw-r--r--docs/source/auto_examples/index.rst120
-rw-r--r--docs/source/auto_examples/plot_OT_2D_samples.ipynb22
-rw-r--r--docs/source/auto_examples/plot_OT_2D_samples.py26
-rw-r--r--docs/source/auto_examples/plot_OT_2D_samples.rst56
-rw-r--r--docs/source/auto_examples/plot_UOT_1D.ipynb108
-rw-r--r--docs/source/auto_examples/plot_UOT_1D.py76
-rw-r--r--docs/source/auto_examples/plot_UOT_1D.rst173
-rw-r--r--docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb126
-rw-r--r--docs/source/auto_examples/plot_UOT_barycenter_1D.py164
-rw-r--r--docs/source/auto_examples/plot_UOT_barycenter_1D.rst261
-rw-r--r--docs/source/auto_examples/plot_barycenter_fgw.ipynb126
-rw-r--r--docs/source/auto_examples/plot_barycenter_fgw.py184
-rw-r--r--docs/source/auto_examples/plot_barycenter_fgw.rst268
-rw-r--r--docs/source/auto_examples/plot_fgw.ipynb162
-rw-r--r--docs/source/auto_examples/plot_fgw.py173
-rw-r--r--docs/source/auto_examples/plot_fgw.rst297
43 files changed, 2319 insertions, 23 deletions
diff --git a/docs/source/auto_examples/auto_examples_jupyter.zip b/docs/source/auto_examples/auto_examples_jupyter.zip
index 88e1e9b..901195a 100644
--- a/docs/source/auto_examples/auto_examples_jupyter.zip
+++ b/docs/source/auto_examples/auto_examples_jupyter.zip
Binary files differ
diff --git a/docs/source/auto_examples/auto_examples_python.zip b/docs/source/auto_examples/auto_examples_python.zip
index 120a586..ded2613 100644
--- a/docs/source/auto_examples/auto_examples_python.zip
+++ b/docs/source/auto_examples/auto_examples_python.zip
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
index 2e93ed1..a5bded7 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
index d6db0ed..1d90c2d 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
index 9a215ab..ea6a405 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
index 81c4ddb..8bc46dc 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png
index 892b2a2..56d18ef 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png
index c53717f..5aef7d2 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png
new file mode 100644
index 0000000..bb8bd7c
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png
new file mode 100644
index 0000000..30cec7b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png
new file mode 100644
index 0000000..69ef5b7
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png
new file mode 100644
index 0000000..0407e44
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png
new file mode 100644
index 0000000..f58d383
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png
new file mode 100644
index 0000000..ec8c51e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png
new file mode 100644
index 0000000..89ab265
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png
new file mode 100644
index 0000000..c6c49cb
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png
new file mode 100644
index 0000000..8870b10
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png
new file mode 100644
index 0000000..77e1282
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png
new file mode 100644
index 0000000..ca6d7f8
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_fgw_004.png b/docs/source/auto_examples/images/sphx_glr_plot_fgw_004.png
new file mode 100644
index 0000000..4e0df9f
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_fgw_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_fgw_010.png b/docs/source/auto_examples/images/sphx_glr_plot_fgw_010.png
new file mode 100644
index 0000000..d0e36e8
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_fgw_010.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_fgw_011.png b/docs/source/auto_examples/images/sphx_glr_plot_fgw_011.png
new file mode 100644
index 0000000..6d7e630
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_fgw_011.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
index b9135dd..ae33588 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png
new file mode 100644
index 0000000..1d048f2
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png
new file mode 100644
index 0000000..999f175
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png
new file mode 100644
index 0000000..9c3244e
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png
new file mode 100644
index 0000000..609339d
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/index.rst b/docs/source/auto_examples/index.rst
index 17a9710..fe6702d 100644
--- a/docs/source/auto_examples/index.rst
+++ b/docs/source/auto_examples/index.rst
@@ -29,13 +29,13 @@ This is a gallery of all the POT example files.
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="Illustrates the use of the generic solver for regularized OT with user-designed regularization ...">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of Unbalanced Optimal transport using a Kullback-Leibl...">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png
- :ref:`sphx_glr_auto_examples_plot_optim_OTreg.py`
+ :ref:`sphx_glr_auto_examples_plot_UOT_1D.py`
.. raw:: html
@@ -45,17 +45,17 @@ This is a gallery of all the POT example files.
.. toctree::
:hidden:
- /auto_examples/plot_optim_OTreg
+ /auto_examples/plot_UOT_1D
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="Illustration of 2D Wasserstein barycenters if discributions that are weighted sum of diracs.">
+ <div class="sphx-glr-thumbcontainer" tooltip="Illustrates the use of the generic solver for regularized OT with user-designed regularization ...">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_free_support_barycenter_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
- :ref:`sphx_glr_auto_examples_plot_free_support_barycenter.py`
+ :ref:`sphx_glr_auto_examples_plot_optim_OTreg.py`
.. raw:: html
@@ -65,17 +65,17 @@ This is a gallery of all the POT example files.
.. toctree::
:hidden:
- /auto_examples/plot_free_support_barycenter
+ /auto_examples/plot_optim_OTreg
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of EMD, Sinkhorn and smooth OT plans and their visuali...">
+ <div class="sphx-glr-thumbcontainer" tooltip="Illustration of 2D Wasserstein barycenters if discributions that are weighted sum of diracs.">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_1D_smooth_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_free_support_barycenter_thumb.png
- :ref:`sphx_glr_auto_examples_plot_OT_1D_smooth.py`
+ :ref:`sphx_glr_auto_examples_plot_free_support_barycenter.py`
.. raw:: html
@@ -85,17 +85,17 @@ This is a gallery of all the POT example files.
.. toctree::
:hidden:
- /auto_examples/plot_OT_1D_smooth
+ /auto_examples/plot_free_support_barycenter
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the Gromov-Wassertsein distance computation in POT....">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of EMD, Sinkhorn and smooth OT plans and their visuali...">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_1D_smooth_thumb.png
- :ref:`sphx_glr_auto_examples_plot_gromov.py`
+ :ref:`sphx_glr_auto_examples_plot_OT_1D_smooth.py`
.. raw:: html
@@ -105,17 +105,17 @@ This is a gallery of all the POT example files.
.. toctree::
:hidden:
- /auto_examples/plot_gromov
+ /auto_examples/plot_OT_1D_smooth
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="Illustration of 2D optimal transport between discributions that are weighted sum of diracs. The...">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the Gromov-Wassertsein distance computation in POT....">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png
- :ref:`sphx_glr_auto_examples_plot_OT_2D_samples.py`
+ :ref:`sphx_glr_auto_examples_plot_gromov.py`
.. raw:: html
@@ -125,7 +125,7 @@ This is a gallery of all the POT example files.
.. toctree::
:hidden:
- /auto_examples/plot_OT_2D_samples
+ /auto_examples/plot_gromov
.. raw:: html
@@ -209,6 +209,26 @@ This is a gallery of all the POT example files.
.. raw:: html
+ <div class="sphx-glr-thumbcontainer" tooltip="Illustration of 2D optimal transport between discributions that are weighted sum of diracs. The...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_OT_2D_samples.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_OT_2D_samples
+
+.. raw:: html
+
<div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the stochatic optimization algorithms for descrete ...">
.. only:: html
@@ -289,6 +309,26 @@ This is a gallery of all the POT example files.
.. raw:: html
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of regularized Wassersyein Barycenter as proposed in [...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_UOT_barycenter_1D.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_UOT_barycenter_1D
+
+.. raw:: html
+
<div class="sphx-glr-thumbcontainer" tooltip="This example presents how to use MappingTransport to estimate at the same time both the couplin...">
.. only:: html
@@ -329,6 +369,26 @@ This is a gallery of all the POT example files.
.. raw:: html
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of FGW for 1D measures[18].">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_fgw.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_fgw
+
+.. raw:: html
+
<div class="sphx-glr-thumbcontainer" tooltip="This example introduces a domain adaptation in a 2D setting and the 4 OTDA approaches currently...">
.. only:: html
@@ -409,6 +469,26 @@ This is a gallery of all the POT example files.
.. raw:: html
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation barycenter of labeled graphs using FGW">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_barycenter_fgw.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_barycenter_fgw
+
+.. raw:: html
+
<div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the Gromov-Wasserstein distance computation in POT....">
.. only:: html
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.ipynb b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
index 26831f9..dad138b 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.ipynb
+++ b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
@@ -26,7 +26,7 @@
},
"outputs": [],
"source": [
- "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot"
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n# Kilian Fatras <kilian.fatras@irisa.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot"
]
},
{
@@ -100,6 +100,24 @@
"source": [
"#%% sinkhorn\n\n# reg term\nlambd = 1e-3\n\nGs = ot.sinkhorn(a, b, M, lambd)\n\npl.figure(5)\npl.imshow(Gs, interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')\n\npl.show()"
]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Emprirical Sinkhorn\n----------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% sinkhorn\n\n# reg term\nlambd = 1e-3\n\nGes = ot.bregman.empirical_sinkhorn(xs, xt, lambd)\n\npl.figure(7)\npl.imshow(Ges, interpolation='nearest')\npl.title('OT matrix empirical sinkhorn')\n\npl.figure(8)\not.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn from samples')\n\npl.show()"
+ ]
}
],
"metadata": {
@@ -118,7 +136,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.5"
+ "version": "3.6.8"
}
},
"nbformat": 4,
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.py b/docs/source/auto_examples/plot_OT_2D_samples.py
index bb952a0..63126ba 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.py
+++ b/docs/source/auto_examples/plot_OT_2D_samples.py
@@ -10,6 +10,7 @@ sum of diracs. The OT matrix is plotted with the samples.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
@@ -100,3 +101,28 @@ pl.legend(loc=0)
pl.title('OT matrix Sinkhorn with samples')
pl.show()
+
+
+##############################################################################
+# Emprirical Sinkhorn
+# ----------------
+
+#%% sinkhorn
+
+# reg term
+lambd = 1e-3
+
+Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
+
+pl.figure(7)
+pl.imshow(Ges, interpolation='nearest')
+pl.title('OT matrix empirical sinkhorn')
+
+pl.figure(8)
+ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix Sinkhorn from samples')
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.rst b/docs/source/auto_examples/plot_OT_2D_samples.rst
index 624ae3e..1f1d713 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.rst
+++ b/docs/source/auto_examples/plot_OT_2D_samples.rst
@@ -17,6 +17,7 @@ sum of diracs. The OT matrix is plotted with the samples.
# Author: Remi Flamary <remi.flamary@unice.fr>
+ # Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
@@ -176,6 +177,8 @@ Compute Sinkhorn
+
+
.. rst-class:: sphx-glr-horizontal
@@ -192,7 +195,58 @@ Compute Sinkhorn
-**Total running time of the script:** ( 0 minutes 3.027 seconds)
+Emprirical Sinkhorn
+----------------
+
+
+
+.. code-block:: python
+
+
+ #%% sinkhorn
+
+ # reg term
+ lambd = 1e-3
+
+ Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
+
+ pl.figure(7)
+ pl.imshow(Ges, interpolation='nearest')
+ pl.title('OT matrix empirical sinkhorn')
+
+ pl.figure(8)
+ ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.legend(loc=0)
+ pl.title('OT matrix Sinkhorn from samples')
+
+ pl.show()
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png
+ :scale: 47
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ Warning: numerical errors at iteration 0
+
+
+**Total running time of the script:** ( 0 minutes 2.616 seconds)
diff --git a/docs/source/auto_examples/plot_UOT_1D.ipynb b/docs/source/auto_examples/plot_UOT_1D.ipynb
new file mode 100644
index 0000000..c695306
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_1D.ipynb
@@ -0,0 +1,108 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 1D Unbalanced optimal transport\n\n\nThis example illustrates the computation of Unbalanced Optimal transport\nusing a Kullback-Leibler relaxation.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Hicham Janati <hicham.janati@inria.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot\nfrom ot.datasets import make_1D_gauss as gauss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\nb = gauss(n, m=60, s=10)\n\n# make distributions unbalanced\nb *= 5.\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot distributions and loss matrix\n----------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()\n\n# plot distributions and loss matrix\n\npl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve Unbalanced Sinkhorn\n--------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Sinkhorn\n\nepsilon = 0.1 # entropy parameter\nalpha = 1. # Unbalanced KL relaxation parameter\nGs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_UOT_1D.py b/docs/source/auto_examples/plot_UOT_1D.py
new file mode 100644
index 0000000..2ea8b05
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_1D.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+"""
+===============================
+1D Unbalanced optimal transport
+===============================
+
+This example illustrates the computation of Unbalanced Optimal transport
+using a Kullback-Leibler relaxation.
+"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# make distributions unbalanced
+b *= 5.
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+# plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+
+##############################################################################
+# Solve Unbalanced Sinkhorn
+# --------------
+
+
+# Sinkhorn
+
+epsilon = 0.1 # entropy parameter
+alpha = 1. # Unbalanced KL relaxation parameter
+Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_UOT_1D.rst b/docs/source/auto_examples/plot_UOT_1D.rst
new file mode 100644
index 0000000..8e618b4
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_1D.rst
@@ -0,0 +1,173 @@
+
+
+.. _sphx_glr_auto_examples_plot_UOT_1D.py:
+
+
+===============================
+1D Unbalanced optimal transport
+===============================
+
+This example illustrates the computation of Unbalanced Optimal transport
+using a Kullback-Leibler relaxation.
+
+
+
+.. code-block:: python
+
+
+ # Author: Hicham Janati <hicham.janati@inria.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ import ot.plot
+ from ot.datasets import make_1D_gauss as gauss
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=20, s=5) # m= mean, s= std
+ b = gauss(n, m=60, s=10)
+
+ # make distributions unbalanced
+ b *= 5.
+
+ # loss matrix
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+ M /= M.max()
+
+
+
+
+
+
+
+
+Plot distributions and loss matrix
+----------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ pl.plot(x, a, 'b', label='Source distribution')
+ pl.plot(x, b, 'r', label='Target distribution')
+ pl.legend()
+
+ # plot distributions and loss matrix
+
+ pl.figure(2, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_UOT_1D_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_UOT_1D_002.png
+ :scale: 47
+
+
+
+
+Solve Unbalanced Sinkhorn
+--------------
+
+
+
+.. code-block:: python
+
+
+
+ # Sinkhorn
+
+ epsilon = 0.1 # entropy parameter
+ alpha = 1. # Unbalanced KL relaxation parameter
+ Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_UOT_1D_006.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Err
+ -------------------
+ 0|1.838786e+00|
+ 10|1.242379e-01|
+ 20|2.581314e-03|
+ 30|5.674552e-05|
+ 40|1.252959e-06|
+ 50|2.768136e-08|
+ 60|6.116090e-10|
+
+
+**Total running time of the script:** ( 0 minutes 0.259 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_UOT_1D.py <plot_UOT_1D.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_UOT_1D.ipynb <plot_UOT_1D.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb b/docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb
new file mode 100644
index 0000000..e59cdc2
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 1D Wasserstein barycenter demo for Unbalanced distributions\n\n\nThis example illustrates the computation of regularized Wassersyein Barycenter\nas proposed in [10] for Unbalanced inputs.\n\n\n[10] Chizat, L., Peyr\u00e9, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Hicham Janati <hicham.janati@inria.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nfrom matplotlib.collections import PolyCollection"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std\na2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n\n# make unbalanced dists\na2 *= 3.\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Barycenter computation\n----------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# non weighted barycenter computation\n\nweight = 0.5 # 0<=weight<=1\nweights = np.array([1 - weight, weight])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\nalpha = 1.\n\nbary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Barycentric interpolation\n-------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# barycenter interpolation\n\nn_weight = 11\nweight_list = np.linspace(0, 1, n_weight)\n\n\nB_l2 = np.zeros((n, n_weight))\n\nB_wass = np.copy(B_l2)\n\nfor i in range(0, n_weight):\n weight = weight_list[i]\n weights = np.array([1 - weight, weight])\n B_l2[:, i] = A.dot(weights)\n B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)\n\n\n# plot interpolation\n\npl.figure(3)\n\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = weight_list\nfor i, z in enumerate(zs):\n ys = B_l2[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel(r'$\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with l2')\npl.tight_layout()\n\npl.figure(4)\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = weight_list\nfor i, z in enumerate(zs):\n ys = B_wass[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel(r'$\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with Wasserstein')\npl.tight_layout()\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_UOT_barycenter_1D.py b/docs/source/auto_examples/plot_UOT_barycenter_1D.py
new file mode 100644
index 0000000..c8d9d3b
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_barycenter_1D.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+"""
+===========================================================
+1D Wasserstein barycenter demo for Unbalanced distributions
+===========================================================
+
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [10] for Unbalanced inputs.
+
+
+[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+from matplotlib.collections import PolyCollection
+
+##############################################################################
+# Generate data
+# -------------
+
+# parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# make unbalanced dists
+a2 *= 3.
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+# plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+# non weighted barycenter computation
+
+weight = 0.5 # 0<=weight<=1
+weights = np.array([1 - weight, weight])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+alpha = 1.
+
+bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+##############################################################################
+# Barycentric interpolation
+# -------------------------
+
+# barycenter interpolation
+
+n_weight = 11
+weight_list = np.linspace(0, 1, n_weight)
+
+
+B_l2 = np.zeros((n, n_weight))
+
+B_wass = np.copy(B_l2)
+
+for i in range(0, n_weight):
+ weight = weight_list[i]
+ weights = np.array([1 - weight, weight])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+
+# plot interpolation
+
+pl.figure(3)
+
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = weight_list
+for i, z in enumerate(zs):
+ ys = B_l2[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel(r'$\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with l2')
+pl.tight_layout()
+
+pl.figure(4)
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = weight_list
+for i, z in enumerate(zs):
+ ys = B_wass[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel(r'$\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with Wasserstein')
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_UOT_barycenter_1D.rst b/docs/source/auto_examples/plot_UOT_barycenter_1D.rst
new file mode 100644
index 0000000..ac17587
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_barycenter_1D.rst
@@ -0,0 +1,261 @@
+
+
+.. _sphx_glr_auto_examples_plot_UOT_barycenter_1D.py:
+
+
+===========================================================
+1D Wasserstein barycenter demo for Unbalanced distributions
+===========================================================
+
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [10] for Unbalanced inputs.
+
+
+[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Hicham Janati <hicham.janati@inria.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ # necessary for 3d plot even if not used
+ from mpl_toolkits.mplot3d import Axes3D # noqa
+ from matplotlib.collections import PolyCollection
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ # parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+ # make unbalanced dists
+ a2 *= 3.
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ n_distributions = A.shape[1]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ # plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png
+ :align: center
+
+
+
+
+Barycenter computation
+----------------------
+
+
+
+.. code-block:: python
+
+
+ # non weighted barycenter computation
+
+ weight = 0.5 # 0<=weight<=1
+ weights = np.array([1 - weight, weight])
+
+ # l2bary
+ bary_l2 = A.dot(weights)
+
+ # wasserstein
+ reg = 1e-3
+ alpha = 1.
+
+ bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+ pl.figure(2)
+ pl.clf()
+ pl.subplot(2, 1, 1)
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+
+ pl.subplot(2, 1, 2)
+ pl.plot(x, bary_l2, 'r', label='l2')
+ pl.plot(x, bary_wass, 'g', label='Wasserstein')
+ pl.legend()
+ pl.title('Barycenters')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png
+ :align: center
+
+
+
+
+Barycentric interpolation
+-------------------------
+
+
+
+.. code-block:: python
+
+
+ # barycenter interpolation
+
+ n_weight = 11
+ weight_list = np.linspace(0, 1, n_weight)
+
+
+ B_l2 = np.zeros((n, n_weight))
+
+ B_wass = np.copy(B_l2)
+
+ for i in range(0, n_weight):
+ weight = weight_list[i]
+ weights = np.array([1 - weight, weight])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+
+ # plot interpolation
+
+ pl.figure(3)
+
+ cmap = pl.cm.get_cmap('viridis')
+ verts = []
+ zs = weight_list
+ for i, z in enumerate(zs):
+ ys = B_l2[:, i]
+ verts.append(list(zip(x, ys)))
+
+ ax = pl.gcf().gca(projection='3d')
+
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+ poly.set_alpha(0.7)
+ ax.add_collection3d(poly, zs=zs, zdir='y')
+ ax.set_xlabel('x')
+ ax.set_xlim3d(0, n)
+ ax.set_ylabel(r'$\alpha$')
+ ax.set_ylim3d(0, 1)
+ ax.set_zlabel('')
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
+ pl.title('Barycenter interpolation with l2')
+ pl.tight_layout()
+
+ pl.figure(4)
+ cmap = pl.cm.get_cmap('viridis')
+ verts = []
+ zs = weight_list
+ for i, z in enumerate(zs):
+ ys = B_wass[:, i]
+ verts.append(list(zip(x, ys)))
+
+ ax = pl.gcf().gca(projection='3d')
+
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+ poly.set_alpha(0.7)
+ ax.add_collection3d(poly, zs=zs, zdir='y')
+ ax.set_xlabel('x')
+ ax.set_xlim3d(0, n)
+ ax.set_ylabel(r'$\alpha$')
+ ax.set_ylim3d(0, 1)
+ ax.set_zlabel('')
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
+ pl.title('Barycenter interpolation with Wasserstein')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png
+ :scale: 47
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.344 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_UOT_barycenter_1D.py <plot_UOT_barycenter_1D.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_UOT_barycenter_1D.ipynb <plot_UOT_barycenter_1D.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_barycenter_fgw.ipynb b/docs/source/auto_examples/plot_barycenter_fgw.ipynb
new file mode 100644
index 0000000..28229b2
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_fgw.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n=================================\nPlot graphs' barycenter using FGW\n=================================\n\nThis example illustrates the computation barycenter of labeled graphs using FGW\n\nRequires networkx >=2\n\n.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain\n and Courty Nicolas\n \"Optimal Transport for structured data with application on graphs\"\n International Conference on Machine Learning (ICML). 2019.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Titouan Vayer <titouan.vayer@irisa.fr>\n#\n# License: MIT License\n\n#%% load libraries\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport networkx as nx\nimport math\nfrom scipy.sparse.csgraph import shortest_path\nimport matplotlib.colors as mcol\nfrom matplotlib import cm\nfrom ot.gromov import fgw_barycenters\n#%% Graph functions\n\n\ndef find_thresh(C, inf=0.5, sup=3, step=10):\n \"\"\" Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected\n Tthe threshold is found by a linesearch between values \"inf\" and \"sup\" with \"step\" thresholds tested.\n The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix\n and the original matrix.\n Parameters\n ----------\n C : ndarray, shape (n_nodes,n_nodes)\n The structure matrix to threshold\n inf : float\n The beginning of the linesearch\n sup : float\n The end of the linesearch\n step : integer\n Number of thresholds tested\n \"\"\"\n dist = []\n search = np.linspace(inf, sup, step)\n for thresh in search:\n Cprime = sp_to_adjency(C, 0, thresh)\n SC = shortest_path(Cprime, method='D')\n SC[SC == float('inf')] = 100\n dist.append(np.linalg.norm(SC - C))\n return search[np.argmin(dist)], dist\n\n\ndef sp_to_adjency(C, threshinf=0.2, threshsup=1.8):\n \"\"\" Thresholds the structure matrix in order to compute an adjency matrix.\n All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0\n Parameters\n ----------\n C : ndarray, shape (n_nodes,n_nodes)\n The structure matrix to threshold\n threshinf : float\n The minimum value of distance from which the new value is set to 1\n threshsup : float\n The maximum value of distance from which the new value is set to 1\n Returns\n -------\n C : ndarray, shape (n_nodes,n_nodes)\n The threshold matrix. Each element is in {0,1}\n \"\"\"\n H = np.zeros_like(C)\n np.fill_diagonal(H, np.diagonal(C))\n C = C - H\n C = np.minimum(np.maximum(C, threshinf), threshsup)\n C[C == threshsup] = 0\n C[C != 0] = 1\n\n return C\n\n\ndef build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):\n \"\"\" Create a noisy circular graph\n \"\"\"\n g = nx.Graph()\n g.add_nodes_from(list(range(N)))\n for i in range(N):\n noise = float(np.random.normal(mu, sigma, 1))\n if with_noise:\n g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)\n else:\n g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))\n g.add_edge(i, i + 1)\n if structure_noise:\n randomint = np.random.randint(0, p)\n if randomint == 0:\n if i <= N - 3:\n g.add_edge(i, i + 2)\n if i == N - 2:\n g.add_edge(i, 0)\n if i == N - 1:\n g.add_edge(i, 1)\n g.add_edge(N, 0)\n noise = float(np.random.normal(mu, sigma, 1))\n if with_noise:\n g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)\n else:\n g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))\n return g\n\n\ndef graph_colors(nx_graph, vmin=0, vmax=7):\n cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)\n cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')\n cpick.set_array([])\n val_map = {}\n for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():\n val_map[k] = cpick.to_rgba(v)\n colors = []\n for node in nx_graph.nodes():\n colors.append(val_map[node])\n return colors"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% circular dataset\n# We build a dataset of noisy circular graphs.\n# Noise is added on the structures by random connections and on the features by gaussian noise.\n\n\nnp.random.seed(30)\nX0 = []\nfor k in range(9):\n X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Plot graphs\n\nplt.figure(figsize=(8, 10))\nfor i in range(len(X0)):\n plt.subplot(3, 3, i + 1)\n g = X0[i]\n pos = nx.kamada_kawai_layout(g)\n nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)\nplt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)\nplt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Barycenter computation\n----------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph\n# Features distances are the euclidean distances\nCs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]\nps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]\nYs = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]\nlambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()\nsizebary = 15 # we choose a barycenter with 15 nodes\n\nA, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot Barycenter\n-------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Create the barycenter\nbary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))\nfor i, v in enumerate(A.ravel()):\n bary.add_node(i, attr_name=v)\n\n#%%\npos = nx.kamada_kawai_layout(bary)\nnx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)\nplt.suptitle('Barycenter', fontsize=20)\nplt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_barycenter_fgw.py b/docs/source/auto_examples/plot_barycenter_fgw.py
new file mode 100644
index 0000000..77b0370
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_fgw.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Plot graphs' barycenter using FGW
+=================================
+
+This example illustrates the computation barycenter of labeled graphs using FGW
+
+Requires networkx >=2
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+#%% load libraries
+import numpy as np
+import matplotlib.pyplot as plt
+import networkx as nx
+import math
+from scipy.sparse.csgraph import shortest_path
+import matplotlib.colors as mcol
+from matplotlib import cm
+from ot.gromov import fgw_barycenters
+#%% Graph functions
+
+
+def find_thresh(C, inf=0.5, sup=3, step=10):
+ """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected
+ Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested.
+ The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix
+ and the original matrix.
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ inf : float
+ The beginning of the linesearch
+ sup : float
+ The end of the linesearch
+ step : integer
+ Number of thresholds tested
+ """
+ dist = []
+ search = np.linspace(inf, sup, step)
+ for thresh in search:
+ Cprime = sp_to_adjency(C, 0, thresh)
+ SC = shortest_path(Cprime, method='D')
+ SC[SC == float('inf')] = 100
+ dist.append(np.linalg.norm(SC - C))
+ return search[np.argmin(dist)], dist
+
+
+def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
+ """ Thresholds the structure matrix in order to compute an adjency matrix.
+ All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ threshinf : float
+ The minimum value of distance from which the new value is set to 1
+ threshsup : float
+ The maximum value of distance from which the new value is set to 1
+ Returns
+ -------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The threshold matrix. Each element is in {0,1}
+ """
+ H = np.zeros_like(C)
+ np.fill_diagonal(H, np.diagonal(C))
+ C = C - H
+ C = np.minimum(np.maximum(C, threshinf), threshsup)
+ C[C == threshsup] = 0
+ C[C != 0] = 1
+
+ return C
+
+
+def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
+ """ Create a noisy circular graph
+ """
+ g = nx.Graph()
+ g.add_nodes_from(list(range(N)))
+ for i in range(N):
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
+ else:
+ g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))
+ g.add_edge(i, i + 1)
+ if structure_noise:
+ randomint = np.random.randint(0, p)
+ if randomint == 0:
+ if i <= N - 3:
+ g.add_edge(i, i + 2)
+ if i == N - 2:
+ g.add_edge(i, 0)
+ if i == N - 1:
+ g.add_edge(i, 1)
+ g.add_edge(N, 0)
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
+ else:
+ g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))
+ return g
+
+
+def graph_colors(nx_graph, vmin=0, vmax=7):
+ cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)
+ cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')
+ cpick.set_array([])
+ val_map = {}
+ for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():
+ val_map[k] = cpick.to_rgba(v)
+ colors = []
+ for node in nx_graph.nodes():
+ colors.append(val_map[node])
+ return colors
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% circular dataset
+# We build a dataset of noisy circular graphs.
+# Noise is added on the structures by random connections and on the features by gaussian noise.
+
+
+np.random.seed(30)
+X0 = []
+for k in range(9):
+ X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% Plot graphs
+
+plt.figure(figsize=(8, 10))
+for i in range(len(X0)):
+ plt.subplot(3, 3, i + 1)
+ g = X0[i]
+ pos = nx.kamada_kawai_layout(g)
+ nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)
+plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
+plt.show()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+# Features distances are the euclidean distances
+Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
+ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
+Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]
+lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
+sizebary = 15 # we choose a barycenter with 15 nodes
+
+A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)
+
+##############################################################################
+# Plot Barycenter
+# -------------------------
+
+#%% Create the barycenter
+bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+for i, v in enumerate(A.ravel()):
+ bary.add_node(i, attr_name=v)
+
+#%%
+pos = nx.kamada_kawai_layout(bary)
+nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
+plt.suptitle('Barycenter', fontsize=20)
+plt.show()
diff --git a/docs/source/auto_examples/plot_barycenter_fgw.rst b/docs/source/auto_examples/plot_barycenter_fgw.rst
new file mode 100644
index 0000000..2c44a65
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_fgw.rst
@@ -0,0 +1,268 @@
+
+
+.. _sphx_glr_auto_examples_plot_barycenter_fgw.py:
+
+
+=================================
+Plot graphs' barycenter using FGW
+=================================
+
+This example illustrates the computation barycenter of labeled graphs using FGW
+
+Requires networkx >=2
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Titouan Vayer <titouan.vayer@irisa.fr>
+ #
+ # License: MIT License
+
+ #%% load libraries
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import networkx as nx
+ import math
+ from scipy.sparse.csgraph import shortest_path
+ import matplotlib.colors as mcol
+ from matplotlib import cm
+ from ot.gromov import fgw_barycenters
+ #%% Graph functions
+
+
+ def find_thresh(C, inf=0.5, sup=3, step=10):
+ """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected
+ Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested.
+ The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix
+ and the original matrix.
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ inf : float
+ The beginning of the linesearch
+ sup : float
+ The end of the linesearch
+ step : integer
+ Number of thresholds tested
+ """
+ dist = []
+ search = np.linspace(inf, sup, step)
+ for thresh in search:
+ Cprime = sp_to_adjency(C, 0, thresh)
+ SC = shortest_path(Cprime, method='D')
+ SC[SC == float('inf')] = 100
+ dist.append(np.linalg.norm(SC - C))
+ return search[np.argmin(dist)], dist
+
+
+ def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
+ """ Thresholds the structure matrix in order to compute an adjency matrix.
+ All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ threshinf : float
+ The minimum value of distance from which the new value is set to 1
+ threshsup : float
+ The maximum value of distance from which the new value is set to 1
+ Returns
+ -------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The threshold matrix. Each element is in {0,1}
+ """
+ H = np.zeros_like(C)
+ np.fill_diagonal(H, np.diagonal(C))
+ C = C - H
+ C = np.minimum(np.maximum(C, threshinf), threshsup)
+ C[C == threshsup] = 0
+ C[C != 0] = 1
+
+ return C
+
+
+ def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
+ """ Create a noisy circular graph
+ """
+ g = nx.Graph()
+ g.add_nodes_from(list(range(N)))
+ for i in range(N):
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
+ else:
+ g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))
+ g.add_edge(i, i + 1)
+ if structure_noise:
+ randomint = np.random.randint(0, p)
+ if randomint == 0:
+ if i <= N - 3:
+ g.add_edge(i, i + 2)
+ if i == N - 2:
+ g.add_edge(i, 0)
+ if i == N - 1:
+ g.add_edge(i, 1)
+ g.add_edge(N, 0)
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
+ else:
+ g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))
+ return g
+
+
+ def graph_colors(nx_graph, vmin=0, vmax=7):
+ cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)
+ cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')
+ cpick.set_array([])
+ val_map = {}
+ for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():
+ val_map[k] = cpick.to_rgba(v)
+ colors = []
+ for node in nx_graph.nodes():
+ colors.append(val_map[node])
+ return colors
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ #%% circular dataset
+ # We build a dataset of noisy circular graphs.
+ # Noise is added on the structures by random connections and on the features by gaussian noise.
+
+
+ np.random.seed(30)
+ X0 = []
+ for k in range(9):
+ X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% Plot graphs
+
+ plt.figure(figsize=(8, 10))
+ for i in range(len(X0)):
+ plt.subplot(3, 3, i + 1)
+ g = X0[i]
+ pos = nx.kamada_kawai_layout(g)
+ nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)
+ plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
+ plt.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png
+ :align: center
+
+
+
+
+Barycenter computation
+----------------------
+
+
+
+.. code-block:: python
+
+
+ #%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+ # Features distances are the euclidean distances
+ Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
+ ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
+ Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]
+ lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
+ sizebary = 15 # we choose a barycenter with 15 nodes
+
+ A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)
+
+
+
+
+
+
+
+Plot Barycenter
+-------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Create the barycenter
+ bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+ for i, v in enumerate(A.ravel()):
+ bary.add_node(i, attr_name=v)
+
+ #%%
+ pos = nx.kamada_kawai_layout(bary)
+ nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
+ plt.suptitle('Barycenter', fontsize=20)
+ plt.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 2.065 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_barycenter_fgw.py <plot_barycenter_fgw.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_barycenter_fgw.ipynb <plot_barycenter_fgw.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_fgw.ipynb b/docs/source/auto_examples/plot_fgw.ipynb
new file mode 100644
index 0000000..1b150bd
--- /dev/null
+++ b/docs/source/auto_examples/plot_fgw.ipynb
@@ -0,0 +1,162 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Plot Fused-gromov-Wasserstein\n\n\nThis example illustrates the computation of FGW for 1D measures[18].\n\n.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain\n and Courty Nicolas\n \"Optimal Transport for structured data with application on graphs\"\n International Conference on Machine Learning (ICML). 2019.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Titouan Vayer <titouan.vayer@irisa.fr>\n#\n# License: MIT License\n\nimport matplotlib.pyplot as pl\nimport numpy as np\nimport ot\nfrom ot.gromov import gromov_wasserstein, fused_gromov_wasserstein"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n# We create two 1D random measures\nn = 20 # number of points in the first distribution\nn2 = 30 # number of points in the second distribution\nsig = 1 # std of first distribution\nsig2 = 0.1 # std of second distribution\n\nnp.random.seed(0)\n\nphi = np.arange(n)[:, None]\nxs = phi + sig * np.random.randn(n, 1)\nys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)\n\nphi2 = np.arange(n2)[:, None]\nxt = phi2 + sig * np.random.randn(n2, 1)\nyt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)\nyt = yt[::-1, :]\n\np = ot.unif(n)\nq = ot.unif(n2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot the distributions\n\npl.close(10)\npl.figure(10, (7, 7))\n\npl.subplot(2, 1, 1)\n\npl.scatter(ys, xs, c=phi, s=70)\npl.ylabel('Feature value a', fontsize=20)\npl.title('$\\mu=\\sum_i \\delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)\npl.xticks(())\npl.yticks(())\npl.subplot(2, 1, 2)\npl.scatter(yt, xt, c=phi2, s=70)\npl.xlabel('coordinates x/y', fontsize=25)\npl.ylabel('Feature value b', fontsize=20)\npl.title('$\\\\nu=\\sum_j \\delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)\npl.yticks(())\npl.tight_layout()\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create structure matrices and across-feature distance matrix\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Structure matrices and across-features distance matrix\nC1 = ot.dist(xs)\nC2 = ot.dist(xt)\nM = ot.dist(ys, yt)\nw1 = ot.unif(C1.shape[0])\nw2 = ot.unif(C2.shape[0])\nGot = ot.emd([], [], M)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot matrices\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%%\ncmap = 'Reds'\npl.close(10)\npl.figure(10, (5, 5))\nfs = 15\nl_x = [0, 5, 10, 15]\nl_y = [0, 5, 10, 15, 20, 25]\ngs = pl.GridSpec(5, 5)\n\nax1 = pl.subplot(gs[3:, :2])\n\npl.imshow(C1, cmap=cmap, interpolation='nearest')\npl.title(\"$C_1$\", fontsize=fs)\npl.xlabel(\"$k$\", fontsize=fs)\npl.ylabel(\"$i$\", fontsize=fs)\npl.xticks(l_x)\npl.yticks(l_x)\n\nax2 = pl.subplot(gs[:3, 2:])\n\npl.imshow(C2, cmap=cmap, interpolation='nearest')\npl.title(\"$C_2$\", fontsize=fs)\npl.ylabel(\"$l$\", fontsize=fs)\n#pl.ylabel(\"$l$\",fontsize=fs)\npl.xticks(())\npl.yticks(l_y)\nax2.set_aspect('auto')\n\nax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)\npl.imshow(M, cmap=cmap, interpolation='nearest')\npl.yticks(l_x)\npl.xticks(l_y)\npl.ylabel(\"$i$\", fontsize=fs)\npl.title(\"$M_{AB}$\", fontsize=fs)\npl.xlabel(\"$j$\", fontsize=fs)\npl.tight_layout()\nax3.set_aspect('auto')\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute FGW/GW\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Computing FGW and GW\nalpha = 1e-3\n\not.tic()\nGwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)\not.toc()\n\n#%reload_ext WGW\nGg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Visualize transport matrices\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% visu OT matrix\ncmap = 'Blues'\nfs = 15\npl.figure(2, (13, 5))\npl.clf()\npl.subplot(1, 3, 1)\npl.imshow(Got, cmap=cmap, interpolation='nearest')\n#pl.xlabel(\"$y$\",fontsize=fs)\npl.ylabel(\"$i$\", fontsize=fs)\npl.xticks(())\n\npl.title('Wasserstein ($M$ only)')\n\npl.subplot(1, 3, 2)\npl.imshow(Gg, cmap=cmap, interpolation='nearest')\npl.title('Gromov ($C_1,C_2$ only)')\npl.xticks(())\npl.subplot(1, 3, 3)\npl.imshow(Gwg, cmap=cmap, interpolation='nearest')\npl.title('FGW ($M+C_1,C_2$)')\n\npl.xlabel(\"$j$\", fontsize=fs)\npl.ylabel(\"$i$\", fontsize=fs)\n\npl.tight_layout()\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_fgw.py b/docs/source/auto_examples/plot_fgw.py
new file mode 100644
index 0000000..43efc94
--- /dev/null
+++ b/docs/source/auto_examples/plot_fgw.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+Plot Fused-gromov-Wasserstein
+==============================
+
+This example illustrates the computation of FGW for 1D measures[18].
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import matplotlib.pyplot as pl
+import numpy as np
+import ot
+from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+
+##############################################################################
+# Generate data
+# ---------
+
+#%% parameters
+# We create two 1D random measures
+n = 20 # number of points in the first distribution
+n2 = 30 # number of points in the second distribution
+sig = 1 # std of first distribution
+sig2 = 0.1 # std of second distribution
+
+np.random.seed(0)
+
+phi = np.arange(n)[:, None]
+xs = phi + sig * np.random.randn(n, 1)
+ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)
+
+phi2 = np.arange(n2)[:, None]
+xt = phi2 + sig * np.random.randn(n2, 1)
+yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)
+yt = yt[::-1, :]
+
+p = ot.unif(n)
+q = ot.unif(n2)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.close(10)
+pl.figure(10, (7, 7))
+
+pl.subplot(2, 1, 1)
+
+pl.scatter(ys, xs, c=phi, s=70)
+pl.ylabel('Feature value a', fontsize=20)
+pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)
+pl.xticks(())
+pl.yticks(())
+pl.subplot(2, 1, 2)
+pl.scatter(yt, xt, c=phi2, s=70)
+pl.xlabel('coordinates x/y', fontsize=25)
+pl.ylabel('Feature value b', fontsize=20)
+pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)
+pl.yticks(())
+pl.tight_layout()
+pl.show()
+
+##############################################################################
+# Create structure matrices and across-feature distance matrix
+# ---------
+
+#%% Structure matrices and across-features distance matrix
+C1 = ot.dist(xs)
+C2 = ot.dist(xt)
+M = ot.dist(ys, yt)
+w1 = ot.unif(C1.shape[0])
+w2 = ot.unif(C2.shape[0])
+Got = ot.emd([], [], M)
+
+##############################################################################
+# Plot matrices
+# ---------
+
+#%%
+cmap = 'Reds'
+pl.close(10)
+pl.figure(10, (5, 5))
+fs = 15
+l_x = [0, 5, 10, 15]
+l_y = [0, 5, 10, 15, 20, 25]
+gs = pl.GridSpec(5, 5)
+
+ax1 = pl.subplot(gs[3:, :2])
+
+pl.imshow(C1, cmap=cmap, interpolation='nearest')
+pl.title("$C_1$", fontsize=fs)
+pl.xlabel("$k$", fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+pl.xticks(l_x)
+pl.yticks(l_x)
+
+ax2 = pl.subplot(gs[:3, 2:])
+
+pl.imshow(C2, cmap=cmap, interpolation='nearest')
+pl.title("$C_2$", fontsize=fs)
+pl.ylabel("$l$", fontsize=fs)
+#pl.ylabel("$l$",fontsize=fs)
+pl.xticks(())
+pl.yticks(l_y)
+ax2.set_aspect('auto')
+
+ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)
+pl.imshow(M, cmap=cmap, interpolation='nearest')
+pl.yticks(l_x)
+pl.xticks(l_y)
+pl.ylabel("$i$", fontsize=fs)
+pl.title("$M_{AB}$", fontsize=fs)
+pl.xlabel("$j$", fontsize=fs)
+pl.tight_layout()
+ax3.set_aspect('auto')
+pl.show()
+
+##############################################################################
+# Compute FGW/GW
+# ---------
+
+#%% Computing FGW and GW
+alpha = 1e-3
+
+ot.tic()
+Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
+ot.toc()
+
+#%reload_ext WGW
+Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
+
+##############################################################################
+# Visualize transport matrices
+# ---------
+
+#%% visu OT matrix
+cmap = 'Blues'
+fs = 15
+pl.figure(2, (13, 5))
+pl.clf()
+pl.subplot(1, 3, 1)
+pl.imshow(Got, cmap=cmap, interpolation='nearest')
+#pl.xlabel("$y$",fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+pl.xticks(())
+
+pl.title('Wasserstein ($M$ only)')
+
+pl.subplot(1, 3, 2)
+pl.imshow(Gg, cmap=cmap, interpolation='nearest')
+pl.title('Gromov ($C_1,C_2$ only)')
+pl.xticks(())
+pl.subplot(1, 3, 3)
+pl.imshow(Gwg, cmap=cmap, interpolation='nearest')
+pl.title('FGW ($M+C_1,C_2$)')
+
+pl.xlabel("$j$", fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+
+pl.tight_layout()
+pl.show()
diff --git a/docs/source/auto_examples/plot_fgw.rst b/docs/source/auto_examples/plot_fgw.rst
new file mode 100644
index 0000000..aec725d
--- /dev/null
+++ b/docs/source/auto_examples/plot_fgw.rst
@@ -0,0 +1,297 @@
+
+
+.. _sphx_glr_auto_examples_plot_fgw.py:
+
+
+==============================
+Plot Fused-gromov-Wasserstein
+==============================
+
+This example illustrates the computation of FGW for 1D measures[18].
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Titouan Vayer <titouan.vayer@irisa.fr>
+ #
+ # License: MIT License
+
+ import matplotlib.pyplot as pl
+ import numpy as np
+ import ot
+ from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+
+
+
+
+
+
+
+Generate data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters
+ # We create two 1D random measures
+ n = 20 # number of points in the first distribution
+ n2 = 30 # number of points in the second distribution
+ sig = 1 # std of first distribution
+ sig2 = 0.1 # std of second distribution
+
+ np.random.seed(0)
+
+ phi = np.arange(n)[:, None]
+ xs = phi + sig * np.random.randn(n, 1)
+ ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)
+
+ phi2 = np.arange(n2)[:, None]
+ xt = phi2 + sig * np.random.randn(n2, 1)
+ yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)
+ yt = yt[::-1, :]
+
+ p = ot.unif(n)
+ q = ot.unif(n2)
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% plot the distributions
+
+ pl.close(10)
+ pl.figure(10, (7, 7))
+
+ pl.subplot(2, 1, 1)
+
+ pl.scatter(ys, xs, c=phi, s=70)
+ pl.ylabel('Feature value a', fontsize=20)
+ pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)
+ pl.xticks(())
+ pl.yticks(())
+ pl.subplot(2, 1, 2)
+ pl.scatter(yt, xt, c=phi2, s=70)
+ pl.xlabel('coordinates x/y', fontsize=25)
+ pl.ylabel('Feature value b', fontsize=20)
+ pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)
+ pl.yticks(())
+ pl.tight_layout()
+ pl.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_fgw_010.png
+ :align: center
+
+
+
+
+Create structure matrices and across-feature distance matrix
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% Structure matrices and across-features distance matrix
+ C1 = ot.dist(xs)
+ C2 = ot.dist(xt)
+ M = ot.dist(ys, yt)
+ w1 = ot.unif(C1.shape[0])
+ w2 = ot.unif(C2.shape[0])
+ Got = ot.emd([], [], M)
+
+
+
+
+
+
+
+Plot matrices
+---------
+
+
+
+.. code-block:: python
+
+
+ #%%
+ cmap = 'Reds'
+ pl.close(10)
+ pl.figure(10, (5, 5))
+ fs = 15
+ l_x = [0, 5, 10, 15]
+ l_y = [0, 5, 10, 15, 20, 25]
+ gs = pl.GridSpec(5, 5)
+
+ ax1 = pl.subplot(gs[3:, :2])
+
+ pl.imshow(C1, cmap=cmap, interpolation='nearest')
+ pl.title("$C_1$", fontsize=fs)
+ pl.xlabel("$k$", fontsize=fs)
+ pl.ylabel("$i$", fontsize=fs)
+ pl.xticks(l_x)
+ pl.yticks(l_x)
+
+ ax2 = pl.subplot(gs[:3, 2:])
+
+ pl.imshow(C2, cmap=cmap, interpolation='nearest')
+ pl.title("$C_2$", fontsize=fs)
+ pl.ylabel("$l$", fontsize=fs)
+ #pl.ylabel("$l$",fontsize=fs)
+ pl.xticks(())
+ pl.yticks(l_y)
+ ax2.set_aspect('auto')
+
+ ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)
+ pl.imshow(M, cmap=cmap, interpolation='nearest')
+ pl.yticks(l_x)
+ pl.xticks(l_y)
+ pl.ylabel("$i$", fontsize=fs)
+ pl.title("$M_{AB}$", fontsize=fs)
+ pl.xlabel("$j$", fontsize=fs)
+ pl.tight_layout()
+ ax3.set_aspect('auto')
+ pl.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_fgw_011.png
+ :align: center
+
+
+
+
+Compute FGW/GW
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% Computing FGW and GW
+ alpha = 1e-3
+
+ ot.tic()
+ Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
+ ot.toc()
+
+ #%reload_ext WGW
+ Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Relative loss|Absolute loss
+ ------------------------------------------------
+ 0|4.734462e+01|0.000000e+00|0.000000e+00
+ 1|2.508258e+01|8.875498e-01|2.226204e+01
+ 2|2.189329e+01|1.456747e-01|3.189297e+00
+ 3|2.189329e+01|0.000000e+00|0.000000e+00
+ Elapsed time : 0.0016989707946777344 s
+ It. |Loss |Relative loss|Absolute loss
+ ------------------------------------------------
+ 0|4.683978e+04|0.000000e+00|0.000000e+00
+ 1|3.860061e+04|2.134468e-01|8.239175e+03
+ 2|2.182948e+04|7.682787e-01|1.677113e+04
+ 3|2.182948e+04|0.000000e+00|0.000000e+00
+
+
+Visualize transport matrices
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% visu OT matrix
+ cmap = 'Blues'
+ fs = 15
+ pl.figure(2, (13, 5))
+ pl.clf()
+ pl.subplot(1, 3, 1)
+ pl.imshow(Got, cmap=cmap, interpolation='nearest')
+ #pl.xlabel("$y$",fontsize=fs)
+ pl.ylabel("$i$", fontsize=fs)
+ pl.xticks(())
+
+ pl.title('Wasserstein ($M$ only)')
+
+ pl.subplot(1, 3, 2)
+ pl.imshow(Gg, cmap=cmap, interpolation='nearest')
+ pl.title('Gromov ($C_1,C_2$ only)')
+ pl.xticks(())
+ pl.subplot(1, 3, 3)
+ pl.imshow(Gwg, cmap=cmap, interpolation='nearest')
+ pl.title('FGW ($M+C_1,C_2$)')
+
+ pl.xlabel("$j$", fontsize=fs)
+ pl.ylabel("$i$", fontsize=fs)
+
+ pl.tight_layout()
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_fgw_004.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 1.468 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_fgw.py <plot_fgw.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_fgw.ipynb <plot_fgw.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_