1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
|
{
"nbformat_minor": 0,
"nbformat": 4,
"cells": [
{
"execution_count": null,
"cell_type": "code",
"source": [
"# Render matplotlib figures inline, below the cell that creates them.\n%matplotlib inline"
],
"outputs": [],
"metadata": {
"collapsed": false
}
},
{
"source": [
"\n# Gromov-Wasserstein Barycenter example\n\n\nThis example shows how to compute Gromov-Wasserstein barycenters with POT:\nit interpolates between pairs of simple 2-D shapes using only the pairwise\ndistance matrices of their point clouds.\n\n"
],
"cell_type": "markdown",
"metadata": {}
},
{
"execution_count": null,
"cell_type": "code",
"source": [
"# Author: Erwan Vautier <erwan.vautier@gmail.com>\r\n# Nicolas Courty <ncourty@irisa.fr>\r\n#\r\n# License: MIT License\r\n\r\n\r\nimport numpy as np\r\nimport scipy as sp\r\n\r\nimport scipy.ndimage as spi\r\nimport matplotlib.pylab as pl\r\nfrom sklearn import manifold\r\nfrom sklearn.decomposition import PCA\r\n\r\nimport ot"
],
"outputs": [],
"metadata": {
"collapsed": false
}
},
{
"source": [
"SMACOF MDS\r\n----------\r\n\r\nThis function finds an embedding of points whose pairwise distances follow a\r\ngiven dissimilarity matrix. It is used below to visualize both the input\r\nshapes and the barycenters output by the algorithm.\r\n\n"
],
"cell_type": "markdown",
"metadata": {}
},
{
"execution_count": null,
"cell_type": "code",
"source": [
"def smacof_mds(C, dim, max_iter=3000, eps=1e-9):\n    \"\"\"Embed points in `dim` dimensions so that their distances follow C.\n\n    Runs SMACOF multidimensional scaling (MDS) twice: a first pass from a\n    random start, then a second pass initialised with the first embedding\n    to refine it. Both passes draw from the same seeded random state, so the\n    result is reproducible across calls and kernel restarts.\n\n    Parameters\n    ----------\n    C : ndarray, shape (ns, ns)\n        Precomputed dissimilarity matrix.\n    dim : int\n        Dimension of the target space.\n    max_iter : int\n        Maximum number of SMACOF iterations for a single run.\n    eps : float\n        Relative tolerance w.r.t. the stress used to declare convergence.\n\n    Returns\n    -------\n    npos : ndarray, shape (ns, dim)\n        Embedded coordinates; only defined up to an isometry.\n    \"\"\"\n    # Seeded so repeated calls on the same C yield the same embedding.\n    rng = np.random.RandomState(seed=3)\n\n    # First pass: plain SMACOF from a random start. Both `dim` and `eps`\n    # come from the arguments so callers actually control them.\n    mds = manifold.MDS(\n        dim,\n        max_iter=max_iter,\n        eps=eps,\n        dissimilarity='precomputed',\n        random_state=rng,\n        n_init=1)\n    pos = mds.fit(C).embedding_\n\n    # Second pass: refine, starting from the (ns, dim) first-pass embedding,\n    # so it must target the same dimension.\n    nmds = manifold.MDS(\n        dim,\n        max_iter=max_iter,\n        eps=eps,\n        dissimilarity='precomputed',\n        random_state=rng,\n        n_init=1)\n    npos = nmds.fit_transform(C, init=pos)\n\n    return npos"
],
"outputs": [],
"metadata": {
"collapsed": false
}
},
{
"source": [
"Data preparation\r\n----------------\r\n\r\nThe four distributions are constructed from four simple images: every\r\nsufficiently dark pixel becomes one point of a 2-D point cloud.\r\n\n"
],
"cell_type": "markdown",
"metadata": {}
},
{
"execution_count": null,
"cell_type": "code",
"source": [
"def im2mat(I):\n    \"\"\"Converts an image to a matrix (one pixel per line).\"\"\"\n    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\n\n\ndef load_shape(path):\n    \"\"\"Load a PNG image and return its third (blue) channel as float64.\n\n    matplotlib's imread returns PNG pixels already scaled to [0, 1], so no\n    division by 256 is needed (scipy.ndimage.imread no longer exists in\n    SciPy >= 1.2).\n    \"\"\"\n    return pl.imread(path).astype(np.float64)[:, :, 2]\n\n\nsquare = load_shape('../data/square.png')\ncross = load_shape('../data/cross.png')\ntriangle = load_shape('../data/triangle.png')\nstar = load_shape('../data/star.png')\n\nshapes = [square, cross, triangle, star]\n\nS = len(shapes)\n# Only the top-left IMG_SIZE x IMG_SIZE pixels are scanned. Each pixel with a\n# value below 0.95 becomes a 2-D point (x = column, y = IMG_SIZE - row, so\n# that y points up as in a regular plot).\nIMG_SIZE = 8\nxs = [[] for _ in range(S)]\nfor nb in range(S):\n    for i in range(IMG_SIZE):\n        for j in range(IMG_SIZE):\n            if shapes[nb][i, j] < 0.95:\n                xs[nb].append([j, IMG_SIZE - i])\n\n# The point clouds have different sizes, so keep a list of arrays: calling\n# np.array on such ragged input raises ValueError in NumPy >= 1.24.\nxs = [np.array(x) for x in xs]"
],
"outputs": [],
"metadata": {
"collapsed": false
}
},
{
"source": [
"Barycenter computation\r\n----------------------\r\n\n"
],
"cell_type": "markdown",
"metadata": {}
},
{
"execution_count": null,
"cell_type": "code",
"source": [
"ns = [len(xs[s]) for s in range(S)]\nn_samples = 30\n\n# Pairwise distance matrix of each shape, scaled so its largest entry is 1.\nCs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]\nCs = [cs / cs.max() for cs in Cs]\n\nps = [ot.unif(ns[s]) for s in range(S)]\np = ot.unif(n_samples)\n\n# Barycenter weights: [1/3, 2/3] and [2/3, 1/3].\nlambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]\n\n\ndef gw_interpolation(a, b):\n    \"\"\"Return the barycenters between shapes a and b, one per weight in lambdast.\n\n    The 5e-4 positional argument is the regularisation term expected by the\n    POT version this example was written for (NOTE(review): newer POT\n    releases changed the gromov_barycenters signature -- confirm).\n    \"\"\"\n    return [ot.gromov.gromov_barycenters(n_samples, [Cs[a], Cs[b]],\n                                         [ps[a], ps[b]], p, lam,\n                                         'square_loss', 5e-4,\n                                         max_iter=100, tol=1e-3)\n            for lam in lambdast]\n\n\n# NOTE(review): the barycenter solver may draw a random initial structure\n# from NumPy's global RNG; seeding keeps Restart & Run All reproducible.\nnp.random.seed(0)\nCt01 = gw_interpolation(0, 1)\nCt02 = gw_interpolation(0, 2)\nCt13 = gw_interpolation(1, 3)\nCt23 = gw_interpolation(2, 3)"
],
"outputs": [],
"metadata": {
"collapsed": false
}
},
{
"source": [
"Visualization\r\n-------------\r\n\r\nMDS embeddings are only defined up to an isometry, so PCA is applied to each\r\nbarycenter embedding to make their orientations more consistent.\r\n\n"
],
"cell_type": "markdown",
"metadata": {}
},
{
"execution_count": null,
"cell_type": "code",
"source": [
"clf = PCA(n_components=2)\n\n\ndef embed_barycenters(Cts):\n    \"\"\"Embed each barycenter in 2-D with MDS, then project it with PCA.\n\n    MDS embeddings are only defined up to an isometry; projecting onto the\n    principal axes gives the barycenter plots a more consistent orientation.\n    \"\"\"\n    return [clf.fit_transform(smacof_mds(Ct, 2)) for Ct in Cts]\n\n\n# The original shapes (red) are embedded directly, without the PCA step.\nnpos = [smacof_mds(Cs[s], 2) for s in range(S)]\n\nnpost01 = embed_barycenters(Ct01)\nnpost02 = embed_barycenters(Ct02)\nnpost13 = embed_barycenters(Ct13)\nnpost23 = embed_barycenters(Ct23)\n\n# 4x4 grid: original shapes in the corners (red); the barycenters of each\n# pair sit on the grid edge between those two shapes (blue), with the one\n# weighted 2/3 towards a shape placed next to that shape.\npanels = [\n    ((0, 0), npos[0], 'r'),\n    ((0, 1), npost01[1], 'b'),\n    ((0, 2), npost01[0], 'b'),\n    ((0, 3), npos[1], 'r'),\n    ((1, 0), npost02[1], 'b'),\n    ((1, 3), npost13[1], 'b'),\n    ((2, 0), npost02[0], 'b'),\n    ((2, 3), npost13[0], 'b'),\n    ((3, 0), npos[2], 'r'),\n    ((3, 1), npost23[1], 'b'),\n    ((3, 2), npost23[0], 'b'),\n    ((3, 3), npos[3], 'r'),\n]\n\nfig = pl.figure(figsize=(10, 10))\nfor loc, pts, color in panels:\n    ax = pl.subplot2grid((4, 4), loc)\n    ax.set_xlim((-1, 1))\n    ax.set_ylim((-1, 1))\n    ax.scatter(pts[:, 0], pts[:, 1], color=color)\npl.show()"
],
"outputs": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"name": "python2",
"language": "python"
},
"language_info": {
"mimetype": "text/x-python",
"nbconvert_exporter": "python",
"name": "python",
"file_extension": ".py",
"version": "2.7.12",
"pygments_lexer": "ipython2",
"codemirror_mode": {
"version": 2,
"name": "ipython"
}
}
}
}
|