summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_gromov_barycenter.ipynb
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-09-15 13:57:01 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-09-15 13:57:01 +0200
commitdd3546baf9c59733b2109a971293eba48d2eaed3 (patch)
treedbc9c5dd126eecf537acbe7d205b91250f2bdc9b /docs/source/auto_examples/plot_gromov_barycenter.ipynb
parentbad3d95523d005a4fbf64dd009c716b9dd560fe3 (diff)
add all files for doc
Diffstat (limited to 'docs/source/auto_examples/plot_gromov_barycenter.ipynb')
-rw-r--r--docs/source/auto_examples/plot_gromov_barycenter.ipynb126
1 files changed, 126 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_gromov_barycenter.ipynb b/docs/source/auto_examples/plot_gromov_barycenter.ipynb
new file mode 100644
index 0000000..d38dfbb
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov_barycenter.ipynb
@@ -0,0 +1,126 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "\n# Gromov-Wasserstein Barycenter example\n\n\nThis example is designed to show how to use the Gromov-Wasserstein distance\ncomputation in POT.\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Author: Erwan Vautier <erwan.vautier@gmail.com>\r\n# Nicolas Courty <ncourty@irisa.fr>\r\n#\r\n# License: MIT License\r\n\r\n\r\nimport numpy as np\r\nimport scipy as sp\r\n\r\nimport scipy.ndimage as spi\r\nimport matplotlib.pylab as pl\r\nfrom sklearn import manifold\r\nfrom sklearn.decomposition import PCA\r\n\r\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Smacof MDS\r\n ----------\r\n\r\n This function allows to find an embedding of points given a dissimilarity matrix\r\n that will be given by the output of the algorithm\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "def smacof_mds(C, dim, max_iter=3000, eps=1e-9):\r\n \"\"\"\r\n Returns an interpolated point cloud following the dissimilarity matrix C\r\n using SMACOF multidimensional scaling (MDS) in specific dimensionned\r\n target space\r\n\r\n Parameters\r\n ----------\r\n C : ndarray, shape (ns, ns)\r\n dissimilarity matrix\r\n dim : int\r\n dimension of the targeted space\r\n max_iter : int\r\n Maximum number of iterations of the SMACOF algorithm for a single run\r\n eps : float\r\n relative tolerance w.r.t stress to declare converge\r\n\r\n Returns\r\n -------\r\n npos : ndarray, shape (R, dim)\r\n Embedded coordinates of the interpolated point cloud (defined with\r\n one isometry)\r\n \"\"\"\r\n\r\n rng = np.random.RandomState(seed=3)\r\n\r\n mds = manifold.MDS(\r\n dim,\r\n max_iter=max_iter,\r\n eps=1e-9,\r\n dissimilarity='precomputed',\r\n n_init=1)\r\n pos = mds.fit(C).embedding_\r\n\r\n nmds = manifold.MDS(\r\n 2,\r\n max_iter=max_iter,\r\n eps=1e-9,\r\n dissimilarity=\"precomputed\",\r\n random_state=rng,\r\n n_init=1)\r\n npos = nmds.fit_transform(C, init=pos)\r\n\r\n return npos"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Data preparation\r\n ----------------\r\n\r\n The four distributions are constructed from 4 simple images\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "def im2mat(I):\r\n \"\"\"Converts and image to matrix (one pixel per line)\"\"\"\r\n return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\r\n\r\n\r\nsquare = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256\r\ncross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256\r\ntriangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256\r\nstar = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256\r\n\r\nshapes = [square, cross, triangle, star]\r\n\r\nS = 4\r\nxs = [[] for i in range(S)]\r\n\r\n\r\nfor nb in range(4):\r\n for i in range(8):\r\n for j in range(8):\r\n if shapes[nb][i, j] < 0.95:\r\n xs[nb].append([j, 8 - i])\r\n\r\nxs = np.array([np.array(xs[0]), np.array(xs[1]),\r\n np.array(xs[2]), np.array(xs[3])])"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Barycenter computation\r\n----------------------\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "ns = [len(xs[s]) for s in range(S)]\r\nn_samples = 30\r\n\r\n\"\"\"Compute all distances matrices for the four shapes\"\"\"\r\nCs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]\r\nCs = [cs / cs.max() for cs in Cs]\r\n\r\nps = [ot.unif(ns[s]) for s in range(S)]\r\np = ot.unif(n_samples)\r\n\r\n\r\nlambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]\r\n\r\nCt01 = [0 for i in range(2)]\r\nfor i in range(2):\r\n Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],\r\n [ps[0], ps[1]\r\n ], p, lambdast[i], 'square_loss', 5e-4,\r\n max_iter=100, tol=1e-3)\r\n\r\nCt02 = [0 for i in range(2)]\r\nfor i in range(2):\r\n Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],\r\n [ps[0], ps[2]\r\n ], p, lambdast[i], 'square_loss', 5e-4,\r\n max_iter=100, tol=1e-3)\r\n\r\nCt13 = [0 for i in range(2)]\r\nfor i in range(2):\r\n Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],\r\n [ps[1], ps[3]\r\n ], p, lambdast[i], 'square_loss', 5e-4,\r\n max_iter=100, tol=1e-3)\r\n\r\nCt23 = [0 for i in range(2)]\r\nfor i in range(2):\r\n Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],\r\n [ps[2], ps[3]\r\n ], p, lambdast[i], 'square_loss', 5e-4,\r\n max_iter=100, tol=1e-3)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Visualization\r\n -------------\r\n\r\n The PCA helps in getting consistency between the rotations\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "clf = PCA(n_components=2)\r\nnpos = [0, 0, 0, 0]\r\nnpos = [smacof_mds(Cs[s], 2) for s in range(S)]\r\n\r\nnpost01 = [0, 0]\r\nnpost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]\r\nnpost01 = [clf.fit_transform(npost01[s]) for s in range(2)]\r\n\r\nnpost02 = [0, 0]\r\nnpost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]\r\nnpost02 = [clf.fit_transform(npost02[s]) for s in range(2)]\r\n\r\nnpost13 = [0, 0]\r\nnpost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]\r\nnpost13 = [clf.fit_transform(npost13[s]) for s in range(2)]\r\n\r\nnpost23 = [0, 0]\r\nnpost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]\r\nnpost23 = [clf.fit_transform(npost23[s]) for s in range(2)]\r\n\r\n\r\nfig = pl.figure(figsize=(10, 10))\r\n\r\nax1 = pl.subplot2grid((4, 4), (0, 0))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')\r\n\r\nax2 = pl.subplot2grid((4, 4), (0, 1))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')\r\n\r\nax3 = pl.subplot2grid((4, 4), (0, 2))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')\r\n\r\nax4 = pl.subplot2grid((4, 4), (0, 3))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')\r\n\r\nax5 = pl.subplot2grid((4, 4), (1, 0))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')\r\n\r\nax6 = pl.subplot2grid((4, 4), (1, 3))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')\r\n\r\nax7 = pl.subplot2grid((4, 4), (2, 0))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')\r\n\r\nax8 = pl.subplot2grid((4, 4), (2, 3))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')\r\n\r\nax9 = pl.subplot2grid((4, 4), (3, 0))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')\r\n\r\nax10 = pl.subplot2grid((4, 4), (3, 1))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')\r\n\r\nax11 = pl.subplot2grid((4, 4), (3, 2))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')\r\n\r\nax12 = pl.subplot2grid((4, 4), (3, 3))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file