# -*- coding: utf-8 -*-
"""
=====================================
Gromov-Wasserstein Barycenter example
=====================================

This example is designed to show how to use the Gromov-Wasserstein
barycenter computation in POT.
"""

# Author: Erwan Vautier
#         Nicolas Courty
#
# License: MIT License

import numpy as np
import scipy as sp
import scipy.ndimage as spi  # noqa: F401 -- kept; no longer used for image loading
# Importing the submodule explicitly makes ``sp.spatial.distance`` available;
# ``import scipy`` alone does not load subpackages on older SciPy releases.
import scipy.spatial.distance  # noqa: F401
import matplotlib.pylab as pl
from sklearn import manifold
from sklearn.decomposition import PCA

import ot

##############################################################################
# Smacof MDS
# ----------
#
# This function computes an embedding of points in a Euclidean space whose
# pairwise distances approximate a given dissimilarity matrix. It is used
# below to turn the dissimilarity matrices returned by the barycenter
# algorithm back into point clouds that can be plotted.


def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
    """
    Returns an interpolated point cloud following the dissimilarity matrix C
    using SMACOF multidimensional scaling (MDS) in a target space of the
    given dimension.

    Parameters
    ----------
    C : ndarray, shape (ns, ns)
        dissimilarity matrix
    dim : int
        dimension of the target space
    max_iter : int
        Maximum number of iterations of the SMACOF algorithm for a single run
    eps : float
        relative tolerance w.r.t. stress to declare convergence

    Returns
    -------
    npos : ndarray, shape (ns, dim)
        Embedded coordinates of the interpolated point cloud (defined up to
        an isometry)
    """
    # A fixed seed makes the embedding (and therefore the figure)
    # reproducible; it is shared by both passes.
    rng = np.random.RandomState(seed=3)

    # First pass: a single SMACOF run from a random initialisation.
    mds = manifold.MDS(
        dim,
        max_iter=max_iter,
        eps=eps,
        dissimilarity='precomputed',
        random_state=rng,
        n_init=1)
    pos = mds.fit(C).embedding_

    # Second pass: refine, starting from the first-pass embedding.
    nmds = manifold.MDS(
        dim,
        max_iter=max_iter,
        eps=eps,
        dissimilarity="precomputed",
        random_state=rng,
        n_init=1)
    npos = nmds.fit_transform(C, init=pos)

    return npos


##############################################################################
# Data preparation
# ----------------
#
# The four distributions are constructed from 4 simple images


def im2mat(I):
    """Converts an image to matrix (one pixel per line)"""
    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))


def _load_shape(path):
    """Load an image file as a 2D float64 grayscale array with values in [0, 1]."""
    img = pl.imread(path)
    if np.issubdtype(img.dtype, np.integer):
        # Integer images (e.g. non-PNG formats read through Pillow) are
        # rescaled the same way the original example did (``/ 256``).
        img = img.astype(np.float64) / 256
    else:
        # matplotlib already returns PNG data as floats in [0, 1].
        img = img.astype(np.float64)
    if img.ndim == 3:
        # Drop the alpha channel (LA or RGBA) and average the remaining
        # color channels into a single gray level per pixel.
        n_color = 1 if img.shape[2] <= 2 else 3
        img = img[:, :, :n_color].mean(axis=2)
    return img


carre = _load_shape('../data/carre.png')
rond = _load_shape('../data/rond.png')
triangle = _load_shape('../data/triangle.png')
coeur = _load_shape('../data/coeur.png')

shapes = [carre, rond, triangle, coeur]

S = len(shapes)

# Every dark pixel (value < 0.95) of the top-left 8x8 pixel grid of each
# image becomes a 2D point (column, 8 - row), so the shape is drawn upright.
# The shapes generally have different numbers of points, so ``xs`` is kept
# as a list of (n_points, 2) arrays rather than a single (ragged) ndarray.
xs = [np.array([[j, 8 - i]
                for i in range(8)
                for j in range(8)
                if shape[i, j] < 0.95])
      for shape in shapes]

##############################################################################
# Barycenter computation
# ----------------------
#
# For four pairs of neighbouring shapes, two Gromov-Wasserstein barycenters
# of N points are computed with weights (1/3, 2/3) and (2/3, 1/3), giving two
# intermediate shapes that interpolate between the two inputs.

ns = [len(xs[s]) for s in range(S)]
N = 30

# Compute all distance matrices for the four shapes, scaled so that the
# largest distance within each shape is 1.
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
Cs = [cs / cs.max() for cs in Cs]

ps = [ot.unif(ns[s]) for s in range(S)]
p = ot.unif(N)

# [[1/3, 2/3], [2/3, 1/3]]: the first weight applies to the first shape of
# each pair.
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]


def _pair_barycenters(a, b):
    """Return the barycenters of shapes ``a`` and ``b``, one per weight pair in ``lambdast``."""
    # NOTE(review): 5e-4 is passed positionally as the 7th argument, which is
    # the entropic regularisation ``epsilon`` in the POT version this example
    # targets -- confirm against the installed POT signature.
    return [ot.gromov.gromov_barycenters(N, [Cs[a], Cs[b]], [ps[a], ps[b]],
                                         p, lambdas, 'square_loss', 5e-4,
                                         numItermax=100, stopThr=1e-3)
            for lambdas in lambdast]


Ct01 = _pair_barycenters(0, 1)
Ct02 = _pair_barycenters(0, 2)
Ct13 = _pair_barycenters(1, 3)
Ct23 = _pair_barycenters(2, 3)

##############################################################################
# Visualization
# -------------
#
# The PCA helps in getting consistency between the rotations

clf = PCA(n_components=2)
npos = [smacof_mds(Cs[s], 2) for s in range(S)]


def _embed(Cts):
    """Embed each barycenter dissimilarity matrix in 2D and align it with PCA."""
    return [clf.fit_transform(smacof_mds(Ct, 2)) for Ct in Cts]


npost01 = _embed(Ct01)
npost02 = _embed(Ct02)
npost13 = _embed(Ct13)
npost23 = _embed(Ct23)

# (grid position, point cloud, color): the input shapes are drawn in red in
# the corners of a 4x4 grid, the interpolated barycenters in blue along the
# edges, each placed next to the input shape it is weighted towards.
panels = [
    ((0, 0), npos[0], 'r'),
    ((0, 1), npost01[1], 'b'),
    ((0, 2), npost01[0], 'b'),
    ((0, 3), npos[1], 'r'),
    ((1, 0), npost02[1], 'b'),
    ((1, 3), npost13[1], 'b'),
    ((2, 0), npost02[0], 'b'),
    ((2, 3), npost13[0], 'b'),
    ((3, 0), npos[2], 'r'),
    ((3, 1), npost23[1], 'b'),
    ((3, 2), npost23[0], 'b'),
    ((3, 3), npos[3], 'r'),
]

fig = pl.figure(figsize=(10, 10))
for loc, pts, color in panels:
    ax = pl.subplot2grid((4, 4), loc)
    pl.xlim((-1, 1))
    pl.ylim((-1, 1))
    ax.scatter(pts[:, 0], pts[:, 1], color=color)

pl.show()