diff options
Diffstat (limited to 'docs')
29 files changed, 1372 insertions, 24 deletions
diff --git a/docs/cache_nbrun b/docs/cache_nbrun index 6f10375..04f6fce 100644 --- a/docs/cache_nbrun +++ b/docs/cache_nbrun @@ -1 +1 @@ -{"plot_otda_mapping_colors_images.ipynb": "cc8bf9a857f52e4a159fe71dfda19018", "plot_optim_OTreg.ipynb": "481801bb0d133ef350a65179cf8f739a", "plot_otda_color_images.ipynb": "f804d5806c7ac1a0901e4542b1eaa77b", "plot_stochastic.ipynb": "e18253354c8c1d72567a4259eb1094f7", "plot_WDA.ipynb": "27f8de4c6d7db46497076523673eedfb", "plot_otda_linear_mapping.ipynb": "a472c767abe82020e0a58125a528785c", "plot_OT_1D_smooth.ipynb": "3a059103652225a0c78ea53895cf79e5", "plot_OT_L1_vs_L2.ipynb": "5d565b8aaf03be4309eba731127851dc", "plot_barycenter_1D.ipynb": "5f6fb8aebd8e2e91ebc77c923cb112b3", "plot_otda_classes.ipynb": "39087b6e98217851575f2271c22853a4", "plot_otda_d2.ipynb": "e6feae588103f2a8fab942e5f4eff483", "plot_otda_mapping.ipynb": "2f1ebbdc0f855d9e2b7adf9edec24d25", "plot_gromov.ipynb": "24f2aea489714d34779521f46d5e2c47", "plot_compute_emd.ipynb": "f5cd71cad882ec157dc8222721e9820c", "plot_OT_1D.ipynb": "b5348bdc561c07ec168a1622e5af4b93", "plot_gromov_barycenter.ipynb": "953e5047b886ec69ec621ec52f5e21d1", "plot_free_support_barycenter.ipynb": "246dd2feff4b233a4f1a553c5a202fdc", "plot_convolutional_barycenter.ipynb": "a72bb3716a1baaffd81ae267a673f9b6", "plot_otda_semi_supervised.ipynb": "f6dfb02ba2bbd939408ffcd22a3b007c", "plot_OT_2D_samples.ipynb": "07dbc14859fa019a966caa79fa0825bd", "plot_barycenter_lp_vs_entropic.ipynb": "51833e8c76aaedeba9599ac7a30eb357"}
\ No newline at end of file +{"plot_otda_color_images.ipynb": "f804d5806c7ac1a0901e4542b1eaa77b", "plot_WDA.ipynb": "27f8de4c6d7db46497076523673eedfb", "plot_OT_L1_vs_L2.ipynb": "5d565b8aaf03be4309eba731127851dc", "plot_otda_semi_supervised.ipynb": "f6dfb02ba2bbd939408ffcd22a3b007c", "plot_fgw.ipynb": "2ba3e100e92ecf4dfbeb605de20b40ab", "plot_otda_d2.ipynb": "e6feae588103f2a8fab942e5f4eff483", "plot_compute_emd.ipynb": "f5cd71cad882ec157dc8222721e9820c", "plot_barycenter_fgw.ipynb": "e14100dd276bff3ffdfdf176f1b6b070", "plot_convolutional_barycenter.ipynb": "a72bb3716a1baaffd81ae267a673f9b6", "plot_optim_OTreg.ipynb": "481801bb0d133ef350a65179cf8f739a", "plot_barycenter_lp_vs_entropic.ipynb": "51833e8c76aaedeba9599ac7a30eb357", "plot_OT_1D_smooth.ipynb": "3a059103652225a0c78ea53895cf79e5", "plot_barycenter_1D.ipynb": "5f6fb8aebd8e2e91ebc77c923cb112b3", "plot_otda_mapping.ipynb": "2f1ebbdc0f855d9e2b7adf9edec24d25", "plot_OT_1D.ipynb": "b5348bdc561c07ec168a1622e5af4b93", "plot_gromov_barycenter.ipynb": "953e5047b886ec69ec621ec52f5e21d1", "plot_otda_mapping_colors_images.ipynb": "cc8bf9a857f52e4a159fe71dfda19018", "plot_stochastic.ipynb": "e18253354c8c1d72567a4259eb1094f7", "plot_otda_linear_mapping.ipynb": "a472c767abe82020e0a58125a528785c", "plot_otda_classes.ipynb": "39087b6e98217851575f2271c22853a4", "plot_free_support_barycenter.ipynb": "246dd2feff4b233a4f1a553c5a202fdc", "plot_gromov.ipynb": "24f2aea489714d34779521f46d5e2c47", "plot_OT_2D_samples.ipynb": "912a77c5dd0fc0fafa03fac3d86f1502"}
\ No newline at end of file diff --git a/docs/source/auto_examples/auto_examples_jupyter.zip b/docs/source/auto_examples/auto_examples_jupyter.zip Binary files differindex 88e1e9b..a3a7c29 100644 --- a/docs/source/auto_examples/auto_examples_jupyter.zip +++ b/docs/source/auto_examples/auto_examples_jupyter.zip diff --git a/docs/source/auto_examples/auto_examples_python.zip b/docs/source/auto_examples/auto_examples_python.zip Binary files differindex 120a586..86a6841 100644 --- a/docs/source/auto_examples/auto_examples_python.zip +++ b/docs/source/auto_examples/auto_examples_python.zip diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png Binary files differindex 2e93ed1..a5bded7 100644 --- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png Binary files differindex d6db0ed..1d90c2d 100644 --- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png Binary files differindex 9a215ab..ea6a405 100644 --- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png Binary files differindex 81c4ddb..8bc46dc 100644 --- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png Binary files differindex 892b2a2..56d18ef 100644 --- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png Binary files differindex c53717f..5aef7d2 100644 --- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png Binary files differnew file mode 100644 index 0000000..bb8bd7c --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png Binary files differnew file mode 100644 index 0000000..30cec7b --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png Binary files differnew file mode 100644 index 0000000..77e1282 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png Binary files differnew file mode 100644 index 0000000..ca6d7f8 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_fgw_004.png b/docs/source/auto_examples/images/sphx_glr_plot_fgw_004.png Binary files differnew file mode 100644 index 0000000..4e0df9f --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_fgw_004.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_fgw_010.png b/docs/source/auto_examples/images/sphx_glr_plot_fgw_010.png Binary files differnew file mode 100644 index 0000000..d0e36e8 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_fgw_010.png diff --git a/docs/source/auto_examples/images/sphx_glr_plot_fgw_011.png b/docs/source/auto_examples/images/sphx_glr_plot_fgw_011.png Binary files differnew file mode 100644 index 0000000..6d7e630 --- /dev/null +++ b/docs/source/auto_examples/images/sphx_glr_plot_fgw_011.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png Binary files differindex b9135dd..ae33588 100644 --- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png Binary files differnew file mode 100644 index 0000000..9c3244e --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png Binary files differnew file mode 100644 index 0000000..609339d --- /dev/null +++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png diff --git a/docs/source/auto_examples/index.rst b/docs/source/auto_examples/index.rst index 17a9710..9f02da4 100644 --- a/docs/source/auto_examples/index.rst +++ b/docs/source/auto_examples/index.rst @@ -109,26 +109,6 @@ This is a gallery of all the POT example files. .. raw:: html - <div class="sphx-glr-thumbcontainer" tooltip="Illustration of 2D optimal transport between discributions that are weighted sum of diracs. The..."> - -.. only:: html - - .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png - - :ref:`sphx_glr_auto_examples_plot_OT_2D_samples.py` - -.. raw:: html - - </div> - - -.. toctree:: - :hidden: - - /auto_examples/plot_OT_2D_samples - -.. raw:: html - <div class="sphx-glr-thumbcontainer" tooltip="Shows how to compute multiple EMD and Sinkhorn with two differnt ground metrics and plot their ..."> .. only:: html @@ -209,6 +189,26 @@ This is a gallery of all the POT example files. .. raw:: html + <div class="sphx-glr-thumbcontainer" tooltip="Illustration of 2D optimal transport between discributions that are weighted sum of diracs. The..."> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png + + :ref:`sphx_glr_auto_examples_plot_OT_2D_samples.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_OT_2D_samples + +.. raw:: html + <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the stochatic optimization algorithms for descrete ..."> .. only:: html @@ -329,6 +329,26 @@ This is a gallery of all the POT example files. .. raw:: html + <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of FGW for 1D measures[18]."> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png + + :ref:`sphx_glr_auto_examples_plot_fgw.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_fgw + +.. raw:: html + <div class="sphx-glr-thumbcontainer" tooltip="This example introduces a domain adaptation in a 2D setting and the 4 OTDA approaches currently..."> .. only:: html @@ -409,6 +429,26 @@ This is a gallery of all the POT example files. .. raw:: html + <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation barycenter of labeled graphs using FGW"> + +.. only:: html + + .. figure:: /auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png + + :ref:`sphx_glr_auto_examples_plot_barycenter_fgw.py` + +.. raw:: html + + </div> + + +.. toctree:: + :hidden: + + /auto_examples/plot_barycenter_fgw + +.. raw:: html + <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the Gromov-Wasserstein distance computation in POT...."> .. only:: html diff --git a/docs/source/auto_examples/plot_OT_2D_samples.ipynb b/docs/source/auto_examples/plot_OT_2D_samples.ipynb index 26831f9..dad138b 100644 --- a/docs/source/auto_examples/plot_OT_2D_samples.ipynb +++ b/docs/source/auto_examples/plot_OT_2D_samples.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot" + "# Author: Remi Flamary <remi.flamary@unice.fr>\n# Kilian Fatras <kilian.fatras@irisa.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot" ] }, { @@ -100,6 +100,24 @@ "source": [ "#%% sinkhorn\n\n# reg term\nlambd = 1e-3\n\nGs = ot.sinkhorn(a, b, M, lambd)\n\npl.figure(5)\npl.imshow(Gs, interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')\n\npl.show()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Emprirical Sinkhorn\n----------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% sinkhorn\n\n# reg term\nlambd = 1e-3\n\nGes = ot.bregman.empirical_sinkhorn(xs, xt, lambd)\n\npl.figure(7)\npl.imshow(Ges, interpolation='nearest')\npl.title('OT matrix empirical sinkhorn')\n\npl.figure(8)\not.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn from samples')\n\npl.show()" + ] } ], "metadata": { @@ -118,7 +136,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.5" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/docs/source/auto_examples/plot_OT_2D_samples.py b/docs/source/auto_examples/plot_OT_2D_samples.py index bb952a0..63126ba 100644 --- a/docs/source/auto_examples/plot_OT_2D_samples.py +++ b/docs/source/auto_examples/plot_OT_2D_samples.py @@ -10,6 +10,7 @@ sum of diracs. The OT matrix is plotted with the samples. """ # Author: Remi Flamary <remi.flamary@unice.fr> +# Kilian Fatras <kilian.fatras@irisa.fr> # # License: MIT License @@ -100,3 +101,28 @@ pl.legend(loc=0) pl.title('OT matrix Sinkhorn with samples') pl.show() + + +############################################################################## +# Emprirical Sinkhorn +# ---------------- + +#%% sinkhorn + +# reg term +lambd = 1e-3 + +Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd) + +pl.figure(7) +pl.imshow(Ges, interpolation='nearest') +pl.title('OT matrix empirical sinkhorn') + +pl.figure(8) +ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('OT matrix Sinkhorn from samples') + +pl.show() diff --git a/docs/source/auto_examples/plot_OT_2D_samples.rst b/docs/source/auto_examples/plot_OT_2D_samples.rst index 624ae3e..1f1d713 100644 --- a/docs/source/auto_examples/plot_OT_2D_samples.rst +++ b/docs/source/auto_examples/plot_OT_2D_samples.rst @@ -17,6 +17,7 @@ sum of diracs. The OT matrix is plotted with the samples. # Author: Remi Flamary <remi.flamary@unice.fr> + # Kilian Fatras <kilian.fatras@irisa.fr> # # License: MIT License @@ -176,6 +177,8 @@ Compute Sinkhorn + + .. rst-class:: sphx-glr-horizontal @@ -192,7 +195,58 @@ Compute Sinkhorn -**Total running time of the script:** ( 0 minutes 3.027 seconds) +Emprirical Sinkhorn +---------------- + + + +.. code-block:: python + + + #%% sinkhorn + + # reg term + lambd = 1e-3 + + Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd) + + pl.figure(7) + pl.imshow(Ges, interpolation='nearest') + pl.title('OT matrix empirical sinkhorn') + + pl.figure(8) + ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1]) + pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') + pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') + pl.legend(loc=0) + pl.title('OT matrix Sinkhorn from samples') + + pl.show() + + + +.. rst-class:: sphx-glr-horizontal + + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png + :scale: 47 + + * + + .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png + :scale: 47 + + +.. rst-class:: sphx-glr-script-out + + Out:: + + Warning: numerical errors at iteration 0 + + +**Total running time of the script:** ( 0 minutes 2.616 seconds) diff --git a/docs/source/auto_examples/plot_barycenter_fgw.ipynb b/docs/source/auto_examples/plot_barycenter_fgw.ipynb new file mode 100644 index 0000000..28229b2 --- /dev/null +++ b/docs/source/auto_examples/plot_barycenter_fgw.ipynb @@ -0,0 +1,126 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n=================================\nPlot graphs' barycenter using FGW\n=================================\n\nThis example illustrates the computation barycenter of labeled graphs using FGW\n\nRequires networkx >=2\n\n.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain\n and Courty Nicolas\n \"Optimal Transport for structured data with application on graphs\"\n International Conference on Machine Learning (ICML). 2019.\n\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Author: Titouan Vayer <titouan.vayer@irisa.fr>\n#\n# License: MIT License\n\n#%% load libraries\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport networkx as nx\nimport math\nfrom scipy.sparse.csgraph import shortest_path\nimport matplotlib.colors as mcol\nfrom matplotlib import cm\nfrom ot.gromov import fgw_barycenters\n#%% Graph functions\n\n\ndef find_thresh(C, inf=0.5, sup=3, step=10):\n \"\"\" Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected\n Tthe threshold is found by a linesearch between values \"inf\" and \"sup\" with \"step\" thresholds tested.\n The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix\n and the original matrix.\n Parameters\n ----------\n C : ndarray, shape (n_nodes,n_nodes)\n The structure matrix to threshold\n inf : float\n The beginning of the linesearch\n sup : float\n The end of the linesearch\n step : integer\n Number of thresholds tested\n \"\"\"\n dist = []\n search = np.linspace(inf, sup, step)\n for thresh in search:\n Cprime = sp_to_adjency(C, 0, thresh)\n SC = shortest_path(Cprime, method='D')\n SC[SC == float('inf')] = 100\n dist.append(np.linalg.norm(SC - C))\n return search[np.argmin(dist)], dist\n\n\ndef sp_to_adjency(C, threshinf=0.2, threshsup=1.8):\n \"\"\" Thresholds the structure matrix in order to compute an adjency matrix.\n All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0\n Parameters\n ----------\n C : ndarray, shape (n_nodes,n_nodes)\n The structure matrix to threshold\n threshinf : float\n The minimum value of distance from which the new value is set to 1\n threshsup : float\n The maximum value of distance from which the new value is set to 1\n Returns\n -------\n C : ndarray, shape (n_nodes,n_nodes)\n The threshold matrix. Each element is in {0,1}\n \"\"\"\n H = np.zeros_like(C)\n np.fill_diagonal(H, np.diagonal(C))\n C = C - H\n C = np.minimum(np.maximum(C, threshinf), threshsup)\n C[C == threshsup] = 0\n C[C != 0] = 1\n\n return C\n\n\ndef build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):\n \"\"\" Create a noisy circular graph\n \"\"\"\n g = nx.Graph()\n g.add_nodes_from(list(range(N)))\n for i in range(N):\n noise = float(np.random.normal(mu, sigma, 1))\n if with_noise:\n g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)\n else:\n g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))\n g.add_edge(i, i + 1)\n if structure_noise:\n randomint = np.random.randint(0, p)\n if randomint == 0:\n if i <= N - 3:\n g.add_edge(i, i + 2)\n if i == N - 2:\n g.add_edge(i, 0)\n if i == N - 1:\n g.add_edge(i, 1)\n g.add_edge(N, 0)\n noise = float(np.random.normal(mu, sigma, 1))\n if with_noise:\n g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)\n else:\n g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))\n return g\n\n\ndef graph_colors(nx_graph, vmin=0, vmax=7):\n cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)\n cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')\n cpick.set_array([])\n val_map = {}\n for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():\n val_map[k] = cpick.to_rgba(v)\n colors = []\n for node in nx_graph.nodes():\n colors.append(val_map[node])\n return colors" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate data\n-------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% circular dataset\n# We build a dataset of noisy circular graphs.\n# Noise is added on the structures by random connections and on the features by gaussian noise.\n\n\nnp.random.seed(30)\nX0 = []\nfor k in range(9):\n X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot data\n---------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% Plot graphs\n\nplt.figure(figsize=(8, 10))\nfor i in range(len(X0)):\n plt.subplot(3, 3, i + 1)\n g = X0[i]\n pos = nx.kamada_kawai_layout(g)\n nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)\nplt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)\nplt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Barycenter computation\n----------------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph\n# Features distances are the euclidean distances\nCs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]\nps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]\nYs = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]\nlambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()\nsizebary = 15 # we choose a barycenter with 15 nodes\n\nA, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot Barycenter\n-------------------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% Create the barycenter\nbary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))\nfor i, v in enumerate(A.ravel()):\n bary.add_node(i, attr_name=v)\n\n#%%\npos = nx.kamada_kawai_layout(bary)\nnx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)\nplt.suptitle('Barycenter', fontsize=20)\nplt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_barycenter_fgw.py b/docs/source/auto_examples/plot_barycenter_fgw.py new file mode 100644 index 0000000..77b0370 --- /dev/null +++ b/docs/source/auto_examples/plot_barycenter_fgw.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +""" +================================= +Plot graphs' barycenter using FGW +================================= + +This example illustrates the computation barycenter of labeled graphs using FGW + +Requires networkx >=2 + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + +""" + +# Author: Titouan Vayer <titouan.vayer@irisa.fr> +# +# License: MIT License + +#%% load libraries +import numpy as np +import matplotlib.pyplot as plt +import networkx as nx +import math +from scipy.sparse.csgraph import shortest_path +import matplotlib.colors as mcol +from matplotlib import cm +from ot.gromov import fgw_barycenters +#%% Graph functions + + +def find_thresh(C, inf=0.5, sup=3, step=10): + """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected + Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. + The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix + and the original matrix. + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + inf : float + The beginning of the linesearch + sup : float + The end of the linesearch + step : integer + Number of thresholds tested + """ + dist = [] + search = np.linspace(inf, sup, step) + for thresh in search: + Cprime = sp_to_adjency(C, 0, thresh) + SC = shortest_path(Cprime, method='D') + SC[SC == float('inf')] = 100 + dist.append(np.linalg.norm(SC - C)) + return search[np.argmin(dist)], dist + + +def sp_to_adjency(C, threshinf=0.2, threshsup=1.8): + """ Thresholds the structure matrix in order to compute an adjency matrix. + All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + threshinf : float + The minimum value of distance from which the new value is set to 1 + threshsup : float + The maximum value of distance from which the new value is set to 1 + Returns + ------- + C : ndarray, shape (n_nodes,n_nodes) + The threshold matrix. Each element is in {0,1} + """ + H = np.zeros_like(C) + np.fill_diagonal(H, np.diagonal(C)) + C = C - H + C = np.minimum(np.maximum(C, threshinf), threshsup) + C[C == threshsup] = 0 + C[C != 0] = 1 + + return C + + +def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None): + """ Create a noisy circular graph + """ + g = nx.Graph() + g.add_nodes_from(list(range(N))) + for i in range(N): + noise = float(np.random.normal(mu, sigma, 1)) + if with_noise: + g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise) + else: + g.add_node(i, attr_name=math.sin(2 * i * math.pi / N)) + g.add_edge(i, i + 1) + if structure_noise: + randomint = np.random.randint(0, p) + if randomint == 0: + if i <= N - 3: + g.add_edge(i, i + 2) + if i == N - 2: + g.add_edge(i, 0) + if i == N - 1: + g.add_edge(i, 1) + g.add_edge(N, 0) + noise = float(np.random.normal(mu, sigma, 1)) + if with_noise: + g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise) + else: + g.add_node(N, attr_name=math.sin(2 * N * math.pi / N)) + return g + + +def graph_colors(nx_graph, vmin=0, vmax=7): + cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis') + cpick.set_array([]) + val_map = {} + for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items(): + val_map[k] = cpick.to_rgba(v) + colors = [] + for node in nx_graph.nodes(): + colors.append(val_map[node]) + return colors + +############################################################################## +# Generate data +# ------------- + +#%% circular dataset +# We build a dataset of noisy circular graphs. +# Noise is added on the structures by random connections and on the features by gaussian noise. + + +np.random.seed(30) +X0 = [] +for k in range(9): + X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) + +############################################################################## +# Plot data +# --------- + +#%% Plot graphs + +plt.figure(figsize=(8, 10)) +for i in range(len(X0)): + plt.subplot(3, 3, i + 1) + g = X0[i] + pos = nx.kamada_kawai_layout(g) + nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100) +plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) +plt.show() + +############################################################################## +# Barycenter computation +# ---------------------- + +#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph +# Features distances are the euclidean distances +Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0] +ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] +Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0] +lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() +sizebary = 15 # we choose a barycenter with 15 nodes + +A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True) + +############################################################################## +# Plot Barycenter +# ------------------------- + +#%% Create the barycenter +bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) +for i, v in enumerate(A.ravel()): + bary.add_node(i, attr_name=v) + +#%% +pos = nx.kamada_kawai_layout(bary) +nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False) +plt.suptitle('Barycenter', fontsize=20) +plt.show() diff --git a/docs/source/auto_examples/plot_barycenter_fgw.rst b/docs/source/auto_examples/plot_barycenter_fgw.rst new file mode 100644 index 0000000..2c44a65 --- /dev/null +++ b/docs/source/auto_examples/plot_barycenter_fgw.rst @@ -0,0 +1,268 @@ + + +.. _sphx_glr_auto_examples_plot_barycenter_fgw.py: + + +================================= +Plot graphs' barycenter using FGW +================================= + +This example illustrates the computation barycenter of labeled graphs using FGW + +Requires networkx >=2 + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + + + + +.. code-block:: python + + + # Author: Titouan Vayer <titouan.vayer@irisa.fr> + # + # License: MIT License + + #%% load libraries + import numpy as np + import matplotlib.pyplot as plt + import networkx as nx + import math + from scipy.sparse.csgraph import shortest_path + import matplotlib.colors as mcol + from matplotlib import cm + from ot.gromov import fgw_barycenters + #%% Graph functions + + + def find_thresh(C, inf=0.5, sup=3, step=10): + """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected + Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. + The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix + and the original matrix. + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + inf : float + The beginning of the linesearch + sup : float + The end of the linesearch + step : integer + Number of thresholds tested + """ + dist = [] + search = np.linspace(inf, sup, step) + for thresh in search: + Cprime = sp_to_adjency(C, 0, thresh) + SC = shortest_path(Cprime, method='D') + SC[SC == float('inf')] = 100 + dist.append(np.linalg.norm(SC - C)) + return search[np.argmin(dist)], dist + + + def sp_to_adjency(C, threshinf=0.2, threshsup=1.8): + """ Thresholds the structure matrix in order to compute an adjency matrix. + All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + threshinf : float + The minimum value of distance from which the new value is set to 1 + threshsup : float + The maximum value of distance from which the new value is set to 1 + Returns + ------- + C : ndarray, shape (n_nodes,n_nodes) + The threshold matrix. Each element is in {0,1} + """ + H = np.zeros_like(C) + np.fill_diagonal(H, np.diagonal(C)) + C = C - H + C = np.minimum(np.maximum(C, threshinf), threshsup) + C[C == threshsup] = 0 + C[C != 0] = 1 + + return C + + + def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None): + """ Create a noisy circular graph + """ + g = nx.Graph() + g.add_nodes_from(list(range(N))) + for i in range(N): + noise = float(np.random.normal(mu, sigma, 1)) + if with_noise: + g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise) + else: + g.add_node(i, attr_name=math.sin(2 * i * math.pi / N)) + g.add_edge(i, i + 1) + if structure_noise: + randomint = np.random.randint(0, p) + if randomint == 0: + if i <= N - 3: + g.add_edge(i, i + 2) + if i == N - 2: + g.add_edge(i, 0) + if i == N - 1: + g.add_edge(i, 1) + g.add_edge(N, 0) + noise = float(np.random.normal(mu, sigma, 1)) + if with_noise: + g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise) + else: + g.add_node(N, attr_name=math.sin(2 * N * math.pi / N)) + return g + + + def graph_colors(nx_graph, vmin=0, vmax=7): + cnorm = mcol.Normalize(vmin=vmin, vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis') + cpick.set_array([]) + val_map = {} + for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items(): + val_map[k] = cpick.to_rgba(v) + colors = [] + for node in nx_graph.nodes(): + colors.append(val_map[node]) + return colors + + + + + + + +Generate data +------------- + + + +.. code-block:: python + + + #%% circular dataset + # We build a dataset of noisy circular graphs. + # Noise is added on the structures by random connections and on the features by gaussian noise. + + + np.random.seed(30) + X0 = [] + for k in range(9): + X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3)) + + + + + + + +Plot data +--------- + + + +.. code-block:: python + + + #%% Plot graphs + + plt.figure(figsize=(8, 10)) + for i in range(len(X0)): + plt.subplot(3, 3, i + 1) + g = X0[i] + pos = nx.kamada_kawai_layout(g) + nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100) + plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20) + plt.show() + + + + +.. image:: /auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png + :align: center + + + + +Barycenter computation +---------------------- + + + +.. code-block:: python + + + #%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph + # Features distances are the euclidean distances + Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0] + ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0] + Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0] + lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel() + sizebary = 15 # we choose a barycenter with 15 nodes + + A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True) + + + + + + + +Plot Barycenter +------------------------- + + + +.. code-block:: python + + + #%% Create the barycenter + bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0])) + for i, v in enumerate(A.ravel()): + bary.add_node(i, attr_name=v) + + #%% + pos = nx.kamada_kawai_layout(bary) + nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False) + plt.suptitle('Barycenter', fontsize=20) + plt.show() + + + +.. image:: /auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png + :align: center + + + + +**Total running time of the script:** ( 0 minutes 2.065 seconds) + + + +.. only :: html + + .. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_barycenter_fgw.py <plot_barycenter_fgw.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_barycenter_fgw.ipynb <plot_barycenter_fgw.ipynb>` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_ diff --git a/docs/source/auto_examples/plot_fgw.ipynb b/docs/source/auto_examples/plot_fgw.ipynb new file mode 100644 index 0000000..1b150bd --- /dev/null +++ b/docs/source/auto_examples/plot_fgw.ipynb @@ -0,0 +1,162 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Plot Fused-gromov-Wasserstein\n\n\nThis example illustrates the computation of FGW for 1D measures[18].\n\n.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain\n and Courty Nicolas\n \"Optimal Transport for structured data with application on graphs\"\n International Conference on Machine Learning (ICML). 2019.\n\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Author: Titouan Vayer <titouan.vayer@irisa.fr>\n#\n# License: MIT License\n\nimport matplotlib.pyplot as pl\nimport numpy as np\nimport ot\nfrom ot.gromov import gromov_wasserstein, fused_gromov_wasserstein" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate data\n---------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% parameters\n# We create two 1D random measures\nn = 20 # number of points in the first distribution\nn2 = 30 # number of points in the second distribution\nsig = 1 # std of first distribution\nsig2 = 0.1 # std of second distribution\n\nnp.random.seed(0)\n\nphi = np.arange(n)[:, None]\nxs = phi + sig * np.random.randn(n, 1)\nys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)\n\nphi2 = np.arange(n2)[:, None]\nxt = phi2 + sig * np.random.randn(n2, 1)\nyt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)\nyt = yt[::-1, :]\n\np = ot.unif(n)\nq = ot.unif(n2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot data\n---------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% plot the distributions\n\npl.close(10)\npl.figure(10, (7, 7))\n\npl.subplot(2, 1, 1)\n\npl.scatter(ys, xs, c=phi, s=70)\npl.ylabel('Feature value a', fontsize=20)\npl.title('$\\mu=\\sum_i \\delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)\npl.xticks(())\npl.yticks(())\npl.subplot(2, 1, 2)\npl.scatter(yt, xt, c=phi2, s=70)\npl.xlabel('coordinates x/y', fontsize=25)\npl.ylabel('Feature value b', fontsize=20)\npl.title('$\\\\nu=\\sum_j \\delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)\npl.yticks(())\npl.tight_layout()\npl.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create structure matrices and across-feature distance matrix\n---------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% Structure matrices and across-features distance matrix\nC1 = ot.dist(xs)\nC2 = ot.dist(xt)\nM = ot.dist(ys, yt)\nw1 = ot.unif(C1.shape[0])\nw2 = ot.unif(C2.shape[0])\nGot = ot.emd([], [], M)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot matrices\n---------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%%\ncmap = 'Reds'\npl.close(10)\npl.figure(10, (5, 5))\nfs = 15\nl_x = [0, 5, 10, 15]\nl_y = [0, 5, 10, 15, 20, 25]\ngs = pl.GridSpec(5, 5)\n\nax1 = pl.subplot(gs[3:, :2])\n\npl.imshow(C1, cmap=cmap, interpolation='nearest')\npl.title(\"$C_1$\", fontsize=fs)\npl.xlabel(\"$k$\", fontsize=fs)\npl.ylabel(\"$i$\", fontsize=fs)\npl.xticks(l_x)\npl.yticks(l_x)\n\nax2 = pl.subplot(gs[:3, 2:])\n\npl.imshow(C2, cmap=cmap, interpolation='nearest')\npl.title(\"$C_2$\", fontsize=fs)\npl.ylabel(\"$l$\", fontsize=fs)\n#pl.ylabel(\"$l$\",fontsize=fs)\npl.xticks(())\npl.yticks(l_y)\nax2.set_aspect('auto')\n\nax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)\npl.imshow(M, cmap=cmap, interpolation='nearest')\npl.yticks(l_x)\npl.xticks(l_y)\npl.ylabel(\"$i$\", fontsize=fs)\npl.title(\"$M_{AB}$\", fontsize=fs)\npl.xlabel(\"$j$\", fontsize=fs)\npl.tight_layout()\nax3.set_aspect('auto')\npl.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compute FGW/GW\n---------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% Computing FGW and GW\nalpha = 1e-3\n\not.tic()\nGwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)\not.toc()\n\n#%reload_ext WGW\nGg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Visualize transport matrices\n---------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "#%% visu OT matrix\ncmap = 'Blues'\nfs = 15\npl.figure(2, (13, 5))\npl.clf()\npl.subplot(1, 3, 1)\npl.imshow(Got, cmap=cmap, interpolation='nearest')\n#pl.xlabel(\"$y$\",fontsize=fs)\npl.ylabel(\"$i$\", fontsize=fs)\npl.xticks(())\n\npl.title('Wasserstein ($M$ only)')\n\npl.subplot(1, 3, 2)\npl.imshow(Gg, cmap=cmap, interpolation='nearest')\npl.title('Gromov ($C_1,C_2$ only)')\npl.xticks(())\npl.subplot(1, 3, 3)\npl.imshow(Gwg, cmap=cmap, interpolation='nearest')\npl.title('FGW ($M+C_1,C_2$)')\n\npl.xlabel(\"$j$\", fontsize=fs)\npl.ylabel(\"$i$\", fontsize=fs)\n\npl.tight_layout()\npl.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}
\ No newline at end of file diff --git a/docs/source/auto_examples/plot_fgw.py b/docs/source/auto_examples/plot_fgw.py new file mode 100644 index 0000000..43efc94 --- /dev/null +++ b/docs/source/auto_examples/plot_fgw.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +""" +============================== +Plot Fused-gromov-Wasserstein +============================== + +This example illustrates the computation of FGW for 1D measures[18]. + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + +""" + +# Author: Titouan Vayer <titouan.vayer@irisa.fr> +# +# License: MIT License + +import matplotlib.pyplot as pl +import numpy as np +import ot +from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein + +############################################################################## +# Generate data +# --------- + +#%% parameters +# We create two 1D random measures +n = 20 # number of points in the first distribution +n2 = 30 # number of points in the second distribution +sig = 1 # std of first distribution +sig2 = 0.1 # std of second distribution + +np.random.seed(0) + +phi = np.arange(n)[:, None] +xs = phi + sig * np.random.randn(n, 1) +ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1) + +phi2 = np.arange(n2)[:, None] +xt = phi2 + sig * np.random.randn(n2, 1) +yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1) +yt = yt[::-1, :] + +p = ot.unif(n) +q = ot.unif(n2) + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + +pl.close(10) +pl.figure(10, (7, 7)) + +pl.subplot(2, 1, 1) + +pl.scatter(ys, xs, c=phi, s=70) +pl.ylabel('Feature value a', fontsize=20) +pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1) +pl.xticks(()) +pl.yticks(()) +pl.subplot(2, 1, 2) +pl.scatter(yt, xt, c=phi2, s=70) +pl.xlabel('coordinates x/y', fontsize=25) +pl.ylabel('Feature value b', fontsize=20) +pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1) +pl.yticks(()) +pl.tight_layout() +pl.show() + +############################################################################## +# Create structure matrices and across-feature distance matrix +# --------- + +#%% Structure matrices and across-features distance matrix +C1 = ot.dist(xs) +C2 = ot.dist(xt) +M = ot.dist(ys, yt) +w1 = ot.unif(C1.shape[0]) +w2 = ot.unif(C2.shape[0]) +Got = ot.emd([], [], M) + +############################################################################## +# Plot matrices +# --------- + +#%% +cmap = 'Reds' +pl.close(10) +pl.figure(10, (5, 5)) +fs = 15 +l_x = [0, 5, 10, 15] +l_y = [0, 5, 10, 15, 20, 25] +gs = pl.GridSpec(5, 5) + +ax1 = pl.subplot(gs[3:, :2]) + +pl.imshow(C1, cmap=cmap, interpolation='nearest') +pl.title("$C_1$", fontsize=fs) +pl.xlabel("$k$", fontsize=fs) +pl.ylabel("$i$", fontsize=fs) +pl.xticks(l_x) +pl.yticks(l_x) + +ax2 = pl.subplot(gs[:3, 2:]) + +pl.imshow(C2, cmap=cmap, interpolation='nearest') +pl.title("$C_2$", fontsize=fs) +pl.ylabel("$l$", fontsize=fs) +#pl.ylabel("$l$",fontsize=fs) +pl.xticks(()) +pl.yticks(l_y) +ax2.set_aspect('auto') + +ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1) +pl.imshow(M, cmap=cmap, interpolation='nearest') +pl.yticks(l_x) +pl.xticks(l_y) +pl.ylabel("$i$", fontsize=fs) +pl.title("$M_{AB}$", fontsize=fs) +pl.xlabel("$j$", fontsize=fs) +pl.tight_layout() +ax3.set_aspect('auto') +pl.show() + +############################################################################## +# Compute FGW/GW +# --------- + +#%% Computing FGW and GW +alpha = 1e-3 + +ot.tic() +Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True) +ot.toc() + +#%reload_ext WGW +Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) + +############################################################################## +# Visualize transport matrices +# --------- + +#%% visu OT matrix +cmap = 'Blues' +fs = 15 +pl.figure(2, (13, 5)) +pl.clf() +pl.subplot(1, 3, 1) +pl.imshow(Got, cmap=cmap, interpolation='nearest') +#pl.xlabel("$y$",fontsize=fs) +pl.ylabel("$i$", fontsize=fs) +pl.xticks(()) + +pl.title('Wasserstein ($M$ only)') + +pl.subplot(1, 3, 2) +pl.imshow(Gg, cmap=cmap, interpolation='nearest') +pl.title('Gromov ($C_1,C_2$ only)') +pl.xticks(()) +pl.subplot(1, 3, 3) +pl.imshow(Gwg, cmap=cmap, interpolation='nearest') +pl.title('FGW ($M+C_1,C_2$)') + +pl.xlabel("$j$", fontsize=fs) +pl.ylabel("$i$", fontsize=fs) + +pl.tight_layout() +pl.show() diff --git a/docs/source/auto_examples/plot_fgw.rst b/docs/source/auto_examples/plot_fgw.rst new file mode 100644 index 0000000..aec725d --- /dev/null +++ b/docs/source/auto_examples/plot_fgw.rst @@ -0,0 +1,297 @@ + + +.. _sphx_glr_auto_examples_plot_fgw.py: + + +============================== +Plot Fused-gromov-Wasserstein +============================== + +This example illustrates the computation of FGW for 1D measures[18]. + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + + + + +.. code-block:: python + + + # Author: Titouan Vayer <titouan.vayer@irisa.fr> + # + # License: MIT License + + import matplotlib.pyplot as pl + import numpy as np + import ot + from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein + + + + + + + +Generate data +--------- + + + +.. code-block:: python + + + #%% parameters + # We create two 1D random measures + n = 20 # number of points in the first distribution + n2 = 30 # number of points in the second distribution + sig = 1 # std of first distribution + sig2 = 0.1 # std of second distribution + + np.random.seed(0) + + phi = np.arange(n)[:, None] + xs = phi + sig * np.random.randn(n, 1) + ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1) + + phi2 = np.arange(n2)[:, None] + xt = phi2 + sig * np.random.randn(n2, 1) + yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1) + yt = yt[::-1, :] + + p = ot.unif(n) + q = ot.unif(n2) + + + + + + + +Plot data +--------- + + + +.. code-block:: python + + + #%% plot the distributions + + pl.close(10) + pl.figure(10, (7, 7)) + + pl.subplot(2, 1, 1) + + pl.scatter(ys, xs, c=phi, s=70) + pl.ylabel('Feature value a', fontsize=20) + pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1) + pl.xticks(()) + pl.yticks(()) + pl.subplot(2, 1, 2) + pl.scatter(yt, xt, c=phi2, s=70) + pl.xlabel('coordinates x/y', fontsize=25) + pl.ylabel('Feature value b', fontsize=20) + pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1) + pl.yticks(()) + pl.tight_layout() + pl.show() + + + + +.. image:: /auto_examples/images/sphx_glr_plot_fgw_010.png + :align: center + + + + +Create structure matrices and across-feature distance matrix +--------- + + + +.. code-block:: python + + + #%% Structure matrices and across-features distance matrix + C1 = ot.dist(xs) + C2 = ot.dist(xt) + M = ot.dist(ys, yt) + w1 = ot.unif(C1.shape[0]) + w2 = ot.unif(C2.shape[0]) + Got = ot.emd([], [], M) + + + + + + + +Plot matrices +--------- + + + +.. code-block:: python + + + #%% + cmap = 'Reds' + pl.close(10) + pl.figure(10, (5, 5)) + fs = 15 + l_x = [0, 5, 10, 15] + l_y = [0, 5, 10, 15, 20, 25] + gs = pl.GridSpec(5, 5) + + ax1 = pl.subplot(gs[3:, :2]) + + pl.imshow(C1, cmap=cmap, interpolation='nearest') + pl.title("$C_1$", fontsize=fs) + pl.xlabel("$k$", fontsize=fs) + pl.ylabel("$i$", fontsize=fs) + pl.xticks(l_x) + pl.yticks(l_x) + + ax2 = pl.subplot(gs[:3, 2:]) + + pl.imshow(C2, cmap=cmap, interpolation='nearest') + pl.title("$C_2$", fontsize=fs) + pl.ylabel("$l$", fontsize=fs) + #pl.ylabel("$l$",fontsize=fs) + pl.xticks(()) + pl.yticks(l_y) + ax2.set_aspect('auto') + + ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1) + pl.imshow(M, cmap=cmap, interpolation='nearest') + pl.yticks(l_x) + pl.xticks(l_y) + pl.ylabel("$i$", fontsize=fs) + pl.title("$M_{AB}$", fontsize=fs) + pl.xlabel("$j$", fontsize=fs) + pl.tight_layout() + ax3.set_aspect('auto') + pl.show() + + + + +.. image:: /auto_examples/images/sphx_glr_plot_fgw_011.png + :align: center + + + + +Compute FGW/GW +--------- + + + +.. code-block:: python + + + #%% Computing FGW and GW + alpha = 1e-3 + + ot.tic() + Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True) + ot.toc() + + #%reload_ext WGW + Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) + + + + + +.. rst-class:: sphx-glr-script-out + + Out:: + + It. |Loss |Relative loss|Absolute loss + ------------------------------------------------ + 0|4.734462e+01|0.000000e+00|0.000000e+00 + 1|2.508258e+01|8.875498e-01|2.226204e+01 + 2|2.189329e+01|1.456747e-01|3.189297e+00 + 3|2.189329e+01|0.000000e+00|0.000000e+00 + Elapsed time : 0.0016989707946777344 s + It. |Loss |Relative loss|Absolute loss + ------------------------------------------------ + 0|4.683978e+04|0.000000e+00|0.000000e+00 + 1|3.860061e+04|2.134468e-01|8.239175e+03 + 2|2.182948e+04|7.682787e-01|1.677113e+04 + 3|2.182948e+04|0.000000e+00|0.000000e+00 + + +Visualize transport matrices +--------- + + + +.. code-block:: python + + + #%% visu OT matrix + cmap = 'Blues' + fs = 15 + pl.figure(2, (13, 5)) + pl.clf() + pl.subplot(1, 3, 1) + pl.imshow(Got, cmap=cmap, interpolation='nearest') + #pl.xlabel("$y$",fontsize=fs) + pl.ylabel("$i$", fontsize=fs) + pl.xticks(()) + + pl.title('Wasserstein ($M$ only)') + + pl.subplot(1, 3, 2) + pl.imshow(Gg, cmap=cmap, interpolation='nearest') + pl.title('Gromov ($C_1,C_2$ only)') + pl.xticks(()) + pl.subplot(1, 3, 3) + pl.imshow(Gwg, cmap=cmap, interpolation='nearest') + pl.title('FGW ($M+C_1,C_2$)') + + pl.xlabel("$j$", fontsize=fs) + pl.ylabel("$i$", fontsize=fs) + + pl.tight_layout() + pl.show() + + + +.. image:: /auto_examples/images/sphx_glr_plot_fgw_004.png + :align: center + + + + +**Total running time of the script:** ( 0 minutes 1.468 seconds) + + + +.. only :: html + + .. container:: sphx-glr-footer + + + .. container:: sphx-glr-download + + :download:`Download Python source code: plot_fgw.py <plot_fgw.py>` + + + + .. container:: sphx-glr-download + + :download:`Download Jupyter notebook: plot_fgw.ipynb <plot_fgw.ipynb>` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_ |