summaryrefslogtreecommitdiff
path: root/examples/plot_gromov.py
blob: 99aaf8100ddd2000dddf85615dcce6587238d676 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# -*- coding: utf-8 -*-
"""
==========================
Gromov-Wasserstein example
==========================
This example shows how to use the Gromov-Wasserstein distance
computation in POT.
"""

# Author: Erwan Vautier <erwan.vautier@gmail.com>
#         Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

import scipy as sp
import numpy as np
import matplotlib.pylab as pl

import ot


"""
Sample two Gaussian distributions (2D and 3D)
=============================================
The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space.
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
"""

<<<<<<< HEAD
n_samples = 30  # nb samples
=======
n = 30  # nb samples
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])


<<<<<<< HEAD
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
=======
xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n, 3).dot(P) + mu_t
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d


"""
Plotting the distributions
==========================
"""
fig = pl.figure()
ax1 = fig.add_subplot(121)
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
pl.show()


"""
Compute distance kernels, normalize them and then display
=========================================================
"""

C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)

C1 /= C1.max()
C2 /= C2.max()

pl.figure()
pl.subplot(121)
pl.imshow(C1)
pl.subplot(122)
pl.imshow(C2)
pl.show()

"""
Compute Gromov-Wasserstein plans and distance
=============================================
"""

<<<<<<< HEAD
p = ot.unif(n_samples)
q = ot.unif(n_samples)
=======
p = ot.unif(n)
q = ot.unif(n)
>>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d

gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)

print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))

pl.figure()
pl.imshow(gw, cmap='jet')
pl.colorbar()
pl.show()