diff options
author | Gard Spreemann <gspr@nonempty.org> | 2022-04-27 11:49:23 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2022-04-27 11:49:23 +0200 |
commit | 35bd2c98b642df78638d7d733bc1a89d873db1de (patch) | |
tree | 6bc637624004713808d3097b95acdccbb9608e52 /examples | |
parent | c4753bd3f74139af8380127b66b484bc09b50661 (diff) | |
parent | eccb1386eea52b94b82456d126bd20cbe3198e05 (diff) |
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'examples')
23 files changed, 1387 insertions, 98 deletions
diff --git a/examples/backends/plot_dual_ot_pytorch.py b/examples/backends/plot_dual_ot_pytorch.py new file mode 100644 index 0000000..d3f7a66 --- /dev/null +++ b/examples/backends/plot_dual_ot_pytorch.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +r""" +====================================================================== +Dual OT solvers for entropic and quadratic regularized OT with Pytorch +====================================================================== + + +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import torch +import ot +import ot.plot + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) + +n_source_samples = 100 +n_target_samples = 100 +theta = 2 * np.pi / 20 +noise_level = 0.1 + +Xs, ys = ot.datasets.make_data_classif( + 'gaussrot', n_source_samples, nz=noise_level) +Xt, yt = ot.datasets.make_data_classif( + 'gaussrot', n_target_samples, theta=theta, nz=noise_level) + +# one of the target mode changes its variance (no linear mapping) +Xt[yt == 2] *= 3 +Xt = Xt + 4 + + +# %% +# Plot data +# --------- + +pl.figure(1, (10, 5)) +pl.clf() +pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples') +pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +# %% +# Convert data to torch tensors +# ----------------------------- + +xs = torch.tensor(Xs) +xt = torch.tensor(Xt) + +# %% +# Estimating dual variables for entropic OT +# ----------------------------------------- + +u = torch.randn(n_source_samples, requires_grad=True) +v = torch.randn(n_source_samples, requires_grad=True) + +reg = 0.5 + +optimizer = torch.optim.Adam([u, v], lr=1) + +# number of iteration +n_iter = 200 + + +losses = [] + +for i in range(n_iter): + + # generate noise samples + + # minus because we maximize te dual loss + loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + +pl.figure(2) +pl.plot(losses) +pl.grid() +pl.title('Dual objective (negative)') +pl.xlabel("Iterations") + +Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg) + +# %% +# Plot teh estimated entropic OT plan +# ----------------------------------- + +pl.figure(3, (10, 5)) +pl.clf() +ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1) +pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2) +pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2) +pl.legend(loc=0) +pl.title('Source and target distributions') + + +# %% +# Estimating dual variables for quadratic OT +# ----------------------------------------- + +u = torch.randn(n_source_samples, requires_grad=True) +v = torch.randn(n_source_samples, requires_grad=True) + +reg = 0.01 + +optimizer = torch.optim.Adam([u, v], lr=1) + +# number of iteration +n_iter = 200 + + +losses = [] + + +for i in range(n_iter): + + # generate noise samples + + # minus because we maximize te dual loss + loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + +pl.figure(4) +pl.plot(losses) +pl.grid() +pl.title('Dual objective (negative)') +pl.xlabel("Iterations") + +Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg) + + +# %% +# Plot the estimated quadratic OT plan +# ----------------------------------- + +pl.figure(5, (10, 5)) +pl.clf() +ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1) +pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2) +pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2) +pl.legend(loc=0) +pl.title('OT plan with quadratic regularization') diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index 05b9952..cf5d64d 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -27,6 +27,8 @@ Machine Learning (pp. 4104-4113). PMLR. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + # %% # Loading the data diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py new file mode 100644 index 0000000..6d9b916 --- /dev/null +++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- +r""" +====================================================================== +Continuous OT plan estimation with Pytorch +====================================================================== + + +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import torch +from torch import nn +import ot +import ot.plot + +# %% +# Data generation +# --------------- + +torch.manual_seed(42) +np.random.seed(42) + +n_source_samples = 10000 +n_target_samples = 10000 +theta = 2 * np.pi / 20 +noise_level = 0.1 + +Xs = np.random.randn(n_source_samples, 2) * 0.5 +Xt = np.random.randn(n_target_samples, 2) * 2 + +# one of the target mode changes its variance (no linear mapping) +Xt = Xt + 4 + + +# %% +# Plot data +# --------- +nvisu = 300 +pl.figure(1, (5, 5)) +pl.clf() +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5) +pl.legend(loc=0) +ax_bounds = pl.axis() +pl.title('Source and target distributions') + +# %% +# Convert data to torch tensors +# ----------------------------- + +xs = torch.tensor(Xs) +xt = torch.tensor(Xt) + +# %% +# Estimating deep dual variables for entropic OT +# ---------------------------------------------- + +torch.manual_seed(42) + +# define the MLP model + + +class Potential(torch.nn.Module): + def __init__(self): + super(Potential, self).__init__() + self.fc1 = nn.Linear(2, 200) + self.fc2 = nn.Linear(200, 1) + self.relu = torch.nn.ReLU() # instead of Heaviside step fn + + def forward(self, x): + output = self.fc1(x) + output = self.relu(output) # instead of Heaviside step fn + output = self.fc2(output) + return output.ravel() + + +u = Potential().double() +v = Potential().double() + +reg = 1 + +optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005) + +# number of iteration +n_iter = 1000 +n_batch = 500 + + +losses = [] + +for i in range(n_iter): + + # generate noise samples + + iperms = torch.randint(0, n_source_samples, (n_batch,)) + ipermt = torch.randint(0, n_target_samples, (n_batch,)) + + xsi = xs[iperms] + xti = xt[ipermt] + + # minus because we maximize te dual loss + loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + +pl.figure(2) +pl.plot(losses) +pl.grid() +pl.title('Dual objective (negative)') +pl.xlabel("Iterations") + + +# %% +# Plot the density on arget for a given source sample +# --------------------------------------------------- + + +nv = 100 +xl = np.linspace(ax_bounds[0], ax_bounds[1], nv) +yl = np.linspace(ax_bounds[2], ax_bounds[3], nv) + +XX, YY = np.meshgrid(xl, yl) + +xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1) + +wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2)) +wxg = wxg / np.sum(wxg) + +xg = torch.tensor(xg) +wxg = torch.tensor(wxg) + + +pl.figure(4, (12, 4)) +pl.clf() +pl.subplot(1, 3, 1) + +iv = 2 +Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg) +Gg = Gg.reshape((nv, nv)).detach().numpy() + +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05) +pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0') +pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported sourec sample') +pl.legend(loc=0) +ax_bounds = pl.axis() +pl.title('Density of transported source sample') + +pl.subplot(1, 3, 2) + +iv = 3 +Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg) +Gg = Gg.reshape((nv, nv)).detach().numpy() + +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05) +pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0') +pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported sourec sample') +pl.legend(loc=0) +ax_bounds = pl.axis() +pl.title('Density of transported source sample') + +pl.subplot(1, 3, 3) + +iv = 6 +Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg) +Gg = Gg.reshape((nv, nv)).detach().numpy() + +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05) +pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0') +pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported sourec sample') +pl.legend(loc=0) +ax_bounds = pl.axis() +pl.title('Density of transported source sample') diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py index 0abdd6d..cd8e2fd 100644 --- a/examples/backends/plot_wass1d_torch.py +++ b/examples/backends/plot_wass1d_torch.py @@ -1,9 +1,9 @@ r""" -================================= -Wasserstein 1D with PyTorch -================================= +================================================= +Wasserstein 1D (flow and barycenter) with PyTorch +================================================= -In this small example, we consider the following minization problem: +In this small example, we consider the following minimization problem: .. math:: \mu^* = \min_\mu W(\mu,\nu) diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 2d68a39..226dfeb 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -9,61 +9,62 @@ sum of diracs. """ -# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp> +# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp> +# Rémi Flamary <remi.flamary@polytechnique.edu> # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import numpy as np import matplotlib.pylab as pl import ot -############################################################################## +# %% # Generate data # ------------- -N = 3 +N = 2 d = 2 -measures_locations = [] -measures_weights = [] - -for i in range(N): - n_i = np.random.randint(low=1, high=20) # nb samples +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] +I2 = pl.imread('../../data/duck.png').astype(np.float64)[::4, ::4, 2] - mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean +sz = I2.shape[0] +XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) - A_i = np.random.rand(d, d) - cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix +x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0 +x2 = np.stack((XX[I2 == 0] + 80, -YY[I2 == 0] + 32), 1) * 1.0 +x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0 - x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations - b_i = np.random.uniform(0., 1., (n_i,)) - b_i = b_i / np.sum(b_i) # Dirac weights +measures_locations = [x1, x2] +measures_weights = [ot.unif(x1.shape[0]), ot.unif(x2.shape[0])] - measures_locations.append(x_i) - measures_weights.append(b_i) +pl.figure(1, (12, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.title('Distributions') -############################################################################## +# %% # Compute free support barycenter # ------------------------------- -k = 10 # number of Diracs of the barycenter +k = 200 # number of Diracs of the barycenter X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) - -############################################################################## -# Plot data +# %% +# Plot the barycenter # --------- -pl.figure(1) -for (x_i, b_i) in zip(measures_locations, measures_weights): - color = np.random.randint(low=1, high=10 * N) - pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure') -pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter') +pl.figure(2, (8, 3)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') pl.title('Data measures and their barycenter') -pl.legend(loc=0) +pl.legend(loc="lower right") pl.show() diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py new file mode 100755 index 0000000..1fdc3b9 --- /dev/null +++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- + +r""" +================================= +(Fused) Gromov-Wasserstein Linear Dictionary Learning +================================= + +In this exemple, we illustrate how to learn a Gromov-Wasserstein dictionary on +a dataset of structured data such as graphs, denoted +:math:`\{ \mathbf{C_s} \}_{s \in [S]}` where every nodes have uniform weights. +Given a dictionary :math:`\mathbf{C_{dict}}` composed of D structures of a fixed +size nt, each graph :math:`(\mathbf{C_s}, \mathbf{p_s})` +is modeled as a convex combination :math:`\mathbf{w_s} \in \Sigma_D` of these +dictionary atoms as :math:`\sum_d w_{s,d} \mathbf{C_{dict}[d]}`. + + +First, we consider a dataset composed of graphs generated by Stochastic Block models +with variable sizes taken in :math:`\{30, ... , 50\}` and quantities of clusters +varying in :math:`\{ 1, 2, 3\}`. We learn a dictionary of 3 atoms, by minimizing +the Gromov-Wasserstein distance from all samples to its model in the dictionary +with respect to the dictionary atoms. + +Second, we illustrate the extension of this dictionary learning framework to +structured data endowed with node features by using the Fused Gromov-Wasserstein +distance. Starting from the aforementioned dataset of unattributed graphs, we +add discrete labels uniformly depending on the number of clusters. Then we learn +and visualize attributed graph atoms where each sample is modeled as a joint convex +combination between atom structures and features. + + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph +Dictionary Learning, International Conference on Machine Learning (ICML), 2021. + +""" +# Author: Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +from sklearn.manifold import MDS +from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning +import ot +import networkx +from networkx.generators.community import stochastic_block_model as sbm +# %% +# ============================================================================= +# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters. +# ============================================================================= + +np.random.seed(42) + +N = 60 # number of graphs in the dataset +# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability. +clusters = [1, 2, 3] +Nc = N // len(clusters) # number of graphs by cluster +nlabels = len(clusters) +dataset = [] +labels = [] + +p_inter = 0.1 +p_intra = 0.9 +for n_cluster in clusters: + for i in range(Nc): + n_nodes = int(np.random.uniform(low=30, high=50)) + + if n_cluster > 1: + P = p_inter * np.ones((n_cluster, n_cluster)) + np.fill_diagonal(P, p_intra) + else: + P = p_intra * np.eye(1) + sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32) + G = sbm(sizes, P, seed=i, directed=False) + C = networkx.to_numpy_array(G) + dataset.append(C) + labels.append(n_cluster) + + +# Visualize samples + +def plot_graph(x, C, binary=True, color='C0', s=None): + for j in range(C.shape[0]): + for i in range(j): + if binary: + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + else: # connection intensity proportional to C[i,j] + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k') + + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + + +pl.figure(1, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color='C0', s=50.) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Estimate the gromov-wasserstein dictionary from the dataset +# ============================================================================= + + +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] + +D = 3 # 3 atoms in the dictionary +nt = 6 # of 6 nodes each + +q = ot.unif(nt) +reg = 0. # regularization coefficient to promote sparsity of unmixings {w_s} + +Cdict_GW, log = gromov_wasserstein_dictionary_learning( + Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16, + learning_rate=0.1, reg=reg, projection='nonnegative_symmetric', + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, + use_log=True, use_adam_optimizer=True, verbose=True +) +# visualize loss evolution over epochs +pl.figure(2, (4, 3)) +pl.clf() +pl.title('loss evolution by epoch', fontsize=14) +pl.plot(log['loss_epochs']) +pl.xlabel('epochs', fontsize=12) +pl.ylabel('loss', fontsize=12) +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the estimated dictionary atoms +# ============================================================================= + + +# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white) + +pl.figure(3, (12, 8)) +pl.clf() +for idx_atom, atom in enumerate(Cdict_GW): + scaled_atom = (atom - atom.min()) / (atom.max() - atom.min()) + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + pl.subplot(2, D, idx_atom + 1) + pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14) + plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.) + pl.axis("off") + pl.subplot(2, D, D + idx_atom + 1) + pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation='nearest') + pl.colorbar() + pl.axis("off") +pl.tight_layout() +pl.show() +#%% +# ============================================================================= +# Visualization of the embedding space +# ============================================================================= + +unmixings = [] +reconstruction_errors = [] +for C in dataset: + p = ot.unif(C.shape[0]) + unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing( + C, Cdict_GW, p=p, q=q, reg=reg, + tol_outer=10**(-5), tol_inner=10**(-5), + max_iter_outer=30, max_iter_inner=300 + ) + unmixings.append(unmixing) + reconstruction_errors.append(reconstruction_error) +unmixings = np.array(unmixings) +print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) + + +# Compute the 2D representation of the unmixing living in the 2-simplex of probability +unmixings2D = np.zeros(shape=(N, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(4, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() +# %% +# ============================================================================= +# Endow the dataset with node features +# ============================================================================= + +# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters +# 1 cluster --> 0 as nodes feature +# 2 clusters --> 1 as nodes feature +# 3 clusters --> 2 as nodes feature +# features are one-hot encoded following these assignments +dataset_features = [] +for i in range(len(dataset)): + n = dataset[i].shape[0] + F = np.zeros((n, 3)) + if i < Nc: # graph with 1 cluster + F[:, 0] = 1. + elif i < 2 * Nc: # graph with 2 clusters + F[:, 1] = 1. + else: # graph with 3 clusters + F[:, 2] = 1. + dataset_features.append(F) + +pl.figure(5, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + F = dataset_features[(c - 1) * Nc] + colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])] + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color=colors, s=50) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() +# %% +# ============================================================================= +# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs +# ============================================================================= +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] +D = 3 # 6 atoms instead of 3 +nt = 6 +q = ot.unif(nt) +reg = 0.001 +alpha = 0.5 # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein + + +Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning( + Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha, + epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, + projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True +) +# visualize loss evolution +pl.figure(6, (4, 3)) +pl.clf() +pl.title('loss evolution by epoch', fontsize=14) +pl.plot(log['loss_epochs']) +pl.xlabel('epochs', fontsize=12) +pl.ylabel('loss', fontsize=12) +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the estimated dictionary atoms +# ============================================================================= + +pl.figure(7, (12, 8)) +pl.clf() +max_features = Ydict_FGW.max() +min_features = Ydict_FGW.min() + +for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)): + scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min()) + #scaled_F = 2 * (Fatom - min_features) / (max_features - min_features) + colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])] + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + pl.subplot(2, D, idx_atom + 1) + pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14) + plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100) + pl.axis("off") + pl.subplot(2, D, D + idx_atom + 1) + pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation='nearest') + pl.colorbar() + pl.axis("off") +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the embedding space +# ============================================================================= + +unmixings = [] +reconstruction_errors = [] +for i in range(len(dataset)): + C = dataset[i] + Y = dataset_features[i] + p = ot.unif(C.shape[0]) + unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing( + C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha, + reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300 + ) + unmixings.append(unmixing) + reconstruction_errors.append(reconstruction_error) +unmixings = np.array(unmixings) +print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) + +# Visualize unmixings in the 2-simplex of probability +unmixings2D = np.zeros(shape=(N, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(8, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) + +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py new file mode 100644 index 0000000..a29c875 --- /dev/null +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Weak Optimal Transport VS exact Optimal Transport +==================================================== + +Illustration of 2D optimal transport between distributions that are weighted +sum of diracs. The OT matrix is plotted with the samples. + +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +############################################################################## +# Generate data an plot it +# ------------------------ + +#%% parameters and data generation + +n = 50 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +pl.figure(2) +pl.imshow(M, interpolation='nearest') +pl.title('Cost matrix M') + + +############################################################################## +# Compute Weak OT and exact OT solutions +# -------------------------------------- + +#%% EMD + +G0 = ot.emd(a, b, M) + +#%% Weak OT + +Gweak = ot.weak_optimal_transport(xs, xt, a, b) + + +############################################################################## +# Plot weak OT and exact OT solutions +# -------------------------------------- + +pl.figure(3, (8, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(G0, interpolation='nearest') +pl.title('OT matrix') + +pl.subplot(1, 2, 2) +pl.imshow(Gweak, interpolation='nearest') +pl.title('Weak OT matrix') + +pl.figure(4, (8, 5)) + +pl.subplot(1, 2, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('OT matrix with samples') + +pl.subplot(1, 2, 2) +ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Weak OT matrix with samples') diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py new file mode 100644 index 0000000..b5b1c9f --- /dev/null +++ b/examples/others/plot_factored_coupling.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +========================================== +Optimal transport with factored couplings +========================================== + +Illustration of the factored coupling OT between 2D empirical distributions + +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +# %% +# Generate data an plot it +# ------------------------ + +# parameters and data generation + +np.random.seed(42) + +n = 100 # nb samples + +xs = np.random.rand(n, 2) - .5 + +xs = xs + np.sign(xs) + +xt = np.random.rand(n, 2) - .5 + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + + +# %% +# Compute Factore OT and exact OT solutions +# -------------------------------------- + +#%% EMD +M = ot.dist(xs, xt) +G0 = ot.emd(a, b, M) + +#%% factored OT OT + +Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4) + + +# %% +# Plot factored OT and exact OT solutions +# -------------------------------------- + +pl.figure(2, (14, 4)) + +pl.subplot(1, 3, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Exact OT with samples') + +pl.subplot(1, 3, 2) +ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5) +ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples') +pl.title('Factored OT with template samples') + +pl.subplot(1, 3, 3) +ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Factored OT low rank OT plan') diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py new file mode 100644 index 0000000..bb4f640 --- /dev/null +++ b/examples/others/plot_logo.py @@ -0,0 +1,112 @@ + +# -*- coding: utf-8 -*- +r""" +======================= +Logo of the POT toolbox +======================= + +In this example we plot the logo of the POT toolbox. + +This logo is that it is done 100% in Python and generated using +matplotlib and ploting teh solution of the EMD solver from POT. + +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% Load modules +import numpy as np +import matplotlib.pyplot as pl +import ot + +# %% +# Data for logo +# ------------- + + +# Letter P +p1 = np.array([[0, 6.], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) +p2 = np.array([[1.5, 6], [2, 4], [2, 5], [1.5, 3], [0.5, 2], [.5, 1], ]) + +# Letter O +o1 = np.array([[0, 6.], [-1, 5], [-1.5, 4], [-1.5, 3], [-1, 2], [0, 1], ]) +o2 = np.array([[1, 6.], [2, 5], [2.5, 4], [2.5, 3], [2, 2], [1, 1], ]) + +# Scaling and translation for letter O +o1[:, 0] += 6.4 +o2[:, 0] += 6.4 +o1[:, 0] *= 0.6 +o2[:, 0] *= 0.6 + +# Letter T +t1 = np.array([[-1, 6.], [-1, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) +t2 = np.array([[1.5, 6.], [1.5, 5], [0.5, 4], [0.5, 3], [0.5, 2], [0.5, 1], ]) + +# Translating the T +t1[:, 0] += 7.1 +t2[:, 0] += 7.1 + +# Concatenate all letters +x1 = np.concatenate((p1, o1, t1), axis=0) +x2 = np.concatenate((p2, o2, t2), axis=0) + +# Horizontal and vertical scaling +sx = 1.0 +sy = .5 +x1[:, 0] *= sx +x1[:, 1] *= sy +x2[:, 0] *= sx +x2[:, 1] *= sy + +# %% +# Plot the logo (clear background) +# -------------------------------- + +# Solve OT problem between the points +M = ot.dist(x1, x2, metric='euclidean') +T = ot.emd([], [], M) + +pl.figure(1, (3.5, 1.1)) +pl.clf() +# plot the OT plan +for i in range(M.shape[0]): + for j in range(M.shape[1]): + if T[i, j] > 1e-8: + pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='k', alpha=0.6, linewidth=3, zorder=1) +# plot the samples +pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='C3', markeredgecolor='k') +pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='b', markeredgecolor='k') + + +pl.axis('equal') +pl.axis('off') + +# Save logo file +# pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight') +# pl.savefig('logo.png', dpi=150, transparent=True, bbox_inches='tight') + +# %% +# Plot the logo (dark background) +# -------------------------------- + +pl.figure(2, (3.5, 1.1), facecolor='darkgray') +pl.clf() +# plot the OT plan +for i in range(M.shape[0]): + for j in range(M.shape[1]): + if T[i, j] > 1e-8: + pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='w', alpha=0.8, linewidth=3, zorder=1) +# plot the samples +pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='w', markeredgecolor='w') +pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='w', markeredgecolor='w') + +pl.axis('equal') +pl.axis('off') + +# Save logo file +# pl.savefig('logo_dark.svg', dpi=150, transparent=True, bbox_inches='tight') +# pl.savefig('logo_dark.png', dpi=150, transparent=True, bbox_inches='tight') diff --git a/examples/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py index 785642a..2023649 100644 --- a/examples/plot_screenkhorn_1D.py +++ b/examples/others/plot_screenkhorn_1D.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -=============================== -1D Screened optimal transport -=============================== +======================================== +Screened optimal transport (Screenkhorn) +======================================== This example illustrates the computation of Screenkhorn [26]. diff --git a/examples/plot_stochastic.py b/examples/others/plot_stochastic.py index 3a1ef31..3a1ef31 100644 --- a/examples/plot_stochastic.py +++ b/examples/others/plot_stochastic.py diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py index f282950..219aa51 100644 --- a/examples/plot_Intro_OT.py +++ b/examples/plot_Intro_OT.py @@ -58,7 +58,7 @@ help(ot.dist) # number of Bakeries to Cafés in a City (in this case Manhattan). We did a # quick google map search in Manhattan for bakeries and Cafés: # -# .. image:: images/bak.png +# .. image:: ../_static/images/bak.png # :align: center # :alt: bakery-cafe-manhattan # :width: 600px @@ -233,7 +233,7 @@ print('Wasserstein loss (EMD) = {0:.2f}'.format(W)) # The Sinkhorn algorithm is very simple to code. You can implement it directly # using the following pseudo-code # -# .. image:: images/sinkhorn.png +# .. image:: ../_static/images/sinkhorn.png # :align: center # :alt: Sinkhorn algorithm # :width: 440px diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index 15ead96..62f0b7d 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -==================== -1D optimal transport -==================== +====================================== +Optimal Transport for 1D distributions +====================================== This example illustrates the computation of EMD and Sinkhorn transport plans and their visualization. @@ -64,7 +64,11 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') #%% EMD -G0 = ot.emd(a, b, M) +# use fast 1D solver +G0 = ot.emd_1d(x, x, a, b) + +# Equivalent to +# G0 = ot.emd(a, b, M) pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index b07f99f..5415e4f 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -=========================== -1D smooth optimal transport -=========================== +================================ +Smooth optimal transport example +================================ This example illustrates the computation of EMD, Sinkhorn and smooth OT plans and their visualization. diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index af1bc12..1d82fb8 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ==================================================== -2D Optimal transport between empirical distributions +Optimal Transport between 2D empirical distributions ==================================================== Illustration of 2D optimal transport between discributions that are weighted @@ -42,7 +42,6 @@ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples # loss matrix M = ot.dist(xs, xt) -M /= M.max() ############################################################################## # Plot data @@ -87,7 +86,7 @@ pl.title('OT matrix with samples') #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Gs = ot.sinkhorn(a, b, M, lambd) @@ -112,7 +111,7 @@ pl.show() #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd) diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py index 60353ab..cce51f8 100644 --- a/examples/plot_OT_L1_vs_L2.py +++ b/examples/plot_OT_L1_vs_L2.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """ -========================================== -2D Optimal transport for different metrics -========================================== +================================================ +Optimal Transport with different gournd metrics +================================================ -2D OT on empirical distributio with different gound metric. +2D OT on empirical distributio with different ground metric. Stole the figure idea from Fig. 1 and 2 in https://arxiv.org/pdf/1706.07650.pdf @@ -23,7 +23,7 @@ import matplotlib.pylab as pl import ot import ot.plot -############################################################################## +# %% # Dataset 1 : uniform sampling # ---------------------------- @@ -46,7 +46,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean') M2 /= M2.max() # loss matrix -Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean')) +Mp = ot.dist(xs, xt, metric='cityblock') Mp /= Mp.max() # Data @@ -71,7 +71,7 @@ pl.title('Squared Euclidean cost') pl.subplot(1, 3, 3) pl.imshow(Mp, interpolation='nearest') -pl.title('Sqrt Euclidean cost') +pl.title('L1 (cityblock cost') pl.tight_layout() ############################################################################## @@ -109,22 +109,22 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') # pl.legend(loc=0) -pl.title('OT sqrt Euclidean') +pl.title('OT L1 (cityblock)') pl.tight_layout() pl.show() -############################################################################## +# %% # Dataset 2 : Partial circle # -------------------------- -n = 50 # nb samples +n = 20 # nb samples xtot = np.zeros((n + 1, 2)) xtot[:, 0] = np.cos( - (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) + (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) xtot[:, 1] = np.sin( - (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) + (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) xs = xtot[:n, :] xt = xtot[1:, :] @@ -140,7 +140,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean') M2 /= M2.max() # loss matrix -Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean')) +Mp = ot.dist(xs, xt, metric='cityblock') Mp /= Mp.max() @@ -150,7 +150,7 @@ pl.clf() pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') -pl.title('Source and traget distributions') +pl.title('Source and target distributions') # Cost matrices @@ -166,13 +166,13 @@ pl.title('Squared Euclidean cost') pl.subplot(1, 3, 3) pl.imshow(Mp, interpolation='nearest') -pl.title('Sqrt Euclidean cost') +pl.title('L1 (cityblock) cost') pl.tight_layout() ############################################################################## # Dataset 2 : Plot OT Matrices # ----------------------------- - +# #%% EMD G1 = ot.emd(a, b, M1) @@ -204,7 +204,7 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') # pl.legend(loc=0) -pl.title('OT sqrt Euclidean') +pl.title('OT L1 (cityblock)') pl.tight_layout() pl.show() diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index 527a847..36cc7da 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """ -================= -Plot multiple EMD -================= +================== +OT distances in 1D +================== -Shows how to compute multiple EMD and Sinkhorn with two different +Shows how to compute multiple Wassersein and Sinkhorn with two different ground metrics and plot their values for different distributions. @@ -14,7 +14,7 @@ ground metrics and plot their values for different distributions. # # License: MIT License -# sphinx_gallery_thumbnail_number = 3 +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pylab as pl @@ -29,7 +29,7 @@ from ot.datasets import make_1D_gauss as gauss #%% parameters n = 100 # nb bins -n_target = 50 # nb target distributions +n_target = 20 # nb target distributions # bin positions @@ -47,9 +47,9 @@ for i, m in enumerate(lst_m): # loss matrix and normalization M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean') -M /= M.max() +M /= M.max() * 0.1 M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean') -M2 /= M2.max() +M2 /= M2.max() * 0.1 ############################################################################## # Plot data @@ -59,10 +59,12 @@ M2 /= M2.max() pl.figure(1) pl.subplot(2, 1, 1) -pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, a, 'r', label='Source distribution') pl.title('Source distribution') pl.subplot(2, 1, 2) -pl.plot(x, B, label='Target distributions') +for i in range(n_target): + pl.plot(x, B[:, i], 'b', alpha=i / n_target) +pl.plot(x, B[:, -1], 'b', label='Target distributions') pl.title('Target distributions') pl.tight_layout() @@ -73,14 +75,27 @@ pl.tight_layout() #%% Compute and plot distributions and loss matrix -d_emd = ot.emd2(a, B, M) # direct computation of EMD -d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2 - +d_emd = ot.emd2(a, B, M) # direct computation of OT loss +d_emd2 = ot.emd2(a, B, M2) # direct computation of OT loss with metrixc M2 +d_tv = [np.sum(abs(a - B[:, i])) for i in range(n_target)] pl.figure(2) -pl.plot(d_emd, label='Euclidean EMD') -pl.plot(d_emd2, label='Squared Euclidean EMD') -pl.title('EMD distances') +pl.subplot(2, 1, 1) +pl.plot(x, a, 'r', label='Source distribution') +pl.title('Distributions') +for i in range(n_target): + pl.plot(x, B[:, i], 'b', alpha=i / n_target) +pl.plot(x, B[:, -1], 'b', label='Target distributions') +pl.ylim((-.01, 0.13)) +pl.xticks(()) +pl.legend() +pl.subplot(2, 1, 2) +pl.plot(d_emd, label='Euclidean OT') +pl.plot(d_emd2, label='Squared Euclidean OT') +pl.plot(d_tv, label='Total Variation (TV)') +#pl.xlim((-7,23)) +pl.xlabel('Displacement') +pl.title('Divergences') pl.legend() ############################################################################## @@ -88,17 +103,30 @@ pl.legend() # ----------------------------------------- #%% -reg = 1e-2 +reg = 1e-1 d_sinkhorn = ot.sinkhorn2(a, B, M, reg) d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg) -pl.figure(2) +pl.figure(3) pl.clf() -pl.plot(d_emd, label='Euclidean EMD') -pl.plot(d_emd2, label='Squared Euclidean EMD') + +pl.subplot(2, 1, 1) +pl.plot(x, a, 'r', label='Source distribution') +pl.title('Distributions') +for i in range(n_target): + pl.plot(x, B[:, i], 'b', alpha=i / n_target) +pl.plot(x, B[:, -1], 'b', label='Target distributions') +pl.ylim((-.01, 0.13)) +pl.xticks(()) +pl.legend() +pl.subplot(2, 1, 2) +pl.plot(d_emd, label='Euclidean OT') +pl.plot(d_emd2, label='Squared Euclidean OT') pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn') pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn') -pl.title('EMD distances') +pl.plot(d_tv, label='Total Variation (TV)') +#pl.xlim((-7,23)) +pl.xlabel('Displacement') +pl.title('Divergences') pl.legend() - pl.show() diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py index 5eb15bd..7b021d2 100644 --- a/examples/plot_optim_OTreg.py +++ b/examples/plot_optim_OTreg.py @@ -24,7 +24,7 @@ arXiv preprint arXiv:1510.06567. """ -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 5 import numpy as np import matplotlib.pylab as pl @@ -58,7 +58,7 @@ M /= M.max() G0 = ot.emd(a, b, M) -pl.figure(3, figsize=(5, 5)) +pl.figure(1, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') ############################################################################## @@ -80,7 +80,7 @@ reg = 1e-1 Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True) -pl.figure(3) +pl.figure(2) ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg') ############################################################################## @@ -102,7 +102,7 @@ reg = 1e-3 Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True) -pl.figure(4, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg') ############################################################################## @@ -125,6 +125,34 @@ reg2 = 1e-1 Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True) -pl.figure(5, figsize=(5, 5)) +pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg') pl.show() + + +# %% +# Comparison of the OT matrices + +nvisu = 40 + +pl.figure(5, figsize=(10, 4)) + +pl.subplot(2, 2, 1) +pl.imshow(G0[:nvisu, :]) +pl.axis('off') +pl.title('Exact OT') + +pl.subplot(2, 2, 2) +pl.imshow(Gl2[:nvisu, :]) +pl.axis('off') +pl.title('Frobenius reg.') + +pl.subplot(2, 2, 3) +pl.imshow(Ge[:nvisu, :]) +pl.axis('off') +pl.title('Entropic reg.') + +pl.subplot(2, 2, 4) +pl.imshow(Gel2[:nvisu, :]) +pl.axis('off') +pl.title('Entropic + Frobenius reg.') diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt index a575345..73e6122 100644 --- a/examples/sliced-wasserstein/README.txt +++ b/examples/sliced-wasserstein/README.txt @@ -1,4 +1,4 @@ Sliced Wasserstein Distance ----------------------------
\ No newline at end of file +--------------------------- diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 7d73907..f12b522 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -============================== -2D Sliced Wasserstein Distance -============================== +=============================================== +Sliced Wasserstein Distance on 2D distributions +=============================================== This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31]. @@ -16,6 +16,8 @@ measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import matplotlib.pylab as pl import numpy as np diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 183849c..06dd02d 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import numpy as np import matplotlib.pylab as pl import ot @@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter alpha = 1. # Unbalanced KL relaxation parameter Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) -pl.figure(4, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') pl.show() + + +# %% +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source') +pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target') +pl.legend(loc='upper right') +pl.title('Distributions and transported mass for UOT') diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py index 4a51c2d..782e8c2 100644 --- a/examples/unbalanced-partial/plot_regpath.py +++ b/examples/unbalanced-partial/plot_regpath.py @@ -15,11 +15,12 @@ penalized linear regression. # Author: Haoran Wu <haoran.wu@univ-ubs.fr> # License: MIT License +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pylab as pl import ot - +import matplotlib.animation as animation ############################################################################## # Generate data # ------------- @@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma, ############################################################################## # Plot the regularization path # ---------------- +# +# The OT plan is ploted as a function of $\gamma$ that is the inverse of the +# weight on the marginal relaxations. #%% fully relaxed l2-penalized UOT @@ -103,13 +107,53 @@ for p in range(4): pl.show() +# %% +# Animation of the regpath for UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(-.5, -2.5, nv) + +pl.figure(3) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list, + t_list) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) + + ############################################################################## # Plot the semi-relaxed regularization path # ------------------- #%% semi-relaxed l2-penalized UOT -pl.figure(3) +pl.figure(4) selected_gamma = [10, 1, 1e-1, 1e-2] for p in range(4): tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2, @@ -133,3 +177,43 @@ for p in range(4): if p < 2: pl.xticks(()) pl.show() + + +# %% +# Animation of the regpath for semi-relaxed UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(2.5, -2, nv) + +pl.figure(5) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2, + t_list2) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py new file mode 100644 index 0000000..03487e7 --- /dev/null +++ b/examples/unbalanced-partial/plot_unbalanced_OT.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +""" +============================================================== +2D examples of exact and entropic unbalanced optimal transport +============================================================== +This example is designed to show how to compute unbalanced and +partial OT in POT. + +UOT aims at solving the following optimization problem: + + .. math:: + W = \min_{\gamma} <\gamma, \mathbf{M}>_F + + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + +where :math:`\mathrm{div}` is a divergence. +When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}` +should be the Kullback-Leibler divergence. +When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}` +can be either the Kullback-Leibler or the quadratic divergence. +Using :math:`\ell_1` norm gives the so-called partial OT. +""" + +# Author: Laetitia Chapel <laetitia.chapel@univ-ubs.fr> +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 40 # nb samples + +mu_s = np.array([-1, -1]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +np.random.seed(0) +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +n_noise = 10 + +xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) +xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) + +n = n + n_noise + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + + +############################################################################## +# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT +# ----------- + +reg = 0.005 +reg_m_kl = 0.05 +reg_m_l2 = 5 +mass = 0.7 + +entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl) +kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl') +l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2') +partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass) + +############################################################################## +# Plot the results +# ---------------- + +pl.figure(2) +transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot] +title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" + + str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl), + "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)] + +for p in range(4): + pl.subplot(2, 4, p + 1) + P = transp[p] + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.3) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2) + pl.title(title[p]) + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("mappings") + pl.subplot(2, 4, p + 5) + pl.imshow(P, cmap='jet') + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("transport plans") +pl.show() |