summaryrefslogtreecommitdiff
path: root/examples/gromov
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2021-11-09 17:05:13 +0100
committerGard Spreemann <gspr@nonempty.org>2021-11-09 17:05:13 +0100
commita9fdc844907decddf54bed3ebeea8d8b2cf0fc5c (patch)
tree449a03fce8fafb78b6badd12b6e633f1e5d73a64 /examples/gromov
parenta16b9471d7114ec08977479b7249efe747702b97 (diff)
parentf1628794d521a8dfa00af383b5e06cd6d34af619 (diff)
Merge tag '0.8.0' into dfsg/latest
Diffstat (limited to 'examples/gromov')
-rw-r--r--examples/gromov/plot_barycenter_fgw.py2
-rw-r--r--examples/gromov/plot_fgw.py10
-rw-r--r--examples/gromov/plot_gromov.py34
-rwxr-xr-xexamples/gromov/plot_gromov_barycenter.py94
4 files changed, 89 insertions, 51 deletions
diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py
index 3f81765..556e08f 100644
--- a/examples/gromov/plot_barycenter_fgw.py
+++ b/examples/gromov/plot_barycenter_fgw.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================
-Plot graphs' barycenter using FGW
+Plot graphs barycenter using FGW
=================================
This example illustrates the computation barycenter of labeled graphs using
diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py
index 97fe619..5475fb3 100644
--- a/examples/gromov/plot_fgw.py
+++ b/examples/gromov/plot_fgw.py
@@ -26,7 +26,7 @@ from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
##############################################################################
# Generate data
-# ---------
+# -------------
#%% parameters
# We create two 1D random measures
@@ -76,7 +76,7 @@ pl.show()
##############################################################################
# Create structure matrices and across-feature distance matrix
-# ---------
+# ------------------------------------------------------------
#%% Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
@@ -88,7 +88,7 @@ Got = ot.emd([], [], M)
##############################################################################
# Plot matrices
-# ---------
+# -------------
#%%
cmap = 'Reds'
@@ -131,7 +131,7 @@ pl.show()
##############################################################################
# Compute FGW/GW
-# ---------
+# --------------
#%% Computing FGW and GW
alpha = 1e-3
@@ -145,7 +145,7 @@ Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True,
##############################################################################
# Visualize transport matrices
-# ---------
+# ----------------------------
#%% visu OT matrix
cmap = 'Blues'
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index deb2f86..5a362cf 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet')
pl.title('Entropic Gromov Wasserstein')
pl.show()
+
+#############################################################################
+#
+# Compute GW with a scalable stochastic method with any loss function
+# ----------------------------------------------------------------------
+
+
+def loss(x, y):
+ return np.abs(x - y)
+
+
+pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
+ log=True)
+
+sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
+ log=True)
+
+print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
+print('Variance estimated: ' + str(plog['gw_dist_std']))
+print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
+print('Variance estimated: ' + str(slog['gw_dist_std']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(pgw.toarray(), cmap='jet')
+pl.title('Pointwise Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(sgw, cmap='jet')
+pl.title('Sampled Gromov Wasserstein')
+
+pl.show()
diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py
index f6f031a..7fe081f 100755
--- a/examples/gromov/plot_gromov_barycenter.py
+++ b/examples/gromov/plot_gromov_barycenter.py
@@ -13,11 +13,13 @@ computation in POT.
#
# License: MIT License
+import os
+from pathlib import Path
import numpy as np
import scipy as sp
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
from sklearn import manifold
from sklearn.decomposition import PCA
@@ -84,22 +86,24 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
# The four distributions are constructed from 4 simple images
-def im2mat(I):
+def im2mat(img):
"""Converts and image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+ return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
-square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2]
-cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2]
-triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2]
-star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2]
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2]
+cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2]
+triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2]
+star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2]
shapes = [square, cross, triangle, star]
S = 4
xs = [[] for i in range(S)]
-
for nb in range(4):
for i in range(8):
for j in range(8):
@@ -184,64 +188,64 @@ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
-fig = pl.figure(figsize=(10, 10))
+fig = plt.figure(figsize=(10, 10))
-ax1 = pl.subplot2grid((4, 4), (0, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax1 = plt.subplot2grid((4, 4), (0, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
-ax2 = pl.subplot2grid((4, 4), (0, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax2 = plt.subplot2grid((4, 4), (0, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
-ax3 = pl.subplot2grid((4, 4), (0, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax3 = plt.subplot2grid((4, 4), (0, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
-ax4 = pl.subplot2grid((4, 4), (0, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax4 = plt.subplot2grid((4, 4), (0, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
-ax5 = pl.subplot2grid((4, 4), (1, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax5 = plt.subplot2grid((4, 4), (1, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
-ax6 = pl.subplot2grid((4, 4), (1, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax6 = plt.subplot2grid((4, 4), (1, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
-ax7 = pl.subplot2grid((4, 4), (2, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax7 = plt.subplot2grid((4, 4), (2, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
-ax8 = pl.subplot2grid((4, 4), (2, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax8 = plt.subplot2grid((4, 4), (2, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
-ax9 = pl.subplot2grid((4, 4), (3, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax9 = plt.subplot2grid((4, 4), (3, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
-ax10 = pl.subplot2grid((4, 4), (3, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax10 = plt.subplot2grid((4, 4), (3, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
-ax11 = pl.subplot2grid((4, 4), (3, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax11 = plt.subplot2grid((4, 4), (3, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
-ax12 = pl.subplot2grid((4, 4), (3, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax12 = plt.subplot2grid((4, 4), (3, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')