summaryrefslogtreecommitdiff
path: root/examples/plot_gromov.py
blob: 11e53369d55e0106bd6f1c3f0af1b47cc65443c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# -*- coding: utf-8 -*-
"""
==========================
Gromov-Wasserstein example
==========================

This example is designed to show how to use the Gromov-Wasserstein distance
computation in POT.


"""

# Author: Erwan Vautier <erwan.vautier@gmail.com>
#         Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

import numpy as np
import scipy as sp
# Load the submodules explicitly: on older SciPy releases ``import scipy``
# alone does not make ``sp.linalg`` / ``sp.spatial`` available, and this
# script uses both.
import scipy.linalg  # noqa: F401
import scipy.spatial.distance  # noqa: F401

import ot
import matplotlib.pylab as pl
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection)



"""
Sample two Gaussian distributions (2D and 3D)
=============================================

The Gromov-Wasserstein distance makes it possible to compare samples that do
not belong to the same metric space. For demonstration purposes, we sample two
Gaussian distributions in 2- and 3-dimensional spaces.

"""
n = 30  # number of samples drawn from each distribution

# Seed NumPy's global random generator so the np.random draws below (and the
# figures / printed distance that depend on them) are reproducible across
# runs. Assumes ot.datasets also draws from the global generator -- TODO confirm.
np.random.seed(42)

# Source: 2D Gaussian with mean (0, 0) and identity covariance.
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

# Target: 3D Gaussian with mean (4, 4, 4) and identity covariance.
mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)

# Draw the 3D samples by hand: multiply standard normal draws by the matrix
# square root of the covariance, then shift them by the mean.
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n, 3).dot(P) + mu_t



"""
Plotting the distributions
==========================
"""
# Side-by-side view: the 2D source cloud on the left and the 3D target cloud
# on the right. The two clouds live in different spaces, so they cannot
# share one set of axes.
fig = pl.figure()

ax1 = fig.add_subplot(121)
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
ax1.set_title('Source samples (2D)')
# The label above is only displayed once a legend is drawn.
ax1.legend(loc=0)

# projection='3d' relies on the mpl_toolkits.mplot3d import having
# registered the 3D projection.
ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
ax2.set_title('Target samples (3D)')

pl.show()


"""
Compute distance matrices, normalize them and then display
==========================================================
"""

# Intra-space pairwise distances: cdist with its default (Euclidean) metric,
# computed separately inside the source cloud and inside the target cloud.
C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)

# Divide each matrix by its own largest entry in place, so both end up on
# the same [0, 1] scale.
C1 /= C1.max()
C2 /= C2.max()

# Display the two normalized matrices next to each other.
pl.figure()
pl.subplot(1, 2, 1)
pl.imshow(C1)
pl.subplot(1, 2, 2)
pl.imshow(C2)
pl.show()

"""
Compute the Gromov-Wasserstein plan and distance
================================================
"""

# Uniform weights over the n source samples and the n target samples.
p = ot.unif(n)
q = ot.unif(n)

# Solve the Gromov-Wasserstein problem between the two normalized distance
# matrices with the square loss. epsilon is presumably the entropic
# regularization strength -- confirm against the installed POT version.
# The first call returns the coupling and the second the GW distance.
# NOTE(review): gromov_wasserstein2 likely re-solves the same problem, so
# the work is done twice.
gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)

print('Gromov-Wasserstein distance between the distributions: ' +
      str(gw_dist))

# Display the coupling. Labels assume POT's convention that rows follow the
# first marginal (p, source) and columns the second (q, target).
pl.figure()
pl.imshow(gw, cmap='jet')
pl.title('Gromov-Wasserstein coupling')
pl.xlabel('Target samples')
pl.ylabel('Source samples')
pl.colorbar()
pl.show()