author	Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>	2023-06-12 12:01:48 +0200
committer	GitHub <noreply@github.com>	2023-06-12 12:01:48 +0200
commit	9076f02903ba2fb9ea9fe704764a755cad8dcd63 (patch)
tree	b7fda84880c5dabd1c441a1655741493e0683342
parent	f0dab2f684f4fc768fd50e0b70918e075dcdd0f3 (diff)
[FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455)
* add entropic fgw + fgw bary + srgw + srfgw with tests
* add examples for entropic srgw - srfgw solvers
* add PPA solvers for GW/FGW + complete previous commits
* update readme
* add tests
* add examples + tests + warning in entropic solvers + releases
* reduce testing runtimes for test_gromov
* fix conflicts
* optional marginals
* improve coverage
* gromov doc harmonization
* fix pep8
* complete optional marginal for entropic srfgw

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--  CONTRIBUTORS.md                                     2
-rw-r--r--  README.md                                           8
-rw-r--r--  RELEASES.md                                        10
-rw-r--r--  examples/gromov/plot_entropic_semirelaxed_fgw.py  304
-rw-r--r--  examples/gromov/plot_fgw.py                        32
-rw-r--r--  examples/gromov/plot_fgw_solvers.py               288
-rw-r--r--  examples/gromov/plot_gromov.py                    112
-rw-r--r--  ot/gromov/__init__.py                              37
-rw-r--r--  ot/gromov/_bregman.py                             782
-rw-r--r--  ot/gromov/_gw.py                                  318
-rw-r--r--  ot/gromov/_semirelaxed.py                         591
-rw-r--r--  ot/gromov/_utils.py                                43
-rw-r--r--  test/test_gromov.py                               796
13 files changed, 2948 insertions, 375 deletions
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 6b35653..2d25f3e 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -36,7 +36,7 @@ The contributors to this library are:
* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
-* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, semi-relaxed FGW)
+* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
diff --git a/README.md b/README.md
index c16b328..25c0401 100644
--- a/README.md
+++ b/README.md
@@ -27,8 +27,8 @@ POT provides the following generic OT solvers (links to examples):
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
* Weak OT solver between empirical distributions [39]
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale).
-* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from Graph Dictionary Learning [38]
- * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
+* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12,51]), differentiable using gradients from Graph Dictionary Learning [38]
+ * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) (exact [24] and regularized [12,51]).
* [Stochastic
solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and
[differentiable losses](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html) for
@@ -42,7 +42,7 @@ POT provides the following generic OT solvers (links to examples):
* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
-* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) [48].
+* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -310,3 +310,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
+
+[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019.
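A minimal sketch of how the exact and regularized (F)GW solvers referenced above can be called after this PR (signatures as introduced in ot/gromov/_bregman.py below; the random inputs are illustrative only, and marginals default to uniform per this PR):

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(10, 2), rng.randn(12, 2)
C1, C2 = ot.dist(xs), ot.dist(xt)  # intra-domain structure matrices
C1 /= C1.max()
C2 /= C2.max()

# exact GW [13] solved with Conditional Gradient; marginals p, q are now
# optional and default to uniform distributions (PR #455)
T_cg = ot.gromov.gromov_wasserstein(C1, C2)

# regularized GW: 'PGD' = Projected Gradient Descent [12],
# 'PPA' = Proximal Point Algorithm [51]
T_ppa = ot.gromov.entropic_gromov_wasserstein(C1, C2, epsilon=0.05, solver='PPA')
```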
diff --git a/RELEASES.md b/RELEASES.md
index 61ad2ca..2f47f40 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,13 +5,18 @@
#### New features
- Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
-- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
+- Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459)
- Add tests on GPU for master branch and approved PR (PR #473)
- Add `median` method to all inherited classes of `backend.Backend` (PR #472)
- Update tests for macOS and Windows, speedup documentation (PR #484)
+- Added Proximal Point algorithm to solve GW problems via a new parameter `solver="PPA"` in `ot.gromov.entropic_gromov_wasserstein` + examples (PR #455)
+- Added features `warmstart` and `kwargs` in `ot.gromov.entropic_gromov_wasserstein` to respectively perform warmstart on dual potentials and pass parameters to `ot.sinkhorn` (PR #455)
+- Added sinkhorn projection based solvers for FGW `ot.gromov.entropic_fused_gromov_wasserstein` and entropic FGW barycenters + examples (PR #455)
+- Added features `warmstartT` and `kwargs` to all CG and entropic (F)GW barycenter solvers (PR #455)
+- Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455)
+- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)
#### Closed issues
-
- Fix circleci-redirector action and codecov (PR #460)
- Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457)
- Major documentation cleanup (PR #462, #467, #475)
@@ -22,6 +27,7 @@
- Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472)
## 0.9.0
+*April 2023*
This new release contains so many new features and bug fixes since 0.8.2 that we
decided to make it a new minor release at 0.9.0.
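A hedged sketch of the new `warmstart` and `**kwargs` options from the feature list above (extra keywords are forwarded to the inner `ot.sinkhorn` projections; [51] suggests `numItermax=1` for the PPA solver; inputs are illustrative only):

```python
import numpy as np
import ot

rng = np.random.RandomState(1)
C1 = ot.dist(rng.randn(8, 2))
C2 = ot.dist(rng.randn(9, 2))

# PPA solver with warmstarted dual potentials across Sinkhorn projections;
# numItermax is forwarded to each inner ot.sinkhorn call
T, log = ot.gromov.entropic_gromov_wasserstein(
    C1, C2, epsilon=0.1, solver='PPA', warmstart=True, log=True, numItermax=1)
print(log['gw_dist'])
```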
diff --git a/examples/gromov/plot_entropic_semirelaxed_fgw.py b/examples/gromov/plot_entropic_semirelaxed_fgw.py
new file mode 100644
index 0000000..642baea
--- /dev/null
+++ b/examples/gromov/plot_entropic_semirelaxed_fgw.py
@@ -0,0 +1,304 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================================
+Entropic-regularized semi-relaxed (Fused) Gromov-Wasserstein example
+====================================================================
+
+This example is designed to show how to use the entropic semi-relaxed Gromov-Wasserstein
+and the entropic semi-relaxed Fused Gromov-Wasserstein divergences.
+
+Entropic-regularized sr(F)GW between two graphs G1 and G2 searches for a reweighting of the nodes of
+G2 at a minimal entropic-regularized (F)GW distance from G1.
+
+First, we generate two graphs following Stochastic Block Models, then show
+how to compute their srGW matchings and illustrate them. These graphs are then
+endowed with node features and we follow the same process with srFGW.
+
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+International Conference on Learning Representations (ICLR), 2021.
+"""
+
+# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+import numpy as np
+import matplotlib.pylab as pl
+from ot.gromov import entropic_semirelaxed_gromov_wasserstein, entropic_semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+
+#############################################################################
+#
+# Generate two graphs following Stochastic Block Models with 2 and 3 clusters.
+# ---------------------------------------------
+
+
+N2 = 20 # 2 communities
+N3 = 30 # 3 communities
+p2 = [[1., 0.1],
+ [0.1, 0.9]]
+p3 = [[1., 0.1, 0.],
+ [0.1, 0.95, 0.1],
+ [0., 0.1, 0.9]]
+G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
+G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)
+
+
+C2 = networkx.to_numpy_array(G2)
+C3 = networkx.to_numpy_array(G3)
+
+h2 = np.ones(C2.shape[0]) / C2.shape[0]
+h3 = np.ones(C3.shape[0]) / C3.shape[0]
+
+# Add weights on the edges for visualization later on
+weight_intra_G2 = 5
+weight_inter_G2 = 0.5
+weight_intra_G3 = 1.
+weight_inter_G3 = 1.5
+
+weightedG2 = networkx.Graph()
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+
+for node in G2.nodes():
+ weightedG2.add_node(node)
+for i, j in G2.edges():
+ if part_G2[i] == part_G2[j]:
+ weightedG2.add_edge(i, j, weight=weight_intra_G2)
+ else:
+ weightedG2.add_edge(i, j, weight=weight_inter_G2)
+
+weightedG3 = networkx.Graph()
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+for node in G3.nodes():
+ weightedG3.add_node(node)
+for i, j in G3.edges():
+ if part_G3[i] == part_G3[j]:
+ weightedG3.add_edge(i, j, weight=weight_intra_G3)
+ else:
+ weightedG3.add_edge(i, j, weight=weight_inter_G3)
+
+#############################################################################
+#
+# Compute their entropic-regularized semi-relaxed Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+# 0) GW(C2, h2, C3, h3) for reference
+OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
+gw = log['gw_dist']
+
+# 1) srGW_e(C2, h2, C3)
+OT_23, log_23 = entropic_semirelaxed_gromov_wasserstein(
+ C2, C3, h2, symmetric=True, epsilon=1., G0=None, log=True)
+srgw_23 = log_23['srgw_dist']
+
+# 2) srGW_e(C3, h3, C2)
+
+OT_32, log_32 = entropic_semirelaxed_gromov_wasserstein(
+ C3, C2, h3, symmetric=None, epsilon=1., G0=None, log=True)
+srgw_32 = log_32['srgw_dist']
+
+print('GW(C2, C3) = ', gw)
+print('srGW_e(C2, h2, C3) = ', srgw_23)
+print('srGW_e(C3, h3, C2) = ', srgw_32)
+
+
+#############################################################################
+#
+# Visualization of the entropic-regularized semi-relaxed Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the entropic srGW matching.
+# We adjust the intensity of links across domains proportionally to the mass
+# sent, adding a minimal intensity of 0.1 if the mass sent is not zero.
+
+
+def draw_graph(G, C, nodes_color_part, Gweights=None,
+ pos=None, edge_color='black', node_size=None,
+ shiftx=0, seed=0):
+
+ if (pos is None):
+ pos = networkx.spring_layout(G, scale=1., seed=seed)
+
+ if shiftx != 0:
+ for k, v in pos.items():
+ v[0] = v[0] + shiftx
+
+ alpha_edge = 0.7
+ width_edge = 1.8
+ if Gweights is None:
+ networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
+ else:
+ # We make more visible connections between activated nodes
+ n = len(Gweights)
+ edgelist_activated = []
+ edgelist_deactivated = []
+ for i in range(n):
+ for j in range(n):
+ if Gweights[i] * Gweights[j] * C[i, j] > 0:
+ edgelist_activated.append((i, j))
+ elif C[i, j] > 0:
+ edgelist_deactivated.append((i, j))
+
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
+ width=width_edge, alpha=alpha_edge,
+ edge_color=edge_color)
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
+ width=width_edge, alpha=0.1,
+ edge_color=edge_color)
+
+ if Gweights is None:
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=node_size, alpha=1,
+ node_color=node_color)
+ else:
+ scaled_Gweights = Gweights / (0.5 * Gweights.max())
+ nodes_size = node_size * scaled_Gweights
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=nodes_size[node], alpha=1,
+ node_color=node_color)
+ return pos
+
+
+def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
+ p1, p2, T, pos1=None, pos2=None,
+ shiftx=4, switchx=False, node_size=70,
+ seed_G1=0, seed_G2=0):
+ starting_color = 0
+ # get graphs partition and their coloring
+ part1 = part_G1.copy()
+ unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
+ nodes_color_part1 = []
+ for cluster in part1:
+ nodes_color_part1.append(unique_colors[cluster])
+
+ nodes_color_part2 = []
+    # T: getting color assignments from the argmax of columns
+ for i in range(len(G2.nodes())):
+ j = np.argmax(T[:, i])
+ nodes_color_part2.append(nodes_color_part1[j])
+ pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
+ pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
+ pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
+ node_size=node_size, shiftx=shiftx, seed=seed_G2)
+ for k1, v1 in pos1.items():
+ max_Tk1 = np.max(T[k1, :])
+ for k2, v2 in pos2.items():
+ if (T[k1, k2] > 0):
+ pl.plot([pos1[k1][0], pos2[k2][0]],
+ [pos1[k1][1], pos2[k2][1]],
+ '-', lw=0.6, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.),
+ color=nodes_color_part1[k1])
+ return pos1, pos2
+
+
+node_size = 40
+fontsize = 10
+seed_G2 = 0
+seed_G3 = 4
+
+pl.figure(1, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.title(r'$srGW_e(\mathbf{C_2},\mathbf{h_2},\mathbf{C_3}) =%s$' % (np.round(srgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'$srGW_e(\mathbf{C_3}, \mathbf{h_3},\mathbf{C_2}) =%s$' % (np.round(srgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+ pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
+
+#############################################################################
+#
+# Add node features
+# ---------------------------------------------
+
+# We add node features with cluster-dependent means,
+# inversely proportional to the clusters' intra-connectivity.
+
+F2 = np.zeros((N2, 1))
+for i, c in enumerate(part_G2):
+ F2[i, 0] = np.random.normal(loc=c, scale=0.01)
+
+F3 = np.zeros((N3, 1))
+for i, c in enumerate(part_G3):
+ F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
+
+#############################################################################
+#
+# Compute their semi-relaxed Fused Gromov-Wasserstein divergences
+# ---------------------------------------------
+
+alpha = 0.5
+# Compute the pairwise squared Euclidean distances between node features
+M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T)
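+# (equivalently, with `import ot`, one could write M = ot.dist(F2, F3),
+# whose default metric is the squared Euclidean distance)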
+
+# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference
+
+OT, log = fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True)
+fgw = log['fgw_dist']
+
+# 1) srFGW_e(C2, F2, h2, C3, F3)
+OT_23, log_23 = entropic_semirelaxed_fused_gromov_wasserstein(
+ M, C2, C3, h2, symmetric=True, epsilon=1., alpha=0.5, log=True, G0=None)
+srfgw_23 = log_23['srfgw_dist']
+
+# 2) srFGW_e(C3, F3, h3, C2, F2)
+
+OT_32, log_32 = entropic_semirelaxed_fused_gromov_wasserstein(
+ M.T, C3, C2, h3, symmetric=None, epsilon=1., alpha=alpha, log=True, G0=None)
+srfgw_32 = log_32['srfgw_dist']
+
+print('FGW(C2, F2, C3, F3) = ', fgw)
+print('srFGW_e(C2, F2, h2, C3, F3) = ', srfgw_23)
+print('srFGW_e(C3, F3, h3, C2, F2) = ', srfgw_32)
+
+#############################################################################
+#
+# Visualization of the entropic semi-relaxed Fused Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the srFGW matching.
+# NB: colors refer to clusters - not to node features.
+
+pl.figure(2, figsize=(8, 2.5))
+pl.clf()
+pl.subplot(121)
+pl.axis('off')
+pl.title(r'$srFGW_e(\mathbf{C_2},\mathbf{F_2},\mathbf{h_2},\mathbf{C_3},\mathbf{F_3}) =%s$' % (np.round(srfgw_23, 3)), fontsize=fontsize)
+
+hbar2 = OT_23.sum(axis=0)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=None, p2=hbar2, T=OT_23,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+pl.subplot(122)
+pl.axis('off')
+hbar3 = OT_32.sum(axis=0)
+pl.title(r'$srFGW_e(\mathbf{C_3}, \mathbf{F_3}, \mathbf{h_3}, \mathbf{C_2}, \mathbf{F_2}) =%s$' % (np.round(srfgw_32, 3)), fontsize=fontsize)
+pos1, pos2 = draw_transp_colored_srGW(
+ weightedG3, C3, weightedG2, C2, part_G3, p1=None, p2=hbar3, T=OT_32,
+ pos1=pos2, pos2=pos1, shiftx=3., node_size=node_size, seed_G1=0, seed_G2=0)
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py
index bf10de6..68ecb13 100644
--- a/examples/gromov/plot_fgw.py
+++ b/examples/gromov/plot_fgw.py
@@ -4,13 +4,13 @@
Plot Fused-Gromov-Wasserstein
==============================
-This example illustrates the computation of FGW for 1D measures [18].
+This example first illustrates the computation of FGW for 1D measures estimated
+using a Conditional Gradient solver [24].
-[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
-
"""
# Author: Titouan Vayer <titouan.vayer@irisa.fr>
@@ -24,11 +24,13 @@ import numpy as np
import ot
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+
##############################################################################
# Generate data
# -------------
-#%% parameters
+# parameters
+
# We create two 1D random measures
n = 20 # number of points in the first distribution
n2 = 30 # number of points in the second distribution
@@ -53,10 +55,9 @@ q = ot.unif(n2)
# Plot data
# ---------
-#%% plot the distributions
+# plot the distributions
-pl.close(10)
-pl.figure(10, (7, 7))
+pl.figure(1, (7, 7))
pl.subplot(2, 1, 1)
@@ -78,7 +79,7 @@ pl.show()
# Create structure matrices and across-feature distance matrix
# ------------------------------------------------------------
-#%% Structure matrices and across-features distance matrix
+# Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
C2 = ot.dist(xt)
M = ot.dist(ys, yt)
@@ -90,10 +91,9 @@ Got = ot.emd([], [], M)
# Plot matrices
# -------------
-#%%
cmap = 'Reds'
-pl.close(10)
-pl.figure(10, (5, 5))
+
+pl.figure(2, (5, 5))
fs = 15
l_x = [0, 5, 10, 15]
l_y = [0, 5, 10, 15, 20, 25]
@@ -113,7 +113,6 @@ ax2 = pl.subplot(gs[:3, 2:])
pl.imshow(C2, cmap=cmap, interpolation='nearest')
pl.title("$C_2$", fontsize=fs)
pl.ylabel("$l$", fontsize=fs)
-#pl.ylabel("$l$",fontsize=fs)
pl.xticks(())
pl.yticks(l_y)
ax2.set_aspect('auto')
@@ -133,28 +132,27 @@ pl.show()
# Compute FGW/GW
# --------------
-#%% Computing FGW and GW
+# Computing FGW and GW
alpha = 1e-3
ot.tic()
Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
ot.toc()
-#%reload_ext WGW
+# reload_ext WGW
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
##############################################################################
# Visualize transport matrices
# ----------------------------
-#%% visu OT matrix
+# visu OT matrix
cmap = 'Blues'
fs = 15
-pl.figure(2, (13, 5))
+pl.figure(3, (13, 5))
pl.clf()
pl.subplot(1, 3, 1)
pl.imshow(Got, cmap=cmap, interpolation='nearest')
-#pl.xlabel("$y$",fontsize=fs)
pl.ylabel("$i$", fontsize=fs)
pl.xticks(())
diff --git a/examples/gromov/plot_fgw_solvers.py b/examples/gromov/plot_fgw_solvers.py
new file mode 100644
index 0000000..5f8a885
--- /dev/null
+++ b/examples/gromov/plot_fgw_solvers.py
@@ -0,0 +1,288 @@
+# -*- coding: utf-8 -*-
+"""
+==============================================
+Comparison of Fused Gromov-Wasserstein solvers
+==============================================
+
+This example illustrates the computation of FGW for attributed graphs
+using 3 different solvers to estimate the distance based on Conditional
+Gradient [24] or Sinkhorn projections [12, 51].
+
+We generate two graphs following Stochastic Block Models further endowed with
+node features and compute their FGW matchings.
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
+"Gromov-Wasserstein averaging of kernel and distance matrices".
+International Conference on Machine Learning (ICML).
+
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+and Courty Nicolas
+"Optimal Transport for structured data with application on graphs"
+International Conference on Machine Learning (ICML). 2019.
+
+[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019).
+"Gromov-wasserstein learning for graph matching and node embedding".
+In International Conference on Machine Learning (ICML), 2019.
+"""
+
+# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+import numpy as np
+import matplotlib.pylab as pl
+from ot.gromov import fused_gromov_wasserstein, entropic_fused_gromov_wasserstein
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+
+#############################################################################
+#
+# Generate two graphs following Stochastic Block Models with 2 and 3 clusters.
+# ---------------------------------------------
+np.random.seed(0)
+
+N2 = 20 # 2 communities
+N3 = 30 # 3 communities
+p2 = [[1., 0.1],
+ [0.1, 0.9]]
+p3 = [[1., 0.1, 0.],
+ [0.1, 0.95, 0.1],
+ [0., 0.1, 0.9]]
+G2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)
+G3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+C2 = networkx.to_numpy_array(G2)
+C3 = networkx.to_numpy_array(G3)
+
+
+# We add node features with cluster-dependent means,
+# inversely proportional to the clusters' intra-connectivity.
+
+F2 = np.zeros((N2, 1))
+for i, c in enumerate(part_G2):
+ F2[i, 0] = np.random.normal(loc=c, scale=0.01)
+
+F3 = np.zeros((N3, 1))
+for i, c in enumerate(part_G3):
+ F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
+
+# Compute the pairwise squared Euclidean distances between node features
+M = (F2 ** 2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3 ** 2).T) - 2 * F2.dot(F3.T)
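+# (equivalently, with `import ot`, one could write M = ot.dist(F2, F3),
+# whose default metric is the squared Euclidean distance)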
+
+h2 = np.ones(C2.shape[0]) / C2.shape[0]
+h3 = np.ones(C3.shape[0]) / C3.shape[0]
+
+#############################################################################
+#
+# Compute their Fused Gromov-Wasserstein distances
+# ---------------------------------------------
+
+alpha = 0.5
+
+
+# Conditional Gradient algorithm
+fgw0, log0 = fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, 'square_loss', alpha=alpha, verbose=True, log=True)
+
+# Proximal Point algorithm with Kullback-Leibler as proximal operator
+fgw, log = entropic_fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA',
+ log=True, verbose=True, warmstart=False, numItermax=10)
+
+# Projected Gradient algorithm with entropic regularization
+fgwe, loge = entropic_fused_gromov_wasserstein(
+ M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD',
+ log=True, verbose=True, warmstart=False, numItermax=10)
+
+print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['fgw_dist']))
+print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['fgw_dist']))
+print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['fgw_dist']))
+
+# compute OT sparsity level
+fgw0_sparsity = 100 * (fgw0 == 0.).astype(np.float64).sum() / (N2 * N3)
+fgw_sparsity = 100 * (fgw == 0.).astype(np.float64).sum() / (N2 * N3)
+fgwe_sparsity = 100 * (fgwe == 0.).astype(np.float64).sum() / (N2 * N3)
+
+# Methods using Sinkhorn projections tend to produce feasibility errors on the
+# marginal constraints
+
+err0 = np.linalg.norm(fgw0.sum(1) - h2) + np.linalg.norm(fgw0.sum(0) - h3)
+err = np.linalg.norm(fgw.sum(1) - h2) + np.linalg.norm(fgw.sum(0) - h3)
+erre = np.linalg.norm(fgwe.sum(1) - h2) + np.linalg.norm(fgwe.sum(0) - h3)
+
+#############################################################################
+#
+# Visualization of the Fused Gromov-Wasserstein matchings
+# ---------------------------------------------
+#
+# We color nodes of the graph on the right - then project its node colors
+# based on the optimal transport plan from the FGW matchings.
+# We adjust the intensity of links across domains proportionally to the mass
+# sent, adding a minimal intensity of 0.1 if the mass sent is not zero.
+# For each matching, all node sizes are proportional to their mass computed
+# from the marginals of the OT plan, to illustrate potential feasibility errors.
+# NB: colors refer to clusters - not to node features.
+
+# Add weights on the edges for visualization later on
+weight_intra_G2 = 5
+weight_inter_G2 = 0.5
+weight_intra_G3 = 1.
+weight_inter_G3 = 1.5
+
+weightedG2 = networkx.Graph()
+part_G2 = [G2.nodes[i]['block'] for i in range(N2)]
+
+for node in G2.nodes():
+ weightedG2.add_node(node)
+for i, j in G2.edges():
+ if part_G2[i] == part_G2[j]:
+ weightedG2.add_edge(i, j, weight=weight_intra_G2)
+ else:
+ weightedG2.add_edge(i, j, weight=weight_inter_G2)
+
+weightedG3 = networkx.Graph()
+part_G3 = [G3.nodes[i]['block'] for i in range(N3)]
+
+for node in G3.nodes():
+ weightedG3.add_node(node)
+for i, j in G3.edges():
+ if part_G3[i] == part_G3[j]:
+ weightedG3.add_edge(i, j, weight=weight_intra_G3)
+ else:
+ weightedG3.add_edge(i, j, weight=weight_inter_G3)
+
+
+def draw_graph(G, C, nodes_color_part, Gweights=None,
+ pos=None, edge_color='black', node_size=None,
+ shiftx=0, seed=0):
+
+ if (pos is None):
+ pos = networkx.spring_layout(G, scale=1., seed=seed)
+
+ if shiftx != 0:
+ for k, v in pos.items():
+ v[0] = v[0] + shiftx
+
+ alpha_edge = 0.7
+ width_edge = 1.8
+ if Gweights is None:
+ networkx.draw_networkx_edges(G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color)
+ else:
+ # We make more visible connections between activated nodes
+ n = len(Gweights)
+ edgelist_activated = []
+ edgelist_deactivated = []
+ for i in range(n):
+ for j in range(n):
+ if Gweights[i] * Gweights[j] * C[i, j] > 0:
+ edgelist_activated.append((i, j))
+ elif C[i, j] > 0:
+ edgelist_deactivated.append((i, j))
+
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_activated,
+ width=width_edge, alpha=alpha_edge,
+ edge_color=edge_color)
+ networkx.draw_networkx_edges(G, pos, edgelist=edgelist_deactivated,
+ width=width_edge, alpha=0.1,
+ edge_color=edge_color)
+
+ if Gweights is None:
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=node_size, alpha=1,
+ node_color=node_color)
+ else:
+ scaled_Gweights = Gweights / (0.5 * Gweights.max())
+ nodes_size = node_size * scaled_Gweights
+ for node, node_color in enumerate(nodes_color_part):
+ networkx.draw_networkx_nodes(G, pos, nodelist=[node],
+ node_size=nodes_size[node], alpha=1,
+ node_color=node_color)
+ return pos
+
+
+def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T,
+ pos1=None, pos2=None, shiftx=4, switchx=False,
+ node_size=70, seed_G1=0, seed_G2=0):
+ starting_color = 0
+ # get graphs partition and their coloring
+ part1 = part_G1.copy()
+ unique_colors = ['C%s' % (starting_color + i) for i in np.unique(part1)]
+ nodes_color_part1 = []
+ for cluster in part1:
+ nodes_color_part1.append(unique_colors[cluster])
+
+ nodes_color_part2 = []
+    # T: getting color assignments from the argmax of columns
+ for i in range(len(G2.nodes())):
+ j = np.argmax(T[:, i])
+ nodes_color_part2.append(nodes_color_part1[j])
+ pos1 = draw_graph(G1, C1, nodes_color_part1, Gweights=p1,
+ pos=pos1, node_size=node_size, shiftx=0, seed=seed_G1)
+ pos2 = draw_graph(G2, C2, nodes_color_part2, Gweights=p2, pos=pos2,
+ node_size=node_size, shiftx=shiftx, seed=seed_G2)
+
+ for k1, v1 in pos1.items():
+ max_Tk1 = np.max(T[k1, :])
+ for k2, v2 in pos2.items():
+ if (T[k1, k2] > 0):
+ pl.plot([pos1[k1][0], pos2[k2][0]],
+ [pos1[k1][1], pos2[k2][1]],
+ '-', lw=0.7, alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.),
+ color=nodes_color_part1[k1])
+ return pos1, pos2
+
+
+node_size = 40
+fontsize = 13
+seed_G2 = 0
+seed_G3 = 4
+
+pl.figure(2, figsize=(12, 3.5))
+pl.clf()
+pl.subplot(131)
+pl.axis('off')
+pl.title('(CG algo) FGW=%s \n \n OT sparsity = %s \n feasibility error = %s' % (
+ np.round(log0['fgw_dist'], 3), str(np.round(fgw0_sparsity, 2)) + ' %',
+ np.round(err0, 4)), fontsize=fontsize)
+
+p0, q0 = fgw0.sum(1), fgw0.sum(0) # check marginals
+
+pos1, pos2 = draw_transp_colored_GW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=p0, p2=q0, T=fgw0,
+ shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
+
+pl.subplot(132)
+pl.axis('off')
+
+p, q = fgw.sum(1), fgw.sum(0) # check marginals
+
+pl.title('(PP algo) FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
+ np.round(log['fgw_dist'], 3), str(np.round(fgw_sparsity, 2)) + ' %',
+ np.round(err, 4)), fontsize=fontsize)
+
+pos1, pos2 = draw_transp_colored_GW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=p, p2=q, T=fgw,
+ pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+
+pl.subplot(133)
+pl.axis('off')
+
+pe, qe = fgwe.sum(1), fgwe.sum(0) # check marginals
+
+pl.title('Entropic FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
+ np.round(loge['fgw_dist'], 3), str(np.round(fgwe_sparsity, 2)) + ' %',
+ np.round(erre, 4)), fontsize=fontsize)
+
+pos1, pos2 = draw_transp_colored_GW(
+ weightedG2, C2, weightedG3, C3, part_G2, p1=pe, p2=qe, T=fgwe,
+ pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
+
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index afb5bdc..252267f 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -5,13 +5,38 @@ Gromov-Wasserstein example
==========================
This example is designed to show how to use the Gromov-Wasserstein distance
computation in POT.
+We first compare 3 solvers to estimate the distance based on
+Conditional Gradient [24] or Sinkhorn projections [12, 51].
+Then we compare 2 stochastic solvers to estimate the distance with a lower
+numerical cost [33].
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
+"Gromov-Wasserstein averaging of kernel and distance matrices".
+International Conference on Machine Learning (ICML).
+
+[24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+and Courty Nicolas
+"Optimal Transport for structured data with application on graphs"
+International Conference on Machine Learning (ICML). 2019.
+
+[33] Kerdoncuff T., Emonet R., Marc S. "Sampled Gromov Wasserstein",
+Machine Learning Journal (MLJ), 2021.
+
+[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019).
+"Gromov-wasserstein learning for graph matching and node embedding".
+In International Conference on Machine Learning (ICML), 2019.
+
"""
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+# Tanguy Kerdoncuff <tanguy.kerdoncuff@laposte.net>
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 1
+
import scipy as sp
import numpy as np
import matplotlib.pylab as pl
@@ -36,7 +61,7 @@ cov_s = np.array([[1, 0], [0, 1]])
mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
-
+np.random.seed(0)
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
@@ -47,7 +72,7 @@ xt = np.random.randn(n_samples, 3).dot(P) + mu_t
# --------------------------
-fig = pl.figure()
+fig = pl.figure(1)
ax1 = fig.add_subplot(121)
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
ax2 = fig.add_subplot(122, projection='3d')
@@ -66,11 +91,15 @@ C2 = sp.spatial.distance.cdist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
-pl.figure()
+pl.figure(2)
pl.subplot(121)
pl.imshow(C1)
+pl.title('C1')
+
pl.subplot(122)
pl.imshow(C2)
+pl.title('C2')
+
pl.show()
#############################################################################
@@ -81,32 +110,63 @@ pl.show()
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+# Conditional Gradient algorithm
gw0, log0 = ot.gromov.gromov_wasserstein(
C1, C2, p, q, 'square_loss', verbose=True, log=True)
+# Proximal Point algorithm with Kullback-Leibler as proximal operator
gw, log = ot.gromov.entropic_gromov_wasserstein(
- C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
-
-
-print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
-print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
-
-
-pl.figure(1, (10, 5))
-
-pl.subplot(1, 2, 1)
-pl.imshow(gw0, cmap='jet')
-pl.title('Gromov Wasserstein')
-
-pl.subplot(1, 2, 2)
-pl.imshow(gw, cmap='jet')
-pl.title('Entropic Gromov Wasserstein')
-
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PPA',
+ log=True, verbose=True)
+
+# Projected Gradient algorithm with entropic regularization
+gwe, loge = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, solver='PGD',
+ log=True, verbose=True)
+
+print('Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['gw_dist']))
+print('Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['gw_dist']))
+print('Entropic Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['gw_dist']))
+
+# compute OT sparsity level
+gw0_sparsity = 100 * (gw0 == 0.).astype(np.float64).sum() / (n_samples ** 2)
+gw_sparsity = 100 * (gw == 0.).astype(np.float64).sum() / (n_samples ** 2)
+gwe_sparsity = 100 * (gwe == 0.).astype(np.float64).sum() / (n_samples ** 2)
+
+# Methods using Sinkhorn projections tend to produce feasibility errors on the
+# marginal constraints
+
+err0 = np.linalg.norm(gw0.sum(1) - p) + np.linalg.norm(gw0.sum(0) - q)
+err = np.linalg.norm(gw.sum(1) - p) + np.linalg.norm(gw.sum(0) - q)
+erre = np.linalg.norm(gwe.sum(1) - p) + np.linalg.norm(gwe.sum(0) - q)
+
+pl.figure(3, (10, 6))
+cmap = 'Blues'
+fontsize = 12
+pl.subplot(131)
+pl.imshow(gw0, cmap=cmap)
+pl.title('(CG algo) GW=%s \n \n OT sparsity=%s \n feasibility error=%s' % (
+    np.round(log0['gw_dist'], 4), str(np.round(gw0_sparsity, 2)) + ' %', np.round(err0, 4)),
+ fontsize=fontsize)
+
+pl.subplot(132)
+pl.imshow(gw, cmap=cmap)
+pl.title('(PP algo) GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % (
+ np.round(log['gw_dist'], 4), str(np.round(gw_sparsity, 2)) + ' %', np.round(err, 4)),
+ fontsize=fontsize)
+
+pl.subplot(133)
+pl.imshow(gwe, cmap=cmap)
+pl.title('Entropic GW=%s \n \n OT sparsity=%s \nfeasibility error=%s' % (
+ np.round(loge['gw_dist'], 4), str(np.round(gwe_sparsity, 2)) + ' %', np.round(erre, 4)),
+ fontsize=fontsize)
+
+pl.tight_layout()
pl.show()
#############################################################################
#
-# Compute GW with a scalable stochastic method with any loss function
+# Compute GW with scalable stochastic methods with any loss function
# ----------------------------------------------------------------------
@@ -126,14 +186,14 @@ print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
print('Variance estimated: ' + str(slog['gw_dist_std']))
-pl.figure(1, (10, 5))
+pl.figure(4, (10, 5))
-pl.subplot(1, 2, 1)
-pl.imshow(pgw.toarray(), cmap='jet')
+pl.subplot(121)
+pl.imshow(pgw.toarray(), cmap=cmap)
pl.title('Pointwise Gromov Wasserstein')
-pl.subplot(1, 2, 2)
-pl.imshow(sgw, cmap='jet')
+pl.subplot(122)
+pl.imshow(sgw, cmap=cmap)
pl.title('Sampled Gromov Wasserstein')
pl.show()
diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py
index 6184edf..e39d906 100644
--- a/ot/gromov/__init__.py
+++ b/ot/gromov/__init__.py
@@ -11,38 +11,51 @@ Solvers related to Gromov-Wasserstein problems.
# All submodules and packages
from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
- update_square_loss, update_kl_loss,
+ update_square_loss, update_kl_loss, update_feature_matrix,
init_matrix_semirelaxed)
+
from ._gw import (gromov_wasserstein, gromov_wasserstein2,
fused_gromov_wasserstein, fused_gromov_wasserstein2,
- solve_gromov_linesearch, gromov_barycenters, fgw_barycenters,
- update_structure_matrix, update_feature_matrix)
+ solve_gromov_linesearch, gromov_barycenters, fgw_barycenters)
+
from ._bregman import (entropic_gromov_wasserstein,
entropic_gromov_wasserstein2,
- entropic_gromov_barycenters)
+ entropic_gromov_barycenters,
+ entropic_fused_gromov_wasserstein,
+ entropic_fused_gromov_wasserstein2,
+ entropic_fused_gromov_barycenters)
+
from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein,
sampled_gromov_wasserstein)
+
from ._semirelaxed import (semirelaxed_gromov_wasserstein,
semirelaxed_gromov_wasserstein2,
semirelaxed_fused_gromov_wasserstein,
semirelaxed_fused_gromov_wasserstein2,
- solve_semirelaxed_gromov_linesearch)
+ solve_semirelaxed_gromov_linesearch,
+ entropic_semirelaxed_gromov_wasserstein,
+ entropic_semirelaxed_gromov_wasserstein2,
+ entropic_semirelaxed_fused_gromov_wasserstein,
+ entropic_semirelaxed_fused_gromov_wasserstein2)
+
from ._dictionary import (gromov_wasserstein_dictionary_learning,
gromov_wasserstein_linear_unmixing,
fused_gromov_wasserstein_dictionary_learning,
fused_gromov_wasserstein_linear_unmixing)
-__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
- 'update_square_loss', 'update_kl_loss', 'init_matrix_semirelaxed',
+__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss',
+ 'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed',
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
- 'fgw_barycenters', 'update_structure_matrix', 'update_feature_matrix',
- 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
- 'entropic_gromov_barycenters', 'GW_distance_estimation',
- 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein',
+ 'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
+ 'entropic_gromov_barycenters', 'entropic_fused_gromov_wasserstein',
+ 'entropic_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters',
+ 'GW_distance_estimation', 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein',
'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2',
'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2',
- 'solve_semirelaxed_gromov_linesearch', 'gromov_wasserstein_dictionary_learning',
+ 'solve_semirelaxed_gromov_linesearch', 'entropic_semirelaxed_gromov_wasserstein',
+ 'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein',
+ 'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning',
'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning',
'fused_gromov_wasserstein_linear_unmixing']
diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py
index aa25f1f..18cef56 100644
--- a/ot/gromov/_bregman.py
+++ b/ot/gromov/_bregman.py
@@ -11,23 +11,29 @@ Bregman projections solvers for entropic Gromov-Wasserstein
#
# License: MIT License
+import numpy as np
+import warnings
+
from ..bregman import sinkhorn
-from ..utils import dist, list_to_array, check_random_state
+from ..utils import dist, list_to_array, check_random_state, unif
from ..backend import get_backend
from ._utils import init_matrix, gwloss, gwggrad
-from ._utils import update_square_loss, update_kl_loss
+from ._utils import update_square_loss, update_kl_loss, update_feature_matrix
-def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
+def entropic_gromov_wasserstein(
+ C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000,
+ tol=1e-9, solver='PGD', warmstart=False, verbose=False, log=False, **kwargs):
r"""
- Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ estimated using Sinkhorn projections.
- The function solves the following optimization problem:
+ If `solver="PGD"`, the function solves the following entropic-regularized
+ Gromov-Wasserstein optimization problem using Projected Gradient Descent [12]:
.. math::
- \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+ \mathbf{T}^* \in \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon H(\mathbf{T})
s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
@@ -35,6 +41,17 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
\mathbf{T} &\geq 0
+ Else if `solver="PPA"`, the function solves the following Gromov-Wasserstein
+ optimization problem using Proximal Point Algorithm [51]:
+
+ .. math::
+ \mathbf{T}^* \in \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
Where :
- :math:`\mathbf{C_1}`: Metric cost matrix in the source space
@@ -58,13 +75,15 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : string
+    p : array-like, shape (ns,), optional
+        Distribution in the source space.
+        If left to its default value None, the uniform distribution is taken.
+    q : array-like, shape (nt,), optional
+        Distribution in the target space.
+        If left to its default value None, the uniform distribution is taken.
+ loss_fun : string, optional
Loss function used for the solver either 'square_loss' or 'kl_loss'
- epsilon : float
+ epsilon : float, optional
Regularization term >0
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
@@ -72,16 +91,28 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
G0: array-like, shape (ns,nt), optional
If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+    Otherwise G0 will be used as the initial transport plan of the solver. G0 is not
+    required to satisfy the marginal constraints, but we strongly recommend it
+    in order to correctly estimate the GW distance.
max_iter : int, optional
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
+ solver: string, optional
+ Solver to use either 'PGD' for Projected Gradient Descent or 'PPA'
+ for Proximal Point Algorithm.
+ Default value is 'PGD'.
+ warmstart: bool, optional
+        Whether to warmstart the dual potentials across the successive
+        Sinkhorn projections.
verbose : bool, optional
Print information along iterations
log : bool, optional
Record log if True.
-
+ **kwargs: dict
+        Parameters that can be passed directly to the ot.sinkhorn solver,
+        such as `numItermax` and `stopThr` to control its estimation precision;
+        e.g. [51] suggests using `numItermax=1`.
Returns
-------
T : array-like, shape (`ns`, `nt`)
@@ -96,22 +127,50 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
.. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
distance between networks and stable network invariants.
Information and Inference: A Journal of the IMA, 8(4), 757-787.
+
+ .. [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). Gromov-wasserstein
+ learning for graph matching and node embedding. In International
+ Conference on Machine Learning (ICML), 2019.
"""
- C1, C2, p, q = list_to_array(C1, C2, p, q)
+ if solver not in ['PGD', 'PPA']:
+ raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)
+
+ C1, C2 = list_to_array(C1, C2)
+ arr = [C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is not None:
+ arr.append(list_to_array(q))
+ else:
+ q = unif(C2.shape[0], type_as=C2)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
+
if G0 is None:
- nx = get_backend(p, q, C1, C2)
G0 = nx.outer(p, q)
- else:
- nx = get_backend(p, q, C1, C2, G0)
+
T = G0
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx)
+
if symmetric is None:
symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
if not symmetric:
constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx)
+
cpt = 0
err = 1
+ if warmstart:
+ # initialize potentials to cope with ot.sinkhorn initialization
+ N1, N2 = C1.shape[0], C2.shape[0]
+ mu = nx.zeros(N1, type_as=C1) - np.log(N1)
+ nu = nx.zeros(N2, type_as=C2) - np.log(N2)
+
if log:
log = {'err': []}
@@ -124,7 +183,17 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
tens = gwggrad(constC, hC1, hC2, T, nx)
else:
tens = 0.5 * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx))
- T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
+
+ if solver == 'PPA':
+ tens = tens - epsilon * nx.log(T)
+
+ if warmstart:
+ T, loginn = sinkhorn(p, q, tens, epsilon, method='sinkhorn', log=True, warmstart=(mu, nu), **kwargs)
+ mu = epsilon * nx.log(loginn['u'])
+ nu = epsilon * nx.log(loginn['v'])
+
+ else:
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn', **kwargs)
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
@@ -142,6 +211,9 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
cpt += 1
+ if abs(nx.sum(T) - 1) > 1e-5:
+ warnings.warn("Solver failed to produce a transport plan. You might "
+ "want to increase the regularization parameter `epsilon`.")
if log:
log['gw_dist'] = gwloss(constC, hC1, hC2, T, nx)
return T, log
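A short derivation, in our reading of [51], of why the `solver='PPA'` branch above (which subtracts `epsilon * nx.log(T)` from the gradient `tens` before calling `ot.sinkhorn`) is exactly a KL proximal step; here :math:`\mathbf{G}^{(k)}` denotes the GW gradient at the current plan, and constants independent of :math:`\mathbf{T}` are dropped on the second line:

```latex
% KL proximal step rewritten as an entropic OT (Sinkhorn) problem
\begin{aligned}
\mathbf{T}^{(k+1)} &= \mathop{\arg\min}_{\mathbf{T} \in \Pi(\mathbf{p}, \mathbf{q})}
  \langle \mathbf{G}^{(k)}, \mathbf{T} \rangle_F
  + \epsilon \, \mathrm{KL}\!\left(\mathbf{T} \,\middle\|\, \mathbf{T}^{(k)}\right) \\
&= \mathop{\arg\min}_{\mathbf{T} \in \Pi(\mathbf{p}, \mathbf{q})}
  \langle \mathbf{G}^{(k)} - \epsilon \log \mathbf{T}^{(k)}, \mathbf{T} \rangle_F
  - \epsilon H(\mathbf{T})
\end{aligned}
```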
@@ -149,17 +221,36 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
return T
-def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
+def entropic_gromov_wasserstein2(
+ C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, G0=None, max_iter=1000,
+ tol=1e-9, solver='PGD', warmstart=False, verbose=False, log=False, **kwargs):
r"""
- Returns the entropic Gromov-Wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ Returns the Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ estimated using Sinkhorn projections.
- The function solves the following optimization problem:
+ If `solver="PGD"`, the function solves the following entropic-regularized
+ Gromov-Wasserstein optimization problem using Projected Gradient Descent [12]:
+
+ .. math::
+ \mathbf{GW} = \mathop{\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon H(\mathbf{T})
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Else if `solver="PPA"`, the function solves the following Gromov-Wasserstein
+ optimization problem using Proximal Point Algorithm [51]:
.. math::
- GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})
- \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+ \mathbf{GW} = \mathop{\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
Where :
- :math:`\mathbf{C_1}`: Metric cost matrix in the source space
@@ -183,13 +274,15 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str
+    p : array-like, shape (ns,), optional
+        Distribution in the source space.
+        If left to its default value None, the uniform distribution is taken.
+    q : array-like, shape (nt,), optional
+        Distribution in the target space.
+        If left to its default value None, the uniform distribution is taken.
+ loss_fun : string, optional
Loss function used for the solver either 'square_loss' or 'kl_loss'
- epsilon : float
+ epsilon : float, optional
Regularization term >0
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
@@ -197,16 +290,28 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
G0: array-like, shape (ns,nt), optional
If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+    Otherwise G0 will be used as the initial transport plan of the solver. G0 is not
+    required to satisfy the marginal constraints, but we strongly recommend it
+    in order to correctly estimate the GW distance.
max_iter : int, optional
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
+ solver: string, optional
+ Solver to use either 'PGD' for Projected Gradient Descent or 'PPA'
+ for Proximal Point Algorithm.
+ Default value is 'PGD'.
+ warmstart: bool, optional
+        Whether to warmstart the dual potentials across the successive
+        Sinkhorn projections.
verbose : bool, optional
Print information along iterations
log : bool, optional
Record log if True.
-
+ **kwargs: dict
+        Parameters that can be passed directly to the ot.sinkhorn solver,
+        such as `numItermax` and `stopThr` to control its estimation precision;
+        e.g. [51] suggests using `numItermax=1`.
Returns
-------
gw_dist : float
@@ -218,11 +323,15 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
+ .. [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). Gromov-wasserstein
+ learning for graph matching and node embedding. In International
+ Conference on Machine Learning (ICML), 2019.
"""
- gw, logv = entropic_gromov_wasserstein(
- C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, tol, verbose, log=True)
+ T, logv = entropic_gromov_wasserstein(
+ C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter,
+ tol, solver, warmstart, verbose, log=True, **kwargs)
- logv['T'] = gw
+ logv['T'] = T
if log:
return logv['gw_dist'], logv
@@ -230,10 +339,13 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None
return logv['gw_dist']
-def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmetric=True,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
+def entropic_gromov_barycenters(
+ N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss',
+ epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9, warmstartT=False,
+ verbose=False, log=False, init_C=None, random_state=None, **kwargs):
r"""
- Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
+ Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
+ estimated using Gromov-Wasserstein transports from Sinkhorn projections.
The function solves the following optimization problem:
@@ -252,19 +364,18 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet
Size of the targeted barycenter
Cs : list of S array-like of shape (ns,ns)
Metric cost matrices
- ps : list of S array-like of shape (ns,)
- Sample weights in the `S` spaces
- p : array-like, shape(N,)
- Weights in the targeted barycenter
- lambdas : list of float
+    ps : list of S array-like of shape (ns,), optional
+        Sample weights in the `S` spaces.
+        If left to its default value None, uniform distributions are taken.
+    p : array-like, shape (N,), optional
+        Weights in the targeted barycenter.
+        If left to its default value None, the uniform distribution is taken.
+ lambdas : list of float, optional
List of the `S` spaces' weights.
- loss_fun : callable
- Tensor-matrix multiplication function based on specific loss function.
- update : callable
- function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
- :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
- calculated at each iteration
- epsilon : float
+        If left to its default value None, uniform weights are taken.
+    loss_fun : string, optional
+        Loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ epsilon : float, optional
Regularization term >0
symmetric : bool, optional.
Either structures are to be assumed symmetric or not. Default value is True.
@@ -273,6 +384,9 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
+ warmstartT: bool, optional
+        Whether to warmstart the transport plans across the successive
+        Gromov-Wasserstein transport problems.
verbose : bool, optional
Print information along iterations.
log : bool, optional
@@ -281,6 +395,8 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet
Random initial value for the :math:`\mathbf{C}` matrix provided by user.
random_state : int or RandomState instance, optional
Fix the seed for reproducibility
+ **kwargs: dict
+        Parameters that can be passed directly to the `ot.entropic_gromov_wasserstein` solver.
Returns
-------
@@ -296,11 +412,21 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet
International Conference on Machine Learning (ICML). 2016.
"""
Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
- p = list_to_array(p)
- nx = get_backend(*Cs, *ps, p)
+ arr = [*Cs]
+ if ps is not None:
+ arr += list_to_array(*ps)
+ else:
+ ps = [unif(C.shape[0], type_as=C) for C in Cs]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(N, type_as=Cs[0])
+
+ nx = get_backend(*arr)
S = len(Cs)
+ if lambdas is None:
+ lambdas = [1. / S] * S
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
@@ -317,11 +443,20 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet
error = []
+ if warmstartT:
+ T = [None] * S
+
while (err > tol) and (cpt < max_iter):
Cprev = C
+ if warmstartT:
+ T = [entropic_gromov_wasserstein(
+ Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, T[s],
+ max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
+ else:
+ T = [entropic_gromov_wasserstein(
+ Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None,
+ max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
- T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None,
- max_iter, 1e-4, verbose, log=False) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -346,3 +481,536 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmet
return C, {"err": error}
else:
return C
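A sketch of the barycenter routine with the marginals and weights made optional by this patch (uniform `ps`, `p` and `lambdas` are filled in automatically; toy graphs for illustration):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    Cs = [ot.dist(rng.randn(n, 2)) for n in (8, 10, 12)]
    C = ot.gromov.entropic_gromov_barycenters(
        6, Cs, epsilon=0.1, warmstartT=True, random_state=42)
    print(C.shape)   # (6, 6)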
+
+
+def entropic_fused_gromov_wasserstein(
+ M, C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1,
+ symmetric=None, alpha=0.5, G0=None, max_iter=1000, tol=1e-9,
+ solver='PGD', warmstart=False, verbose=False, log=False, **kwargs):
+ r"""
+ Returns the Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})`
+ with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}`,
+ estimated using Sinkhorn projections.
+
+ If `solver="PGD"`, the function solves the following entropic-regularized
+ Fused Gromov-Wasserstein optimization problem using Projected Gradient Descent [12]:
+
+ .. math::
+ \mathbf{T}^* \in \mathop{\arg\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon H(\mathbf{T})
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Else if `solver="PPA"`, the function solves the following Fused Gromov-Wasserstein
+ optimization problem using Proximal Point Algorithm [51]:
+
+ .. math::
+ \mathbf{T}^* \in\mathop{\arg\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+ Where :
+
+ - :math:`\mathbf{M}`: metric cost matrix between features across domains
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity and feature matrices
+ - `H`: entropy
+ - :math:`\alpha`: trade-off parameter
+
+    .. note:: If the inner solver `ot.sinkhorn` did not converge, the
+ optimal coupling :math:`\mathbf{T}` returned by this function does not
+ necessarily satisfy the marginal constraints
+ :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and
+ :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned
+ Fused Gromov-Wasserstein loss does not necessarily satisfy distance
+ properties and may be negative.
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+        If left to its default value None, uniform distribution is taken.
+    q : array-like, shape (nt,), optional
+        Distribution in the target space.
+        If left to its default value None, uniform distribution is taken.
+ loss_fun : string, optional
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float, optional
+ Regularization term >0
+ symmetric : bool, optional
+        Whether C1 and C2 are to be assumed symmetric or not.
+        If left to its default None value, a symmetry test will be conducted.
+        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 will be used as initial transport of the solver. G0 is not
+        required to satisfy marginal constraints but we strongly recommend it
+        to correctly estimate the GW distance.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ solver: string, optional
+ Solver to use either 'PGD' for Projected Gradient Descent or 'PPA'
+ for Proximal Point Algorithm.
+ Default value is 'PGD'.
+ warmstart: bool, optional
+        Whether to perform warmstart of dual potentials in the successive
+ Sinkhorn projections.
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+ **kwargs: dict
+ parameters can be directly passed to the ot.sinkhorn solver.
+ Such as `numItermax` and `stopThr` to control its estimation precision,
+        e.g. [51] suggests using `numItermax=1`.
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two joint spaces
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+
+ .. [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). Gromov-wasserstein
+ learning for graph matching and node embedding. In International
+ Conference on Machine Learning (ICML), 2019.
+
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+ """
+ if solver not in ['PGD', 'PPA']:
+ raise ValueError("Unknown solver '%s'. Pick one in ['PGD', 'PPA']." % solver)
+
+ M, C1, C2 = list_to_array(M, C1, C2)
+ arr = [M, C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is not None:
+ arr.append(list_to_array(q))
+ else:
+ q = unif(C2.shape[0], type_as=C2)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
+
+ if G0 is None:
+ G0 = nx.outer(p, q)
+
+ T = G0
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx)
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+ if not symmetric:
+ constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx)
+ cpt = 0
+ err = 1
+
+ if warmstart:
+ # initialize potentials to cope with ot.sinkhorn initialization
+ N1, N2 = C1.shape[0], C2.shape[0]
+ mu = nx.zeros(N1, type_as=C1) - np.log(N1)
+ nu = nx.zeros(N2, type_as=C2) - np.log(N2)
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Tprev = T
+
+ # compute the gradient
+ if symmetric:
+ tens = alpha * gwggrad(constC, hC1, hC2, T, nx) + (1 - alpha) * M
+ else:
+ tens = (alpha * 0.5) * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx)) + (1 - alpha) * M
+
+ if solver == 'PPA':
+ tens = tens - epsilon * nx.log(T)
+
+ if warmstart:
+ T, loginn = sinkhorn(p, q, tens, epsilon, method='sinkhorn', log=True, warmstart=(mu, nu), **kwargs)
+ mu = epsilon * nx.log(loginn['u'])
+ nu = epsilon * nx.log(loginn['v'])
+
+ else:
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn', **kwargs)
+
+ if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+ err = nx.norm(T - Tprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if abs(nx.sum(T) - 1) > 1e-5:
+ warnings.warn("Solver failed to produce a transport plan. You might "
+ "want to increase the regularization parameter `epsilon`.")
+ if log:
+ log['fgw_dist'] = (1 - alpha) * nx.sum(M * T) + alpha * gwloss(constC, hC1, hC2, T, nx)
+ return T, log
+ else:
+ return T
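A hedged sketch of the new fused solver on random attributed graphs; as the docstring notes, extra keyword arguments such as `numItermax` are forwarded to the inner `ot.sinkhorn` call:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    C1 = ot.dist(rng.randn(10, 2))                    # source structure
    C2 = ot.dist(rng.randn(12, 2))                    # target structure
    M = ot.dist(rng.randn(10, 3), rng.randn(12, 3))   # feature cost matrix

    # PPA solver of [51] with warmstarted dual potentials; numItermax=1 is
    # the inner-Sinkhorn setting suggested in [51]
    T, log = ot.gromov.entropic_fused_gromov_wasserstein(
        M, C1, C2, epsilon=0.1, alpha=0.5, solver='PPA', warmstart=True,
        log=True, numItermax=1)
    print(log['fgw_dist'])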
+
+
+def entropic_fused_gromov_wasserstein2(
+ M, C1, C2, p=None, q=None, loss_fun='square_loss', epsilon=0.1,
+ symmetric=None, alpha=0.5, G0=None, max_iter=1000, tol=1e-9,
+ solver='PGD', warmstart=False, verbose=False, log=False, **kwargs):
+ r"""
+ Returns the Fused Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{Y_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{Y_2}, \mathbf{q})`
+ with pairwise distance matrix :math:`\mathbf{M}` between node feature matrices :math:`\mathbf{Y_1}` and :math:`\mathbf{Y_2}`,
+ estimated using Sinkhorn projections.
+
+ If `solver="PGD"`, the function solves the following entropic-regularized
+ Fused Gromov-Wasserstein optimization problem using Projected Gradient Descent [12]:
+
+ .. math::
+ \mathbf{FGW} = \mathop{\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon H(\mathbf{T})
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Else if `solver="PPA"`, the function solves the following Fused Gromov-Wasserstein
+ optimization problem using Proximal Point Algorithm [51]:
+
+ .. math::
+ \mathbf{FGW} = \mathop{\min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+ Where :
+
+ - :math:`\mathbf{M}`: metric cost matrix between features across domains
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity and feature matrices
+ - `H`: entropy
+ - :math:`\alpha`: trade-off parameter
+
+    .. note:: If the inner solver `ot.sinkhorn` did not converge, the
+ optimal coupling :math:`\mathbf{T}` returned by this function does not
+ necessarily satisfy the marginal constraints
+ :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and
+ :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned
+ Fused Gromov-Wasserstein loss does not necessarily satisfy distance
+ properties and may be negative.
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+        If left to its default value None, uniform distribution is taken.
+    q : array-like, shape (nt,), optional
+        Distribution in the target space.
+        If left to its default value None, uniform distribution is taken.
+ loss_fun : string, optional
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float, optional
+ Regularization term >0
+ symmetric : bool, optional
+        Whether C1 and C2 are to be assumed symmetric or not.
+        If left to its default None value, a symmetry test will be conducted.
+        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 will be used as initial transport of the solver. G0 is not
+        required to satisfy marginal constraints but we strongly recommend it
+        to correctly estimate the GW distance.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+
+ Returns
+ -------
+ fgw_dist : float
+ Fused Gromov-Wasserstein distance
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). Gromov-wasserstein
+ learning for graph matching and node embedding. In International
+ Conference on Machine Learning (ICML), 2019.
+
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+
+ """
+ T, logv = entropic_fused_gromov_wasserstein(
+ M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter,
+ tol, solver, warmstart, verbose, log=True, **kwargs)
+
+ logv['T'] = T
+
+ if log:
+ return logv['fgw_dist'], logv
+ else:
+ return logv['fgw_dist']
+
+
+def entropic_fused_gromov_barycenters(
+ N, Ys, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss',
+ epsilon=0.1, symmetric=True, alpha=0.5, max_iter=1000, tol=1e-9,
+ warmstartT=False, verbose=False, log=False, init_C=None, init_Y=None,
+ random_state=None, **kwargs):
+ r"""
+    Returns the Fused Gromov-Wasserstein barycenters of `S` measured networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}`
+ estimated using Fused Gromov-Wasserstein transports from Sinkhorn projections.
+
+ The function solves the following optimization problem:
+
+ .. math::
+
+ \mathbf{C}, \mathbf{Y} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}, \mathbf{Y}\in \mathbb{Y}^{N \times d}} \quad \sum_s \lambda_s \mathrm{FGW}_{\alpha}(\mathbf{C}, \mathbf{C}_s, \mathbf{Y}, \mathbf{Y}_s, \mathbf{p}, \mathbf{p}_s)
+
+ Where :
+
+ - :math:`\mathbf{Y}_s`: feature matrix
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
+
+ Parameters
+ ----------
+ N : int
+ Size of the targeted barycenter
+ Ys: list of array-like, each element has shape (ns,d)
+ Features of all samples
+ Cs : list of S array-like of shape (ns,ns)
+ Metric cost matrices
+ ps : list of S array-like of shape (ns,), optional
+ Sample weights in the `S` spaces.
+        If left to its default value None, uniform distributions are taken.
+    p : array-like, shape (N,), optional
+        Weights in the targeted barycenter.
+        If left to its default value None, uniform distribution is taken.
+    lambdas : list of float, optional
+        List of the `S` spaces' weights.
+        If left to its default value None, uniform weights are taken.
+    loss_fun : str, optional
+        Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float, optional
+ Regularization term >0
+    symmetric : bool, optional
+        Whether the structures are to be assumed symmetric or not. Default value is True.
+        If set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ warmstartT: bool, optional
+        Whether to perform warmstart of transport plans in the successive
+        fused Gromov-Wasserstein transport problems.
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+    init_C : array-like, shape (N, N), optional
+ Random initial value for the :math:`\mathbf{C}` matrix provided by user.
+ init_Y : array-like, shape (N,d), optional
+ Initialization for the barycenters' features. If not set a
+ random init is used.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+ **kwargs: dict
+ parameters can be directly passed to the `ot.entropic_fused_gromov_wasserstein` solver.
+
+ Returns
+ -------
+ Y : array-like, shape (`N`, `d`)
+        Feature matrix in the barycenter space (permuted arbitrarily)
+    C : array-like, shape (`N`, `N`)
+        Similarity matrix in the barycenter space (permuted as Y's rows)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ Cs = list_to_array(*Cs)
+ Ys = list_to_array(*Ys)
+ arr = [*Cs, *Ys]
+ if ps is not None:
+ arr += list_to_array(*ps)
+ else:
+ ps = [unif(C.shape[0], type_as=C) for C in Cs]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(N, type_as=Cs[0])
+
+ nx = get_backend(*arr)
+ S = len(Cs)
+ if lambdas is None:
+ lambdas = [1. / S] * S
+
+    d = Ys[0].shape[1]  # dimension of the node features
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
+ else:
+ C = init_C
+
+ # Initialization of Y
+ if init_Y is None:
+ Y = nx.zeros((N, d), type_as=ps[0])
+ else:
+ Y = init_Y
+
+ T = [nx.outer(p_, p) for p_ in ps]
+
+ Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
+
+ cpt = 0
+ err = 1
+
+ err_feature = 1
+ err_structure = 1
+
+ if warmstartT:
+ T = [None] * S
+
+ if log:
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
+
+ while (err > tol) and (cpt < max_iter):
+ Cprev = C
+ Yprev = Y
+
+ if warmstartT:
+ T = [entropic_fused_gromov_wasserstein(
+ Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
+                T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
+
+ else:
+ T = [entropic_fused_gromov_wasserstein(
+ Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
+ None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
+
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ Ys_temp = [y.T for y in Ys]
+ T_temp = [Ts.T for Ts in T]
+ Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p)
+ Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
+
+ if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+ err_feature = nx.norm(Y - nx.reshape(Yprev, (N, d)))
+ err_structure = nx.norm(C - Cprev)
+ if log:
+ log_['err_feature'].append(err_feature)
+ log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err_structure))
+ print('{:5d}|{:8e}|'.format(cpt, err_feature))
+
+ cpt += 1
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms
+
+ if log:
+ return Y, C, log_
+ else:
+ return Y, C
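A corresponding sketch for the fused barycenter defined above, on toy attributed graphs (uniform `ps`, `p` and `lambdas` assumed by default):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    Ys = [rng.randn(n, 3) for n in (8, 10, 12)]           # node features
    Cs = [ot.dist(rng.randn(n, 2)) for n in (8, 10, 12)]  # structures
    Y, C = ot.gromov.entropic_fused_gromov_barycenters(
        6, Ys, Cs, epsilon=0.1, alpha=0.5, warmstartT=True, random_state=42)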
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py
index cdfa9a3..adf6b82 100644
--- a/ot/gromov/_gw.py
+++ b/ot/gromov/_gw.py
@@ -16,14 +16,14 @@ import numpy as np
from ..utils import dist, UndefinedParameter, list_to_array
from ..optim import cg, line_search_armijo, solve_1d_linesearch_quad
-from ..utils import check_random_state
+from ..utils import check_random_state, unif
from ..backend import get_backend, NumpyBackend
from ._utils import init_matrix, gwloss, gwggrad
-from ._utils import update_square_loss, update_kl_loss
+from ._utils import update_square_loss, update_kl_loss, update_feature_matrix
-def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
+def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -31,7 +31,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log
The function solves the following optimization problem:
.. math::
- \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
@@ -60,11 +60,13 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+        If left to its default value None, uniform distribution is taken.
+    q : array-like, shape (nt,), optional
+        Distribution in the target space.
+        If left to its default value None, uniform distribution is taken.
+ loss_fun : str, optional
loss function used for the solver either 'square_loss' or 'kl_loss'
symmetric : bool, optional
Either C1 and C2 are to be assumed symmetric or not.
@@ -112,15 +114,24 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log
distance between networks and stable network invariants.
Information and Inference: A Journal of the IMA, 8(4), 757-787.
"""
- p, q = list_to_array(p, q)
- p0, q0, C10, C20 = p, q, C1, C2
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20)
+ arr = [C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is not None:
+ arr.append(list_to_array(q))
else:
+ q = unif(C2.shape[0], type_as=C2)
+ if G0 is not None:
G0_ = G0
- nx = get_backend(p0, q0, C10, C20, G0_)
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
+ arr.append(G0)
+
+ nx = get_backend(*arr)
+ p0, q0, C10, C20 = p, q, C1, C2
+
+ p = nx.to_numpy(p0)
+ q = nx.to_numpy(q0)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
if symmetric is None:
@@ -168,7 +179,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log
return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
+def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Returns the Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -176,7 +187,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo
The function solves the following optimization problem:
.. math::
- GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ \mathbf{GW} = \min_\mathbf{T} \quad \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
@@ -209,10 +220,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
+ p : array-like, shape (ns,), optional
Distribution in the source space.
- q : array-like, shape (nt,)
+        If left to its default value None, uniform distribution is taken.
+ q : array-like, shape (nt,), optional
Distribution in the target space.
+        If left to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
symmetric : bool, optional
@@ -266,6 +279,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo
# simple get_backend as the full one will be handled in gromov_wasserstein
nx = get_backend(C1, C2)
+ # init marginals if set as None
+ if p is None:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is None:
+ q = unif(C2.shape[0], type_as=C2)
+
T, log_gw = gromov_wasserstein(
C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0,
max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
@@ -286,20 +305,20 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, lo
return gw
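With the defaults introduced in this hunk, both exact solvers can be called without marginals; a minimal sketch on toy data:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    C1 = ot.dist(rng.randn(10, 2))
    C2 = ot.dist(rng.randn(12, 2))

    T = ot.gromov.gromov_wasserstein(C1, C2)    # uniform p and q assumed
    gw = ot.gromov.gromov_wasserstein2(C1, C2)  # GW value, same defaults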
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
+def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
.. math::
- \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+ \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
where :
@@ -323,10 +342,12 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
Metric cost matrix representative of the structure in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix representative of the structure in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+        If left to its default value None, uniform distribution is taken.
+    q : array-like, shape (nt,), optional
+        Distribution in the target space.
+        If left to its default value None, uniform distribution is taken.
loss_fun : str, optional
Loss function used for the solver
symmetric : bool, optional
@@ -354,7 +375,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
Returns
-------
- gamma : array-like, shape (`ns`, `nt`)
+ T : array-like, shape (`ns`, `nt`)
Optimal transportation matrix for the given parameters.
log : dict
Log dictionary return only if log==True in parameters.
@@ -372,16 +393,24 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
distance between networks and stable network invariants.
Information and Inference: A Journal of the IMA, 8(4), 757-787.
"""
- p, q = list_to_array(p, q)
- p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20, M0)
+ arr = [C1, C2, M]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is not None:
+ arr.append(list_to_array(q))
else:
+ q = unif(C2.shape[0], type_as=C2)
+ if G0 is not None:
G0_ = G0
- nx = get_backend(p0, q0, C10, C20, M0, G0_)
+ arr.append(G0)
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
+ nx = get_backend(*arr)
+ p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha
+
+ p = nx.to_numpy(p0)
+ q = nx.to_numpy(q0)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
M = nx.to_numpy(M0)
@@ -433,20 +462,20 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
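The fused solver follows the same optional-marginal pattern; a short sketch with a toy feature cost matrix:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    C1, C2 = ot.dist(rng.randn(10, 2)), ot.dist(rng.randn(12, 2))
    M = ot.dist(rng.randn(10, 3), rng.randn(12, 3))

    T = ot.gromov.fused_gromov_wasserstein(M, C1, C2, alpha=0.5)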
-def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
+def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Computes the FGW distance between two graphs see (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
.. math::
- \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
+        \mathbf{FGW} = \min_\mathbf{T} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+      s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+             \mathbf{T}^T \mathbf{1} &= \mathbf{q}
- \mathbf{\gamma} &\geq 0
+             \mathbf{T} &\geq 0
where :
@@ -474,10 +503,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
Metric cost matrix representative of the structure in the source space.
C2 : array-like, shape (nt, nt)
Metric cost matrix representative of the structure in the target space.
- p : array-like, shape (ns,)
+ p : array-like, shape (ns,), optional
Distribution in the source space.
- q : array-like, shape (nt,)
+        If left to its default value None, uniform distribution is taken.
+ q : array-like, shape (nt,), optional
Distribution in the target space.
+        If left to its default value None, uniform distribution is taken.
loss_fun : str, optional
Loss function used for the solver.
symmetric : bool, optional
@@ -529,6 +560,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
"""
nx = get_backend(C1, C2, M)
+ # init marginals if set as None
+ if p is None:
+ p = unif(C1.shape[0], type_as=C1)
+ if q is None:
+ q = unif(C2.shape[0], type_as=C2)
+
T, log_fgw = fused_gromov_wasserstein(
M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True,
max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
@@ -626,9 +663,10 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
return alpha, 1, cost_G
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=False,
- max_iter=1000, tol=1e-9, verbose=False, log=False,
- init_C=None, random_state=None, **kwargs):
+def gromov_barycenters(
+ N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss', symmetric=True, armijo=False,
+ max_iter=1000, tol=1e-9, warmstartT=False, verbose=False, log=False,
+ init_C=None, random_state=None, **kwargs):
r"""
Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
@@ -649,13 +687,16 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
Size of the targeted barycenter
Cs : list of S array-like of shape (ns, ns)
Metric cost matrices
- ps : list of S array-like of shape (ns,)
- Sample weights in the `S` spaces
- p : array-like, shape (N,)
- Weights in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights
- loss_fun : callable
+ ps : list of S array-like of shape (ns,), optional
+ Sample weights in the `S` spaces.
+        If left to its default value None, uniform distributions are taken.
+    p : array-like, shape (N,), optional
+        Weights in the targeted barycenter.
+        If left to its default value None, uniform distribution is taken.
+    lambdas : list of float, optional
+        List of the `S` spaces' weights.
+        If left to its default value None, uniform weights are taken.
+    loss_fun : str, optional
tensor-matrix multiplication function based on specific loss function
symmetric : bool, optional.
Either structures are to be assumed symmetric or not. Default value is True.
@@ -668,6 +709,9 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
Max number of iterations
tol : float, optional
Stop threshold on relative error (>0)
+ warmstartT: bool, optional
+        Whether to perform warmstart of transport plans in the successive
+        Gromov-Wasserstein transport problems.
verbose : bool, optional
Print information along iterations.
log : bool, optional
@@ -692,11 +736,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
"""
Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
- p = list_to_array(p)
- nx = get_backend(*Cs, *ps, p)
+ arr = [*Cs]
+ if ps is not None:
+ arr += list_to_array(*ps)
+ else:
+ ps = [unif(C.shape[0], type_as=C) for C in Cs]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(N, type_as=Cs[0])
+
+ nx = get_backend(*arr)
S = len(Cs)
+ if lambdas is None:
+ lambdas = [1. / S] * S
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
@@ -714,13 +768,19 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
cpt = 0
err = 1
+ if warmstartT:
+ T = [None] * S
error = []
while (err > tol and cpt < max_iter):
Cprev = C
- T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo,
- max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
+ if warmstartT:
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s],
+ max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
+ else:
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=None,
+ max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -747,9 +807,11 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=F
return C
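A sketch of the non-entropic barycenter with the new `warmstartT` flag, which reuses each previous coupling `T[s]` as `G0` across outer iterations:

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    Cs = [ot.dist(rng.randn(n, 2)) for n in (8, 10, 12)]
    C = ot.gromov.gromov_barycenters(6, Cs, warmstartT=True, random_state=42)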
-def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
- p=None, loss_fun='square_loss', armijo=False, symmetric=True, max_iter=100, tol=1e-9,
- verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs):
+def fgw_barycenters(
+ N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False,
+ fixed_features=False, p=None, loss_fun='square_loss', armijo=False,
+ symmetric=True, max_iter=100, tol=1e-9, warmstartT=False, verbose=False,
+ log=False, init_C=None, init_X=None, random_state=None, **kwargs):
r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
Parameters
@@ -760,16 +822,21 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Features of all samples
Cs : list of array-like, each element has shape (ns,ns)
Structure matrices of all samples
- ps : list of array-like, each element has shape (ns,)
+ ps : list of array-like, each element has shape (ns,), optional
Masses of all samples.
- lambdas : list of float
- List of the `S` spaces' weights
- alpha : float
- Alpha parameter for the fgw distance
+        If left to its default value None, uniform distributions are taken.
+    lambdas : list of float, optional
+        List of the `S` spaces' weights.
+        If left to its default value None, uniform weights are taken.
+ alpha : float, optional
+ Alpha parameter for the fgw distance.
fixed_structure : bool
Whether to fix the structure of the barycenter during the updates
fixed_features : bool
Whether to fix the feature of the barycenter during the updates
+ p : array-like, shape (N,), optional
+ Weights in the targeted barycenter.
+        If left to its default value None, uniform distribution is taken.
loss_fun : str
Loss function used for the solver either 'square_loss' or 'kl_loss'
symmetric : bool, optional
@@ -779,6 +846,9 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Max number of iterations
tol : float, optional
Stop threshold on relative error (>0)
+ warmstartT: bool, optional
+        Whether to perform warmstart of transport plans in the successive
+        fused Gromov-Wasserstein transport problems.
verbose : bool, optional
Print information along iterations.
log : bool, optional
@@ -814,15 +884,24 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
International Conference on Machine Learning (ICML). 2019.
"""
Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
Ys = list_to_array(*Ys)
- p = list_to_array(p)
- nx = get_backend(*Cs, *Ys, *ps)
+ arr = [*Cs, *Ys]
+ if ps is not None:
+ arr += list_to_array(*ps)
+ else:
+ ps = [unif(C.shape[0], type_as=C) for C in Cs]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(N, type_as=Cs[0])
+
+ nx = get_backend(*arr)
S = len(Cs)
+ if lambdas is None:
+ lambdas = [1. / S] * S
+
d = Ys[0].shape[1] # dimension on the node features
- if p is None:
- p = nx.ones(N, type_as=Cs[0]) / N
if fixed_structure:
if init_C is None:
@@ -877,13 +956,21 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
if not fixed_structure:
+ T_temp = [t.T for t in T]
if loss_fun == 'square_loss':
- T_temp = [t.T for t in T]
- C = update_structure_matrix(p, lambdas, T_temp, Cs)
+ C = update_square_loss(p, lambdas, T_temp, Cs)
- T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
- max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T_temp, Cs)
+ if warmstartT:
+ T = [fused_gromov_wasserstein(
+ Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
+ G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
+ else:
+ T = [fused_gromov_wasserstein(
+ Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
+ G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
# T is N,ns
err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
err_structure = nx.norm(C - Cprev)
@@ -910,82 +997,3 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
return X, C, log_
else:
return X, C
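And the FGW barycenter entry point with the same optional arguments, sketched on toy attributed graphs:

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    Ys = [rng.randn(n, 3) for n in (8, 10, 12)]
    Cs = [ot.dist(rng.randn(n, 2)) for n in (8, 10, 12)]
    X, C = ot.gromov.fgw_barycenters(
        6, Ys, Cs, alpha=0.5, warmstartT=True, random_state=42)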
-
-
-def update_structure_matrix(p, lambdas, T, Cs):
- r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
-
- It is calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- Masses in the targeted barycenter.
- lambdas : list of float
- List of the `S` spaces' weights.
- T : list of S array-like of shape (ns, N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
- Cs : list of S array-like, shape (ns, ns)
- Metric cost matrices.
-
- Returns
- -------
- C : array-like, shape (`nt`, `nt`)
- Updated :math:`\mathbf{C}` matrix.
- """
- p = list_to_array(p)
- T = list_to_array(*T)
- Cs = list_to_array(*Cs)
- nx = get_backend(*Cs, *T, p)
-
- tmpsum = sum([
- lambdas[s] * nx.dot(
- nx.dot(T[s].T, Cs[s]),
- T[s]
- ) for s in range(len(T))
- ])
- ppt = nx.outer(p, p)
- return tmpsum / ppt
-
-
-def update_feature_matrix(lambdas, Ys, Ts, p):
- r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
-
-
- See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
- in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- masses in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights
- Ts : list of S array-like, shape (ns,N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
- Ys : list of S array-like, shape (d,ns)
- The features.
-
- Returns
- -------
- X : array-like, shape (`d`, `N`)
-
-
- .. _references-update-feature-matrix:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- p = list_to_array(p)
- Ts = list_to_array(*Ts)
- Ys = list_to_array(*Ys)
- nx = get_backend(*Ys, *Ts, p)
-
- p = 1. / p
- tmpsum = sum([
- lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
- for s in range(len(Ts))
- ])
- return tmpsum
diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py
index 94dc975..206329d 100644
--- a/ot/gromov/_semirelaxed.py
+++ b/ot/gromov/_semirelaxed.py
@@ -18,7 +18,7 @@ from ..backend import get_backend
from ._utils import init_matrix_semirelaxed, gwloss, gwggrad
-def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None,
+def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Returns the semi-relaxed Gromov-Wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
@@ -26,12 +26,12 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=
The function solves the following optimization problem:
.. math::
- \mathbf{srGW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+        \mathbf{T}^* \in \mathop{\arg \min}_{\mathbf{T}} \quad \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
Where :
@@ -51,8 +51,9 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+        If left to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
@@ -93,11 +94,16 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=
"""
if loss_fun == 'kl_loss':
raise NotImplementedError()
- p = list_to_array(p)
- if G0 is None:
- nx = get_backend(p, C1, C2)
+ arr = [C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
else:
- nx = get_backend(p, C1, C2, G0)
+ p = unif(C1.shape[0], type_as=C1)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
if symmetric is None:
symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
@@ -143,7 +149,7 @@ def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=
return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
-def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None,
+def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Returns the semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
@@ -151,12 +157,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric
The function solves the following optimization problem:
.. math::
- srGW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ \text{srGW} = \min_{\mathbf{T}} \quad \sum_{i,j,k,l}
L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
Where :
@@ -179,8 +185,9 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric
Metric cost matrix in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : array-like, shape (ns,)
+ p : array-like, shape (ns,), optional
Distribution in the source space.
+        If left to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
@@ -218,7 +225,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
- nx = get_backend(p, C1, C2)
+    # partial get_backend as the full one will be handled in semirelaxed_gromov_wasserstein
+ nx = get_backend(C1, C2)
+
+ # init marginals if set as None
+ if p is None:
+ p = unif(C1.shape[0], type_as=C1)
T, log_srgw = semirelaxed_gromov_wasserstein(
C1, C2, p, loss_fun, symmetric, log=True, G0=G0,
@@ -239,18 +251,19 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric
return srgw
-def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
- max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+def semirelaxed_fused_gromov_wasserstein(
+ M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5,
+ G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Computes the semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`)
.. math::
- \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
- \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+ \mathbf{T}^* \in \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
where :
@@ -273,8 +286,9 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s
Metric cost matrix representative of the structure in the source space
C2 : array-like, shape (nt, nt)
Metric cost matrix representative of the structure in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+        If left to its default value None, uniform distribution is taken.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
@@ -321,11 +335,16 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s
if loss_fun == 'kl_loss':
raise NotImplementedError()
- p = list_to_array(p)
- if G0 is None:
- nx = get_backend(p, C1, C2, M)
+ arr = [M, C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
else:
- nx = get_backend(p, C1, C2, M, G0)
+ p = unif(C1.shape[0], type_as=C1)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
if symmetric is None:
symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
@@ -373,18 +392,18 @@ def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', s
return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
-def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
+def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
r"""
Computes the semi-relaxed FGW divergence between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein2>`)
.. math::
- \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+ \mathbf{srFGW} = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) T_{i,j} T_{k,l}
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- \mathbf{\gamma} &\geq 0
+ \mathbf{T} &\geq 0
where :
@@ -412,6 +431,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss',
Metric cost matrix representative of the structure in the target space.
p : array-like, shape (ns,)
Distribution in the source space.
+        If left to its default value None, uniform distribution is taken.
loss_fun : str, optional
loss function used for the solver either 'square_loss' or 'kl_loss'.
'kl_loss' is not implemented yet and will raise an error.
@@ -455,7 +475,12 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss',
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2022.
"""
- nx = get_backend(p, C1, C2, M)
+    # partial get_backend as the full one will be handled in semirelaxed_fused_gromov_wasserstein
+ nx = get_backend(C1, C2)
+
+ # init marginals if set as None
+ if p is None:
+ p = unif(C1.shape[0], type_as=C1)
T, log_fgw = semirelaxed_fused_gromov_wasserstein(
M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True,
@@ -551,3 +576,501 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
cost_G = cost_G + a * (alpha ** 2) + b * alpha
return alpha, 1, cost_G
+
+
+def entropic_semirelaxed_gromov_wasserstein(
+ C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None,
+ G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs):
+ r"""
+ Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence
+ transport plan from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
+ estimated using a Mirror Descent algorithm following the KL geometry.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{T}^* \in \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ .. note:: This function is backend-compatible and will work on arrays
+        from all compatible backends. However, not all the steps of the mirror
+        descent scheme are differentiable.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+        If left to its default value None, uniform distribution is taken.
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+        Whether C1 and C2 are to be assumed symmetric or not.
+        If left to its default None value, a symmetry test will be conducted.
+        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error computed on transport plans
+ log : bool, optional
+ record log if True
+ verbose : bool, optional
+ Print information along iterations
+ Returns
+ -------
+ G : array-like, shape (`ns`, `nt`)
+ Coupling between the two spaces that minimizes:
+
+ :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
+ log : dict
+ Convergence information and loss.
+
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if loss_fun == 'kl_loss':
+ raise NotImplementedError()
+ arr = [C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
+
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+ if G0 is None:
+ q = unif(C2.shape[0], type_as=p)
+ G0 = nx.outer(p, q)
+ else:
+ q = nx.sum(G0, 0)
+ # Check first marginal of G0
+ np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
+
+ constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
+
+ ones_p = nx.ones(p.shape[0], type_as=p)
+
+ if symmetric:
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwggrad(constC + marginal_product, hC1, hC2, G, nx)
+ else:
+ constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t))
+ marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2))
+ return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))
+
+ cpt = 0
+ err = 1e15
+ G = G0
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Gprev = G
+ # compute the kernel
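+        # (KL mirror-descent step: multiplicative update of the plan by
+        # exp(-grad / epsilon), then a row rescaling so that the fixed
+        # marginal G 1 = p holds exactly; the target marginal stays free)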
+ K = G * nx.exp(- df(G) / epsilon)
+ scaling = p / nx.sum(K, 1)
+ G = nx.reshape(scaling, (-1, 1)) * K
+ if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+ err = nx.norm(G - Gprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ log['srgw_dist'] = gwloss(constC + marginal_product, hC1, hC2, G, nx)
+ return G, log
+ else:
+ return G
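A sketch of the semi-relaxed solver above: only the source marginal is constrained, so the target marginal is an output of the optimization rather than an input:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    C1 = ot.dist(rng.randn(10, 2))
    C2 = ot.dist(rng.randn(12, 2))

    G = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, epsilon=0.1)
    q_implied = G.sum(axis=0)   # free target marginal found by the solver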
+
+
+def entropic_semirelaxed_gromov_wasserstein2(
+ C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None,
+ G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs):
+ r"""
+ Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence
+ from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
+ estimated using a Mirror Descent algorithm following the KL geometry.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{srGW} = \min_{\mathbf{T}} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+    Note that when using backends, this loss function is differentiable w.r.t. the
+    matrices (C1, C2) but not yet w.r.t. the weights p.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However all the steps in the conditional
+ gradient are not differentiable.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If left to its default value None, the uniform distribution is taken.
+ loss_fun : str, optional
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+ Whether C1 and C2 are to be assumed symmetric or not.
+ If left to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ G0 : array-like, shape (ns, nt), optional
+ If None, the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy the source marginal constraint and will be used as the initial transport plan of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error computed on transport plans
+ log : bool, optional
+ record log if True
+ verbose : bool, optional
+ Print information along iterations
+ **kwargs : dict
+ Additional parameters; accepted for API consistency but not used by the mirror descent solver.
+
+ Returns
+ -------
+ srgw : float
+ Semi-relaxed Gromov-Wasserstein divergence
+ log : dict
+ Convergence information and coupling matrix.
+
+ References
+ ----------
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ T, log_srgw = entropic_semirelaxed_gromov_wasserstein(
+ C1, C2, p, loss_fun, epsilon, symmetric, G0,
+ max_iter, tol, log=True, verbose=verbose, **kwargs)
+
+ log_srgw['T'] = T
+
+ if log:
+ return log_srgw['srgw_dist'], log_srgw
+ else:
+ return log_srgw['srgw_dist']
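As a sanity check on the pair of functions above, a hedged usage sketch on a toy point cloud (it assumes a POT build that contains this patch; the data is made up for illustration):

    import numpy as np
    import ot

    rng = np.random.default_rng(42)
    xs = rng.standard_normal((10, 2))
    xt = xs[::-1].copy()
    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)
    C1 /= C1.max()
    C2 /= C2.max()

    # plan variant: p defaults to uniform, the target marginal is left free
    T = ot.gromov.entropic_semirelaxed_gromov_wasserstein(
        C1, C2, loss_fun='square_loss', epsilon=0.1)
    # value variant: same solver run with log=True under the hood
    srgw, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(
        C1, C2, loss_fun='square_loss', epsilon=0.1, log=True)
    np.testing.assert_allclose(T, log['T'])   # identical plans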
+
+
+def entropic_semirelaxed_fused_gromov_wasserstein(
+ M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1,
+ alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs):
+ r"""
+ Computes the entropic-regularized semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`)
+
+ .. math::
+ \mathbf{T}^* \in \mathop{\arg \min}_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix between features
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}` source weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However, not all steps of the mirror
+ descent solver are differentiable.
+
+ The problem is solved with a Mirror Descent algorithm following the KL geometry, applied to the semi-relaxed problem introduced in :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If left to its default value None, the uniform distribution is taken.
+ loss_fun : str, optional
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Whether C1 and C2 are to be assumed symmetric or not.
+ If left to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ epsilon : float
+ Regularization term >0
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0 : array-like, shape (ns, nt), optional
+ If None, the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy the source marginal constraint and will be used as the initial transport plan of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error computed on transport plans
+ log : bool, optional
+ record log if True
+ verbose : bool, optional
+ Print information along iterations
+ **kwargs : dict
+ Additional parameters; accepted for API consistency but not used by the mirror descent solver.
+
+ Returns
+ -------
+ G : array-like, shape (`ns`, `nt`)
+ Optimal transportation matrix for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-semirelaxed-fused-gromov-wasserstein:
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if loss_fun == 'kl_loss':
+ raise NotImplementedError("`kl_loss` is not yet implemented for the entropic semi-relaxed solvers.")
+ arr = [M, C1, C2]
+ if p is not None:
+ arr.append(list_to_array(p))
+ else:
+ p = unif(C1.shape[0], type_as=C1)
+
+ if G0 is not None:
+ arr.append(G0)
+
+ nx = get_backend(*arr)
+
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+ if G0 is None:
+ q = unif(C2.shape[0], type_as=p)
+ G0 = nx.outer(p, q)
+ else:
+ q = nx.sum(G0, 0)
+ # Check first marginal of G0
+ np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
+
+ constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
+
+ ones_p = nx.ones(p.shape[0], type_as=p)
+ dM = (1 - alpha) * M
+ if symmetric:
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return alpha * gwggrad(constC + marginal_product, hC1, hC2, G, nx) + dM
+ else:
+ constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t))
+ marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2))
+ return 0.5 * alpha * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx)) + dM
+
+ cpt = 0
+ err = 1e15
+ G = G0
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Gprev = G
+ # compute the kernel
+ K = G * nx.exp(- df(G) / epsilon)
+ scaling = p / nx.sum(K, 1)
+ G = nx.reshape(scaling, (-1, 1)) * K
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only
+ # every 10th iteration
+ err = nx.norm(G - Gprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ log['srfgw_dist'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + (1 - alpha) * nx.sum(M * G)
+ return G, log
+ else:
+ return G
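To make the alpha trade-off concrete, a small usage sketch (again assuming this patch is installed; features and structure are random stand-ins):

    import numpy as np
    import ot

    rng = np.random.default_rng(0)
    xs = rng.standard_normal((12, 2))
    xt = rng.standard_normal((4, 2))
    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)
    C1 /= C1.max()
    C2 /= C2.max()
    M = ot.dist(xs, xt)   # linear feature cost across domains

    # alpha close to 0 favors the feature term, close to 1 the structure term
    T, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(
        M, C1, C2, loss_fun='square_loss', epsilon=0.1, alpha=0.5, log=True)
    print(log['srfgw_dist'], T.sum(axis=0))   # the target marginal is free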
+
+
+def entropic_semirelaxed_fused_gromov_wasserstein2(
+ M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1,
+ alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs):
+ r"""
+ Computes the entropic-regularized semi-relaxed FGW divergence between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein2>`)
+
+ .. math::
+ \mathbf{srFGW} = \min_{\mathbf{T}} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix between features
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}` source weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However, not all steps of the mirror
+ descent solver are differentiable.
+
+ The problem is solved with a Mirror Descent algorithm following the KL geometry, applied to the semi-relaxed problem introduced in :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein2>`
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space.
+ If left to its default value None, the uniform distribution is taken.
+ loss_fun : str, optional
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Whether C1 and C2 are to be assumed symmetric or not.
+ If left to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ epsilon : float
+ Regularization term >0
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0 : array-like, shape (ns, nt), optional
+ If None, the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy the source marginal constraint and will be used as the initial transport plan of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error computed on transport plans
+ log : bool, optional
+ record log if True
+ verbose : bool, optional
+ Print information along iterations
+ **kwargs : dict
+ Additional parameters; accepted for API consistency but not used by the mirror descent solver.
+
+ Returns
+ -------
+ srfgw-divergence : float
+ Semi-relaxed Fused Gromov-Wasserstein divergence for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-semirelaxed-fused-gromov-wasserstein2:
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ T, log_srfgw = entropic_semirelaxed_fused_gromov_wasserstein(
+ M, C1, C2, p, loss_fun, symmetric, epsilon, alpha, G0,
+ max_iter, tol, log=True, verbose=verbose, **kwargs)
+
+ log_srfgw['T'] = T
+
+ if log:
+ return log_srfgw['srfgw_dist'], log_srfgw
+ else:
+ return log_srfgw['srfgw_dist']
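Because this variant returns a scalar, it is convenient for sweeping the trade-off parameter. A hedged sketch under the same assumption that this patch is available:

    import numpy as np
    import ot

    rng = np.random.default_rng(1)
    C1 = ot.dist(rng.standard_normal((8, 2)))
    C2 = ot.dist(rng.standard_normal((3, 2)))
    C1 /= C1.max()
    C2 /= C2.max()
    M = rng.random((8, 3))

    # interpolate between feature-dominated and structure-dominated matchings
    for alpha in (0.1, 0.5, 0.9):
        srfgw = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(
            M, C1, C2, loss_fun='square_loss', epsilon=0.1, alpha=alpha)
        print(alpha, srfgw)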
diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py
index ef8cd88..0b8bb00 100644
--- a/ot/gromov/_utils.py
+++ b/ot/gromov/_utils.py
@@ -324,6 +324,49 @@ def update_kl_loss(p, lambdas, T, Cs):
return nx.exp(tmpsum / ppt)
+def update_feature_matrix(lambdas, Ys, Ts, p):
+ r"""Updates the feature matrix with respect to the `S` :math:`\mathbf{T}_s` couplings.
+
+ See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
+ in :ref:`[24] <references-update-feature-matrix>`, calculated at each iteration.
+
+ Parameters
+ ----------
+ lambdas : list of float
+ List of the `S` spaces' weights
+ Ys : list of S array-like, shape (d, ns)
+ The features of the `S` spaces.
+ Ts : list of S array-like, shape (N, ns)
+ The `S` :math:`\mathbf{T}_s` couplings between the barycenter and the `S` spaces, calculated at each iteration.
+ p : array-like, shape (N,)
+ Masses of the targeted barycenter.
+
+ Returns
+ -------
+ X : array-like, shape (`d`, `N`)
+ Updated feature matrix of the barycenter.
+
+
+ .. _references-update-feature-matrix:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ p = list_to_array(p)
+ Ts = list_to_array(*Ts)
+ Ys = list_to_array(*Ys)
+ nx = get_backend(*Ys, *Ts, p)
+
+ p = 1. / p
+ tmpsum = sum([
+ lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
+ for s in range(len(Ts))
+ ])
+ return tmpsum
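The update is a lambdas-weighted sum of barycentric projections of each space's features, divided entrywise by the barycenter masses. A minimal numpy sketch of the same computation (shapes follow the docstring above; all names are local to the example):

    import numpy as np

    d, N = 2, 3
    ns_list = [4, 5]
    lambdas = [0.5, 0.5]
    rng = np.random.default_rng(0)
    Ys = [rng.standard_normal((d, ns)) for ns in ns_list]        # features
    Ts = [np.full((N, ns), 1.0 / (N * ns)) for ns in ns_list]    # couplings
    p = np.full(N, 1.0 / N)                                      # barycenter masses

    # weighted barycentric projection, then division by the masses
    X = sum(lam * (Y @ T.T) / p[None, :]
            for lam, Y, T in zip(lambdas, Ys, Ts))
    assert X.shape == (d, N)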
+
+
def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 1beb818..13ff3fe 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -34,8 +34,11 @@ def test_gromov(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
- G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True)
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=G0b, verbose=True))
+ G = ot.gromov.gromov_wasserstein(
+ C1, C2, None, q, 'square_loss', G0=G0, verbose=True,
+ alpha_min=0., alpha_max=1.)
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(
+ C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True))
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -48,8 +51,8 @@ def test_gromov(nx):
np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
- gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, log=True)
- gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=True)
+ gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=True, log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=True, log=True)
gwb = nx.to_numpy(gwb)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False)
@@ -312,11 +315,11 @@ def test_entropic_gromov(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
G, log = ot.gromov.entropic_gromov_wasserstein(
- C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
- epsilon=1e-2, verbose=True, log=True)
+ C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-2, max_iter=10, verbose=True, log=True)
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
- C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None,
- epsilon=1e-2, verbose=True, log=False
+ C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None,
+ epsilon=1e-2, max_iter=10, verbose=True, log=False
))
# check constraints
@@ -327,10 +330,10 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
- C1, C2, p, q, 'kl_loss', symmetric=True, G0=None,
+ C1, C2, p, None, 'kl_loss', symmetric=True, G0=None,
max_iter=10, epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
- C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b,
max_iter=10, epsilon=1e-2, log=True)
gwb = nx.to_numpy(gwb)
@@ -348,6 +351,65 @@ def test_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+def test_entropic_proximal_gromov(nx):
+ n_samples = 10 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
+
+ G, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=True, numItermax=1)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None,
+ epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=False, numItermax=1
+ ))
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein2(
+ C1, C2, p, q, 'kl_loss', symmetric=True, G0=None,
+ max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True)
+ gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True)
+ gwb = nx.to_numpy(gwb)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
def test_asymmetric_entropic_gromov(nx):
n_samples = 10 # nb samples
np.random.seed(0)
@@ -363,10 +425,10 @@ def test_asymmetric_entropic_gromov(nx):
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
- epsilon=1e-1, verbose=True, log=False)
+ epsilon=1e-1, max_iter=5, verbose=True, log=False)
Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None,
- epsilon=1e-1, verbose=True, log=False
+ epsilon=1e-1, max_iter=5, verbose=True, log=False
))
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -376,11 +438,11 @@ def test_asymmetric_entropic_gromov(nx):
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw = ot.gromov.entropic_gromov_wasserstein2(
- C1, C2, p, q, 'kl_loss', symmetric=False, G0=None,
- max_iter=10, epsilon=1e-1, log=False)
+ C1, C2, None, None, 'kl_loss', symmetric=False, G0=None,
+ max_iter=5, epsilon=1e-1, log=False)
gwb = ot.gromov.entropic_gromov_wasserstein2(
C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
- max_iter=10, epsilon=1e-1, log=False)
+ max_iter=5, epsilon=1e-1, log=False)
gwb = nx.to_numpy(gwb)
np.testing.assert_allclose(gw, gwb, atol=1e-06)
@@ -414,15 +476,300 @@ def test_entropic_gromov_dtype_device(nx):
C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp)
- Gb = ot.gromov.entropic_gromov_wasserstein(
- C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
- )
- gw_valb = ot.gromov.entropic_gromov_wasserstein2(
- C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
- )
+ for solver in ['PGD', 'PPA']:
+ Gb = ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=1e-1, max_iter=5,
+ solver=solver, verbose=True
+ )
+ gw_valb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=1e-1, max_iter=5,
+ solver=solver, verbose=True
+ )
- nx.assert_same_dtype_device(C1b, Gb)
- nx.assert_same_dtype_device(C1b, gw_valb)
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
+def test_entropic_fgw(nx):
+ n_samples = 10 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0], 2)
+ yt = ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+
+ G, log = ot.gromov.entropic_fused_gromov_wasserstein(
+ M, C1, C2, None, None, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-1, max_iter=10, verbose=True, log=True)
+ Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein(
+ Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None,
+ epsilon=1e-1, max_iter=10, verbose=True, log=False
+ ))
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2(
+ M, C1, C2, p, q, 'kl_loss', symmetric=True, G0=None,
+ max_iter=10, epsilon=1e-1, log=True)
+ fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2(
+ Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=10, epsilon=1e-1, log=True)
+ fgwb = nx.to_numpy(fgwb)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+
+ np.testing.assert_allclose(fgw, fgwb, atol=1e-06)
+ np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_entropic_proximal_fgw(nx):
+ n_samples = 10 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0], 2)
+ yt = ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+
+ G, log = ot.gromov.entropic_fused_gromov_wasserstein(
+ M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1)
+ Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein(
+ Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None,
+ epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=False, numItermax=1
+ ))
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2(
+ M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None,
+ max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True)
+ fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2(
+ Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True)
+ fgwb = nx.to_numpy(fgwb)
+
+ G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+
+ np.testing.assert_allclose(fgw, fgwb, atol=1e-06)
+ np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_asymmetric_entropic_fgw(nx):
+ n_samples = 10 # nb samples
+ np.random.seed(0)
+ C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples))
+ idx = np.arange(n_samples)
+ np.random.shuffle(idx)
+ C2 = C1[idx, :][:, idx]
+
+ ys = np.random.randn(n_samples, 2)
+ yt = ys[idx, :]
+ M = ot.dist(ys, yt)
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
+
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
+ G = ot.gromov.entropic_fused_gromov_wasserstein(
+ M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
+ max_iter=5, epsilon=1e-1, verbose=True, log=False)
+ Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein(
+ Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None,
+ max_iter=5, epsilon=1e-1, verbose=True, log=False
+ ))
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ fgw = ot.gromov.entropic_fused_gromov_wasserstein2(
+ M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None,
+ max_iter=5, epsilon=1e-1, log=False)
+ fgwb = ot.gromov.entropic_fused_gromov_wasserstein2(
+ Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
+ max_iter=5, epsilon=1e-1, log=False)
+ fgwb = nx.to_numpy(fgwb)
+
+ np.testing.assert_allclose(fgw, fgwb, atol=1e-06)
+ np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+
+
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
+def test_entropic_fgw_dtype_device(nx):
+ # setup
+ n_samples = 5 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0], 2)
+ yt = ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ Mb, C1b, C2b, pb, qb = nx.from_numpy(M, C1, C2, p, q, type_as=tp)
+
+ for solver in ['PGD', 'PPA']:
+ Gb = ot.gromov.entropic_fused_gromov_wasserstein(
+ Mb, C1b, C2b, pb, qb, 'square_loss', epsilon=0.1, max_iter=5,
+ solver=solver, verbose=True
+ )
+ fgw_valb = ot.gromov.entropic_fused_gromov_wasserstein2(
+ Mb, C1b, C2b, pb, qb, 'square_loss', epsilon=0.1, max_iter=5,
+ solver=solver, verbose=True
+ )
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, fgw_valb)
+
+
+def test_entropic_fgw_barycenter(nx):
+ ns = 5
+ nt = 10
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
+
+ ys = np.random.randn(Xs.shape[0], 2)
+ yt = np.random.randn(Xt.shape[0], 2)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
+ n_samples = 2
+ p = ot.unif(n_samples)
+
+ ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
+
+ X, C, log = ot.gromov.entropic_fused_gromov_barycenters(
+ n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', 0.1,
+ max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42,
+ solver='PPA', numItermax=1, log=True
+ )
+ Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 0.1,
+ max_iter=10, tol=1e-3, verbose=False, warmstartT=True, random_state=42,
+ solver='PPA', numItermax=1, log=False)
+ Xb, Cb = nx.to_numpy(Xb, Cb)
+
+ np.testing.assert_allclose(C, Cb, atol=1e-06)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X, Xb, atol=1e-06)
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
+
+ # test with 'kl_loss' and log=True
+ # providing init_C, init_Y
+ generator = ot.utils.check_random_state(42)
+ xalea = generator.randn(n_samples, 2)
+ init_C = ot.utils.dist(xalea, xalea)
+ init_C /= init_C.max()
+ init_Cb = nx.from_numpy(init_C)
+
+ init_Y = np.zeros((n_samples, ys.shape[1]), dtype=ys.dtype)
+ init_Yb = nx.from_numpy(init_Y)
+
+ X, C, log = ot.gromov.entropic_fused_gromov_barycenters(
+ n_samples, [ys, yt], [C1, C2], [p1, p2], p, None, 'kl_loss', 0.1,
+ max_iter=10, tol=1e-3, verbose=False, warmstartT=False, random_state=42,
+ solver='PPA', numItermax=1, init_C=init_C, init_Y=init_Y, log=True
+ )
+ Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', 0.1,
+ max_iter=10, tol=1e-3, verbose=False, warmstartT=False, random_state=42,
+ solver='PPA', numItermax=1, init_C=init_Cb, init_Y=init_Yb, log=True)
+ Xb, Cb = nx.to_numpy(Xb, Cb)
+
+ np.testing.assert_allclose(C, Cb, atol=1e-06)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X, Xb, atol=1e-06)
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
+ np.testing.assert_array_almost_equal(log['err_feature'], nx.to_numpy(*logb['err_feature']))
+ np.testing.assert_array_almost_equal(log['err_structure'], nx.to_numpy(*logb['err_structure']))
def test_pointwise_gromov(nx):
@@ -539,11 +886,11 @@ def test_gromov_barycenter(nx):
C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)
Cb = ot.gromov.gromov_barycenters(
- n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ n_samples, [C1, C2], None, p, [.5, .5],
'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42
)
Cbb = nx.to_numpy(ot.gromov.gromov_barycenters(
- n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5],
'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42
))
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
@@ -551,12 +898,12 @@ def test_gromov_barycenter(nx):
# test of gromov_barycenters with `log` on
Cb_, err_ = ot.gromov.gromov_barycenters(
- n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
+ n_samples, [C1, C2], [p1, p2], p, None, 'square_loss', max_iter=100,
+ tol=1e-3, verbose=False, warmstartT=True, random_state=42, log=True
)
Cbb_, errb_ = ot.gromov.gromov_barycenters(
- n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'square_loss', max_iter=100,
+ tol=1e-3, verbose=False, warmstartT=True, random_state=42, log=True
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
@@ -565,23 +912,31 @@ def test_gromov_barycenter(nx):
Cb2 = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ 'kl_loss', max_iter=100, tol=1e-3, warmstartT=True, random_state=42
)
Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ 'kl_loss', max_iter=100, tol=1e-3, warmstartT=True, random_state=42
))
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
# test of gromov_barycenters with `log` on
+ # providing init_C
+ generator = ot.utils.check_random_state(42)
+ xalea = generator.randn(n_samples, 2)
+ init_C = ot.utils.dist(xalea, xalea)
+ init_C /= init_C.max()
+ init_Cb = nx.from_numpy(init_C)
+
Cb2_, err2_ = ot.gromov.gromov_barycenters(
- n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', max_iter=100,
+ tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C
)
Cb2b_, err2b_ = ot.gromov.gromov_barycenters(
- n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss',
+ max_iter=100, tol=1e-3, verbose=True, random_state=42,
+ init_C=init_Cb, log=True
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
@@ -607,24 +962,24 @@ def test_gromov_entropic_barycenter(nx):
C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)
Cb = ot.gromov.entropic_gromov_barycenters(
- n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3,
+ max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42
)
Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
- n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 1e-3,
+ max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42
))
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
# test of entropic_gromov_barycenters with `log` on
Cb_, err_ = ot.gromov.entropic_gromov_barycenters(
- n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ n_samples, [C1, C2], [p1, p2], p, None,
+ 'square_loss', 1e-3, max_iter=10, tol=1e-3, verbose=True, random_state=42, log=True
)
Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'square_loss', 1e-3, max_iter=10, tol=1e-3, verbose=True, random_state=42, log=True
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
@@ -633,23 +988,32 @@ def test_gromov_entropic_barycenter(nx):
Cb2 = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42
)
Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42
))
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
# test of entropic_gromov_barycenters with `log` on
+ # providing init_C
+ generator = ot.utils.check_random_state(42)
+ xalea = generator.randn(n_samples, 2)
+ init_C = ot.utils.dist(xalea, xalea)
+ init_C /= init_C.max()
+ init_Cb = nx.from_numpy(init_C)
+
Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters(
- n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', 1e-3,
+ max_iter=10, tol=1e-3, warmstartT=True, verbose=True, random_state=42,
+ init_C=init_C, log=True
)
Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'kl_loss', 1e-3, max_iter=10, tol=1e-3, warmstartT=True, verbose=True,
+ random_state=42, init_C=init_Cb, log=True
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
@@ -685,8 +1049,8 @@ def test_fgw(nx):
Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
- G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True)
- Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, None, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, None, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -701,8 +1065,8 @@ def test_fgw(nx):
np.testing.assert_allclose(
Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
- fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True)
- fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True)
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, None, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, None, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True)
fgwb = nx.to_numpy(fgwb)
G = log['T']
@@ -923,6 +1287,9 @@ def test_fgw_barycenter(nx):
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
+ C1 /= C1.max()
+ C2 /= C2.max()
+
p1, p2 = ot.unif(ns), ot.unif(nt)
n_samples = 3
p = ot.unif(n_samples)
@@ -930,18 +1297,19 @@ def test_fgw_barycenter(nx):
ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
Xb, Cb = ot.gromov.fgw_barycenters(
- n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
+ n_samples, [ysb, ytb], [C1b, C2b], None, [.5, .5], 0.5, fixed_structure=False,
fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345
)
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
+ init_C /= init_C.max()
init_Cb = nx.from_numpy(init_C)
Xb, Cb = ot.gromov.fgw_barycenters(
- n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5],
+ n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
- p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3
+ p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
)
Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
@@ -953,11 +1321,21 @@ def test_fgw_barycenter(nx):
Xb, Cb, logb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
fixed_structure=False, fixed_features=True, init_X=init_Xb,
- p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
+ warmstartT=True, log=True, random_state=98765, verbose=True
)
- Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
- np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
+ X, C = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ # add test with 'kl_loss'
+ X, C = ot.gromov.fgw_barycenters(
+ n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss',
+ max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True, random_state=12345
+ )
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
def test_gromov_wasserstein_linear_unmixing(nx):
@@ -1501,8 +1879,11 @@ def test_semirelaxed_gromov(nx):
# asymmetric
C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
- G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0)
- Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=False, log=True, G0=None)
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(
+ C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0)
+ Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(
+ C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True,
+ G0=None, alpha_min=0., alpha_max=1.)
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -1510,8 +1891,10 @@ def test_semirelaxed_gromov(nx):
np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
- srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=False, log=True, G0=G0)
- srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(
+ C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(
+ C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
G = log2['T']
Gb = nx.to_numpy(logb2['T'])
@@ -1527,16 +1910,20 @@ def test_semirelaxed_gromov(nx):
C1 = 0.5 * (C1 + C1.T)
C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
- G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None)
- Gb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b)
+ G, log = ot.gromov.semirelaxed_gromov_wasserstein(
+ C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.semirelaxed_gromov_wasserstein(
+ C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b)
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
- srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0)
- srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
+ srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(
+ C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(
+ C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None)
srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0)
@@ -1661,7 +2048,7 @@ def test_semirelaxed_fgw(nx):
# asymmetric
Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
- G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b)
# check constraints
@@ -1670,7 +2057,7 @@ def test_semirelaxed_fgw(nx):
np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0)
- srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
+ srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None)
G = log2['T']
Gb = nx.to_numpy(logb2['T'])
@@ -1819,3 +2206,276 @@ def test_srfgw_helper_backend(nx):
res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9)
# check constraints
np.testing.assert_allclose(res, Gb, atol=1e-06)
+
+
+def test_entropic_semirelaxed_gromov(nx):
+ np.random.seed(0)
+ # unbalanced proportions
+ list_n = [30, 15]
+ nt = 2
+ ns = np.sum(list_n)
+ # create directed sbm with C2 as connectivity matrix
+ C1 = np.zeros((ns, ns), dtype=np.float64)
+ C2 = np.array([[0.8, 0.05],
+ [0.05, 1.]], dtype=np.float64)
+ offsets = np.concatenate(([0], np.cumsum(list_n)))
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ # blocks have unequal sizes, so index with cumulative offsets
+ C1[offsets[i]: offsets[i + 1], offsets[j]: offsets[j + 1]] = xij
+ p = ot.unif(ns, type_as=C1)
+ q0 = ot.unif(C2.shape[0], type_as=C1)
+ G0 = p[:, None] * q0[None, :]
+ # asymmetric
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+ epsilon = 0.1
+ G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=G0)
+ Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=None)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04)
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01)
+ np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01)
+
+ np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
+def test_entropic_semirelaxed_gromov_dtype_device(nx):
+ # setup
+ n_samples = 5 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp)
+
+ Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(
+ C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True
+ )
+ gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(
+ C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True
+ )
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, gw_valb)
+
+
+def test_entropic_semirelaxed_fgw(nx):
+ np.random.seed(0)
+ list_n = [16, 8]
+ nt = 2
+ ns = 24
+ # create directed sbm with C2 as connectivity matrix
+ C1 = np.zeros((ns, ns))
+ C2 = np.array([[0.7, 0.05],
+ [0.05, 0.9]])
+ offsets = np.concatenate(([0], np.cumsum(list_n)))
+ for i in range(nt):
+ for j in range(nt):
+ ni, nj = list_n[i], list_n[j]
+ xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j])
+ # blocks have unequal sizes, so index with cumulative offsets
+ C1[offsets[i]: offsets[i + 1], offsets[j]: offsets[j + 1]] = xij
+ F1 = np.zeros((ns, 1))
+ F1[:16] = np.random.normal(loc=0., scale=0.01, size=(16, 1))
+ F1[16:] = np.random.normal(loc=1., scale=0.01, size=(8, 1))
+ F2 = np.zeros((2, 1))
+ F2[1, :] = 1.
+ M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T)
+
+ p = ot.unif(ns)
+ q0 = ot.unif(C2.shape[0])
+ G0 = p[:, None] * q0[None, :]
+
+ # asymmetric
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+
+ # symmetric
+ C1 = 0.5 * (C1 + C1.T)
+ Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0)
+
+ G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None)
+ Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b)
+
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov
+
+ srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0)
+ srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None)
+
+ srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0)
+
+ G = log2['T']
+ Gb = nx.to_numpy(logb2['T'])
+ # check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov
+
+ np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07)
+ np.testing.assert_allclose(srgw, srgw_, atol=1e-07)
+
+
+@pytest.skip_backend("tf", reason="test very slow with tf backend")
+def test_entropic_semirelaxed_fgw_dtype_device(nx):
+ # setup
+ n_samples = 5 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0], 2)
+ yt = ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ for tp in nx.__type_list__:
+ print(nx.dtype_device(tp))
+
+ Mb, C1b, C2b, pb = nx.from_numpy(M, C1, C2, p, type_as=tp)
+
+ Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(
+ Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True
+ )
+ fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(
+ Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True
+ )
+
+ nx.assert_same_dtype_device(C1b, Gb)
+ nx.assert_same_dtype_device(C1b, fgw_valb)
+
+
+def test_not_implemented_solver():
+ # test sinkhorn
+ n_samples = 5 # nb samples
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0], 2)
+ yt = ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+ M = ot.dist(ys, yt)
+
+ solver = 'not_implemented'
+ # entropic gw and fgw
+ with pytest.raises(ValueError):
+ ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver)
+ with pytest.raises(ValueError):
+ ot.gromov.entropic_fused_gromov_wasserstein(
+ M, C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver)
+
+ # exact and entropic srgw and srfgw loss functions
+ loss_fun = 'kl_loss'
+ with pytest.raises(NotImplementedError):
+ ot.gromov.semirelaxed_gromov_wasserstein(
+ C1, C2, p, loss_fun, armijo=False)
+ with pytest.raises(NotImplementedError):
+ ot.gromov.entropic_semirelaxed_gromov_wasserstein(
+ C1, C2, p, loss_fun, epsilon=0.1)
+ with pytest.raises(NotImplementedError):
+ ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun)
+ with pytest.raises(NotImplementedError):
+ ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(
+ M, C1, C2, p, loss_fun, epsilon=0.1)