From a54775103541ea37f54269de1ba1e1396a6d7b30 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 24 Apr 2020 17:32:57 +0200 Subject: exmaples in sections --- examples/gromov/plot_gromov_barycenter.py | 247 ++++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100755 examples/gromov/plot_gromov_barycenter.py (limited to 'examples/gromov/plot_gromov_barycenter.py') diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py new file mode 100755 index 0000000..6b29687 --- /dev/null +++ b/examples/gromov/plot_gromov_barycenter.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +""" +===================================== +Gromov-Wasserstein Barycenter example +===================================== + +This example is designed to show how to use the Gromov-Wasserstein distance +computation in POT. +""" + +# Author: Erwan Vautier +# Nicolas Courty +# +# License: MIT License + + +import numpy as np +import scipy as sp + +import matplotlib.pylab as pl +from sklearn import manifold +from sklearn.decomposition import PCA + +import ot + +############################################################################## +# Smacof MDS +# ---------- +# +# This function allows to find an embedding of points given a dissimilarity matrix +# that will be given by the output of the algorithm + + +def smacof_mds(C, dim, max_iter=3000, eps=1e-9): + """ + Returns an interpolated point cloud following the dissimilarity matrix C + using SMACOF multidimensional scaling (MDS) in specific dimensionned + target space + + Parameters + ---------- + C : ndarray, shape (ns, ns) + dissimilarity matrix + dim : int + dimension of the targeted space + max_iter : int + Maximum number of iterations of the SMACOF algorithm for a single run + eps : float + relative tolerance w.r.t stress to declare converge + + Returns + ------- + npos : ndarray, shape (R, dim) + Embedded coordinates of the interpolated point cloud (defined with + one isometry) + """ + + rng = np.random.RandomState(seed=3) + + mds = manifold.MDS( + dim, + max_iter=max_iter, + eps=1e-9, + dissimilarity='precomputed', + n_init=1) + pos = mds.fit(C).embedding_ + + nmds = manifold.MDS( + 2, + max_iter=max_iter, + eps=1e-9, + dissimilarity="precomputed", + random_state=rng, + n_init=1) + npos = nmds.fit_transform(C, init=pos) + + return npos + + +############################################################################## +# Data preparation +# ---------------- +# +# The four distributions are constructed from 4 simple images + + +def im2mat(I): + """Converts and image to matrix (one pixel per line)""" + return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) + + +square = pl.imread('../data/square.png').astype(np.float64)[:, :, 2] +cross = pl.imread('../data/cross.png').astype(np.float64)[:, :, 2] +triangle = pl.imread('../data/triangle.png').astype(np.float64)[:, :, 2] +star = pl.imread('../data/star.png').astype(np.float64)[:, :, 2] + +shapes = [square, cross, triangle, star] + +S = 4 +xs = [[] for i in range(S)] + + +for nb in range(4): + for i in range(8): + for j in range(8): + if shapes[nb][i, j] < 0.95: + xs[nb].append([j, 8 - i]) + +xs = np.array([np.array(xs[0]), np.array(xs[1]), + np.array(xs[2]), np.array(xs[3])]) + +############################################################################## +# Barycenter computation +# ---------------------- + + +ns = [len(xs[s]) for s in range(S)] +n_samples = 30 + +"""Compute all distances matrices for the four shapes""" +Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)] +Cs = [cs / cs.max() for cs in Cs] + +ps = [ot.unif(ns[s]) for s in range(S)] +p = ot.unif(n_samples) + + +lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]] + +Ct01 = [0 for i in range(2)] +for i in range(2): + Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]], + [ps[0], ps[1] + ], p, lambdast[i], 'square_loss', # 5e-4, + max_iter=100, tol=1e-3) + +Ct02 = [0 for i in range(2)] +for i in range(2): + Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]], + [ps[0], ps[2] + ], p, lambdast[i], 'square_loss', # 5e-4, + max_iter=100, tol=1e-3) + +Ct13 = [0 for i in range(2)] +for i in range(2): + Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]], + [ps[1], ps[3] + ], p, lambdast[i], 'square_loss', # 5e-4, + max_iter=100, tol=1e-3) + +Ct23 = [0 for i in range(2)] +for i in range(2): + Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]], + [ps[2], ps[3] + ], p, lambdast[i], 'square_loss', # 5e-4, + max_iter=100, tol=1e-3) + + +############################################################################## +# Visualization +# ------------- +# +# The PCA helps in getting consistency between the rotations + + +clf = PCA(n_components=2) +npos = [0, 0, 0, 0] +npos = [smacof_mds(Cs[s], 2) for s in range(S)] + +npost01 = [0, 0] +npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)] +npost01 = [clf.fit_transform(npost01[s]) for s in range(2)] + +npost02 = [0, 0] +npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)] +npost02 = [clf.fit_transform(npost02[s]) for s in range(2)] + +npost13 = [0, 0] +npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)] +npost13 = [clf.fit_transform(npost13[s]) for s in range(2)] + +npost23 = [0, 0] +npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)] +npost23 = [clf.fit_transform(npost23[s]) for s in range(2)] + + +fig = pl.figure(figsize=(10, 10)) + +ax1 = pl.subplot2grid((4, 4), (0, 0)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r') + +ax2 = pl.subplot2grid((4, 4), (0, 1)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b') + +ax3 = pl.subplot2grid((4, 4), (0, 2)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b') + +ax4 = pl.subplot2grid((4, 4), (0, 3)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r') + +ax5 = pl.subplot2grid((4, 4), (1, 0)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b') + +ax6 = pl.subplot2grid((4, 4), (1, 3)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b') + +ax7 = pl.subplot2grid((4, 4), (2, 0)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b') + +ax8 = pl.subplot2grid((4, 4), (2, 3)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b') + +ax9 = pl.subplot2grid((4, 4), (3, 0)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r') + +ax10 = pl.subplot2grid((4, 4), (3, 1)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b') + +ax11 = pl.subplot2grid((4, 4), (3, 2)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b') + +ax12 = pl.subplot2grid((4, 4), (3, 3)) +pl.xlim((-1, 1)) +pl.ylim((-1, 1)) +ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r') -- cgit v1.2.3 From 4bbabc602678a0227bfe9ffae4bbb4caab8a3767 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 24 Apr 2020 17:38:01 +0200 Subject: relative path exmaples --- examples/domain-adaptation/plot_otda_color_images.py | 4 ++-- examples/domain-adaptation/plot_otda_mapping_colors_images.py | 4 ++-- examples/gromov/plot_gromov_barycenter.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) (limited to 'examples/gromov/plot_gromov_barycenter.py') diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 7e0afee..929365e 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -46,8 +46,8 @@ def minmax(I): # ------------- # Loading images -I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 +I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 +I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index 1276714..9d3a7c7 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -47,8 +47,8 @@ def minmax(I): # ------------- # Loading images -I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 +I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 +I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 X1 = im2mat(I1) diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index 6b29687..f6f031a 100755 --- a/examples/gromov/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -89,10 +89,10 @@ def im2mat(I): return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) -square = pl.imread('../data/square.png').astype(np.float64)[:, :, 2] -cross = pl.imread('../data/cross.png').astype(np.float64)[:, :, 2] -triangle = pl.imread('../data/triangle.png').astype(np.float64)[:, :, 2] -star = pl.imread('../data/star.png').astype(np.float64)[:, :, 2] +square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2] +cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2] +triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2] +star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2] shapes = [square, cross, triangle, star] -- cgit v1.2.3