From a5e0f0d40d5046a6639924347ef97e2ac80ad0c9 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 2 Feb 2022 11:53:12 +0100 Subject: [MRG] Add weak OT solver (#341) * add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation --- ot/gromov.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'ot/gromov.py') diff --git a/ot/gromov.py b/ot/gromov.py index 6544260..b7e7949 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F - :math:`\mathbf{q}`: distribution in the target space - `L`: loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Parameters ---------- C1 : array-like, shape (ns, ns) @@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Parameters ---------- C1 : array-like, shape (ns, ns) @@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - `L` is a loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters @@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. -- cgit v1.2.3 From 50c0f17d00e3492c4d56a356af30cf00d6d07913 Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Fri, 11 Feb 2022 10:53:38 +0100 Subject: [MRG] GW dictionary learning (#319) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add fgw dictionary learning feature * add fgw dictionary learning feature * plot gromov wasserstein dictionary learning * Update __init__.py * fix pep8 errors exact E501 line too long * fix last pep8 issues * add unitary tests for (F)GW dictionary learning without using autodifferentiable functions * correct tests for (F)GW dictionary learning without using autodiff * correct tests for (F)GW dictionary learning without using autodiff * fix docs and notations * answer to review: improve tests, docs, examples + make node weights optional * fix pep8 and examples * improve docs + tests + thumbnail * make example faster * improve ex * update README.md * make GDL tests faster Co-authored-by: Rémi Flamary --- README.md | 2 + RELEASES.md | 2 +- .../plot_gromov_wasserstein_dictionary_learning.py | 357 +++++++ ot/__init__.py | 4 - ot/gromov.py | 1074 +++++++++++++++++++- test/test_gromov.py | 554 +++++++++- 6 files changed, 1954 insertions(+), 39 deletions(-) create mode 100755 examples/gromov/plot_gromov_wasserstein_dictionary_learning.py (limited to 'ot/gromov.py') diff --git a/README.md b/README.md index a7627df..c6bfd9c 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. +* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. POT provides the following Machine Learning related solvers: @@ -198,6 +199,7 @@ The contributors to this library are * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) +* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): diff --git a/RELEASES.md b/RELEASES.md index 4d05582..925920a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,7 +10,7 @@ of the regularization parameter (PR #336). - Backend implementation for `ot.lp.free_support_barycenter` (PR #340). - Add weak OT solver + example (PR #341). - +- Add (F)GW linear dictionary learning solvers + example (PR #319) #### Closed issues diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py new file mode 100755 index 0000000..1fdc3b9 --- /dev/null +++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- + +r""" +================================= +(Fused) Gromov-Wasserstein Linear Dictionary Learning +================================= + +In this exemple, we illustrate how to learn a Gromov-Wasserstein dictionary on +a dataset of structured data such as graphs, denoted +:math:`\{ \mathbf{C_s} \}_{s \in [S]}` where every nodes have uniform weights. +Given a dictionary :math:`\mathbf{C_{dict}}` composed of D structures of a fixed +size nt, each graph :math:`(\mathbf{C_s}, \mathbf{p_s})` +is modeled as a convex combination :math:`\mathbf{w_s} \in \Sigma_D` of these +dictionary atoms as :math:`\sum_d w_{s,d} \mathbf{C_{dict}[d]}`. + + +First, we consider a dataset composed of graphs generated by Stochastic Block models +with variable sizes taken in :math:`\{30, ... , 50\}` and quantities of clusters +varying in :math:`\{ 1, 2, 3\}`. We learn a dictionary of 3 atoms, by minimizing +the Gromov-Wasserstein distance from all samples to its model in the dictionary +with respect to the dictionary atoms. + +Second, we illustrate the extension of this dictionary learning framework to +structured data endowed with node features by using the Fused Gromov-Wasserstein +distance. Starting from the aforementioned dataset of unattributed graphs, we +add discrete labels uniformly depending on the number of clusters. Then we learn +and visualize attributed graph atoms where each sample is modeled as a joint convex +combination between atom structures and features. + + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph +Dictionary Learning, International Conference on Machine Learning (ICML), 2021. + +""" +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +from sklearn.manifold import MDS +from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning +import ot +import networkx +from networkx.generators.community import stochastic_block_model as sbm +# %% +# ============================================================================= +# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters. +# ============================================================================= + +np.random.seed(42) + +N = 60 # number of graphs in the dataset +# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability. +clusters = [1, 2, 3] +Nc = N // len(clusters) # number of graphs by cluster +nlabels = len(clusters) +dataset = [] +labels = [] + +p_inter = 0.1 +p_intra = 0.9 +for n_cluster in clusters: + for i in range(Nc): + n_nodes = int(np.random.uniform(low=30, high=50)) + + if n_cluster > 1: + P = p_inter * np.ones((n_cluster, n_cluster)) + np.fill_diagonal(P, p_intra) + else: + P = p_intra * np.eye(1) + sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32) + G = sbm(sizes, P, seed=i, directed=False) + C = networkx.to_numpy_array(G) + dataset.append(C) + labels.append(n_cluster) + + +# Visualize samples + +def plot_graph(x, C, binary=True, color='C0', s=None): + for j in range(C.shape[0]): + for i in range(j): + if binary: + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + else: # connection intensity proportional to C[i,j] + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k') + + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + + +pl.figure(1, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color='C0', s=50.) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Estimate the gromov-wasserstein dictionary from the dataset +# ============================================================================= + + +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] + +D = 3 # 3 atoms in the dictionary +nt = 6 # of 6 nodes each + +q = ot.unif(nt) +reg = 0. # regularization coefficient to promote sparsity of unmixings {w_s} + +Cdict_GW, log = gromov_wasserstein_dictionary_learning( + Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16, + learning_rate=0.1, reg=reg, projection='nonnegative_symmetric', + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, + use_log=True, use_adam_optimizer=True, verbose=True +) +# visualize loss evolution over epochs +pl.figure(2, (4, 3)) +pl.clf() +pl.title('loss evolution by epoch', fontsize=14) +pl.plot(log['loss_epochs']) +pl.xlabel('epochs', fontsize=12) +pl.ylabel('loss', fontsize=12) +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the estimated dictionary atoms +# ============================================================================= + + +# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white) + +pl.figure(3, (12, 8)) +pl.clf() +for idx_atom, atom in enumerate(Cdict_GW): + scaled_atom = (atom - atom.min()) / (atom.max() - atom.min()) + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + pl.subplot(2, D, idx_atom + 1) + pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14) + plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.) + pl.axis("off") + pl.subplot(2, D, D + idx_atom + 1) + pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation='nearest') + pl.colorbar() + pl.axis("off") +pl.tight_layout() +pl.show() +#%% +# ============================================================================= +# Visualization of the embedding space +# ============================================================================= + +unmixings = [] +reconstruction_errors = [] +for C in dataset: + p = ot.unif(C.shape[0]) + unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing( + C, Cdict_GW, p=p, q=q, reg=reg, + tol_outer=10**(-5), tol_inner=10**(-5), + max_iter_outer=30, max_iter_inner=300 + ) + unmixings.append(unmixing) + reconstruction_errors.append(reconstruction_error) +unmixings = np.array(unmixings) +print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) + + +# Compute the 2D representation of the unmixing living in the 2-simplex of probability +unmixings2D = np.zeros(shape=(N, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(4, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() +# %% +# ============================================================================= +# Endow the dataset with node features +# ============================================================================= + +# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters +# 1 cluster --> 0 as nodes feature +# 2 clusters --> 1 as nodes feature +# 3 clusters --> 2 as nodes feature +# features are one-hot encoded following these assignments +dataset_features = [] +for i in range(len(dataset)): + n = dataset[i].shape[0] + F = np.zeros((n, 3)) + if i < Nc: # graph with 1 cluster + F[:, 0] = 1. + elif i < 2 * Nc: # graph with 2 clusters + F[:, 1] = 1. + else: # graph with 3 clusters + F[:, 2] = 1. + dataset_features.append(F) + +pl.figure(5, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + F = dataset_features[(c - 1) * Nc] + colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])] + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color=colors, s=50) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() +# %% +# ============================================================================= +# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs +# ============================================================================= +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] +D = 3 # 6 atoms instead of 3 +nt = 6 +q = ot.unif(nt) +reg = 0.001 +alpha = 0.5 # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein + + +Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning( + Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha, + epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, + projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True +) +# visualize loss evolution +pl.figure(6, (4, 3)) +pl.clf() +pl.title('loss evolution by epoch', fontsize=14) +pl.plot(log['loss_epochs']) +pl.xlabel('epochs', fontsize=12) +pl.ylabel('loss', fontsize=12) +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the estimated dictionary atoms +# ============================================================================= + +pl.figure(7, (12, 8)) +pl.clf() +max_features = Ydict_FGW.max() +min_features = Ydict_FGW.min() + +for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)): + scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min()) + #scaled_F = 2 * (Fatom - min_features) / (max_features - min_features) + colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])] + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + pl.subplot(2, D, idx_atom + 1) + pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14) + plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100) + pl.axis("off") + pl.subplot(2, D, D + idx_atom + 1) + pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation='nearest') + pl.colorbar() + pl.axis("off") +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the embedding space +# ============================================================================= + +unmixings = [] +reconstruction_errors = [] +for i in range(len(dataset)): + C = dataset[i] + Y = dataset_features[i] + p = ot.unif(C.shape[0]) + unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing( + C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha, + reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300 + ) + unmixings.append(unmixing) + reconstruction_errors.append(reconstruction_error) +unmixings = np.array(unmixings) +print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) + +# Visualize unmixings in the 2-simplex of probability +unmixings2D = np.zeros(shape=(N, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(8, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) + +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 7253318..bda7a35 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -1,5 +1,4 @@ """ - .. warning:: The list of automatically imported sub-modules is as follows: :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim` @@ -7,13 +6,10 @@ :py:mod:`ot.gromov`, :py:mod:`ot.smooth` :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath` , :py:mod:`ot.unbalanced`. - The following sub-modules are not imported due to additional dependencies: - - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. - :any:`ot.plot` : depends on :code:`matplotlib` - """ # Author: Remi Flamary diff --git a/ot/gromov.py b/ot/gromov.py index b7e7949..f5a1f91 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -7,6 +7,7 @@ Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers # Nicolas Courty # Rémi Flamary # Titouan Vayer +# Cédric Vincent-Cuaz # # License: MIT License @@ -17,7 +18,7 @@ from .bregman import sinkhorn from .utils import dist, UndefinedParameter, list_to_array from .optim import cg from .lp import emd_1d, emd -from .utils import check_random_state +from .utils import check_random_state, unif from .backend import get_backend @@ -320,7 +321,7 @@ def update_kl_loss(p, lambdas, T, Cs): return nx.exp(tmpsum / ppt) -def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): r""" Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -365,6 +366,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -389,18 +393,26 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F """ p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - nx = get_backend(p0, q0, C10, C20) - + if G0 is None: + nx = get_backend(p0, q0, C10, C20) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - G0 = p[:, None] * q[None, :] + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) def f(G): return gwloss(constC, hC1, hC2, G) @@ -418,7 +430,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10) -def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): r""" Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -467,6 +479,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. Returns ------- @@ -491,9 +506,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= """ p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - nx = get_backend(p0, q0, C10, C20) + if G0 is None: + nx = get_backend(p0, q0, C10, C20) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) @@ -502,7 +520,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - G0 = p[:, None] * q[None, :] + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) def f(G): return gwloss(constC, hC1, hC2, G) @@ -533,7 +557,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= return gw -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): +def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): r""" Computes the FGW transport between two graphs (see :ref:`[24] `) @@ -578,6 +602,9 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. log : bool, optional record log if True **kwargs : dict @@ -600,20 +627,28 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, (ICML). 2019. """ p, q = list_to_array(p, q) - p0, q0, C10, C20, M0 = p, q, C1, C2, M - nx = get_backend(p0, q0, C10, C20, M0) + if G0 is None: + nx = get_backend(p0, q0, C10, C20, M0) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, M0, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) M = nx.to_numpy(M0) + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - G0 = p[:, None] * q[None, :] - def f(G): return gwloss(constC, hC1, hC2, G) @@ -622,19 +657,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, if log: res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10) - log['fgw_dist'] = fgw_dist log['u'] = nx.from_numpy(log['u'], type_as=C10) log['v'] = nx.from_numpy(log['v'], type_as=C10) return nx.from_numpy(res, type_as=C10), log - else: return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10) -def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): +def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): r""" Computes the FGW distance between two graphs see (see :ref:`[24] `) @@ -683,6 +715,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. log : bool, optional Record log if True. **kwargs : dict @@ -711,7 +746,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 p, q = list_to_array(p, q) p0, q0, C10, C20, M0 = p, q, C1, C2, M - nx = get_backend(p0, q0, C10, C20, M0) + if G0 is None: + nx = get_backend(p0, q0, C10, C20, M0) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, M0, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) @@ -721,7 +760,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - G0 = p[:, None] * q[None, :] + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) def f(G): return gwloss(constC, hC1, hC2, G) @@ -1796,3 +1841,988 @@ def update_feature_matrix(lambdas, Ys, Ts, p): for s in range(len(Ts)) ]) return tmpsum + + +def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): + r""" + Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s` + + .. math:: + \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - reg is the regularization coefficient. + + The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38] + + Parameters + ---------- + Cs : list of S symmetric array-like, shape (ns, ns) + List of Metric/Graph cost matrices of variable size (ns, ns). + D: int + Number of dictionary atoms to learn + nt: int + Number of samples within each dictionary atoms + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + ps : list of S array-like, shape (ns,), optional + Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + q : array-like, shape (nt,), optional + Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. + epochs: int, optional + Number of epochs used to learn the dictionary. Default is 32. + batch_size: int, optional + Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. + learning_rate: float, optional + Learning rate used for the stochastic gradient descent. Default is 1. + Cdict_init: list of D array-like with shape (nt, nt), optional + Used to initialize the dictionary. + If set to None (Default), the dictionary will be initialized randomly. + Else Cdict must have shape (D, nt, nt) i.e match provided shape features. + projection: str , optional + If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary + Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' + log: bool, optional + If set to True, losses evolution by batches and epochs are tracked. Default is False. + use_adam_optimizer: bool, optional + If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. + Else perform SGD with fixed learning rate. Default is True. + tol_outer : float, optional + Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + verbose : bool, optional + Print the reconstruction loss every epoch. Default is False. + + Returns + ------- + + Cdict_best_state : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epochs. + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + # Handle backend of non-optional arguments + Cs0 = Cs + nx = get_backend(*Cs0) + Cs = [nx.to_numpy(C) for C in Cs0] + dataset_size = len(Cs) + # Handle backend of optional arguments + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0 + if use_adam_optimizer: + adam_moments = _initialize_adam_optimizer(Cdict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + Cdict_best_state = Cdict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + # batch sampling + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. + unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing( + Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner, + max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Cdict *= 2 / batch_size + if use_adam_optimizer: + Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments) + else: + Cdict -= learning_rate * grad_Cdict + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + if verbose: + print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), log + + +def _initialize_adam_optimizer(variable): + + # Initialization for our numpy implementation of adam optimizer + atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor + atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor + atoms_adam_count = 1 + + return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count} + + +def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09): + + adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad + adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2) + unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count']) + unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count']) + variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps) + adam_moments['count'] += 1 + + return variable, adam_moments + + +def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): + r""" + Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`. + + .. math:: + \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary. + Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + T: array-like (ns, nt) + Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})` + current_loss: float + reconstruction error + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + C0, Cdict0 = C, Cdict + nx = get_backend(C0, Cdict0) + C = nx.to_numpy(C0) + Cdict = nx.to_numpy(Cdict0) + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + + w = unif(D) # Initialize uniformly the unmixing w + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + + const_q = q[:, None] * q[None, :] + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs) + current_loss = log['gw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing( + C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T, + starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs + ) + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2 + + + Such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. + - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`. + - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`. + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38] + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). + Cembedded: array-like, shape (nt,nt) + Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. + w: array-like, shape (D,) + Linear unmixing of the input structure onto the dictionary + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + T: array-like, shape (ns,nt) + fixed transport plan between the input structure and its representation in the dictionary. + p : array-like, shape (ns,) + Distribution in the source space. + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + + Returns + ------- + w: ndarray (D,) + optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing. + """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + + while (convergence_criterion > tol) and (count < max_iter): + + previous_loss = current_loss + # 1) Compute gradient at current point w + grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + grad_w -= 2 * reg * w + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: # not that the loss can be negative if reg >0 + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + count += 1 + + return w, Cembedded, current_loss + + +def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that: + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). + Cembedded: array-like, shape (nt,nt) + Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + const_TCT: array-like, shape (nt, nt) + :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. + Returns + ------- + gamma: float + Optimal value for the line-search step + a: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + b: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + Cembedded_diff: numpy array, shape (nt, nt) + Difference between models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`. + """ + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) + Cembedded_diff = Cembedded_x - Cembedded + trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) + trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) + a = trace_diffx - trace_diffw + b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) + if reg != 0: + a -= reg * np.sum((x - w)**2) + b -= 2 * reg * np.sum(w * (x - w)) + + if a > 0: + gamma = min(1, max(0, - b / (2 * a))) + elif a + b < 0: + gamma = 1 + else: + gamma = 0 + + return gamma, a, b, Cembedded_diff + + +def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1., + Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): + r""" + Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s` + + .. math:: + \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2 + + + Such that :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + + The stochastic algorithm used for estimating the attributed graph dictionary atoms as proposed in [38] + + Parameters + ---------- + Cs : list of S symmetric array-like, shape (ns, ns) + List of Metric/Graph cost matrices of variable size (ns,ns). + Ys : list of S array-like, shape (ns, d) + List of feature matrix of variable size (ns,d) with d fixed. + D: int + Number of dictionary atoms to learn + nt: int + Number of samples within each dictionary atoms + alpha : float + Trade-off parameter of Fused Gromov-Wasserstein + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + ps : list of S array-like, shape (ns,), optional + Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + q : array-like, shape (nt,), optional + Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. + epochs: int, optional + Number of epochs used to learn the dictionary. Default is 32. + batch_size: int, optional + Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. + learning_rate_C: float, optional + Learning rate used for the stochastic gradient descent on Cdict. Default is 1. + learning_rate_Y: float, optional + Learning rate used for the stochastic gradient descent on Ydict. Default is 1. + Cdict_init: list of D array-like with shape (nt, nt), optional + Used to initialize the dictionary structures Cdict. + If set to None (Default), the dictionary will be initialized randomly. + Else Cdict must have shape (D, nt, nt) i.e match provided shape features. + Ydict_init: list of D array-like with shape (nt, d), optional + Used to initialize the dictionary features Ydict. + If set to None, the dictionary features will be initialized randomly. + Else Ydict must have shape (D, nt, d) where d is the features dimension of inputs Ys and also match provided shape features. + projection: str, optional + If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary + Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' + log: bool, optional + If set to True, losses evolution by batches and epochs are tracked. Default is False. + use_adam_optimizer: bool, optional + If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. + Else perform SGD with fixed learning rate. Default is True. + tol_outer : float, optional + Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + verbose : bool, optional + Print the reconstruction loss every epoch. Default is False. + + Returns + ------- + + Cdict_best_state : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + Ydict_best_state : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epoches. + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + Cs0, Ys0 = Cs, Ys + nx = get_backend(*Cs0, *Ys0) + Cs = [nx.to_numpy(C) for C in Cs0] + Ys = [nx.to_numpy(Y) for Y in Ys0] + + d = Ys[0].shape[-1] + dataset_size = len(Cs) + + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + if Ydict_init is None: + # Initialize randomly features of dictionary atoms based on samples distribution by feature component + dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys]) + Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d)) + else: + Ydict = nx.to_numpy(Ydict_init).copy() + assert Ydict.shape == (D, nt, d) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_adam_optimizer: + adam_moments_C = _initialize_adam_optimizer(Cdict) + adam_moments_Y = _initialize_adam_optimizer(Ydict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + + # Batch iterations + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. + unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ys_embedded = np.zeros((batch_size, nt, d)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing( + Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q, + tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + grad_Ydict = np.zeros_like(Ydict) + + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx]) + grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :] + grad_Cdict *= 2 / batch_size + grad_Ydict *= 2 / batch_size + + if use_adam_optimizer: + Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C) + Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y) + else: + Cdict -= learning_rate_C * grad_Cdict + Ydict -= learning_rate_Y * grad_Ydict + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + if verbose: + print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log + + +def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): + r""" + Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` + + .. math:: + \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Ydict : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary. + Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + Yembedded: array-like, shape (nt,d) + embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`. + T: array-like (ns,nt) + Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`. + current_loss: float + reconstruction error + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict + nx = get_backend(C0, Y0, Cdict0, Ydict0) + C = nx.to_numpy(C0) + Y = nx.to_numpy(Y0) + Cdict = nx.to_numpy(Cdict0) + Ydict = nx.to_numpy(Ydict0) + + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + d = Y.shape[-1] + w = unif(D) # Initialize with uniform weights + ns = C.shape[-1] + nt = Cdict.shape[-1] + + # modeling (C,Y) + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + Yembedded = np.sum(w[:, None, None] * Ydict, axis=0) + + # constants depending on q + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T) + M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features + T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True) + current_loss = log['fgw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, + T, p, q, const_q, diag_q, current_loss, alpha, reg, + tol=tol_inner, max_iter=max_iter_inner, **kwargs) + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2 + + Such that : + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}` + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7. + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt,nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt,d) + Embedded features of (C,Y) onto the dictionary + w: array-like, shape (n_D,) + Linear unmixing of (C,Y) onto (Cdict,Ydict) + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution. + diag_q: array-like, shape (nt,nt) + diagonal matrix with values of q on the diagonal. + T: array-like, shape (ns,nt) + fixed transport plan between (C,Y) and its model + p : array-like, shape (ns,) + Distribution in the source space (C,Y). + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + w: ndarray (D,) + linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing. + """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + ones_ns_d = np.ones(Y.shape) + + while (convergence_criterion > tol) and (count < max_iter): + previous_loss = current_loss + + # 1) Compute gradient at current point w + # structure + grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + # feature + grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2)) + grad_w -= reg * w + grad_w *= 2 + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + Yembedded += gamma * Yembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + count += 1 + + return w, Cembedded, Yembedded, current_loss + + +def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that : + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Y: arrat-like, shape (ns,d) + Feature matrix of the input space + Cdict : list of D array-like, shape (nt, nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt, d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt, nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt, d) + Embedded features of (C,Y) onto the dictionary + T: array-like, shape (ns, nt) + Fixed transport plan between (C,Y) and its current model. + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + const_TCT: array-like, shape (nt, nt) + :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. + ones_ns_d: array-like, shape (ns, d) + :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + gamma: float + Optimal value for the line-search step + a: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + b: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + Cembedded_diff: numpy array, shape (nt, nt) + Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + Yembedded_diff: numpy array, shape (nt, nt) + Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + """ + # polynomial coefficients from quadratic objective (with respect to w) on structures + Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) + Cembedded_diff = Cembedded_x - Cembedded + trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) + trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_gw = trace_diffx - trace_diffw + b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) + + # polynomial coefficient from quadratic objective (with respect to w) on features + Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0) + Yembedded_diff = Yembedded_x - Yembedded + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T) + b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T))) + + a = alpha * a_gw + (1 - alpha) * a_w + b = alpha * b_gw + (1 - alpha) * b_w + if reg != 0: + a -= reg * np.sum((x - w)**2) + b -= 2 * reg * np.sum(w * (x - w)) + if a > 0: + gamma = min(1, max(0, -b / (2 * a))) + elif a + b < 0: + gamma = 1 + else: + gamma = 0 + + return gamma, a, b, Cembedded_diff, Yembedded_diff diff --git a/test/test_gromov.py b/test/test_gromov.py index 4b995d5..329f99c 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -3,6 +3,7 @@ # Author: Erwan Vautier # Nicolas Courty # Titouan Vayer +# Cédric Vincent-Cuaz # # License: MIT License @@ -26,6 +27,7 @@ def test_gromov(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -37,9 +39,10 @@ def test_gromov(nx): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0) - G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) - Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -56,9 +59,9 @@ def test_gromov(nx): gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) gwb = nx.to_numpy(gwb) - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False) + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False) gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) ) G = log['T'] @@ -91,6 +94,7 @@ def test_gromov_dtype_device(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -105,9 +109,10 @@ def test_gromov_dtype_device(nx): C2b = nx.from_numpy(C2, type_as=tp) pb = nx.from_numpy(p, type_as=tp) qb = nx.from_numpy(q, type_as=tp) + G0b = nx.from_numpy(G0, type_as=tp) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -123,6 +128,7 @@ def test_gromov_device_tf(): xt = xs[::-1].copy() p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) C1 /= C1.max() @@ -134,8 +140,9 @@ def test_gromov_device_tf(): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + G0b = nx.from_numpy(G0) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -145,6 +152,7 @@ def test_gromov_device_tf(): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0b) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) nx.assert_same_dtype_device(C1b, Gb) @@ -554,6 +562,7 @@ def test_fgw(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -569,9 +578,10 @@ def test_fgw(nx): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0) - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True) Gb = nx.to_numpy(Gb) # check constraints @@ -586,8 +596,8 @@ def test_fgw(nx): np.testing.assert_allclose( Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov - fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True) fgwb = nx.to_numpy(fgwb) G = log['T'] @@ -698,3 +708,523 @@ def test_fgw_barycenter(nx): Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + +def test_gromov_wasserstein_linear_unmixing(nx): + n = 10 + + X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + p = ot.unif(n) + + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + Cdictb = nx.from_numpy(Cdict) + pb = nx.from_numpy(p) + tol = 10**(-5) + # Tests without regularization + reg = 0. + unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + + reg = 0.001 + unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + shape = 10 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + ps = [ot.unif(shape) for _ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. + dataset_means = [C.mean() for C in Cs] + np.random.seed(0) + Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. + Csb = [nx.from_numpy(C) for C in Cs] + psb = [nx.from_numpy(p) for p in ps] + qb = nx.from_numpy(q) + Cdict_initb = nx.from_numpy(Cdict_init) + + # Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization + # > Compute initial reconstruction of samples on this random dictionary without backend + use_adam_optimizer = True + verbose = False + tol = 10**(-5) + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_init, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + initial_total_reconstruction += reconstruction + + # > Learn the dictionary using this init + Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, + epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary without backend + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction += reconstruction + + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + + Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b += reconstruction + + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # knowing than the initialization scheme is the same than implemented to set the benchmarked initialization. + np.random.seed(0) + Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # Test: Same after going through backend + np.random.seed(0) + Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # and testing other optimization settings untested until now. + # We pass previously estimated dictionaries to speed up the process. + use_adam_optimizer = False + verbose = True + use_log = True + + np.random.seed(0) + Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # Test: Same after going through backend + np.random.seed(0) + Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis2 += reconstruction + + np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05) + + +def test_fused_gromov_wasserstein_linear_unmixing(nx): + + n = 10 + X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + Ydict = np.stack([F, F]) + p = ot.unif(n) + + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + Fb = nx.from_numpy(F) + Cdictb = nx.from_numpy(Cdict) + Ydictb = nx.from_numpy(Ydict) + pb = nx.from_numpy(p) + # Tests without regularization + reg = 0. + + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + reg = 0.001 + + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_fused_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + shape = 10 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + Ys = [F.copy() for _ in range(n_samples)] + ps = [ot.unif(shape) for _ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. + dataset_structure_means = [C.mean() for C in Cs] + np.random.seed(0) + Cdict_init = np.random.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. + dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys]) + Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2)) + + Csb = [nx.from_numpy(C) for C in Cs] + Ysb = [nx.from_numpy(Y) for Y in Ys] + psb = [nx.from_numpy(p) for p in ps] + qb = nx.from_numpy(q) + Cdict_initb = nx.from_numpy(Cdict_init) + Ydict_initb = nx.from_numpy(Ydict_init) + + # Test: Compute initial reconstruction of samples on this random dictionary + alpha = 0.5 + use_adam_optimizer = True + verbose = False + tol = 1e-05 + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q, + alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + initial_total_reconstruction += reconstruction + + # > Learn a dictionary using this given initialization and check that the reconstruction loss + # on the learned dictionary is lower than the one using its initialization. + Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction += reconstruction + # Compare both + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + + Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb, + epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b += reconstruction + + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03) + + # Test: Perform similar experiment without providing the initial dictionary being an optional input + np.random.seed(0) + Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # > Same after going through backend + np.random.seed(0) + Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis += reconstruction + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + + # Test: without using adam optimizer, with log and verbose set to True + use_adam_optimizer = False + verbose = True + use_log = True + + # > Experiment providing previously estimated dictionary to speed up the test compared to providing initial random init. + np.random.seed(0) + Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # > Same after going through backend + np.random.seed(0) + Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis2 += reconstruction + + # > Compare results with/without backend + np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) -- cgit v1.2.3 From 9412f0ad1c0003e659b7d779bf8b6728e0e5e60f Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 2 Mar 2022 11:35:47 +0100 Subject: [MRG] Gromov_Wasserstein2 not performing backward properly on GPU (#352) * Resolves gromov wasserstein backward bug * release file updated --- RELEASES.md | 3 +++ ot/gromov.py | 12 +++++++---- test/test_gromov.py | 60 +++++++++++++++++++++++++++++++---------------------- 3 files changed, 46 insertions(+), 29 deletions(-) (limited to 'ot/gromov.py') diff --git a/RELEASES.md b/RELEASES.md index c1068f3..18562e7 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -18,6 +18,9 @@ - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, PR #338) - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) +- Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA + tensors (Issue #351, PR #352) + ## 0.8.1.0 *December 2021* diff --git a/ot/gromov.py b/ot/gromov.py index f5a1f91..c5a82d1 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -546,8 +546,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= gw = log_gw['gw_dist'] if loss_fun == 'square_loss': - gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) - gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) + gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) + gC1 = nx.from_numpy(gC1, type_as=C10) + gC2 = nx.from_numpy(gC2, type_as=C10) gw = nx.set_gradients(gw, (p0, q0, C10, C20), (log_gw['u'], log_gw['v'], gC1, gC2)) @@ -786,8 +788,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 log_fgw['T'] = T0 if loss_fun == 'square_loss': - gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) - gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) + gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) + gC1 = nx.from_numpy(gC1, type_as=C10) + gC2 = nx.from_numpy(gC2, type_as=C10) fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) diff --git a/test/test_gromov.py b/test/test_gromov.py index 329f99c..0dcf2da 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -181,19 +181,24 @@ def test_gromov2_gradients(): if torch: - p1 = torch.tensor(p, requires_grad=True) - q1 = torch.tensor(q, requires_grad=True) - C11 = torch.tensor(C1, requires_grad=True) - C12 = torch.tensor(C2, requires_grad=True) + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) - val = ot.gromov_wasserstein2(C11, C12, p1, q1) + val = ot.gromov_wasserstein2(C11, C12, p1, q1) - val.backward() + val.backward() - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape @pytest.skip_backend("jax", reason="test very slow with jax backend") @@ -636,21 +641,26 @@ def test_fgw2_gradients(): if torch: - p1 = torch.tensor(p, requires_grad=True) - q1 = torch.tensor(q, requires_grad=True) - C11 = torch.tensor(C1, requires_grad=True) - C12 = torch.tensor(C2, requires_grad=True) - M1 = torch.tensor(M, requires_grad=True) - - val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) - - val.backward() - - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert M1.shape == M1.grad.shape + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape def test_fgw_barycenter(nx): -- cgit v1.2.3 From 486b0d6397182a57cd53651dca87fcea89747490 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 11 Apr 2022 16:26:30 +0200 Subject: [MRG] Center gradients for mass of emd2 and gw2 (#363) * center gradients for mass of emd2 and gw2 * debug fgw gradient * debug fgw --- RELEASES.md | 4 +++- ot/gromov.py | 7 +++++-- ot/lp/__init__.py | 7 ++++--- test/test_ot.py | 8 +++++++- 4 files changed, 19 insertions(+), 7 deletions(-) (limited to 'ot/gromov.py') diff --git a/RELEASES.md b/RELEASES.md index 7942a15..33d1ab6 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,7 +5,7 @@ #### New features -- remode deprecated `ot.gpu` submodule (PR #361) +- Remove deprecated `ot.gpu` submodule (PR #361) - Update examples in the gallery (PR #359). - Add stochastic loss and OT plan computation for regularized OT and backend examples(PR #360). @@ -23,6 +23,8 @@ #### Closed issues +- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are + centered (Issue #364, PR #363) - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, PR #338) - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) diff --git a/ot/gromov.py b/ot/gromov.py index c5a82d1..55ab0bd 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -551,7 +551,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= gC1 = nx.from_numpy(gC1, type_as=C10) gC2 = nx.from_numpy(gC2, type_as=C10) gw = nx.set_gradients(gw, (p0, q0, C10, C20), - (log_gw['u'], log_gw['v'], gC1, gC2)) + (log_gw['u'] - nx.mean(log_gw['u']), + log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2)) if log: return gw, log_gw @@ -793,7 +794,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 gC1 = nx.from_numpy(gC1, type_as=C10) gC2 = nx.from_numpy(gC2, type_as=C10) fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), - (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) + (log_fgw['u'] - nx.mean(log_fgw['u']), + log_fgw['v'] - nx.mean(log_fgw['v']), + alpha * gC1, alpha * gC2, (1 - alpha) * T0)) if log: return fgw_dist, log_fgw diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index abf7fe0..390c32d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -517,7 +517,8 @@ def emd2(a, b, M, processes=1, log['warning'] = result_code_string log['result_code'] = result_code cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (log['u'], log['v'], G)) + (a0, b0, M0), (log['u'] - nx.mean(log['u']), + log['v'] - nx.mean(log['v']), G)) return [cost, log] else: def f(b): @@ -540,8 +541,8 @@ def emd2(a, b, M, processes=1, ) G = nx.from_numpy(G, type_as=type_as) cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (nx.from_numpy(u, type_as=type_as), - nx.from_numpy(v, type_as=type_as), G)) + (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), G)) check_result(result_code) return cost diff --git a/test/test_ot.py b/test/test_ot.py index bb258e2..bf832f6 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -147,7 +147,7 @@ def test_emd2_gradients(): b1 = torch.tensor(a, requires_grad=True) M1 = torch.tensor(M, requires_grad=True) - val = ot.emd2(a1, b1, M1) + val, log = ot.emd2(a1, b1, M1, log=True) val.backward() @@ -155,6 +155,12 @@ def test_emd2_gradients(): assert b1.shape == b1.grad.shape assert M1.shape == M1.grad.shape + assert np.allclose(a1.grad.cpu().detach().numpy(), + log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean()) + + assert np.allclose(b1.grad.cpu().detach().numpy(), + log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean()) + # Testing for bug #309, checking for scaling of gradient a2 = torch.tensor(a, requires_grad=True) b2 = torch.tensor(a, requires_grad=True) -- cgit v1.2.3