From 9076f02903ba2fb9ea9fe704764a755cad8dcd63 Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Mon, 12 Jun 2023 12:01:48 +0200 Subject: [FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add entropic fgw + fgw bary + srgw + srfgw with tests * add exemples for entropic srgw - srfgw solvers * add PPA solvers for GW/FGW + complete previous commits * update readme * add tests * add examples + tests + warning in entropic solvers + releases * reduce testing runtimes for test_gromov * fix conflicts * optional marginals * improve coverage * gromov doc harmonization * fix pep8 * complete optional marginal for entropic srfgw --------- Co-authored-by: Rémi Flamary --- CONTRIBUTORS.md | 2 +- README.md | 8 +- RELEASES.md | 10 +- examples/gromov/plot_entropic_semirelaxed_fgw.py | 304 +++++++++ examples/gromov/plot_fgw.py | 32 +- examples/gromov/plot_fgw_solvers.py | 288 ++++++++ examples/gromov/plot_gromov.py | 112 +++- ot/gromov/__init__.py | 37 +- ot/gromov/_bregman.py | 782 ++++++++++++++++++++-- ot/gromov/_gw.py | 318 ++++----- ot/gromov/_semirelaxed.py | 591 ++++++++++++++++- ot/gromov/_utils.py | 43 ++ test/test_gromov.py | 796 +++++++++++++++++++++-- 13 files changed, 2948 insertions(+), 375 deletions(-) create mode 100644 examples/gromov/plot_entropic_semirelaxed_fgw.py create mode 100644 examples/gromov/plot_fgw_solvers.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 6b35653..2d25f3e 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -36,7 +36,7 @@ The contributors to this library are: * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) -* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, semi-relaxed FGW) +* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) diff --git a/README.md b/README.md index c16b328..25c0401 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,8 @@ POT provides the following generic OT solvers (links to examples): * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. * Weak OT solver between empirical distributions [39] * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale). 
-* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from Graph Dictionary Learning [38] - * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] +* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12,51]), differentiable using gradients from Graph Dictionary Learning [38] + * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) (exact [24] and regularized [12,51]). * [Stochastic solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and [differentiable losses](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html) for @@ -42,7 +42,7 @@ POT provides the following generic OT solvers (links to examples): * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. -* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) [48]. +* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]). * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. POT provides the following Machine Learning related solvers: @@ -310,3 +310,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR). + +[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019. 
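A minimal usage sketch of the entropic GW API extended by this PR, for orientation while reading the hunks below. It relies only on signatures introduced in this patch (`ot.gromov.entropic_gromov_wasserstein` with `solver='PPA'`, optional marginals, and `**kwargs` forwarded to `ot.sinkhorn`); the toy data is illustrative and not part of the patch.

import numpy as np
import ot

# two toy metric-measure spaces (illustrative only)
rng = np.random.RandomState(0)
xs, xt = rng.randn(10, 2), rng.randn(15, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()

# marginals p, q are now optional: uniform weights are used when omitted
T_pgd = ot.gromov.entropic_gromov_wasserstein(C1, C2, epsilon=5e-2, solver='PGD')

# Proximal Point variant added by this PR; extra kwargs are forwarded to
# ot.sinkhorn, e.g. numItermax=1 as suggested in [51]
T_ppa = ot.gromov.entropic_gromov_wasserstein(
    C1, C2, epsilon=5e-2, solver='PPA', numItermax=1)
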
diff --git a/RELEASES.md b/RELEASES.md
index 61ad2ca..2f47f40 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,13 +5,18 @@
 #### New features
 - Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
 - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
-- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
+- Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459)
 - Add tests on GPU for master branch and approved PR (PR #473)
 - Add `median` method to all inherited classes of `backend.Backend` (PR #472)
 - Update tests for macOS and Windows, speedup documentation (PR #484)
+- Added the Proximal Point algorithm to solve GW problems via a new parameter `solver="PPA"` in `ot.gromov.entropic_gromov_wasserstein` + examples (PR #455)
+- Added features `warmstart` and `kwargs` in `ot.gromov.entropic_gromov_wasserstein` to perform warmstart on dual potentials and to pass parameters to `ot.sinkhorn`, respectively (PR #455)
+- Added Sinkhorn projection-based solvers for FGW `ot.gromov.entropic_fused_gromov_wasserstein` and entropic FGW barycenters + examples (PR #455)
+- Added features `warmstartT` and `kwargs` to all CG and entropic (F)GW barycenter solvers (PR #455)
+- Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455)
+- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)

 #### Closed issues
-
 - Fix circleci-redirector action and codecov (PR #460)
 - Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457)
 - Major documentation cleanup (PR #462, #467, #475)
@@ -22,6 +27,7 @@
 - Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472)

 ## 0.9.0
+*April 2023*

 This new release contains so many new features and bug fixes since 0.8.2 that we decided to make it a new minor release at 0.9.0.
diff --git a/examples/gromov/plot_entropic_semirelaxed_fgw.py b/examples/gromov/plot_entropic_semirelaxed_fgw.py
new file mode 100644
index 0000000..642baea
--- /dev/null
+++ b/examples/gromov/plot_entropic_semirelaxed_fgw.py
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================================
+Entropic-regularized semi-relaxed (Fused) Gromov-Wasserstein example
+====================================================================
+
+This example is designed to show how to use the entropic semi-relaxed Gromov-Wasserstein
+and the entropic semi-relaxed Fused Gromov-Wasserstein divergences.
+
+Entropic-regularized sr(F)GW between two graphs G1 and G2 searches for a reweighting of the nodes of
+G2 at a minimal entropic-regularized (F)GW distance from G1.
+
+First, we generate two graphs following Stochastic Block Models, then show
+how to compute their srGW matchings and illustrate them. These graphs are then
+endowed with node features and we follow the same process with srFGW.
+
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+International Conference on Learning Representations (ICLR), 2021.
+""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +import numpy as np +import matplotlib.pylab as pl +from ot.gromov import entropic_semirelaxed_gromov_wasserstein, entropic_semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein +import networkx +from networkx.generators.community import stochastic_block_model as sbm + +############################################################################# +# +# Generate two graphs following Stochastic Block models of 2 and 3 clusters. +# --------------------------------------------- + + +N2 = 20 # 2 communities +N3 = 30 # 3 communities +p2 = [[1., 0.1], + [0.1, 0.9]] +p3 = [[1., 0.1, 0.], + [0.1, 0.95, 0.1], + [0., 0.1, 0.9]] +G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) +G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) + + +C2 = networkx.to_numpy_array(G2) +C3 = networkx.to_numpy_array(G3) + +h2 = np.ones(C2.shape[0]) / C2.shape[0] +h3 = np.ones(C3.shape[0]) / C3.shape[0] + +# Add weights on the edges for visualization later on +weight_intra_G2 = 5 +weight_inter_G2 = 0.5 +weight_intra_G3 = 1. +weight_inter_G3 = 1.5 + +weightedG2 = networkx.Graph() +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] + +for node in G2.nodes(): + weightedG2.add_node(node) +for i, j in G2.edges(): + if part_G2[i] == part_G2[j]: + weightedG2.add_edge(i, j, weight=weight_intra_G2) + else: + weightedG2.add_edge(i, j, weight=weight_inter_G2) + +weightedG3 = networkx.Graph() +part_G3 = [G3.nodes[i]['block'] for i in range(N3)] + +for node in G3.nodes(): + weightedG3.add_node(node) +for i, j in G3.edges(): + if part_G3[i] == part_G3[j]: + weightedG3.add_edge(i, j, weight=weight_intra_G3) + else: + weightedG3.add_edge(i, j, weight=weight_inter_G3) + +############################################################################# +# +# Compute their entropic-regularized semi-relaxed Gromov-Wasserstein divergences +# --------------------------------------------- + +# 0) GW(C2, h2, C3, h3) for reference +OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True) +gw = log['gw_dist'] + +# 1) srGW_e(C2, h2, C3) +OT_23, log_23 = entropic_semirelaxed_gromov_wasserstein( + C2, C3, h2, symmetric=True, epsilon=1., G0=None, log=True) +srgw_23 = log_23['srgw_dist'] + +# 2) srGW_e(C3, h3, C2) + +OT_32, log_32 = entropic_semirelaxed_gromov_wasserstein( + C3, C2, h3, symmetric=None, epsilon=1., G0=None, log=True) +srgw_32 = log_32['srgw_dist'] + +print('GW(C2, C3) = ', gw) +print('srGW_e(C2, h2, C3) = ', srgw_23) +print('srGW_e(C3, h3, C2) = ', srgw_32) + + +############################################################################# +# +# Visualization of the entropic-regularized semi-relaxed Gromov-Wasserstein matchings +# --------------------------------------------- +# +# We color nodes of the graph on the right - then project its node colors +# based on the optimal transport plan from the entropic srGW matching. +# We adjust the intensity of links across domains proportionaly to the mass +# sent, adding a minimal intensity of 0.1 if mass sent is not zero. 
+
+
+def draw_graph(G, C, nodes_color_part, Gweights=None,
+               pos=None, edge_color='black', node_size=None,
+               shiftx=0, seed=0):
+
+    if (pos is None):
+        pos = networkx.spring_layout(G, scale=1., seed=seed)
+
+    if shiftx != 0:
+        for k, v in pos.items():
+            v[0] = v[0] + shiftx
+
+    alpha_edge = 0.7
+    width_edge = 1.8
+    if Gweights is None:
+        networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
+    else:
+        # We make more visible connections between activated nodes
+        n = len(Gweights)
+        edgelist_activated = []
+        edgelist_deactivated = []
+        for i in range(n):
+            for j in range(n):
+                if Gweights[i] * Gweights[j] * C[i, j] > 0:
+                    edgelist_activated.append((i, j))
+                elif C[i, j] > 0:
+                    edgelist_deactivated.append((i, j))
+
+        networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
+                                     width=width_edge, alpha=alpha_edge,
+                                     edge_color=edge_color)
+        networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
+                                     width=width_edge, alpha=0.1,
+                                     edge_color=edge_color)
+
+    if Gweights is None:
+        for node, node_color in enumerate(nodes_color_part):
+            networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+                                         node_size=node_size, alpha=1,
+                                         node_color=node_color)
+    else:
+        scaled_Gweights = Gweights / (0.5 * Gweights.max())
+        nodes_size = node_size * scaled_Gweights
+        for node, node_color in enumerate(nodes_color_part):
+            networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+                                         node_size=nodes_size[node], alpha=1,
+                                         node_color=node_color)
+    return pos
+
+
+def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
+                             p1, p2, T, pos1=None, pos2=None,
+                             shiftx=4, switchx=False, node_size=70,
+                             seed_G1=0, seed_G2=0):
+    starting_color = 0
+    # get graphs partition and their coloring
+    part1 = part_G1.copy()
+    unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
+    nodes_color_part1 = []
+    for cluster in part1:
+        nodes_color_part1.append(unique_colors[cluster])
+
+    nodes_color_part2 = []
+    # T: getting color assignments from argmax of columns
+    for i in range(len(G2.nodes())):
+        j = np.argmax(T[:, i])
+        nodes_color_part2.append(nodes_color_part1[j])
+    pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
+                      pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
+    pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
+                      node_size=node_size, shiftx=shiftx, seed=seed_G2)
+    for k1, v1 in pos1.items():
+        max_Tk1 = np.max(T[k1, :])
+        for k2, v2 in pos2.items():
+            if (T[k1, k2] > 0):
+                pl.plot([pos1[k1][0], pos2[k2][0]],
+                        [pos1[k1][1], pos2[k2][1]],
+                        '-', lw=0.6, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.),
+                        color=nodes_color_part1[k1])
+    return pos1, pos2
+
+
+node_size = 40
+fontsize = 10
+seed_G2 = 0
+seed_G3 = 4
+
+pl.figure(1, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.title(r'$srGW_e(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+    weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+    shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'$srGW_e(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+    weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+    pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
+
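For reference, the exact (unregularized) semi-relaxed solver exported by the same `ot.gromov` module can be run on identical inputs; a short sketch (not part of the patch), assuming the import below, to compare srGW with and without entropic smoothing:

from ot.gromov import semirelaxed_gromov_wasserstein

# exact srGW solved with Conditional Gradient, same inputs as above
T_cg, log_cg = semirelaxed_gromov_wasserstein(
    C2, C3, h2, symmetric=True, log=True, G0=None)
print('srGW(C2, h2, C3) =', log_cg['srgw_dist'])
print('srGW_e(C2, h2, C3) =', srgw_23)
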
+#############################################################################
+#
+# Add node features
+# ---------------------------------------------
+
+# We add node features with given mean - by clusters
+# and inversely proportional to clusters' intra-connectivity
+
+F2 = np.zeros((N2, 1))
+for i, c in enumerate(part_G2):
+    F2[i, 0] = np.random.normal(loc=c, scale=0.01)
+
+F3 = np.zeros((N3, 1))
+for i, c in enumerate(part_G3):
+    F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
+
+#############################################################################
+#
+# Compute their entropic-regularized semi-relaxed Fused Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+alpha = 0.5
+# Compute pairwise squared Euclidean distances between node features
+M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T)
+
+# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference
+
+OT, log = fused_gromov_wasserstein(
+    M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True)
+fgw = log['fgw_dist']
+
+# 1) srFGW_e(C2, F2, h2, C3, F3)
+OT_23, log_23 = entropic_semirelaxed_fused_gromov_wasserstein(
+    M, C2, C3, h2, symmetric=True, epsilon=1., alpha=0.5, log=True, G0=None)
+srfgw_23 = log_23['srfgw_dist']
+
+# 2) srFGW_e(C3, F3, h3, C2, F2)
+
+OT_32, log_32 = entropic_semirelaxed_fused_gromov_wasserstein(
+    M.T, C3, C2, h3, symmetric=None, epsilon=1., alpha=alpha, log=True, G0=None)
+srfgw_32 = log_32['srfgw_dist']
+
+print('FGW(C2, F2, C3, F3) = ', fgw)
+print('srFGW_e(C2, F2, h2, C3, F3) = ', srfgw_23)
+print('srFGW_e(C3, F3, h3, C2, F2) = ', srfgw_32)
+
+#############################################################################
+#
+# Visualization of the entropic semi-relaxed Fused Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the srFGW matching.
+# NB: colors refer to clusters - not to node features
+
+pl.figure(2, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.title(r'$srFGW_e(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+    weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+    shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'$srFGW_e(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+    weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+    pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py
index bf10de6..68ecb13 100644
--- a/examples/gromov/plot_fgw.py
+++ b/examples/gromov/plot_fgw.py
@@ -4,13 +4,13 @@
 Plot Fused-Gromov-Wasserstein
 ==============================
 
-This example illustrates the computation of FGW for 1D measures [18].
+This example first illustrates the computation of FGW for 1D measures,
+estimated using a Conditional Gradient solver [24].
-[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. - """ # Author: Titouan Vayer @@ -24,11 +24,13 @@ import numpy as np import ot from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein + ############################################################################## # Generate data # ------------- -#%% parameters +# parameters + # We create two 1D random measures n = 20 # number of points in the first distribution n2 = 30 # number of points in the second distribution @@ -53,10 +55,9 @@ q = ot.unif(n2) # Plot data # --------- -#%% plot the distributions +# plot the distributions -pl.close(10) -pl.figure(10, (7, 7)) +pl.figure(1, (7, 7)) pl.subplot(2, 1, 1) @@ -78,7 +79,7 @@ pl.show() # Create structure matrices and across-feature distance matrix # ------------------------------------------------------------ -#%% Structure matrices and across-features distance matrix +# Structure matrices and across-features distance matrix C1 = ot.dist(xs) C2 = ot.dist(xt) M = ot.dist(ys, yt) @@ -90,10 +91,9 @@ Got = ot.emd([], [], M) # Plot matrices # ------------- -#%% cmap = 'Reds' -pl.close(10) -pl.figure(10, (5, 5)) + +pl.figure(2, (5, 5)) fs = 15 l_x = [0, 5, 10, 15] l_y = [0, 5, 10, 15, 20, 25] @@ -113,7 +113,6 @@ ax2 = pl.subplot(gs[:3, 2:]) pl.imshow(C2, cmap=cmap, interpolation='nearest') pl.title("$C_2$", fontsize=fs) pl.ylabel("$l$", fontsize=fs) -#pl.ylabel("$l$",fontsize=fs) pl.xticks(()) pl.yticks(l_y) ax2.set_aspect('auto') @@ -133,28 +132,27 @@ pl.show() # Compute FGW/GW # -------------- -#%% Computing FGW and GW +# Computing FGW and GW alpha = 1e-3 ot.tic() Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True) ot.toc() -#%reload_ext WGW +# reload_ext WGW Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True) ############################################################################## # Visualize transport matrices # ---------------------------- -#%% visu OT matrix +# visu OT matrix cmap = 'Blues' fs = 15 -pl.figure(2, (13, 5)) +pl.figure(3, (13, 5)) pl.clf() pl.subplot(1, 3, 1) pl.imshow(Got, cmap=cmap, interpolation='nearest') -#pl.xlabel("$y$",fontsize=fs) pl.ylabel("$i$", fontsize=fs) pl.xticks(()) diff --git a/examples/gromov/plot_fgw_solvers.py b/examples/gromov/plot_fgw_solvers.py new file mode 100644 index 0000000..5f8a885 --- /dev/null +++ b/examples/gromov/plot_fgw_solvers.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +""" +============================== +Comparison of Fused Gromov-Wasserstein solvers +============================== + +This example illustrates the computation of FGW for attributed graphs +using 3 different solvers to estimate the distance based on Conditional +Gradient [24] or Sinkhorn projections [12, 51]. + +We generate two graphs following Stochastic Block Models further endowed with +node features and compute their FGW matchings. + +[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), +"Gromov-Wasserstein averaging of kernel and distance matrices". +International Conference on Machine Learning (ICML). + +[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +and Courty Nicolas +"Optimal Transport for structured data with application on graphs" +International Conference on Machine Learning (ICML). 2019. 
+ +[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). +"Gromov-wasserstein learning for graph matching and node embedding". +In International Conference on Machine Learning (ICML), 2019. +""" + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +import numpy as np +import matplotlib.pylab as pl +from ot.gromov import fused_gromov_wasserstein, entropic_fused_gromov_wasserstein +import networkx +from networkx.generators.community import stochastic_block_model as sbm + +############################################################################# +# +# Generate two graphs following Stochastic Block models of 2 and 3 clusters. +# --------------------------------------------- +np.random.seed(0) + +N2 = 20 # 2 communities +N3 = 30 # 3 communities +p2 = [[1., 0.1], + [0.1, 0.9]] +p3 = [[1., 0.1, 0.], + [0.1, 0.95, 0.1], + [0., 0.1, 0.9]] +G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2) +G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3) +part_G2 = [G2.nodes[i]['block'] for i in range(N2)] +part_G3 = [G3.nodes[i]['block'] for i in range(N3)] + +C2 = networkx.to_numpy_array(G2) +C3 = networkx.to_numpy_array(G3) + + +# We add node features with given mean - by clusters +# and inversely proportional to clusters' intra-connectivity + +F2 = np.zeros((N2, 1)) +for i, c in enumerate(part_G2): + F2[i, 0] = np.random.normal(loc=c, scale=0.01) + +F3 = np.zeros((N3, 1)) +for i, c in enumerate(part_G3): + F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01) + +# Compute pairwise euclidean distance between node features +M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T) + +h2 = np.ones(C2.shape[0]) / C2.shape[0] +h3 = np.ones(C3.shape[0]) / C3.shape[0] + +############################################################################# +# +# Compute their Fused Gromov-Wasserstein distances +# --------------------------------------------- + +alpha = 0.5 + + +# Conditional Gradient algorithm +fgw0, log0 = fused_gromov_wasserstein( + M, C2, C3, h2, h3, 'square_loss', alpha=alpha, verbose=True, log=True) + +# Proximal Point algorithm with Kullback-Leibler as proximal operator +fgw, log = entropic_fused_gromov_wasserstein( + M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA', + log=True, verbose=True, warmstart=False, numItermax=10) + +# Projected Gradient algorithm with entropic regularization +fgwe, loge = entropic_fused_gromov_wasserstein( + M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD', + log=True, verbose=True, warmstart=False, numItermax=10) + +print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['fgw_dist'])) +print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['fgw_dist'])) +print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['fgw_dist'])) + +# compute OT sparsity level +fgw0_sparsity = 100 * (fgw0 == 0.).astype(np.float64).sum() / (N2 * N3) +fgw_sparsity = 100 * (fgw == 0.).astype(np.float64).sum() / (N2 * N3) +fgwe_sparsity = 100 * (fgwe == 0.).astype(np.float64).sum() / (N2 * N3) + +# Methods using Sinkhorn projections tend to produce feasibility errors on the +# marginal constraints + +err0 = np.linalg.norm(fgw0.sum(1) - h2) + np.linalg.norm(fgw0.sum(0) - h3) +err = np.linalg.norm(fgw.sum(1) - h2) + np.linalg.norm(fgw.sum(0) - h3) +erre = np.linalg.norm(fgwe.sum(1) - h2) + np.linalg.norm(fgwe.sum(0) - h3) + 
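The sparsity and feasibility statistics computed above are only displayed later in the figure titles; a small optional addition (not in the patch) that prints them directly, reusing the variables defined just above:

print('OT sparsity in % (CG, PP, PGD):', fgw0_sparsity, fgw_sparsity, fgwe_sparsity)
print('marginal feasibility errors (CG, PP, PGD):', err0, err, erre)
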
+#############################################################################
+#
+# Visualization of the Fused Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the FGW matchings.
+# We adjust the intensity of links across domains proportionally to the mass
+# sent, adding a minimal intensity of 0.1 if mass sent is not zero.
+# For each matching, all node sizes are proportional to their mass computed
+# from marginals of the OT plan to illustrate potential feasibility errors.
+# NB: colors refer to clusters - not to node features
+
+# Add weights on the edges for visualization later on
+weight_intra_G2 = 5
+weight_inter_G2 = 0.5
+weight_intra_G3 = 1.
+weight_inter_G3 = 1.5
+
+weightedG2 = networkx.Graph()
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+
+for node in G2.nodes():
+    weightedG2.add_node(node)
+for i, j in G2.edges():
+    if part_G2[i] == part_G2[j]:
+        weightedG2.add_edge(i, j, weight=weight_intra_G2)
+    else:
+        weightedG2.add_edge(i, j, weight=weight_inter_G2)
+
+weightedG3 = networkx.Graph()
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+for node in G3.nodes():
+    weightedG3.add_node(node)
+for i, j in G3.edges():
+    if part_G3[i] == part_G3[j]:
+        weightedG3.add_edge(i, j, weight=weight_intra_G3)
+    else:
+        weightedG3.add_edge(i, j, weight=weight_inter_G3)
+
+
+def draw_graph(G, C, nodes_color_part, Gweights=None,
+               pos=None, edge_color='black', node_size=None,
+               shiftx=0, seed=0):
+
+    if (pos is None):
+        pos = networkx.spring_layout(G, scale=1., seed=seed)
+
+    if shiftx != 0:
+        for k, v in pos.items():
+            v[0] = v[0] + shiftx
+
+    alpha_edge = 0.7
+    width_edge = 1.8
+    if Gweights is None:
+        networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
+    else:
+        # We make more visible connections between activated nodes
+        n = len(Gweights)
+        edgelist_activated = []
+        edgelist_deactivated = []
+        for i in range(n):
+            for j in range(n):
+                if Gweights[i] * Gweights[j] * C[i, j] > 0:
+                    edgelist_activated.append((i, j))
+                elif C[i, j] > 0:
+                    edgelist_deactivated.append((i, j))
+
+        networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
+                                     width=width_edge, alpha=alpha_edge,
+                                     edge_color=edge_color)
+        networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
+                                     width=width_edge, alpha=0.1,
+                                     edge_color=edge_color)
+
+    if Gweights is None:
+        for node, node_color in enumerate(nodes_color_part):
+            networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+                                         node_size=node_size, alpha=1,
+                                         node_color=node_color)
+    else:
+        scaled_Gweights = Gweights / (0.5 * Gweights.max())
+        nodes_size = node_size * scaled_Gweights
+        for node, node_color in enumerate(nodes_color_part):
+            networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+                                         node_size=nodes_size[node], alpha=1,
+                                         node_color=node_color)
+    return pos
+
+
+def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T,
+                           pos1=None, pos2=None, shiftx=4, switchx=False,
+                           node_size=70, seed_G1=0, seed_G2=0):
+    starting_color = 0
+    # get graphs partition and their coloring
+    part1 = part_G1.copy()
+    unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
+    nodes_color_part1 = []
+    for cluster in part1:
+        nodes_color_part1.append(unique_colors[cluster])
+
+    nodes_color_part2 = []
+    # T: getting color assignments from argmax of columns
+    for i in range(len(G2.nodes())):
+        j = np.argmax(T[:, i])
+        nodes_color_part2.append(nodes_color_part1[j])
+    pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
+                      pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
+    pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
+                      node_size=node_size, shiftx=shiftx, seed=seed_G2)
+
+    for k1, v1 in pos1.items():
+        max_Tk1 = np.max(T[k1, :])
+        for k2, v2 in pos2.items():
+            if (T[k1, k2] > 0):
+                pl.plot([pos1[k1][0], pos2[k2][0]],
+                        [pos1[k1][1], pos2[k2][1]],
+                        '-', lw=0.7, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.),
+                        color=nodes_color_part1[k1])
+    return pos1, pos2
+
+
+node_size = 40
+fontsize = 13
+seed_G2 = 0
+seed_G3 = 4
+
+pl.figure(2, figsize=(12, 3.5))
+pl.clf()
+pl.subplot(131)
+pl.axis('off')
+pl.title('(CG algo) FGW=%s \n \n OT sparsity = %s \n feasibility error = %s' % (
+    np.round(log0['fgw_dist'], 3), str(np.round(fgw0_sparsity, 2)) + ' %',
+    np.round(err0, 4)), fontsize=fontsize)
+
+p0, q0 = fgw0.sum(1), fgw0.sum(0)  # check marginals
+
+pos1, pos2 = draw_transp_colored_GW(
+    weightedG2, C2, weightedG3, C3, part_G2, p1=p0, p2=q0, T=fgw0,
+    shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+
+pl.subplot(132)
+pl.axis('off')
+
+p, q = fgw.sum(1), fgw.sum(0)  # check marginals
+
+pl.title('(PP algo) FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
+    np.round(log['fgw_dist'], 3), str(np.round(fgw_sparsity, 2)) + ' %',
+    np.round(err, 4)), fontsize=fontsize)
+
+pos1, pos2 = draw_transp_colored_GW(
+    weightedG2, C2, weightedG3, C3, part_G2, p1=p, p2=q, T=fgw,
+    pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+
+pl.subplot(133)
+pl.axis('off')
+
+pe, qe = fgwe.sum(1), fgwe.sum(0)  # check marginals
+
+pl.title('Entropic FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
+    np.round(loge['fgw_dist'], 3), str(np.round(fgwe_sparsity, 2)) + ' %',
+    np.round(erre, 4)), fontsize=fontsize)
+
+pos1, pos2 = draw_transp_colored_GW(
+    weightedG2, C2, weightedG3, C3, part_G2, p1=pe, p2=qe, T=fgwe,
+    pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index afb5bdc..252267f 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -5,13 +5,38 @@
 Gromov-Wasserstein example
 ==========================
 This example is designed to show how to use the Gromov-Wasserstein distance computation in POT.
+We first compare 3 solvers to estimate the distance based on
+Conditional Gradient [24] or Sinkhorn projections [12, 51].
+Then we compare 2 stochastic solvers to estimate the distance with a lower
+numerical cost [33].
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
+"Gromov-Wasserstein averaging of kernel and distance matrices".
+International Conference on Machine Learning (ICML).
+
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+and Courty Nicolas
+"Optimal Transport for structured data with application on graphs"
+International Conference on Machine Learning (ICML). 2019.
+
+[33] Kerdoncuff T., Emonet R., Sebban M. "Sampled Gromov Wasserstein",
+Machine Learning Journal (MLJ), 2021.
+
+[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019).
+"Gromov-wasserstein learning for graph matching and node embedding".
+In International Conference on Machine Learning (ICML), 2019.
+ """ # Author: Erwan Vautier # Nicolas Courty +# Cédric Vincent-Cuaz +# Tanguy Kerdoncuff # # License: MIT License +# sphinx_gallery_thumbnail_number = 1 + import scipy as sp import numpy as np import matplotlib.pylab as pl @@ -36,7 +61,7 @@ cov_s = np.array([[1, 0], [0, 1]]) mu_t = np.array([4, 4, 4]) cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - +np.random.seed(0) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s) P = sp.linalg.sqrtm(cov_t) xt = np.random.randn(n_samples, 3).dot(P) + mu_t @@ -47,7 +72,7 @@ xt = np.random.randn(n_samples, 3).dot(P) + mu_t # -------------------------- -fig = pl.figure() +fig = pl.figure(1) ax1 = fig.add_subplot(121) ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') ax2 = fig.add_subplot(122, projection='3d') @@ -66,11 +91,15 @@ C2 = sp.spatial.distance.cdist(xt, xt) C1 /= C1.max() C2 /= C2.max() -pl.figure() +pl.figure(2) pl.subplot(121) pl.imshow(C1) +pl.title('C1') + pl.subplot(122) pl.imshow(C2) +pl.title('C2') + pl.show() ############################################################################# @@ -81,32 +110,63 @@ pl.show() p = ot.unif(n_samples) q = ot.unif(n_samples) +# Conditional Gradient algorithm gw0, log0 = ot.gromov.gromov_wasserstein( C1, C2, p, q, 'square_loss', verbose=True, log=True) +# Proximal Point algorithm with Kullback-Leibler as proximal operator gw, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True) - - -print('Gromov-Wasserstein distances: ' + str(log0['gw_dist'])) -print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist'])) - - -pl.figure(1, (10, 5)) - -pl.subplot(1, 2, 1) -pl.imshow(gw0, cmap='jet') -pl.title('Gromov Wasserstein') - -pl.subplot(1, 2, 2) -pl.imshow(gw, cmap='jet') -pl.title('Entropic Gromov Wasserstein') - + C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PPA', + log=True, verbose=True) + +# Projected Gradient algorithm with entropic regularization +gwe, loge = ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PGD', + log=True, verbose=True) + +print('Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['gw_dist'])) +print('Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['gw_dist'])) +print('Entropic Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['gw_dist'])) + +# compute OT sparsity level +gw0_sparsity = 100 * (gw0 == 0.).astype(np.float64).sum() / (n_samples ** 2) +gw_sparsity = 100 * (gw == 0.).astype(np.float64).sum() / (n_samples ** 2) +gwe_sparsity = 100 * (gwe == 0.).astype(np.float64).sum() / (n_samples ** 2) + +# Methods using Sinkhorn projections tend to produce feasibility errors on the +# marginal constraints + +err0 = np.linalg.norm(gw0.sum(1) - p) + np.linalg.norm(gw0.sum(0) - q) +err = np.linalg.norm(gw.sum(1) - p) + np.linalg.norm(gw.sum(0) - q) +erre = np.linalg.norm(gwe.sum(1) - p) + np.linalg.norm(gwe.sum(0) - q) + +pl.figure(3, (10, 6)) +cmap = 'Blues' +fontsize = 12 +pl.subplot(131) +pl.imshow(gw0, cmap=cmap) +pl.title('(CG algo) GW=%s \n \n OT sparsity=%s \n feasibility error=%s' % ( + np.round(log0['gw_dist'], 4), str(np.round(gw0_sparsity, 2)) + ' %', np.round(np.round(err0, 4))), + fontsize=fontsize) + +pl.subplot(132) +pl.imshow(gw, cmap=cmap) +pl.title('(PP algo) GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % ( + np.round(log['gw_dist'], 4), str(np.round(gw_sparsity, 2)) + ' %', np.round(err, 4)), + fontsize=fontsize) + 
+pl.subplot(133) +pl.imshow(gwe, cmap=cmap) +pl.title('Entropic GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % ( + np.round(loge['gw_dist'], 4), str(np.round(gwe_sparsity, 2)) + ' %', np.round(erre, 4)), + fontsize=fontsize) + +pl.tight_layout() pl.show() ############################################################################# # -# Compute GW with a scalable stochastic method with any loss function +# Compute GW with scalable stochastic methods with any loss function # ---------------------------------------------------------------------- @@ -126,14 +186,14 @@ print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated'])) print('Variance estimated: ' + str(slog['gw_dist_std'])) -pl.figure(1, (10, 5)) +pl.figure(4, (10, 5)) -pl.subplot(1, 2, 1) -pl.imshow(pgw.toarray(), cmap='jet') +pl.subplot(121) +pl.imshow(pgw.toarray(), cmap=cmap) pl.title('Pointwise Gromov Wasserstein') -pl.subplot(1, 2, 2) -pl.imshow(sgw, cmap='jet') +pl.subplot(122) +pl.imshow(sgw, cmap=cmap) pl.title('Sampled Gromov Wasserstein') pl.show() diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 6184edf..e39d906 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -11,38 +11,51 @@ Solvers related to Gromov-Wasserstein problems. # All submodules and packages from ._utils import (init_matrix, tensor_product, gwloss, gwggrad, - update_square_loss, update_kl_loss, + update_square_loss, update_kl_loss, update_feature_matrix, init_matrix_semirelaxed) + from ._gw import (gromov_wasserstein, gromov_wasserstein2, fused_gromov_wasserstein, fused_gromov_wasserstein2, - solve_gromov_linesearch, gromov_barycenters, fgw_barycenters, - update_structure_matrix, update_feature_matrix) + solve_gromov_linesearch, gromov_barycenters, fgw_barycenters) + from ._bregman import (entropic_gromov_wasserstein, entropic_gromov_wasserstein2, - entropic_gromov_barycenters) + entropic_gromov_barycenters, + entropic_fused_gromov_wasserstein, + entropic_fused_gromov_wasserstein2, + entropic_fused_gromov_barycenters) + from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein, sampled_gromov_wasserstein) + from ._semirelaxed import (semirelaxed_gromov_wasserstein, semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein2, - solve_semirelaxed_gromov_linesearch) + solve_semirelaxed_gromov_linesearch, + entropic_semirelaxed_gromov_wasserstein, + entropic_semirelaxed_gromov_wasserstein2, + entropic_semirelaxed_fused_gromov_wasserstein, + entropic_semirelaxed_fused_gromov_wasserstein2) + from ._dictionary import (gromov_wasserstein_dictionary_learning, gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing) -__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', - 'update_square_loss', 'update_kl_loss', 'init_matrix_semirelaxed', +__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss', + 'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed', 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters', - 'fgw_barycenters', 'update_structure_matrix', 'update_feature_matrix', - 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2', - 'entropic_gromov_barycenters', 'GW_distance_estimation', - 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein', + 'fgw_barycenters', 'entropic_gromov_wasserstein', 
'entropic_gromov_wasserstein2', + 'entropic_gromov_barycenters', 'entropic_fused_gromov_wasserstein', + 'entropic_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters', + 'GW_distance_estimation', 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein', 'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2', 'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2', - 'solve_semirelaxed_gromov_linesearch', 'gromov_wasserstein_dictionary_learning', + 'solve_semirelaxed_gromov_linesearch', 'entropic_semirelaxed_gromov_wasserstein', + 'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein', + 'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning', 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning', 'fused_gromov_wasserstein_linear_unmixing'] diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index aa25f1f..18cef56 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -11,23 +11,29 @@ Bregman projections solvers for entropic Gromov-Wasserstein # # License: MIT License +import numpy as np +import warnings + from ..bregman import sinkhorn -from ..utils import dist, list_to_array, check_random_state +from ..utils import dist, list_to_array, check_random_state, unif from ..backend import get_backend from ._utils import init_matrix, gwloss, gwggrad -from ._utils import update_square_loss, update_kl_loss +from ._utils import update_square_loss, update_kl_loss, update_feature_matrix -def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None, - max_iter=1000, tol=1e-9, verbose=False, log=False): +def entropic_gromov_wasserstein( + C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, + tol=1e-9, solver='PGD', warmstart=False, verbose=False, log=False, **kwargs): r""" - Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + estimated using Sinkhorn projections. - The function solves the following optimization problem: + If `solver="PGD"`, the function solves the following entropic-regularized + Gromov-Wasserstein optimization problem using Projected Gradient Descent [12]: .. math:: - \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T})) + \mathbf{T}^* \in \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon H(\mathbf{T}) s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} @@ -35,6 +41,17 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, \mathbf{T} &\geq 0 + Else if `solver="PPA"`, the function solves the following Gromov-Wasserstein + optimization problem using Proximal Point Algorithm [51]: + + .. math:: + \mathbf{T}^* \in \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. 
\ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+        \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+        \mathbf{T} &\geq 0
 
     Where :
 
     - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
@@ -58,13 +75,15 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
         Metric cost matrix in the source space
     C2 : array-like, shape (nt, nt)
         Metric cost matrix in the target space
-    p : array-like, shape (ns,)
-        Distribution in the source space
-    q : array-like, shape (nt,)
-        Distribution in the target space
-    loss_fun : string
+    p : array-like, shape (ns,), optional
+        Distribution in the source space.
+        If left to its default value None, the uniform distribution is taken.
+    q : array-like, shape (nt,), optional
+        Distribution in the target space.
+        If left to its default value None, the uniform distribution is taken.
+    loss_fun : string, optional
         Loss function used for the solver either 'square_loss' or 'kl_loss'
-    epsilon : float
+    epsilon : float, optional
        Regularization term >0
     symmetric : bool, optional
        Either C1 and C2 are to be assumed symmetric or not.
@@ -72,16 +91,28 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
     G0: array-like, shape (ns,nt), optional
        If None the initial transport plan of the solver is pq^T.
-       Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+       Otherwise G0 will be used as initial transport of the solver. G0 is not
+       required to satisfy marginal constraints but we strongly recommend it
+       to correctly estimate the GW distance.
     max_iter : int, optional
        Max number of iterations
     tol : float, optional
        Stop threshold on error (>0)
+    solver: string, optional
+        Solver to use either 'PGD' for Projected Gradient Descent or 'PPA'
+        for Proximal Point Algorithm.
+        Default value is 'PGD'.
+    warmstart: bool, optional
+        Whether to perform warmstart of dual potentials in the successive
+        Sinkhorn projections.
     verbose : bool, optional
        Print information along iterations
     log : bool, optional
        Record log if True.
-
+    **kwargs: dict
+        Parameters can be directly passed to the ot.sinkhorn solver,
+        such as `numItermax` and `stopThr` to control its estimation precision,
+        e.g. [51] suggests to use `numItermax=1`.
+
     Returns
     -------
     T : array-like, shape (`ns`, `nt`)
        Optimal coupling between the two joint spaces
@@ -96,22 +127,50 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
     .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
        distance between networks and stable network invariants.
        Information and Inference: A Journal of the IMA, 8(4), 757-787.
+
+    .. [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). Gromov-wasserstein
+        learning for graph matching and node embedding. In International
+        Conference on Machine Learning (ICML), 2019.
     """
-    C1, C2, p, q = list_to_array(C1, C2, p, q)
+    if solver not in ['PGD', 'PPA']:
+        raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']."
% solver) + + C1, C2 = list_to_array(C1, C2) + arr = [C1, C2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + if q is not None: + arr.append(list_to_array(q)) + else: + q = unif(C2.shape[0], type_as=C2) + + if G0 is not None: + arr.append(G0) + + nx = get_backend(*arr) + if G0 is None: - nx = get_backend(p, q, C1, C2) G0 = nx.outer(p, q) - else: - nx = get_backend(p, q, C1, C2, G0) + T = G0 constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) + if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) if not symmetric: constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) + cpt = 0 err = 1 + if warmstart: + # initialize potentials to cope with ot.sinkhorn initialization + N1, N2 = C1.shape[0], C2.shape[0] + mu = nx.zeros(N1, type_as=C1) - np.log(N1) + nu = nx.zeros(N2, type_as=C2) - np.log(N2) + if log: log = {'err': []} @@ -124,7 +183,17 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, tens = gwggrad(constC, hC1, hC2, T, nx) else: tens = 0.5 * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx)) - T = sinkhorn(p, q, tens, epsilon, method='sinkhorn') + + if solver == 'PPA': + tens = tens - epsilon * nx.log(T) + + if warmstart: + T, loginn = sinkhorn(p, q, tens, epsilon, method='sinkhorn', log=True, warmstart=(mu, nu), **kwargs) + mu = epsilon * nx.log(loginn['u']) + nu = epsilon * nx.log(loginn['v']) + + else: + T = sinkhorn(p, q, tens, epsilon, method='sinkhorn', **kwargs) if cpt % 10 == 0: # we can speed up the process by checking for the error only all @@ -142,6 +211,9 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, cpt += 1 + if abs(nx.sum(T) - 1) > 1e-5: + warnings.warn("Solver failed to produce a transport plan. You might " + "want to increase the regularization parameter `epsilon`.") if log: log['gw_dist'] = gwloss(constC, hC1, hC2, T, nx) return T, log @@ -149,17 +221,36 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, return T -def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None, - max_iter=1000, tol=1e-9, verbose=False, log=False): +def entropic_gromov_wasserstein2( + C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000, + tol=1e-9, solver='PGD', warmstart=False, verbose=False, log=False, **kwargs): r""" - Returns the entropic Gromov-Wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + estimated using Sinkhorn projections. - The function solves the following optimization problem: + If `solver="PGD"`, the function solves the following entropic-regularized + Gromov-Wasserstein optimization problem using Projected Gradient Descent [12]: + + .. math:: + \mathbf{GW} = \mathop{\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon H(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Else if `solver="PPA"`, the function solves the following Gromov-Wasserstein + optimization problem using Proximal Point Algorithm [51]: .. 
math::
-        GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})
-        \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+        \mathbf{GW} = \mathop{\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
 
+        s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+        \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+        \mathbf{T} &\geq 0
 
     Where :
 
     - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
@@ -183,13 +274,15 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None
         Metric cost matrix in the source space
     C2 : array-like, shape (nt, nt)
         Metric cost matrix in the target space
-    p : array-like, shape (ns,)
-        Distribution in the source space
-    q : array-like, shape (nt,)
-        Distribution in the target space
-    loss_fun : str
+    p : array-like, shape (ns,), optional
+        Distribution in the source space.
+        If left to its default value None, the uniform distribution is taken.
+    q : array-like, shape (nt,), optional
+        Distribution in the target space.
+        If left to its default value None, the uniform distribution is taken.
+    loss_fun : string, optional
        Loss function used for the solver either 'square_loss' or 'kl_loss'
-    epsilon : float
+    epsilon : float, optional
        Regularization term >0
     symmetric : bool, optional
        Either C1 and C2 are to be assumed symmetric or not.
@@ -197,16 +290,28 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None
        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
     G0: array-like, shape (ns,nt), optional
        If None the initial transport plan of the solver is pq^T.
-       Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+       Otherwise G0 will be used as initial transport of the solver. G0 is not
+       required to satisfy marginal constraints but we strongly recommend it
+       to correctly estimate the GW distance.
     max_iter : int, optional
        Max number of iterations
     tol : float, optional
        Stop threshold on error (>0)
+    solver: string, optional
+        Solver to use either 'PGD' for Projected Gradient Descent or 'PPA'
+        for Proximal Point Algorithm.
+        Default value is 'PGD'.
+    warmstart: bool, optional
+        Whether to perform warmstart of dual potentials in the successive
+        Sinkhorn projections.
     verbose : bool, optional
        Print information along iterations
     log : bool, optional
        Record log if True.
-
+    **kwargs: dict
+        Parameters can be directly passed to the ot.sinkhorn solver,
+        such as `numItermax` and `stopThr` to control its estimation precision,
+        e.g. [51] suggests to use `numItermax=1`.
+
     Returns
     -------
     gw_dist : float
        Gromov-Wasserstein distance
 
     References
     ----------
     .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.
 
+    .. [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). Gromov-wasserstein
+        learning for graph matching and node embedding. In International
+        Conference on Machine Learning (ICML), 2019.
""" - gw, logv = entropic_gromov_wasserstein( - C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, tol, verbose, log=True) + T, logv = entropic_gromov_wasserstein( + C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, + tol, solver, warmstart, verbose, log=True, **kwargs) - logv['T'] = gw + logv['T'] = T if log: return logv['gw_dist'], logv @@ -230,10 +339,13 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None return logv['gw_dist'] -def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmetric=True, - max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None): +def entropic_gromov_barycenters( + N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss', + epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9, warmstartT=False, + verbose=False, log=False, init_C=None, random_state=None, **kwargs): r""" - Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` + Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` + estimated using Gromov-Wasserstein transports from Sinkhorn projections. The function solves the following optimization problem: @@ -252,19 +364,18 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet Size of the targeted barycenter Cs : list of S array-like of shape (ns,ns) Metric cost matrices - ps : list of S array-like of shape (ns,) - Sample weights in the `S` spaces - p : array-like, shape(N,) - Weights in the targeted barycenter - lambdas : list of float + ps : list of S array-like of shape (ns,), optional + Sample weights in the `S` spaces. + If let to its default value None, uniform distributions are taken. + p : array-like, shape (N,), optional + Weights in the targeted barycenter. + If let to its default value None, uniform distribution is taken. + lambdas : list of float, optional List of the `S` spaces' weights. - loss_fun : callable - Tensor-matrix multiplication function based on specific loss function. - update : callable - function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates - :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings - calculated at each iteration - epsilon : float + If let to its default value None, uniform weights are taken. + loss_fun : callable, optional + tensor-matrix multiplication function based on specific loss function + epsilon : float, optional Regularization term >0 symmetric : bool, optional. Either structures are to be assumed symmetric or not. Default value is True. @@ -273,6 +384,9 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet Max number of iterations tol : float, optional Stop threshold on error (>0) + warmstartT: bool, optional + Either to perform warmstart of transport plans in the successive + gromov-wasserstein transport problems. verbose : bool, optional Print information along iterations. log : bool, optional @@ -281,6 +395,8 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet Random initial value for the :math:`\mathbf{C}` matrix provided by user. random_state : int or RandomState instance, optional Fix the seed for reproducibility + **kwargs: dict + parameters can be directly passed to the `ot.entropic_gromov_wasserstein` solver. 
Returns ------- @@ -296,11 +412,21 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet International Conference on Machine Learning (ICML). 2016. """ Cs = list_to_array(*Cs) - ps = list_to_array(*ps) - p = list_to_array(p) - nx = get_backend(*Cs, *ps, p) + arr = [*Cs] + if ps is not None: + arr += list_to_array(*ps) + else: + ps = [unif(C.shape[0], type_as=C) for C in Cs] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(N, type_as=Cs[0]) + + nx = get_backend(*arr) S = len(Cs) + if lambdas is None: + lambdas = [1. / S] * S # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: @@ -317,11 +443,20 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet error = [] + if warmstartT: + T = [None] * S + while (err > tol) and (cpt < max_iter): Cprev = C + if warmstartT: + T = [entropic_gromov_wasserstein( + Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, T[s], + max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)] + else: + T = [entropic_gromov_wasserstein( + Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None, + max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)] - T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None, - max_iter, 1e-4, verbose, log=False) for s in range(S)] if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) @@ -346,3 +481,536 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet return C, {"err": error} else: return C + + +def entropic_fused_gromov_wasserstein( + M, C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, + symmetric=None, alpha=0.5, G0=None, max_iter=1000, tol=1e-9, + solver='PGD', warmstart=False, verbose=False, log=False, **kwargs): + r""" + Returns the Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})` + with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}`, + estimated using Sinkhorn projections. + + If `solver="PGD"`, the function solves the following entropic-regularized + Fused Gromov-Wasserstein optimization problem using Projected Gradient Descent [12]: + + .. math:: + \mathbf{T}^* \in \mathop{\arg\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon H(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Else if `solver="PPA"`, the function solves the following Fused Gromov-Wasserstein + optimization problem using Proximal Point Algorithm [51]: + + .. math:: + \mathbf{T}^* \in\mathop{\arg\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. 
\ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + Where : + + - :math:`\mathbf{M}`: metric cost matrix between features across domains + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity and feature matrices + - `H`: entropy + - :math:`\alpha`: trade-off parameter + + .. note:: If the inner solver `ot.sinkhorn` did not converge, the + optimal coupling :math:`\mathbf{T}` returned by this function does not + necessarily satisfy the marginal constraints + :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and + :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned + Fused Gromov-Wasserstein loss does not necessarily satisfy distance + properties and may be negative. + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss' + epsilon : float, optional + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 will be used as initial transport of the solver. G0 is not + required to satisfy marginal constraints but we strongly recommend it + to correctly estimate the GW distance. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + solver: string, optional + Solver to use either 'PGD' for Projected Gradient Descent or 'PPA' + for Proximal Point Algorithm. + Default value is 'PGD'. + warmstart: bool, optional + Whether to perform warmstart of dual potentials in the successive + Sinkhorn projections. + verbose : bool, optional + Print information along iterations + log : bool, optional + Record log if True. + **kwargs: dict + parameters can be directly passed to the ot.sinkhorn solver. + Such as `numItermax` and `stopThr` to control its estimation precision, + e.g. [51] suggests using `numItermax=1`. + Returns + ------- + T : array-like, shape (`ns`, `nt`) + Optimal coupling between the two joint spaces + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein + distance between networks and stable network invariants. + Information and Inference: A Journal of the IMA, 8(4), 757-787. + + .. [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019).
Gromov-wasserstein + learning for graph matching and node embedding. In International + Conference on Machine Learning (ICML), 2019. + + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + """ + if solver not in ['PGD', 'PPA']: + raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver) + + M, C1, C2 = list_to_array(M, C1, C2) + arr = [M, C1, C2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + if q is not None: + arr.append(list_to_array(q)) + else: + q = unif(C2.shape[0], type_as=C2) + + if G0 is not None: + arr.append(G0) + + nx = get_backend(*arr) + + if G0 is None: + G0 = nx.outer(p, q) + + T = G0 + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx) + if symmetric is None: + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if not symmetric: + constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx) + cpt = 0 + err = 1 + + if warmstart: + # initialize potentials to cope with ot.sinkhorn initialization + N1, N2 = C1.shape[0], C2.shape[0] + mu = nx.zeros(N1, type_as=C1) - np.log(N1) + nu = nx.zeros(N2, type_as=C2) - np.log(N2) + + if log: + log = {'err': []} + + while (err > tol and cpt < max_iter): + + Tprev = T + + # compute the gradient + if symmetric: + tens = alpha * gwggrad(constC, hC1, hC2, T, nx) + (1 - alpha) * M + else: + tens = (alpha * 0.5) * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx)) + (1 - alpha) * M + + if solver == 'PPA': + tens = tens - epsilon * nx.log(T) + + if warmstart: + T, loginn = sinkhorn(p, q, tens, epsilon, method='sinkhorn', log=True, warmstart=(mu, nu), **kwargs) + mu = epsilon * nx.log(loginn['u']) + nu = epsilon * nx.log(loginn['v']) + + else: + T = sinkhorn(p, q, tens, epsilon, method='sinkhorn', **kwargs) + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(T - Tprev) + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if abs(nx.sum(T) - 1) > 1e-5: + warnings.warn("Solver failed to produce a transport plan. You might " + "want to increase the regularization parameter `epsilon`.") + if log: + log['fgw_dist'] = (1 - alpha) * nx.sum(M * T) + alpha * gwloss(constC, hC1, hC2, T, nx) + return T, log + else: + return T + + +def entropic_fused_gromov_wasserstein2( + M, C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, + symmetric=None, alpha=0.5, G0=None, max_iter=1000, tol=1e-9, + solver='PGD', warmstart=False, verbose=False, log=False, **kwargs): + r""" + Returns the Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})` + with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}`, + estimated using Sinkhorn projections. + + If `solver="PGD"`, the function solves the following entropic-regularized + Fused Gromov-Wasserstein optimization problem using Projected Gradient Descent [12]: + + .. 
math:: + \mathbf{FGW} = \mathop{\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon H(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + + Else if `solver="PPA"`, the function solves the following Fused Gromov-Wasserstein + optimization problem using Proximal Point Algorithm [51]: + + .. math:: + \mathbf{FGW} = \mathop{\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T}^T \mathbf{1} &= \mathbf{q} + + \mathbf{T} &\geq 0 + Where : + + - :math:`\mathbf{M}`: metric cost matrix between features across domains + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - :math:`\mathbf{q}`: distribution in the target space + - `L`: loss function to account for the misfit between the similarity and feature matrices + - `H`: entropy + - :math:`\alpha`: trade-off parameter + + .. note:: If the inner solver `ot.sinkhorn` did not converge, the + optimal coupling :math:`\mathbf{T}` returned by this function does not + necessarily satisfy the marginal constraints + :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and + :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned + Fused Gromov-Wasserstein loss does not necessarily satisfy distance + properties and may be negative. + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + loss_fun : str, optional + Loss function used for the solver either 'square_loss' or 'kl_loss' + epsilon : float, optional + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 will be used as initial transport of the solver. G0 is not + required to satisfy marginal constraints but we strongly recommend it + to correctly estimate the GW distance. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + solver: string, optional + Solver to use either 'PGD' for Projected Gradient Descent or 'PPA' + for Proximal Point Algorithm. + Default value is 'PGD'. + warmstart: bool, optional + Whether to perform warmstart of dual potentials in the successive + Sinkhorn projections. + verbose : bool, optional + Print information along iterations + log : bool, optional + Record log if True. + **kwargs: dict + parameters can be directly passed to the ot.sinkhorn solver. + + Returns + ------- + fgw_dist : float + Fused Gromov-Wasserstein distance + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + ..
[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). Gromov-wasserstein + learning for graph matching and node embedding. In International + Conference on Machine Learning (ICML), 2019. + + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas "Optimal Transport for structured data with + application on graphs", International Conference on Machine Learning + (ICML). 2019. + + """ + T, logv = entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter, + tol, solver, warmstart, verbose, log=True, **kwargs) + + logv['T'] = T + + if log: + return logv['fgw_dist'], logv + else: + return logv['fgw_dist'] + + +def entropic_fused_gromov_barycenters( + N, Ys, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss', + epsilon=0.1, symmetric=True, alpha=0.5, max_iter=1000, tol=1e-9, + warmstartT=False, verbose=False, log=False, init_C=None, init_Y=None, + random_state=None, **kwargs): + r""" + Returns the Fused Gromov-Wasserstein barycenters of `S` measurable networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}` + estimated using Fused Gromov-Wasserstein transports from Sinkhorn projections. + + The function solves the following optimization problem: + + .. math:: + + \mathbf{C}, \mathbf{Y} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}, \mathbf{Y}\in \mathbb{Y}^{N \times d}} \quad \sum_s \lambda_s \mathrm{FGW}_{\alpha}(\mathbf{C}, \mathbf{C}_s, \mathbf{Y}, \mathbf{Y}_s, \mathbf{p}, \mathbf{p}_s) + + Where : + + - :math:`\mathbf{Y}_s`: feature matrix + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution + + Parameters + ---------- + N : int + Size of the targeted barycenter + Ys: list of array-like, each element has shape (ns,d) + Features of all samples + Cs : list of S array-like of shape (ns,ns) + Metric cost matrices + ps : list of S array-like of shape (ns,), optional + Sample weights in the `S` spaces. + If let to its default value None, uniform distributions are taken. + p : array-like, shape (N,), optional + Weights in the targeted barycenter. + If let to its default value None, uniform distribution is taken. + lambdas : list of float, optional + List of the `S` spaces' weights. + If let to its default value None, uniform weights are taken. + loss_fun : callable, optional + tensor-matrix multiplication function based on specific loss function + epsilon : float, optional + Regularization term >0 + symmetric : bool, optional. + Either structures are to be assumed symmetric or not. Default value is True. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0) + warmstartT: bool, optional + Either to perform warmstart of transport plans in the successive + fused gromov-wasserstein transport problems. + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : bool | array-like, shape (N, N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + init_Y : array-like, shape (N,d), optional + Initialization for the barycenters' features. If not set a + random init is used. 
+ random_state : int or RandomState instance, optional + Fix the seed for reproducibility + **kwargs: dict + parameters can be directly passed to the `ot.entropic_fused_gromov_wasserstein` solver. + + Returns + ------- + Y : array-like, shape (`N`, `d`) + Feature matrix in the barycenter space (permutated arbitrarily) + C : array-like, shape (`N`, `N`) + Similarity matrix in the barycenter space (permutated as Y's rows) + log : dict + Log dictionary of error during iterations. Return only if `log=True` in parameters. + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + Cs = list_to_array(*Cs) + Ys = list_to_array(*Ys) + arr = [*Cs, *Ys] + if ps is not None: + arr += list_to_array(*ps) + else: + ps = [unif(C.shape[0], type_as=C) for C in Cs] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(N, type_as=Cs[0]) + + nx = get_backend(*arr) + S = len(Cs) + if lambdas is None: + lambdas = [1. / S] * S + + d = Ys[0].shape[1] # dimension on the node features + + # Initialization of C : random SPD matrix (if not provided by user) + if init_C is None: + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) + C = dist(xalea, xalea) + C /= C.max() + C = nx.from_numpy(C, type_as=p) + else: + C = init_C + + # Initialization of Y + if init_Y is None: + Y = nx.zeros((N, d), type_as=ps[0]) + else: + Y = init_Y + + T = [nx.outer(p_, p) for p_ in ps] + + Ms = [dist(Ys[s], Y) for s in range(len(Ys))] + + cpt = 0 + + err_feature = 1 + err_structure = 1 + + if warmstartT: + T = [None] * S + + if log: + log_ = {} + log_['err_feature'] = [] + log_['err_structure'] = [] + log_['Ts_iter'] = [] + + while (err_feature > tol or err_structure > tol) and (cpt < max_iter): + Cprev = C + Yprev = Y + + if warmstartT: + T = [entropic_fused_gromov_wasserstein( + Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha, + T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)] + + else: + T = [entropic_fused_gromov_wasserstein( + Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha, + None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)] + + if loss_fun == 'square_loss': + C = update_square_loss(p, lambdas, T, Cs) + + elif loss_fun == 'kl_loss': + C = update_kl_loss(p, lambdas, T, Cs) + + Ys_temp = [y.T for y in Ys] + T_temp = [Ts.T for Ts in T] + Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p) + Ms = [dist(Ys[s], Y) for s in range(len(Ys))] + + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err_feature = nx.norm(Y - nx.reshape(Yprev, (N, d))) + err_structure = nx.norm(C - Cprev) + if log: + log_['err_feature'].append(err_feature) + log_['err_structure'].append(err_structure) + log_['Ts_iter'].append(T) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err_structure)) + print('{:5d}|{:8e}|'.format(cpt, err_feature)) + + cpt += 1 + if log: + log_['T'] = T # from target to Ys + log_['p'] = p + log_['Ms'] = Ms + + if log: + return Y, C, log_ + else: + return Y, C diff --git
a/ot/gromov/_gw.py b/ot/gromov/_gw.py index cdfa9a3..adf6b82 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -16,14 +16,14 @@ import numpy as np from ..utils import dist, UndefinedParameter, list_to_array from ..optim import cg, line_search_armijo, solve_1d_linesearch_quad -from ..utils import check_random_state +from ..utils import check_random_state, unif from ..backend import get_backend, NumpyBackend from ._utils import init_matrix, gwloss, gwggrad -from ._utils import update_square_loss, update_kl_loss +from ._utils import update_square_loss, update_kl_loss, update_feature_matrix -def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, +def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -31,7 +31,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log The function solves the following optimization problem: .. math:: - \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} @@ -60,11 +60,13 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space - loss_fun : str + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. + loss_fun : str, optional loss function used for the solver either 'square_loss' or 'kl_loss' symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. @@ -112,15 +114,24 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log distance between networks and stable network invariants. Information and Inference: A Journal of the IMA, 8(4), 757-787. 
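A minimal sketch of the now-optional marginals introduced by this hunk (assuming the usual public aliases `ot.gromov.gromov_wasserstein` / `gromov_wasserstein2`, which this patch does not rename):

import numpy as np
import ot

rng = np.random.RandomState(0)
C1 = ot.dist(rng.randn(10, 2))
C2 = ot.dist(rng.randn(14, 2))
C1 /= C1.max()
C2 /= C2.max()
# p and q can now be omitted and default to uniform distributions
T = ot.gromov.gromov_wasserstein(C1, C2, loss_fun='square_loss')
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, log=True)
# rows of the coupling sum to the uniform source marginal 1/10
assert np.allclose(T.sum(axis=1), 1. / 10)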
""" - p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - if G0 is None: - nx = get_backend(p0, q0, C10, C20) + arr = [C1, C2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + if q is not None: + arr.append(list_to_array(q)) else: + q = unif(C2.shape[0], type_as=C2) + if G0 is not None: G0_ = G0 - nx = get_backend(p0, q0, C10, C20, G0_) - p = nx.to_numpy(p) - q = nx.to_numpy(q) + arr.append(G0) + + nx = get_backend(*arr) + p0, q0, C10, C20 = p, q, C1, C2 + + p = nx.to_numpy(p0) + q = nx.to_numpy(q0) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) if symmetric is None: @@ -168,7 +179,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) -def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, +def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Returns the Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -176,7 +187,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo The function solves the following optimization problem: .. math:: - GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} + \mathbf{GW} = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} @@ -209,10 +220,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : array-like, shape (ns,) + p : array-like, shape (ns,), optional Distribution in the source space. - q : array-like, shape (nt,) + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional Distribution in the target space. + If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' symmetric : bool, optional @@ -266,6 +279,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo # simple get_backend as the full one will be handled in gromov_wasserstein nx = get_backend(C1, C2) + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) + if q is None: + q = unif(C2.shape[0], type_as=C2) + T, log_gw = gromov_wasserstein( C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0, max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) @@ -286,20 +305,20 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo return gw -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5, +def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5, armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Computes the FGW transport between two graphs (see :ref:`[24] `) .. 
math:: - \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + \mathbf{T}^T \mathbf{1} &= \mathbf{q} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 where : @@ -323,10 +342,12 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= Metric cost matrix representative of the structure in the source space C2 : array-like, shape (nt, nt) Metric cost matrix representative of the structure in the target space - p : array-like, shape (ns,) - Distribution in the source space - q : array-like, shape (nt,) - Distribution in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional + Distribution in the target space. + If let to its default value None, uniform distribution is taken. loss_fun : str, optional Loss function used for the solver symmetric : bool, optional @@ -354,7 +375,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= Returns ------- - gamma : array-like, shape (`ns`, `nt`) + T : array-like, shape (`ns`, `nt`) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. @@ -372,16 +393,24 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= distance between networks and stable network invariants. Information and Inference: A Journal of the IMA, 8(4), 757-787. """ - p, q = list_to_array(p, q) - p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha - if G0 is None: - nx = get_backend(p0, q0, C10, C20, M0) + arr = [C1, C2, M] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + if q is not None: + arr.append(list_to_array(q)) else: + q = unif(C2.shape[0], type_as=C2) + if G0 is not None: G0_ = G0 - nx = get_backend(p0, q0, C10, C20, M0, G0_) + arr.append(G0) - p = nx.to_numpy(p) - q = nx.to_numpy(q) + nx = get_backend(*arr) + p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha + + p = nx.to_numpy(p0) + q = nx.to_numpy(q0) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) M = nx.to_numpy(M0) @@ -433,20 +462,20 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric= return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) -def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5, +def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5, armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Computes the FGW distance between two graphs see (see :ref:`[24] `) .. math:: - \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} + \mathbf{FGW} = \min_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t.
\ \mathbf{T}\mathbf{1} &= \mathbf{p} - \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q} + \mathbf{T}^T \mathbf{1} &= \mathbf{q} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 where : @@ -474,10 +503,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric Metric cost matrix representative of the structure in the source space. C2 : array-like, shape (nt, nt) Metric cost matrix representative of the structure in the target space. - p : array-like, shape (ns,) + p : array-like, shape (ns,), optional Distribution in the source space. - q : array-like, shape (nt,) + If let to its default value None, uniform distribution is taken. + q : array-like, shape (nt,), optional Distribution in the target space. + If let to its default value None, uniform distribution is taken. loss_fun : str, optional Loss function used for the solver. symmetric : bool, optional @@ -529,6 +560,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric """ nx = get_backend(C1, C2, M) + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) + if q is None: + q = unif(C2.shape[0], type_as=C2) + T, log_fgw = fused_gromov_wasserstein( M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True, max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) @@ -626,9 +663,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, return alpha, 1, cost_G -def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=False, - max_iter=1000, tol=1e-9, verbose=False, log=False, - init_C=None, random_state=None, **kwargs): +def gromov_barycenters( + N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss', symmetric=True, armijo=False, + max_iter=1000, tol=1e-9, warmstartT=False, verbose=False, log=False, + init_C=None, random_state=None, **kwargs): r""" Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` @@ -649,13 +687,16 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F Size of the targeted barycenter Cs : list of S array-like of shape (ns, ns) Metric cost matrices - ps : list of S array-like of shape (ns,) - Sample weights in the `S` spaces - p : array-like, shape (N,) - Weights in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights - loss_fun : callable + ps : list of S array-like of shape (ns,), optional + Sample weights in the `S` spaces. + If let to its default value None, uniform distributions are taken. + p : array-like, shape (N,), optional + Weights in the targeted barycenter. + If let to its default value None, uniform distribution is taken. + lambdas : list of float, optional + List of the `S` spaces' weights. + If let to its default value None, uniform weights are taken. + loss_fun : callable, optional tensor-matrix multiplication function based on specific loss function symmetric : bool, optional. Either structures are to be assumed symmetric or not. Default value is True. @@ -668,6 +709,9 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F Max number of iterations tol : float, optional Stop threshold on relative error (>0) + warmstartT: bool, optional + Either to perform warmstart of transport plans in the successive + gromov-wasserstein transport problems. verbose : bool, optional Print information along iterations.
log : bool, optional @@ -692,11 +736,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F """ Cs = list_to_array(*Cs) - ps = list_to_array(*ps) - p = list_to_array(p) - nx = get_backend(*Cs, *ps, p) + arr = [*Cs] + if ps is not None: + arr += list_to_array(*ps) + else: + ps = [unif(C.shape[0], type_as=C) for C in Cs] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(N, type_as=Cs[0]) + + nx = get_backend(*arr) S = len(Cs) + if lambdas is None: + lambdas = [1. / S] * S # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: @@ -714,13 +768,19 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F cpt = 0 err = 1 + if warmstartT: + T = [None] * S error = [] while (err > tol and cpt < max_iter): Cprev = C - T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] + if warmstartT: + T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s], + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] + else: + T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=None, + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) @@ -747,9 +807,11 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F return C -def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, - p=None, loss_fun='square_loss', armijo=False, symmetric=True, max_iter=100, tol=1e-9, - verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs): +def fgw_barycenters( + N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False, + fixed_features=False, p=None, loss_fun='square_loss', armijo=False, + symmetric=True, max_iter=100, tol=1e-9, warmstartT=False, verbose=False, + log=False, init_C=None, init_X=None, random_state=None, **kwargs): r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` Parameters @@ -760,16 +822,21 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Features of all samples Cs : list of array-like, each element has shape (ns,ns) Structure matrices of all samples - ps : list of array-like, each element has shape (ns,) + ps : list of array-like, each element has shape (ns,), optional Masses of all samples. - lambdas : list of float - List of the `S` spaces' weights - alpha : float - Alpha parameter for the fgw distance + If let to its default value None, uniform distributions are taken. + lambdas : list of float, optional + List of the `S` spaces' weights. + If let to its default value None, uniform weights are taken. + alpha : float, optional + Alpha parameter for the fgw distance. fixed_structure : bool Whether to fix the structure of the barycenter during the updates fixed_features : bool Whether to fix the feature of the barycenter during the updates + p : array-like, shape (N,), optional + Weights in the targeted barycenter. + If let to its default value None, uniform distribution is taken. 
loss_fun : str Loss function used for the solver either 'square_loss' or 'kl_loss' symmetric : bool, optional @@ -779,6 +846,9 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Max number of iterations tol : float, optional Stop threshold on relative error (>0) + warmstartT: bool, optional + Either to perform warmstart of transport plans in the successive + fused gromov-wasserstein transport problems. verbose : bool, optional Print information along iterations. log : bool, optional @@ -814,15 +884,24 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ International Conference on Machine Learning (ICML). 2019. """ Cs = list_to_array(*Cs) - ps = list_to_array(*ps) Ys = list_to_array(*Ys) - p = list_to_array(p) - nx = get_backend(*Cs, *Ys, *ps) + arr = [*Cs, *Ys] + if ps is not None: + arr += list_to_array(*ps) + else: + ps = [unif(C.shape[0], type_as=C) for C in Cs] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(N, type_as=Cs[0]) + + nx = get_backend(*arr) S = len(Cs) + if lambdas is None: + lambdas = [1. / S] * S + d = Ys[0].shape[1] # dimension on the node features - if p is None: - p = nx.ones(N, type_as=Cs[0]) / N if fixed_structure: if init_C is None: @@ -877,13 +956,21 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: + T_temp = [t.T for t in T] if loss_fun == 'square_loss': - T_temp = [t.T for t in T] - C = update_structure_matrix(p, lambdas, T_temp, Cs) + C = update_square_loss(p, lambdas, T_temp, Cs) - T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] + elif loss_fun == 'kl_loss': + C = update_kl_loss(p, lambdas, T_temp, Cs) + if warmstartT: + T = [fused_gromov_wasserstein( + Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, + G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] + else: + T = [fused_gromov_wasserstein( + Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, + G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] # T is N,ns err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) err_structure = nx.norm(C - Cprev) @@ -910,82 +997,3 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ return X, C, log_ else: return X, C - - -def update_structure_matrix(p, lambdas, T, Cs): - r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. - - It is calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - Masses in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights. - T : list of S array-like of shape (ns, N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape (ns, ns) - Metric cost matrices. - - Returns - ------- - C : array-like, shape (`nt`, `nt`) - Updated :math:`\mathbf{C}` matrix. 
- """ - p = list_to_array(p) - T = list_to_array(*T) - Cs = list_to_array(*Cs) - nx = get_backend(*Cs, *T, p) - - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - return tmpsum / ppt - - -def update_feature_matrix(lambdas, Ys, Ts, p): - r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. - - - See "Solving the barycenter problem with Block Coordinate Descent (BCD)" - in :ref:`[24] ` calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - masses in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights - Ts : list of S array-like, shape (ns,N) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration - Ys : list of S array-like, shape (d,ns) - The features. - - Returns - ------- - X : array-like, shape (`d`, `N`) - - - .. _references-update-feature-matrix: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - p = list_to_array(p) - Ts = list_to_array(*Ts) - Ys = list_to_array(*Ys) - nx = get_backend(*Ys, *Ts, p) - - p = 1. / p - tmpsum = sum([ - lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] - for s in range(len(Ts)) - ]) - return tmpsum diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 94dc975..206329d 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -18,7 +18,7 @@ from ..backend import get_backend from ._utils import init_matrix_semirelaxed, gwloss, gwggrad -def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None, +def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Returns the semi-relaxed Gromov-Wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -26,12 +26,12 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric= The function solves the following optimization problem: .. math:: - \mathbf{srGW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + \mathbf{T}^^* \in \mathop{\arg \min}_{\mathbf{T}} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 Where : @@ -51,8 +51,9 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric= Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : array-like, shape (ns,) - Distribution in the source space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. 
@@ -93,11 +94,16 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric= """ if loss_fun == 'kl_loss': raise NotImplementedError() - p = list_to_array(p) - if G0 is None: - nx = get_backend(p, C1, C2) + arr = [C1, C2] + if p is not None: + arr.append(list_to_array(p)) else: - nx = get_backend(p, C1, C2, G0) + p = unif(C1.shape[0], type_as=C1) + + if G0 is not None: + arr.append(G0) + + nx = get_backend(*arr) if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) @@ -143,7 +149,7 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric= return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) -def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None, +def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Returns the semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -151,12 +157,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric The function solves the following optimization problem: .. math:: - srGW = \min_\mathbf{T} \quad \sum_{i,j,k,l} + \text{srGW} = \min_{\mathbf{T}} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 Where : @@ -179,8 +185,9 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric Metric cost matrix in the source space C2 : array-like, shape (nt, nt) Metric cost matrix in the target space - p : array-like, shape (ns,) + p : array-like, shape (ns,), optional Distribution in the source space. + If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. @@ -218,7 +225,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - nx = get_backend(p, C1, C2) + # partial get_backend as the full one will be handled in gromov_wasserstein + nx = get_backend(C1, C2) + + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) T, log_srgw = semirelaxed_gromov_wasserstein( C1, C2, p, loss_fun, symmetric, log=True, G0=G0, @@ -239,18 +251,19 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric return srgw -def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, + G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Computes the semi-relaxed FGW transport between two graphs (see :ref:`[48] `) .. 
math:: - \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \mathbf{T}^* \in \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 where : @@ -273,8 +286,9 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s Metric cost matrix representative of the structure in the source space C2 : array-like, shape (nt, nt) Metric cost matrix representative of the structure in the target space - p : array-like, shape (ns,) - Distribution in the source space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. @@ -321,11 +335,16 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s if loss_fun == 'kl_loss': raise NotImplementedError() - p = list_to_array(p) - if G0 is None: - nx = get_backend(p, C1, C2, M) + arr = [M, C1, C2] + if p is not None: + arr.append(list_to_array(p)) else: - nx = get_backend(p, C1, C2, M, G0) + p = unif(C1.shape[0], type_as=C1) + + if G0 is not None: + arr.append(G0) + + nx = get_backend(*arr) if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) @@ -373,18 +392,18 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) -def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, +def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): r""" Computes the semi-relaxed FGW divergence between two graphs (see :ref:`[48] `) .. math:: - \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l} - L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + \mathbf{srFGW} = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l} - s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p} + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} - \mathbf{\gamma} &\geq 0 + \mathbf{T} &\geq 0 where : @@ -412,6 +431,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', Metric cost matrix representative of the structure in the target space. p : array-like, shape (ns,) Distribution in the source space. + If let to its default value None, uniform distribution is taken. loss_fun : str, optional loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. 
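Similarly for the fused divergence documented above, a sketch with node features (hypothetical toy data; `alpha` balances the feature and structure costs, and the value/log return follows the usual POT convention):

import numpy as np
import ot
from ot.gromov import semirelaxed_fused_gromov_wasserstein2

rng = np.random.RandomState(0)
F1, F2 = rng.randn(12, 3), rng.randn(6, 3)  # node features
C1 = ot.dist(rng.randn(12, 2))
C2 = ot.dist(rng.randn(6, 2))
C1 /= C1.max()
C2 /= C2.max()
M = ot.dist(F1, F2)                          # feature cost matrix
srfgw, log = semirelaxed_fused_gromov_wasserstein2(M, C1, C2, alpha=0.5, log=True)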
@@ -455,7 +475,12 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ - nx = get_backend(p, C1, C2, M) + # partial get_backend as the full one will be handled in gromov_wasserstein + nx = get_backend(C1, C2) + + # init marginals if set as None + if p is None: + p = unif(C1.shape[0], type_as=C1) T, log_fgw = semirelaxed_fused_gromov_wasserstein( M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True, @@ -551,3 +576,501 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, cost_G = cost_G + a * (alpha ** 2) + b * alpha return alpha, 1, cost_G + + +def entropic_semirelaxed_gromov_wasserstein( + C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, + G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + r""" + Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence + transport plan from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` + estimated using a Mirror Descent algorithm following the KL geometry. + + The function solves the following optimization problem: + + .. math:: + \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + + - `L`: loss function to account for the misfit between the similarity matrices + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the mirror + descent are not differentiable. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error computed on transport plans + log : bool, optional + record log if True + verbose : bool, optional + Print information along iterations + Returns + ------- + G : array-like, shape (`ns`, `nt`) + Coupling between the two spaces that minimizes: + + :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}` + log : dict + Convergence information and loss. + + References + ---------- + ..
[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + if loss_fun == 'kl_loss': + raise NotImplementedError() + arr = [C1, C2] + if p is not None: + arr.append(list_to_array(p)) + else: + p = unif(C1.shape[0], type_as=C1) + + if G0 is not None: + arr.append(G0) + + nx = get_backend(*arr) + + if symmetric is None: + symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: + q = unif(C2.shape[0], type_as=p) + G0 = nx.outer(p, q) + else: + q = nx.sum(G0, 0) + # Check first marginal of G0 + np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08) + + constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx) + + ones_p = nx.ones(p.shape[0], type_as=p) + + if symmetric: + def df(G): + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + return gwggrad(constC + marginal_product, hC1, hC2, G, nx) + else: + constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx) + + def df(G): + qG = nx.sum(G, 0) + marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t)) + marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2)) + return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + + cpt = 0 + err = 1e15 + G = G0 + + if log: + log = {'err': []} + + while (err > tol and cpt < max_iter): + + Gprev = G + # compute the kernel + K = G * nx.exp(- df(G) / epsilon) + scaling = p / nx.sum(K, 1) + G = nx.reshape(scaling, (-1, 1)) * K + if cpt % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.norm(G - Gprev) + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + cpt += 1 + + if log: + qG = nx.sum(G, 0) + marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) + log['srgw_dist'] = gwloss(constC + marginal_product, hC1, hC2, G, nx) + return G, log + else: + return G + + +def entropic_semirelaxed_gromov_wasserstein2( + C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, + G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + r""" + Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence + from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` + estimated using a Mirror Descent algorithm following the KL geometry. + + The function solves the following optimization problem: + + .. math:: + \mathbf{srGW} = \min_{\mathbf{T}} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T} &\geq 0 + Where : + + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}`: distribution in the source space + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + matrices (C1, C2) but not yet for the weights p. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the mirror + descent are not differentiable.
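The multiplicative update in the `entropic_semirelaxed_gromov_wasserstein` loop above is a single KL mirror-descent step followed by a closed-form rescaling onto the row-marginal constraint. A standalone numpy sketch of that step (`grad` stands for the gradient of the srGW objective at `G`):

import numpy as np

def kl_mirror_step(G, grad, p, epsilon):
    # KL (multiplicative) step: G <- G * exp(-grad / epsilon)
    K = G * np.exp(-grad / epsilon)
    # rescale each row so that K @ 1 = p; the target marginal is left free
    return K * (p / K.sum(axis=1))[:, None]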
+ + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space + p : array-like, shape (ns,), optional + Distribution in the source space. + If let to its default value None, uniform distribution is taken. + loss_fun : str + loss function used for the solver either 'square_loss' or 'kl_loss'. + 'kl_loss' is not implemented yet and will raise an error. + epsilon : float + Regularization term >0 + symmetric : bool, optional + Either C1 and C2 are to be assumed symmetric or not. + If let to its default None value, a symmetry test will be conducted. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error computed on transport plans + log : bool, optional + record log if True + verbose : bool, optional + Print information along iterations + **kwargs : dict + parameters can be directly passed to the `entropic_semirelaxed_gromov_wasserstein` solver + + Returns + ------- + srgw : float + Semi-relaxed Gromov-Wasserstein divergence + log : dict + convergence information and Coupling matrix + + References + ---------- + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + T, log_srgw = entropic_semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun, epsilon, symmetric, G0, + max_iter, tol, log=True, verbose=verbose, **kwargs) + + log_srgw['T'] = T + + if log: + return log_srgw['srgw_dist'], log_srgw + else: + return log_srgw['srgw_dist'] + + +def entropic_semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1, + alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + r""" + Computes the entropic-regularized semi-relaxed FGW transport between two graphs (see :ref:`[48] `) + + .. math:: + \mathbf{T}^* \in \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} + + s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p} + + \mathbf{T} &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix between features + - :math:`\mathbf{C_1}`: Metric cost matrix in the source space + - :math:`\mathbf{C_2}`: Metric cost matrix in the target space + - :math:`\mathbf{p}` source weights (sum to 1) + - `L` is a loss function to account for the misfit between the similarity matrices + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. However all the steps in the mirror + descent are not differentiable.
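For the scalar wrapper `entropic_semirelaxed_gromov_wasserstein2` above, a sketch sweeping the regularization (hypothetical toy data; `epsilon` acts as an inverse step size for the mirror-descent updates, so smaller values take larger multiplicative steps):

import numpy as np
import ot
from ot.gromov import entropic_semirelaxed_gromov_wasserstein2

rng = np.random.RandomState(0)
C1 = ot.dist(rng.randn(10, 2))
C2 = ot.dist(rng.randn(5, 2))
C1 /= C1.max()
C2 /= C2.max()
for eps in [1., 0.5, 0.1]:
    srgw = entropic_semirelaxed_gromov_wasserstein2(C1, C2, epsilon=eps)
    print('epsilon=%s srgw=%.4f' % (eps, srgw))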
+
+    The problem is solved with a Mirror Descent algorithm following the KL geometry, as discussed in :ref:`[48] `
+
+    Parameters
+    ----------
+    M : array-like, shape (ns, nt)
+        Metric cost matrix between features across domains
+    C1 : array-like, shape (ns, ns)
+        Metric cost matrix representative of the structure in the source space
+    C2 : array-like, shape (nt, nt)
+        Metric cost matrix representative of the structure in the target space
+    p : array-like, shape (ns,), optional
+        Distribution in the source space.
+        If left to its default value None, the uniform distribution is used.
+    loss_fun : str
+        Loss function used for the solver: either 'square_loss' or 'kl_loss'.
+        'kl_loss' is not implemented yet and will raise an error.
+    epsilon : float
+        Regularization term (>0)
+    symmetric : bool, optional
+        Whether C1 and C2 are assumed symmetric.
+        If left to its default value None, a symmetry test is conducted.
+        If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
+    alpha : float, optional
+        Trade-off parameter (0 < alpha < 1)
+    G0 : array-like, shape (ns, nt), optional
+        If None, the initial transport plan of the solver is pq^T with q uniform.
+        Otherwise G0 must satisfy the source marginal constraint and is used as the solver's initial transport plan.
+    max_iter : int, optional
+        Max number of iterations
+    tol : float, optional
+        Stop threshold on the error computed on transport plans
+    log : bool, optional
+        Record log if True
+    verbose : bool, optional
+        Print information along iterations
+    **kwargs : dict
+        Parameters forwarded to the underlying mirror-descent solver
+
+    Returns
+    -------
+    G : array-like, shape (`ns`, `nt`)
+        Optimal transportation matrix for the given parameters.
+    log : dict
+        Log dictionary, returned only if `log` is True.
+
+
+    .. _references-semirelaxed-fused-gromov-wasserstein:
+    References
+    ----------
+    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+        "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+        International Conference on Learning Representations (ICLR), 2022.
+    """
+    if loss_fun == 'kl_loss':
+        raise NotImplementedError("'kl_loss' is not implemented yet for this solver.")
+    arr = [M, C1, C2]
+    if p is not None:
+        arr.append(list_to_array(p))
+    else:
+        p = unif(C1.shape[0], type_as=C1)
+
+    if G0 is not None:
+        arr.append(G0)
+
+    nx = get_backend(*arr)
+
+    if symmetric is None:
+        symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+    if G0 is None:
+        q = unif(C2.shape[0], type_as=p)
+        G0 = nx.outer(p, q)
+    else:
+        q = nx.sum(G0, 0)
+        # check the first marginal of G0
+        np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
+
+    constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
+
+    ones_p = nx.ones(p.shape[0], type_as=p)
+    dM = (1 - alpha) * M
+    if symmetric:
+        def df(G):
+            qG = nx.sum(G, 0)
+            marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+            return alpha * gwggrad(constC + marginal_product, hC1, hC2, G, nx) + dM
+    else:
+        constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx)
+
+        def df(G):
+            qG = nx.sum(G, 0)
+            marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t))
+            marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2))
+            return 0.5 * alpha * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + dM
+
+    cpt = 0
+    err = 1e15
+    G = G0
+
+    if log:
+        log = {'err': []}
+
+    while (err > tol and cpt < max_iter):
+
+        Gprev = G
+        # mirror-descent step: multiplicative (KL) update with the gradient kernel
+        K = G * nx.exp(- df(G) / epsilon)
+        scaling = p / nx.sum(K, 1)
+        G = nx.reshape(scaling, (-1, 1)) * K
+        if cpt % 10 == 0:
+            # we can speed up the process by checking the error
+            # only every 10th iteration
+            err = nx.norm(G - Gprev)
+
+            if log:
+                log['err'].append(err)
+
+            if verbose:
+                if cpt % 200 == 0:
+                    print('{:5s}|{:12s}'.format(
+                        'It.', 'Err') + '\n' + '-' * 19)
+                print('{:5d}|{:8e}|'.format(cpt, err))
+
+        cpt += 1
+
+    if log:
+        qG = nx.sum(G, 0)
+        marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+        log['srfgw_dist'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + (1 - alpha) * nx.sum(M * G)
+        return G, log
+    else:
+        return G
+
+
+def entropic_semirelaxed_fused_gromov_wasserstein2(
+        M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1,
+        alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs):
+    r"""
+    Computes the entropic-regularized semi-relaxed FGW divergence between two graphs (see :ref:`[48] `)
+
+    .. math::
+        \mathbf{srFGW} = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+        s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+             \mathbf{T} &\geq 0
+
+    where:
+
+    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix between features
+    - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+    - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+    - :math:`\mathbf{p}` source weights (sum to 1)
+    - `L` is a loss function to account for the misfit between the similarity matrices
+
+
+    .. note:: This function is backend-compatible and will work on arrays
+        from all compatible backends. However, not all steps of the mirror
+        descent solver are differentiable.
+
+    The problem is solved with a Mirror Descent algorithm following the KL geometry, as discussed in :ref:`[48] `
+
+    Parameters
+    ----------
+    M : array-like, shape (ns, nt)
+        Metric cost matrix between features across domains
+    C1 : array-like, shape (ns, ns)
+        Metric cost matrix representative of the structure in the source space.
+    C2 : array-like, shape (nt, nt)
+        Metric cost matrix representative of the structure in the target space.
+    p : array-like, shape (ns,), optional
+        Distribution in the source space.
+        If left to its default value None, the uniform distribution is used.
+    loss_fun : str, optional
+        Loss function used for the solver: either 'square_loss' or 'kl_loss'.
+        'kl_loss' is not implemented yet and will raise an error.
+    epsilon : float
+        Regularization term (>0)
+    symmetric : bool, optional
+        Whether C1 and C2 are assumed symmetric.
+        If left to its default value None, a symmetry test is conducted.
+        If set to True (resp. False), C1 and C2 are assumed symmetric (resp. asymmetric).
+    alpha : float, optional
+        Trade-off parameter (0 < alpha < 1)
+    G0 : array-like, shape (ns, nt), optional
+        If None, the initial transport plan of the solver is pq^T with q uniform.
+        Otherwise G0 must satisfy the source marginal constraint and is used as the solver's initial transport plan.
+    max_iter : int, optional
+        Max number of iterations
+    tol : float, optional
+        Stop threshold on the error computed on transport plans
+    log : bool, optional
+        Record log if True
+    verbose : bool, optional
+        Print information along iterations
+    **kwargs : dict
+        Parameters forwarded to the underlying mirror-descent solver.
+
+    Returns
+    -------
+    srfgw-divergence : float
+        Semi-relaxed Fused Gromov-Wasserstein divergence for the given parameters.
+    log : dict
+        Log dictionary, returned only if `log` is True.
+
+
+    .. _references-semirelaxed-fused-gromov-wasserstein2:
+    References
+    ----------
+    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+        "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+        International Conference on Learning Representations (ICLR), 2022.
+    """
+    T, log_srfgw = entropic_semirelaxed_fused_gromov_wasserstein(
+        M, C1, C2, p, loss_fun, symmetric, epsilon, alpha, G0,
+        max_iter, tol, log=True, verbose=verbose, **kwargs)
+
+    log_srfgw['T'] = T
+
+    if log:
+        return log_srfgw['srfgw_dist'], log_srfgw
+    else:
+        return log_srfgw['srfgw_dist']
diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py
index ef8cd88..0b8bb00 100644
--- a/ot/gromov/_utils.py
+++ b/ot/gromov/_utils.py
@@ -324,6 +324,49 @@ def update_kl_loss(p, lambdas, T, Cs):
     return nx.exp(tmpsum / ppt)
 
 
+def update_feature_matrix(lambdas, Ys, Ts, p):
+    r"""Updates the barycenter features with respect to the `S` :math:`\mathbf{T}_s` couplings.
+
+
+    See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
+    in :ref:`[24] `.
+
+    Parameters
+    ----------
+    p : array-like, shape (N,)
+        Masses in the targeted barycenter
+    lambdas : list of float
+        List of the `S` spaces' weights
+    Ts : list of S array-like, shape (ns, N)
+        The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
+    Ys : list of S array-like, shape (d, ns)
+        The feature matrices
+
+    Returns
+    -------
+    X : array-like, shape (`d`, `N`)
+        Updated feature matrix of the barycenter
+
+
+    .. _references-update-feature-matrix:
+    References
+    ----------
+    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
+        "Optimal Transport for structured data with application on graphs"
+        International Conference on Machine Learning (ICML). 2019.
+    """
+    p = list_to_array(p)
+    Ts = list_to_array(*Ts)
+    Ys = list_to_array(*Ys)
+    nx = get_backend(*Ys, *Ts, p)
+
+    p = 1.
/ p + tmpsum = sum([ + lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] + for s in range(len(Ts)) + ]) + return tmpsum + + def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation diff --git a/test/test_gromov.py b/test/test_gromov.py index 1beb818..13ff3fe 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -34,8 +34,11 @@ def test_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) - Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=G0b, verbose=True)) + G = ot.gromov.gromov_wasserstein( + C1, C2, None, q, 'square_loss', G0=G0, verbose=True, + alpha_min=0., alpha_max=1.) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True)) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -48,8 +51,8 @@ def test_gromov(nx): np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, log=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=True) + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=True, log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=True, log=True) gwb = nx.to_numpy(gwb) gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False) @@ -312,11 +315,11 @@ def test_entropic_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-2, verbose=True, log=True) + C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-2, max_iter=10, verbose=True, log=True) Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, - epsilon=1e-2, verbose=True, log=False + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, + epsilon=1e-2, max_iter=10, verbose=True, log=False )) # check constraints @@ -327,10 +330,10 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, max_iter=10, epsilon=1e-2, log=True) gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, max_iter=10, epsilon=1e-2, log=True) gwb = nx.to_numpy(gwb) @@ -348,6 +351,65 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_entropic_proximal_gromov(nx): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + G, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=True, 
numItermax=1) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=False, numItermax=1 + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + gw, log = ot.gromov.entropic_gromov_wasserstein2( + C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + gwb, logb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + gwb = nx.to_numpy(gwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") def test_asymmetric_entropic_gromov(nx): n_samples = 10 # nb samples np.random.seed(0) @@ -363,10 +425,10 @@ def test_asymmetric_entropic_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, verbose=True, log=False) + epsilon=1e-1, max_iter=5, verbose=True, log=False) Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, - epsilon=1e-1, verbose=True, log=False + epsilon=1e-1, max_iter=5, verbose=True, log=False )) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -376,11 +438,11 @@ def test_asymmetric_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=10, epsilon=1e-1, log=False) + C1, C2, None, None, 'kl_loss', symmetric=False, G0=None, + max_iter=5, epsilon=1e-1, log=False) gwb = ot.gromov.entropic_gromov_wasserstein2( C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=10, epsilon=1e-1, log=False) + max_iter=5, epsilon=1e-1, log=False) gwb = nx.to_numpy(gwb) np.testing.assert_allclose(gw, gwb, atol=1e-06) @@ -414,15 +476,300 @@ def test_entropic_gromov_dtype_device(nx): C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp) - Gb = ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True - ) - gw_valb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True - ) + for solver in ['PGD', 'PPA']: + Gb = ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', epsilon=1e-1, max_iter=5, + solver=solver, verbose=True + ) + gw_valb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'square_loss', epsilon=1e-1, max_iter=5, + solver=solver, verbose=True + ) - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, gw_valb) + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_fgw(nx): + n_samples = 10 # nb samples + + mu_s = 
np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, None, None, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, verbose=True, log=True) + Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=10, verbose=True, log=False + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( + M, C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + max_iter=10, epsilon=1e-1, log=True) + fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-1, log=True) + fgwb = nx.to_numpy(fgwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +def test_entropic_proximal_fgw(nx): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) + Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=False, numItermax=1 + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( + M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, + max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + fgwb = nx.to_numpy(fgwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + 
np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +def test_asymmetric_entropic_fgw(nx): + n_samples = 10 # nb samples + np.random.seed(0) + C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + np.random.shuffle(idx) + C2 = C1[idx, :][:, idx] + + ys = np.random.randn(n_samples, 2) + yt = ys[idx, :] + M = ot.dist(ys, yt) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + G = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + max_iter=5, epsilon=1e-1, verbose=True, log=False) + Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, + max_iter=5, epsilon=1e-1, verbose=True, log=False + )) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + fgw = ot.gromov.entropic_fused_gromov_wasserstein2( + M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, + max_iter=5, epsilon=1e-1, log=False) + fgwb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=5, epsilon=1e-1, log=False) + fgwb = nx.to_numpy(fgwb) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_fgw_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Mb, C1b, C2b, pb, qb = nx.from_numpy(M, C1, C2, p, q, type_as=tp) + + for solver in ['PGD', 'PPA']: + Gb = ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', epsilon=0.1, max_iter=5, + solver=solver, verbose=True + ) + fgw_valb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'square_loss', epsilon=0.1, max_iter=5, + solver=solver, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, fgw_valb) + + +def test_entropic_fgw_barycenter(nx): + ns = 5 + nt = 10 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + ys = np.random.randn(Xs.shape[0], 2) + yt = np.random.randn(Xt.shape[0], 2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + p1 = ot.unif(ns) + p2 = ot.unif(nt) + n_samples = 2 + p = ot.unif(n_samples) + + ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) + + X, C, log = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys, yt], [C1, C2], 
None, p, [.5, .5], 'square_loss', 0.1, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42, + solver='PPA', numItermax=1, log=True + ) + Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 0.1, + max_iter=10, tol=1e-3, verbose=False, warmstartT=True, random_state=42, + solver='PPA', numItermax=1, log=False) + Xb, Cb = nx.to_numpy(Xb, Cb) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X, Xb, atol=1e-06) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + # test with 'kl_loss' and log=True + # providing init_C, init_Y + generator = ot.utils.check_random_state(42) + xalea = generator.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + init_Y = np.zeros((n_samples, ys.shape[1]), dtype=ys.dtype) + init_Yb = nx.from_numpy(init_Y) + + X, C, log = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], p, None, 'kl_loss', 0.1, + max_iter=10, tol=1e-3, verbose=False, warmstartT=False, random_state=42, + solver='PPA', numItermax=1, init_C=init_C, init_Y=init_Y, log=True + ) + Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', 0.1, + max_iter=10, tol=1e-3, verbose=False, warmstartT=False, random_state=42, + solver='PPA', numItermax=1, init_C=init_Cb, init_Y=init_Yb, log=True) + Xb, Cb = nx.to_numpy(Xb, Cb) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X, Xb, atol=1e-06) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + np.testing.assert_array_almost_equal(log['err_feature'], nx.to_numpy(*logb['err_feature'])) + np.testing.assert_array_almost_equal(log['err_structure'], nx.to_numpy(*logb['err_structure'])) def test_pointwise_gromov(nx): @@ -539,11 +886,11 @@ def test_gromov_barycenter(nx): C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) Cb = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], + n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42 ) Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42 )) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) @@ -551,12 +898,12 @@ def test_gromov_barycenter(nx): # test of gromov_barycenters with `log` on Cb_, err_ = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True + n_samples, [C1, C2], [p1, p2], p, None, 'square_loss', max_iter=100, + tol=1e-3, verbose=False, warmstartT=True, random_state=42, log=True ) Cbb_, errb_ = ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'square_loss', max_iter=100, + tol=1e-3, verbose=False, warmstartT=True, random_state=42, log=True ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) @@ -565,23 +912,31 @@ def test_gromov_barycenter(nx): Cb2 = ot.gromov.gromov_barycenters( 
n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + 'kl_loss', max_iter=100, tol=1e-3, warmstartT=True, random_state=42 ) Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + 'kl_loss', max_iter=100, tol=1e-3, warmstartT=True, random_state=42 )) np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) # test of gromov_barycenters with `log` on + # providing init_C + generator = ot.utils.check_random_state(42) + xalea = generator.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + Cb2_, err2_ = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True + n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', max_iter=100, + tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C ) Cb2b_, err2b_ = ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', + max_iter=100, tol=1e-3, verbose=True, random_state=42, + init_C=init_Cb, log=True ) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) @@ -607,24 +962,24 @@ def test_gromov_entropic_barycenter(nx): C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) Cb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 ) Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 1e-3, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 )) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) # test of entropic_gromov_barycenters with `log` on Cb_, err_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + n_samples, [C1, C2], [p1, p2], p, None, + 'square_loss', 1e-3, max_iter=10, tol=1e-3, verbose=True, random_state=42, log=True ) Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'square_loss', 1e-3, max_iter=10, tol=1e-3, verbose=True, random_state=42, log=True ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) @@ -633,23 +988,32 @@ def test_gromov_entropic_barycenter(nx): Cb2 = ot.gromov.entropic_gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 ) Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 )) 
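As a reading aid for the barycenter signatures exercised here, a minimal standalone sketch (illustrative sizes and values, not part of the patch or test suite) of `ot.gromov.entropic_gromov_barycenters` with the new `warmstartT` and `init_C` options:

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    C1 = ot.dist(rng.randn(8, 2))   # structure matrix of the first space
    C2 = ot.dist(rng.randn(12, 2))  # structure matrix of the second space
    C1 /= C1.max()
    C2 /= C2.max()

    # optional initial structure for a 4-node barycenter
    init_C = ot.dist(rng.randn(4, 2))
    init_C /= init_C.max()

    # positional order: N, Cs, ps, p, lambdas, loss_fun, epsilon
    C = ot.gromov.entropic_gromov_barycenters(
        4, [C1, C2], [ot.unif(8), ot.unif(12)], ot.unif(4), [.5, .5],
        'square_loss', 1e-3, max_iter=10, tol=1e-3,
        warmstartT=True, init_C=init_C, random_state=42)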
     np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
     np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
 
     # test of entropic_gromov_barycenters with `log` on
+    # providing init_C
+    generator = ot.utils.check_random_state(42)
+    xalea = generator.randn(n_samples, 2)
+    init_C = ot.utils.dist(xalea, xalea)
+    init_C /= init_C.max()
+    init_Cb = nx.from_numpy(init_C)
+
     Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters(
-        n_samples, [C1, C2], [p1, p2], p, [.5, .5],
-        'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+        n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', 1e-3,
+        max_iter=10, tol=1e-3, warmstartT=True, verbose=True, random_state=42,
+        init_C=init_C, log=True
     )
     Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters(
         n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
-        'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+        'kl_loss', 1e-3, max_iter=10, tol=1e-3, warmstartT=True, verbose=True,
+        random_state=42, init_C=init_Cb, log=True
     )
     Cb2b_ = nx.to_numpy(Cb2b_)
     np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
@@ -685,8 +1049,8 @@ def test_fgw(nx):
 
     Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
 
-    G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True)
-    Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True)
+    G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, None, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True)
+    Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, None, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True)
     Gb = nx.to_numpy(Gb)
 
     # check constraints
@@ -701,8 +1065,8 @@ def test_fgw(nx):
     np.testing.assert_allclose(
         Gb, np.flipud(Id), atol=1e-04)  # cf convergence gromov
 
-    fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True)
-    fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True)
+    fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, None, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True)
+    fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, None, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True)
     fgwb = nx.to_numpy(fgwb)
 
     G = log['T']
@@ -923,6 +1287,9 @@ def test_fgw_barycenter(nx):
     C1 = ot.dist(Xs)
     C2 = ot.dist(Xt)
 
+    C1 /= C1.max()
+    C2 /= C2.max()
+
     p1, p2 = ot.unif(ns), ot.unif(nt)
     n_samples = 3
     p = ot.unif(n_samples)
@@ -930,18 +1297,19 @@ def test_fgw_barycenter(nx):
     ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
 
     Xb, Cb = ot.gromov.fgw_barycenters(
-        n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
+        n_samples, [ysb, ytb], [C1b, C2b], None, [.5, .5], 0.5, fixed_structure=False,
         fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
         random_state=12345
     )
 
     xalea = np.random.randn(n_samples, 2)
     init_C = ot.dist(xalea, xalea)
+    init_C /= init_C.max()
     init_Cb = nx.from_numpy(init_C)
 
     Xb, Cb = ot.gromov.fgw_barycenters(
-        n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5],
+        n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
         alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
-        p=pb, 
loss_fun='square_loss', max_iter=100, tol=1e-3 + p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 ) Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) @@ -953,11 +1321,21 @@ def test_fgw_barycenter(nx): Xb, Cb, logb = ot.gromov.fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, fixed_features=True, init_X=init_Xb, - p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765 + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True ) - Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) - np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + # add test with 'kl_loss' + X, C = ot.gromov.fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', + max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True, random_state=12345 + ) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) def test_gromov_wasserstein_linear_unmixing(nx): @@ -1501,8 +1879,11 @@ def test_semirelaxed_gromov(nx): # asymmetric C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) - Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=False, log=True, G0=None) + G, log = ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True, + G0=None, alpha_min=0., alpha_max=1.) 
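For readers less familiar with the semi-relaxed API touched by these tests, a minimal standalone sketch (illustrative data, not part of the test suite); only the source marginal `p` is prescribed, the target marginal is learned by the solver:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    C1 = ot.dist(rng.randn(10, 2))
    C2 = ot.dist(rng.randn(4, 2))
    C1 /= C1.max()
    C2 /= C2.max()
    p = ot.unif(10)  # only the first marginal is constrained

    T, log = ot.gromov.semirelaxed_gromov_wasserstein(
        C1, C2, p, loss_fun='square_loss', log=True)
    srgw = ot.gromov.semirelaxed_gromov_wasserstein2(
        C1, C2, p, loss_fun='square_loss')
    # the second marginal T.sum(0) is an output of the solver, not an input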
# check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -1510,8 +1891,10 @@ def test_semirelaxed_gromov(nx): np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) G = log2['T'] Gb = nx.to_numpy(logb2['T']) @@ -1527,16 +1910,20 @@ def test_semirelaxed_gromov(nx): C1 = 0.5 * (C1 + C1.T) C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) + G, log = ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0) @@ -1661,7 +2048,7 @@ def test_semirelaxed_fgw(nx): # asymmetric Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) # check constraints @@ -1670,7 +2057,7 @@ def test_semirelaxed_fgw(nx): np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) G = log2['T'] Gb = 
nx.to_numpy(logb2['T']) @@ -1819,3 +2206,276 @@ def test_srfgw_helper_backend(nx): res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) # check constraints np.testing.assert_allclose(res, Gb, atol=1e-06) + + +def test_entropic_semirelaxed_gromov(nx): + np.random.seed(0) + # unbalanced proportions + list_n = [30, 15] + nt = 2 + ns = np.sum(list_n) + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns), dtype=np.float64) + C2 = np.array([[0.8, 0.05], + [0.05, 1.]], dtype=np.float64) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + p = ot.unif(ns, type_as=C1) + q0 = ot.unif(C2.shape[0], type_as=C1) + G0 = p[:, None] * q0[None, :] + # asymmetric + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + epsilon = 0.1 + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=None) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0) + + G = log2['T'] + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + 
np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_semirelaxed_gromov_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp) + + Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True + ) + gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + +def test_entropic_semirelaxed_fgw(nx): + np.random.seed(0) + list_n = [16, 8] + nt = 2 + ns = 24 + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns)) + C2 = np.array([[0.7, 0.05], + [0.05, 0.9]]) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + F1 = np.zeros((ns, 1)) + F1[:16] = np.random.normal(loc=0., scale=0.01, size=(16, 1)) + F1[16:] = np.random.normal(loc=1., scale=0.01, size=(8, 1)) + F2 = np.zeros((2, 1)) + F2[1, :] = 1. 
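The feature cost built on the next line is the pairwise squared Euclidean distance between the rows of F1 and F2, written out explicitly; for reference, a sketch of the equivalent construction using the default squared-Euclidean metric of `ot.dist`:

    # M_ij = ||F1_i - F2_j||^2 = ||F1_i||^2 + ||F2_j||^2 - 2 <F1_i, F2_j>
    M_alt = ot.dist(F1, F2)  # ot.dist defaults to metric='sqeuclidean'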
+ M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) + + p = ot.unif(ns) + q0 = ot.unif(C2.shape[0]) + G0 = p[:, None] * q0[None, :] + + # asymmetric + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_semirelaxed_fgw_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = 
np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Mb, C1b, C2b, pb = nx.from_numpy(M, C1, C2, p, type_as=tp) + + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True + ) + fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, fgw_valb) + + +def test_not_implemented_solver(): + # test sinkhorn + n_samples = 5 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + M = ot.dist(ys, yt) + + solver = 'not_implemented' + # entropic gw and fgw + with pytest.raises(ValueError): + ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) + with pytest.raises(ValueError): + ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) + + # exact and entropic srgw and srfgw loss functions + loss_fun = 'kl_loss' + with pytest.raises(NotImplementedError): + ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun, armijo=False) + with pytest.raises(NotImplementedError): + ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun, epsilon=0.1) + with pytest.raises(NotImplementedError): + ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun) + with pytest.raises(NotImplementedError): + ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p, loss_fun, epsilon=0.1) -- cgit v1.2.3
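Finally, the guard behaviour exercised by `test_not_implemented_solver` can be reproduced directly; a minimal sketch (illustrative data; exact error messages are not specified by the patch):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    C1 = ot.dist(rng.randn(5, 2))
    C2 = ot.dist(rng.randn(5, 2))
    C1 /= C1.max()
    C2 /= C2.max()
    p = ot.unif(5)

    # an unknown solver name raises a ValueError
    try:
        ot.gromov.entropic_gromov_wasserstein(
            C1, C2, p, ot.unif(5), 'square_loss',
            epsilon=1e-1, solver='not_implemented')
    except ValueError:
        pass

    # 'kl_loss' raises NotImplementedError in the (entropic) semi-relaxed solvers
    try:
        ot.gromov.entropic_semirelaxed_gromov_wasserstein(
            C1, C2, p, loss_fun='kl_loss', epsilon=0.1)
    except NotImplementedError:
        pass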