summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-08-28 14:41:09 +0200
committerNicolas Courty <Nico@pc-mna-08.univ-ubs.fr>2017-08-28 14:41:09 +0200
commit7ab9037f1e4a08439083d1bc5705be5ed2e9e10a (patch)
treed8d5be164f7c45d67e68b4b4bb545e1b21b70e17 /examples
parent7638d019b43e52d17600cac653939e7cd807478c (diff)
Gromov-Wasserstein distance
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_gromov.py98
1 files changed, 98 insertions, 0 deletions
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
new file mode 100644
index 0000000..11e5336
--- /dev/null
+++ b/examples/plot_gromov.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+"""
+====================
+Gromov-Wasserstein example
+====================
+
+This example is designed to show how to use the Gromov-Wassertsein distance
+computation in POT.
+
+
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import scipy as sp
+import numpy as np
+
+import ot
+import matplotlib.pylab as pl
+from mpl_toolkits.mplot3d import Axes3D
+
+
+
+"""
+Sample two Gaussian distributions (2D and 3D)
+====================
+
+The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space. For
+demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
+
+"""
+n=30 # nb samples
+
+mu_s=np.array([0,0])
+cov_s=np.array([[1,0],[0,1]])
+
+mu_t=np.array([4,4,4])
+cov_t=np.array([[1,0,0],[0,1,0],[0,0,1]])
+
+
+
+xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
+P=sp.linalg.sqrtm(cov_t)
+xt= np.random.randn(n,3).dot(P)+mu_t
+
+
+
+"""
+Plotting the distributions
+====================
+"""
+fig=pl.figure()
+ax1=fig.add_subplot(121)
+ax1.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
+ax2=fig.add_subplot(122,projection='3d')
+ax2.scatter(xt[:,0],xt[:,1],xt[:,2],color='r')
+pl.show()
+
+
+"""
+Compute distance kernels, normalize them and then display
+====================
+"""
+
+C1=sp.spatial.distance.cdist(xs,xs)
+C2=sp.spatial.distance.cdist(xt,xt)
+
+C1/=C1.max()
+C2/=C2.max()
+
+pl.figure()
+pl.subplot(121)
+pl.imshow(C1)
+pl.subplot(122)
+pl.imshow(C2)
+pl.show()
+
+"""
+Compute Gromov-Wasserstein plans and distance
+====================
+"""
+
+p=ot.unif(n)
+q=ot.unif(n)
+
+gw=ot.gromov_wasserstein(C1,C2,p,q,'square_loss',epsilon=5e-4)
+gw_dist=ot.gromov_wasserstein2(C1,C2,p,q,'square_loss',epsilon=5e-4)
+
+print('Gromov-Wasserstein distances between the distribution: '+str(gw_dist))
+
+pl.figure()
+pl.imshow(gw,cmap='jet')
+pl.colorbar()
+pl.show()
+