summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_gromov_barycenter.ipynb
blob: d38dfbb449daccfa2252119346929a3afdc43eec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
{
  "nbformat_minor": 0, 
  "nbformat": 4, 
  "cells": [
    {
      "execution_count": null, 
      "cell_type": "code", 
      "source": [
        "%matplotlib inline"
      ], 
      "outputs": [], 
      "metadata": {
        "collapsed": false
      }
    }, 
    {
      "source": [
        "\n# Gromov-Wasserstein Barycenter example\n\n\nThis example is designed to show how to compute Gromov-Wasserstein barycenters\nin POT, here used to interpolate between four simple 2D shapes.\n\n"
      ], 
      "cell_type": "markdown", 
      "metadata": {}
    }, 
    {
      "execution_count": null, 
      "cell_type": "code", 
      "source": [
        "# Author: Erwan Vautier <erwan.vautier@gmail.com>\r\n#         Nicolas Courty <ncourty@irisa.fr>\r\n#\r\n# License: MIT License\r\n\r\n\r\nimport numpy as np\r\nimport scipy as sp\r\n\r\nimport scipy.ndimage as spi\r\nimport matplotlib.pylab as pl\r\nfrom sklearn import manifold\r\nfrom sklearn.decomposition import PCA\r\n\r\nimport ot"
      ], 
      "outputs": [], 
      "metadata": {
        "collapsed": false
      }
    }, 
    {
      "source": [
        "Smacof MDS\r\n----------\r\n\r\nThis function finds an embedding of points given a dissimilarity matrix;\r\nhere the matrix is the output of the barycenter algorithm.\r\n\n"
      ], 
      "cell_type": "markdown", 
      "metadata": {}
    }, 
    {
      "execution_count": null, 
      "cell_type": "code", 
      "source": [
        "def smacof_mds(C, dim, max_iter=3000, eps=1e-9):\r\n    \"\"\"\r\n    Returns an interpolated point cloud following the dissimilarity matrix C\r\n    using SMACOF multidimensional scaling (MDS) in a target space of the\r\n    given dimension\r\n\r\n    Parameters\r\n    ----------\r\n    C : ndarray, shape (ns, ns)\r\n        dissimilarity matrix\r\n    dim : int\r\n          dimension of the targeted space\r\n    max_iter :  int\r\n        Maximum number of iterations of the SMACOF algorithm for a single run\r\n    eps : float\r\n        relative tolerance w.r.t stress to declare convergence\r\n\r\n    Returns\r\n    -------\r\n    npos : ndarray, shape (ns, dim)\r\n           Embedded coordinates of the interpolated point cloud (defined with\r\n           one isometry)\r\n    \"\"\"\r\n\r\n    # fixed seed shared by both MDS runs so the embedding is reproducible\r\n    rng = np.random.RandomState(seed=3)\r\n\r\n    # first run: SMACOF from a seeded random initialization\r\n    mds = manifold.MDS(\r\n        dim,\r\n        max_iter=max_iter,\r\n        eps=eps,\r\n        dissimilarity='precomputed',\r\n        random_state=rng,\r\n        n_init=1)\r\n    pos = mds.fit(C).embedding_\r\n\r\n    # second run: restart SMACOF from the first embedding; its number of\r\n    # components must match the `dim` columns of `init`\r\n    nmds = manifold.MDS(\r\n        dim,\r\n        max_iter=max_iter,\r\n        eps=eps,\r\n        dissimilarity=\"precomputed\",\r\n        random_state=rng,\r\n        n_init=1)\r\n    npos = nmds.fit_transform(C, init=pos)\r\n\r\n    return npos"
      ], 
      "outputs": [], 
      "metadata": {
        "collapsed": false
      }
    }, 
    {
      "source": [
        "Data preparation\r\n----------------\r\n\r\nThe four distributions are constructed from four simple images.\r\n\n"
      ], 
      "cell_type": "markdown", 
      "metadata": {}
    }, 
    {
      "execution_count": null, 
      "cell_type": "code", 
      "source": [
        "def im2mat(I):\r\n    \"\"\"Converts an image to a matrix (one pixel per line)\"\"\"\r\n    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\r\n\r\n\r\n# scipy.ndimage.imread was removed in SciPy 1.2, so the images are read with\r\n# matplotlib, which returns 8-bit PNG pixels as floats already scaled to [0, 1]\r\n# (v / 255 instead of the former v / 256: with the 0.95 threshold below, only a\r\n# pixel value of exactly 243 is classified differently). Channel 2 is kept.\r\nsquare = pl.imread('../data/square.png').astype(np.float64)[:, :, 2]\r\ncross = pl.imread('../data/cross.png').astype(np.float64)[:, :, 2]\r\ntriangle = pl.imread('../data/triangle.png').astype(np.float64)[:, :, 2]\r\nstar = pl.imread('../data/star.png').astype(np.float64)[:, :, 2]\r\n\r\nshapes = [square, cross, triangle, star]\r\n\r\nS = 4\r\nxs = [[] for i in range(S)]\r\n\r\n# each pixel of the first 8x8 block with a value below 0.95 becomes a 2D point\r\n# (x = column, y = 8 - row, i.e. rows are flipped so the image top is +y)\r\nfor nb in range(S):\r\n    for i in range(8):\r\n        for j in range(8):\r\n            if shapes[nb][i, j] < 0.95:\r\n                xs[nb].append([j, 8 - i])\r\n\r\n# the shapes have different numbers of points, so keep a list of arrays:\r\n# NumPy >= 1.24 refuses to build a single array from ragged sequences\r\nxs = [np.array(xs[s]) for s in range(S)]"
      ], 
      "outputs": [], 
      "metadata": {
        "collapsed": false
      }
    }, 
    {
      "source": [
        "Barycenter computation\r\n----------------------\r\n\n"
      ], 
      "cell_type": "markdown", 
      "metadata": {}
    }, 
    {
      "execution_count": null, 
      "cell_type": "code", 
      "source": [
        "ns = [len(xs[s]) for s in range(S)]\r\nn_samples = 30\r\n\r\n# Compute all distance matrices for the four shapes, normalized by their maximum\r\nCs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]\r\nCs = [cs / cs.max() for cs in Cs]\r\n\r\n# uniform weights on the points of each shape and of the barycenter\r\nps = [ot.unif(ns[s]) for s in range(S)]\r\np = ot.unif(n_samples)\r\n\r\n# barycenter weights: [1/3, 2/3] and [2/3, 1/3]\r\nlambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]\r\n\r\n\r\ndef interpolate_shapes(a, b):\r\n    \"\"\"Barycenters of shapes a and b, one per weight pair in lambdast\"\"\"\r\n    # NOTE(review): 5e-4 is presumably the regularization (epsilon) argument of\r\n    # the POT version this example targets; confirm before passing it by keyword\r\n    return [ot.gromov.gromov_barycenters(n_samples, [Cs[a], Cs[b]],\r\n                                         [ps[a], ps[b]], p, lambdas,\r\n                                         'square_loss', 5e-4,\r\n                                         max_iter=100, tol=1e-3)\r\n            for lambdas in lambdast]\r\n\r\n\r\nCt01 = interpolate_shapes(0, 1)\r\nCt02 = interpolate_shapes(0, 2)\r\nCt13 = interpolate_shapes(1, 3)\r\nCt23 = interpolate_shapes(2, 3)"
      ], 
      "outputs": [], 
      "metadata": {
        "collapsed": false
      }
    }, 
    {
      "source": [
        "Visualization\r\n-------------\r\n\r\nA PCA is applied to each barycenter embedding so that the orientations\r\nof the plotted shapes are consistent.\r\n\n"
      ], 
      "cell_type": "markdown", 
      "metadata": {}
    }, 
    {
      "execution_count": null, 
      "cell_type": "code", 
      "source": [
        "clf = PCA(n_components=2)\r\nnpos = [smacof_mds(Cs[s], 2) for s in range(S)]\r\n\r\n\r\ndef embed_barycenters(Ct):\r\n    \"\"\"2D SMACOF embeddings of the two barycenters in Ct, each re-oriented with PCA\"\"\"\r\n    return [clf.fit_transform(smacof_mds(Ct[s], 2)) for s in range(2)]\r\n\r\n\r\nnpost01 = embed_barycenters(Ct01)\r\nnpost02 = embed_barycenters(Ct02)\r\nnpost13 = embed_barycenters(Ct13)\r\nnpost23 = embed_barycenters(Ct23)\r\n\r\n# (grid position, points, color) of each panel: input shapes in red on the\r\n# corners, interpolated barycenters in blue along the edges\r\npanels = [\r\n    ((0, 0), npos[0], 'r'),\r\n    ((0, 1), npost01[1], 'b'),\r\n    ((0, 2), npost01[0], 'b'),\r\n    ((0, 3), npos[1], 'r'),\r\n    ((1, 0), npost02[1], 'b'),\r\n    ((1, 3), npost13[1], 'b'),\r\n    ((2, 0), npost02[0], 'b'),\r\n    ((2, 3), npost13[0], 'b'),\r\n    ((3, 0), npos[2], 'r'),\r\n    ((3, 1), npost23[1], 'b'),\r\n    ((3, 2), npost23[0], 'b'),\r\n    ((3, 3), npos[3], 'r'),\r\n]\r\n\r\nfig = pl.figure(figsize=(10, 10))\r\n\r\nfor loc, pts, color in panels:\r\n    ax = pl.subplot2grid((4, 4), loc)\r\n    ax.set_xlim((-1, 1))\r\n    ax.set_ylim((-1, 1))\r\n    ax.scatter(pts[:, 0], pts[:, 1], color=color)"
      ], 
      "outputs": [], 
      "metadata": {
        "collapsed": false
      }
    }
  ], 
  "metadata": {
    "kernelspec": {
      "display_name": "Python 2", 
      "name": "python2", 
      "language": "python"
    }, 
    "language_info": {
      "mimetype": "text/x-python", 
      "nbconvert_exporter": "python", 
      "name": "python", 
      "file_extension": ".py", 
      "version": "2.7.12", 
      "pygments_lexer": "ipython2", 
      "codemirror_mode": {
        "version": 2, 
        "name": "ipython"
      }
    }
  }
}