authorCédric Vincent-Cuaz <cedvincentcuaz@gmail.com>2022-02-11 10:53:38 +0100
committerGitHub <noreply@github.com>2022-02-11 10:53:38 +0100
commit50c0f17d00e3492c4d56a356af30cf00d6d07913 (patch)
tree57abfe9510fdba64f6e9c1c4b4716e7b0ba28ed0
parenta5e0f0d40d5046a6639924347ef97e2ac80ad0c9 (diff)
[MRG] GW dictionary learning (#319)
* add fgw dictionary learning feature
* add fgw dictionary learning feature
* plot gromov wasserstein dictionary learning
* Update __init__.py
* fix pep8 errors exact E501 line too long
* fix last pep8 issues
* add unitary tests for (F)GW dictionary learning without using autodifferentiable functions
* correct tests for (F)GW dictionary learning without using autodiff
* correct tests for (F)GW dictionary learning without using autodiff
* fix docs and notations
* answer to review: improve tests, docs, examples + make node weights optional
* fix pep8 and examples
* improve docs + tests + thumbnail
* make example faster
* improve ex
* update README.md
* make GDL tests faster

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--  README.md                                                        |    2
-rw-r--r--  RELEASES.md                                                      |    2
-rwxr-xr-x  examples/gromov/plot_gromov_wasserstein_dictionary_learning.py  |  357
-rw-r--r--  ot/__init__.py                                                   |    4
-rw-r--r--  ot/gromov.py                                                     | 1074
-rw-r--r--  test/test_gromov.py                                              |  554
6 files changed, 1954 insertions, 39 deletions
diff --git a/README.md b/README.md
index a7627df..c6bfd9c 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,7 @@ POT provides the following generic OT solvers (links to examples):
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
+* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
POT provides the following Machine Learning related solvers:
@@ -198,6 +199,7 @@ The contributors to this library are
* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
+* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
This toolbox benefits a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
diff --git a/RELEASES.md b/RELEASES.md
index 4d05582..925920a 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -10,7 +10,7 @@
of the regularization parameter (PR #336).
- Backend implementation for `ot.lp.free_support_barycenter` (PR #340).
- Add weak OT solver + example (PR #341).
-
+- Add (F)GW linear dictionary learning solvers + example (PR #319).
#### Closed issues
diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
new file mode 100755
index 0000000..1fdc3b9
--- /dev/null
+++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py
@@ -0,0 +1,357 @@
+# -*- coding: utf-8 -*-
+
+r"""
+=======================================================
+(Fused) Gromov-Wasserstein Linear Dictionary Learning
+=======================================================
+
+In this example, we illustrate how to learn a Gromov-Wasserstein dictionary on
+a dataset of structured data such as graphs, denoted
+:math:`\{ \mathbf{C_s} \}_{s \in [S]}`, where all nodes have uniform weights.
+Given a dictionary :math:`\mathbf{C_{dict}}` composed of D structures of a fixed
+size nt, each graph :math:`(\mathbf{C_s}, \mathbf{p_s})`
+is modeled as the convex combination :math:`\sum_d w_{s,d} \mathbf{C_{dict}[d]}`
+of these dictionary atoms, with weights :math:`\mathbf{w_s} \in \Sigma_D`.
+
+
+First, we consider a dataset composed of graphs generated by Stochastic Block models
+with sizes randomly taken in :math:`\{30, ... , 50\}` and numbers of clusters
+varying in :math:`\{ 1, 2, 3\}`. We learn a dictionary of 3 atoms by minimizing
+the Gromov-Wasserstein distance between each sample and its model in the dictionary,
+with respect to the dictionary atoms.
+
+Second, we illustrate the extension of this dictionary learning framework to
+structured data endowed with node features by using the Fused Gromov-Wasserstein
+distance. Starting from the aforementioned dataset of unattributed graphs, we
+add a discrete label to all nodes of a graph depending on its number of clusters. Then we learn
+and visualize attributed graph atoms, where each sample is modeled as a joint convex
+combination of atom structures and atom features.
+
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph
+Dictionary Learning, International Conference on Machine Learning (ICML), 2021.
+
+"""
+# Author: Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 4
+
+import numpy as np
+import matplotlib.pylab as pl
+from sklearn.manifold import MDS
+from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning
+import ot
+import networkx
+from networkx.generators.community import stochastic_block_model as sbm
+# %%
+# =============================================================================
+# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
+# =============================================================================
+
+np.random.seed(42)
+
+N = 60 # number of graphs in the dataset
+# For every number of clusters, we generate SBMs with fixed inter/intra-cluster connection probabilities.
+clusters = [1, 2, 3]
+Nc = N // len(clusters) # number of graphs by cluster
+nlabels = len(clusters)
+dataset = []
+labels = []
+
+p_inter = 0.1
+p_intra = 0.9
+for n_cluster in clusters:
+ for i in range(Nc):
+ n_nodes = int(np.random.uniform(low=30, high=50))
+
+ if n_cluster > 1:
+ P = p_inter * np.ones((n_cluster, n_cluster))
+ np.fill_diagonal(P, p_intra)
+ else:
+ P = p_intra * np.eye(1)
+ sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32)
+ G = sbm(sizes, P, seed=i, directed=False)
+ C = networkx.to_numpy_array(G)
+ dataset.append(C)
+ labels.append(n_cluster)
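+# Each graph is stored as its dense adjacency matrix C (shape n_nodes x n_nodes); these
+# matrices play the role of the structure matrices C_s in the dictionary model above.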
+
+
+# Visualize samples
+
+def plot_graph(x, C, binary=True, color='C0', s=None):
+ for j in range(C.shape[0]):
+ for i in range(j):
+ if binary:
+ if C[i, j] > 0:
+ pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
+ else: # connection intensity proportional to C[i,j]
+ pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k')
+
+ pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)
+
+
+pl.figure(1, (12, 8))
+pl.clf()
+for idx_c, c in enumerate(clusters):
+ C = dataset[(c - 1) * Nc] # sample with c clusters
+ # get 2d position for nodes
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
+ pl.subplot(2, nlabels, c)
+ pl.title('(graph) sample from label ' + str(c), fontsize=14)
+ plot_graph(x, C, binary=True, color='C0', s=50.)
+ pl.axis("off")
+ pl.subplot(2, nlabels, nlabels + c)
+ pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
+ pl.imshow(C, interpolation='nearest')
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Estimate the Gromov-Wasserstein dictionary from the dataset
+# =============================================================================
+
+
+np.random.seed(0)
+ps = [ot.unif(C.shape[0]) for C in dataset]
+
+D = 3 # 3 atoms in the dictionary
+nt = 6 # of 6 nodes each
+
+q = ot.unif(nt)
+reg = 0. # regularization coefficient to promote sparsity of unmixings {w_s}
+
+Cdict_GW, log = gromov_wasserstein_dictionary_learning(
+ Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16,
+ learning_rate=0.1, reg=reg, projection='nonnegative_symmetric',
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
+ use_log=True, use_adam_optimizer=True, verbose=True
+)
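+# Cdict_GW is a (D, nt, nt) array of learned atoms (the best state found over the epochs),
+# and log['loss_epochs'] / log['loss_batches'] track the cumulated reconstruction errors.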
+# visualize loss evolution over epochs
+pl.figure(2, (4, 3))
+pl.clf()
+pl.title('loss evolution by epoch', fontsize=14)
+pl.plot(log['loss_epochs'])
+pl.xlabel('epochs', fontsize=12)
+pl.ylabel('loss', fontsize=12)
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Visualization of the estimated dictionary atoms
+# =============================================================================
+
+
+# Continuous connections between nodes of the atoms are colored in shades of grey (0: white / 1: dark)
+
+pl.figure(3, (12, 8))
+pl.clf()
+for idx_atom, atom in enumerate(Cdict_GW):
+ scaled_atom = (atom - atom.min()) / (atom.max() - atom.min())
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
+ pl.subplot(2, D, idx_atom + 1)
+ pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14)
+ plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.)
+ pl.axis("off")
+ pl.subplot(2, D, D + idx_atom + 1)
+ pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
+ pl.imshow(scaled_atom, interpolation='nearest')
+ pl.colorbar()
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+# %%
+# =============================================================================
+# Visualization of the embedding space
+# =============================================================================
+
+unmixings = []
+reconstruction_errors = []
+for C in dataset:
+ p = ot.unif(C.shape[0])
+ unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing(
+ C, Cdict_GW, p=p, q=q, reg=reg,
+ tol_outer=10**(-5), tol_inner=10**(-5),
+ max_iter_outer=30, max_iter_inner=300
+ )
+ unmixings.append(unmixing)
+ reconstruction_errors.append(reconstruction_error)
+unmixings = np.array(unmixings)
+print('cumulated reconstruction error:', np.array(reconstruction_errors).sum())
+
+
+# Compute the 2D representation of the unmixing living in the 2-simplex of probability
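+# The simplex vertices (the atoms) are placed at (0, 0), (1, 0) and (0.5, sqrt(3)/2), so a weight
+# vector w = (w_0, w_1, w_2) maps to w_1 * (1, 0) + w_2 * (0.5, sqrt(3)/2) = (w_1 + w_2 / 2, sqrt(3) * w_2 / 2).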
+unmixings2D = np.zeros(shape=(N, 2))
+for i, w in enumerate(unmixings):
+ unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
+ unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
+x = [0., 0.]
+y = [1., 0.]
+z = [0.5, np.sqrt(3) / 2.]
+extremities = np.stack([x, y, z])
+
+pl.figure(4, (4, 4))
+pl.clf()
+pl.title('Embedding space', fontsize=14)
+for cluster in range(nlabels):
+ start, end = Nc * cluster, Nc * (cluster + 1)
+ if cluster == 0:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
+ else:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))
+pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
+pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
+pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
+pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
+pl.axis('off')
+pl.legend(fontsize=11)
+pl.tight_layout()
+pl.show()
+# %%
+# =============================================================================
+# Endow the dataset with node features
+# =============================================================================
+
+# We assign one feature to all nodes of a graph depending on its label, i.e. its number of clusters:
+# 1 cluster  --> 0 as node feature
+# 2 clusters --> 1 as node feature
+# 3 clusters --> 2 as node feature
+# Features are one-hot encoded following these assignments.
+dataset_features = []
+for i in range(len(dataset)):
+ n = dataset[i].shape[0]
+ F = np.zeros((n, 3))
+ if i < Nc: # graph with 1 cluster
+ F[:, 0] = 1.
+ elif i < 2 * Nc: # graph with 2 clusters
+ F[:, 1] = 1.
+ else: # graph with 3 clusters
+ F[:, 2] = 1.
+ dataset_features.append(F)
+
+pl.figure(5, (12, 8))
+pl.clf()
+for idx_c, c in enumerate(clusters):
+ C = dataset[(c - 1) * Nc] # sample with c clusters
+ F = dataset_features[(c - 1) * Nc]
+ colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])]
+ # get 2d position for nodes
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
+ pl.subplot(2, nlabels, c)
+ pl.title('(graph) sample from label ' + str(c), fontsize=14)
+ plot_graph(x, C, binary=True, color=colors, s=50)
+ pl.axis("off")
+ pl.subplot(2, nlabels, nlabels + c)
+ pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
+ pl.imshow(C, interpolation='nearest')
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+# %%
+# =============================================================================
+# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
+# =============================================================================
+np.random.seed(0)
+ps = [ot.unif(C.shape[0]) for C in dataset]
+D = 3 # 3 atoms, as in the structure-only case
+nt = 6
+q = ot.unif(nt)
+reg = 0.001
+alpha = 0.5 # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein
+
+
+Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning(
+ Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha,
+ epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300,
+ projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True
+)
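+# Cdict_FGW has shape (D, nt, nt) and Ydict_FGW has shape (D, nt, 3): each atom now carries both a
+# structure matrix and a feature matrix (here 3-dimensional one-hot features), taken at the best epoch.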
+# visualize loss evolution
+pl.figure(6, (4, 3))
+pl.clf()
+pl.title('loss evolution by epoch', fontsize=14)
+pl.plot(log['loss_epochs'])
+pl.xlabel('epochs', fontsize=12)
+pl.ylabel('loss', fontsize=12)
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Visualization of the estimated dictionary atoms
+# =============================================================================
+
+pl.figure(7, (12, 8))
+pl.clf()
+max_features = Ydict_FGW.max()
+min_features = Ydict_FGW.min()
+
+for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):
+ scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min())
+ #scaled_F = 2 * (Fatom - min_features) / (max_features - min_features)
+ colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])]
+ x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom)
+ pl.subplot(2, D, idx_atom + 1)
+ pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14)
+ plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100)
+ pl.axis("off")
+ pl.subplot(2, D, D + idx_atom + 1)
+ pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14)
+ pl.imshow(scaled_atom, interpolation='nearest')
+ pl.colorbar()
+ pl.axis("off")
+pl.tight_layout()
+pl.show()
+
+# %%
+# =============================================================================
+# Visualization of the embedding space
+# =============================================================================
+
+unmixings = []
+reconstruction_errors = []
+for i in range(len(dataset)):
+ C = dataset[i]
+ Y = dataset_features[i]
+ p = ot.unif(C.shape[0])
+ unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing(
+ C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha,
+ reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300
+ )
+ unmixings.append(unmixing)
+ reconstruction_errors.append(reconstruction_error)
+unmixings = np.array(unmixings)
+print('cumulated reconstruction error:', np.array(reconstruction_errors).sum())
+
+# Visualize unmixings in the 2-simplex of probability
+unmixings2D = np.zeros(shape=(N, 2))
+for i, w in enumerate(unmixings):
+ unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
+ unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
+x = [0., 0.]
+y = [1., 0.]
+z = [0.5, np.sqrt(3) / 2.]
+extremities = np.stack([x, y, z])
+
+pl.figure(8, (4, 4))
+pl.clf()
+pl.title('Embedding space', fontsize=14)
+for cluster in range(nlabels):
+ start, end = Nc * cluster, Nc * (cluster + 1)
+ if cluster == 0:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster')
+ else:
+ pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1))
+
+pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms')
+pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
+pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
+pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
+pl.axis('off')
+pl.legend(fontsize=11)
+pl.tight_layout()
+pl.show()
diff --git a/ot/__init__.py b/ot/__init__.py
index 7253318..bda7a35 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -1,5 +1,4 @@
"""
-
.. warning::
The list of automatically imported sub-modules is as follows:
:py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
@@ -7,13 +6,10 @@
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
:py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`
, :py:mod:`ot.unbalanced`.
-
The following sub-modules are not imported due to additional dependencies:
-
- :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
- :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU.
- :any:`ot.plot` : depends on :code:`matplotlib`
-
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
diff --git a/ot/gromov.py b/ot/gromov.py
index b7e7949..f5a1f91 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -7,6 +7,7 @@ Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers
# Nicolas Courty <ncourty@irisa.fr>
# Rémi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License
@@ -17,7 +18,7 @@ from .bregman import sinkhorn
from .utils import dist, UndefinedParameter, list_to_array
from .optim import cg
from .lp import emd_1d, emd
-from .utils import check_random_state
+from .utils import check_random_state, unif
from .backend import get_backend
@@ -320,7 +321,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return nx.exp(tmpsum / ppt)
-def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
r"""
Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -365,6 +366,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None, the initial transport plan of the solver is :math:`\mathbf{p} \mathbf{q}^\top`.
+ Otherwise G0 must satisfy the marginal constraints and is used as the initial transport plan of the solver.
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
@@ -389,18 +393,26 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20 = p, q, C1, C2
- nx = get_backend(p0, q0, C10, C20)
-
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
- G0 = p[:, None] * q[None, :]
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -418,7 +430,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
r"""
Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -467,6 +479,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None, the initial transport plan of the solver is :math:`\mathbf{p} \mathbf{q}^\top`.
+ Otherwise G0 must satisfy the marginal constraints and is used as the initial transport plan of the solver.
Returns
-------
@@ -491,9 +506,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20 = p, q, C1, C2
- nx = get_backend(p0, q0, C10, C20)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
@@ -502,7 +520,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -533,7 +557,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
return gw
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
r"""
Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
@@ -578,6 +602,9 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
armijo : bool, optional
If True the step of the line-search is found via an armijo research. Else closed form is used.
If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None, the initial transport plan of the solver is :math:`\mathbf{p} \mathbf{q}^\top`.
+ Otherwise G0 must satisfy the marginal constraints and is used as the initial transport plan of the solver.
log : bool, optional
record log if True
**kwargs : dict
@@ -600,20 +627,28 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
(ICML). 2019.
"""
p, q = list_to_array(p, q)
-
p0, q0, C10, C20, M0 = p, q, C1, C2, M
- nx = get_backend(p0, q0, C10, C20, M0)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
C1 = nx.to_numpy(C10)
C2 = nx.to_numpy(C20)
M = nx.to_numpy(M0)
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
-
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -622,19 +657,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
-
fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
-
log['fgw_dist'] = fgw_dist
log['u'] = nx.from_numpy(log['u'], type_as=C10)
log['v'] = nx.from_numpy(log['v'], type_as=C10)
return nx.from_numpy(res, type_as=C10), log
-
else:
return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
-def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
r"""
Computes the FGW distance between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
@@ -683,6 +715,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
armijo : bool, optional
If True the step of the line-search is found via an armijo research.
Else closed form is used. If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None, the initial transport plan of the solver is :math:`\mathbf{p} \mathbf{q}^\top`.
+ Otherwise G0 must satisfy the marginal constraints and is used as the initial transport plan of the solver.
log : bool, optional
Record log if True.
**kwargs : dict
@@ -711,7 +746,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
p, q = list_to_array(p, q)
p0, q0, C10, C20, M0 = p, q, C1, C2, M
- nx = get_backend(p0, q0, C10, C20, M0)
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
p = nx.to_numpy(p)
q = nx.to_numpy(q)
@@ -721,7 +760,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
- G0 = p[:, None] * q[None, :]
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
def f(G):
return gwloss(constC, hC1, hC2, G)
@@ -1796,3 +1841,988 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
for s in range(len(Ts))
])
return tmpsum
+
+
+def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2
+
+ such that, :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - reg is the regularization coefficient.
+
+ The stochastic algorithm used for estimating the graph dictionary atoms is the one proposed in [38].
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns, ns).
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+ Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+ Number of epochs used to learn the dictionary. Default is 20.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate: float, optional
+ Learning rate used for the stochastic gradient descent. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary.
+ If set to None (Default), the dictionary will be initialized randomly.
+ Else Cdict must have shape (D, nt, nt), i.e. match the provided number of atoms D and atom size nt.
+ projection: str, optional
+ If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary.
+ Otherwise the set of atoms is :math:`R^{nt \times nt}`. Default is 'nonnegative_symmetric'.
+ use_log: bool, optional
+ If set to True, losses evolution by batches and epochs are tracked. Default is False.
+ use_adam_optimizer: bool, optional
+ If set to True, adam optimizer with default settings is used as an adaptive learning rate strategy.
+ Else perform SGD with fixed learning rate. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+ If use_log is True, contains loss evolutions by batches and epochs.
+ References
+ ----------
+
+ .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ # Handle backend of non-optional arguments
+ Cs0 = Cs
+ nx = get_backend(*Cs0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ dataset_size = len(Cs)
+ # Handle backend of optional arguments
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0
+ if use_adam_optimizer:
+ adam_moments = _initialize_adam_optimizer(Cdict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ Cdict_best_state = Cdict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+ # batch sampling
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+ # BCD solver for Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner,
+ max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
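+ # For the square loss, the gradient of GW_2 w.r.t. the embedded structure Cembedded is
+ # 2 * (Cembedded * qq^T - T^T Cs T); chaining through Cembedded = sum_d w_d Cdict[d] weights
+ # this shared term by w_d for each atom, and the factor 2 / batch_size averages over the batch.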
+ grad_Cdict = np.zeros_like(Cdict)
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ if use_adam_optimizer:
+ Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments)
+ else:
+ Cdict -= learning_rate * grad_Cdict
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ if verbose:
+ print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), log
+
+
+def _initialize_adam_optimizer(variable):
+
+ # Initialization for our numpy implementation of adam optimizer
+ atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor
+ atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor
+ atoms_adam_count = 1
+
+ return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count}
+
+
+def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09):
+
+ adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad
+ adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2)
+ unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count'])
+ unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count'])
+ variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps)
+ adam_moments['count'] += 1
+
+ return variable, adam_moments
+
+
+def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
+ r"""
+ Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`.
+
+ .. math::
+ \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ T: array-like (ns, nt)
+ Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})`
+ current_loss: float
+ reconstruction error
+ References
+ ----------
+
+ .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ C0, Cdict0 = C, Cdict
+ nx = get_backend(C0, Cdict0)
+ C = nx.to_numpy(C0)
+ Cdict = nx.to_numpy(Cdict0)
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+
+ w = unif(D) # Initialize uniformly the unmixing w
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+
+ const_q = q[:, None] * q[None, :]
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+ # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
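+ # The plan T from the previous iteration is passed as G0 to warm-start the conditional gradient solver.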
+ T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs)
+ current_loss = log['gw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing(
+ C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T,
+ starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs
+ )
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
+
+def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs):
+ r"""
+ Returns for a fixed admissible transport plan,
+ the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ w: array-like, shape (D,)
+ Linear unmixing of the input structure onto the dictionary
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between the input structure and its representation in the dictionary.
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given the fixed transport plan, starting from the previously optimal unmixing.
+ Cembedded: ndarray (nt, nt)
+ embedded structure :math:`\sum_d w_d \mathbf{C_{dict}[d]}` corresponding to the updated unmixing.
+ current_loss: float
+ reconstruction error at the updated unmixing.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+
+ previous_loss = current_loss
+ # 1) Compute gradient at current point w
+ grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ grad_w -= 2 * reg * w
+
+ # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0: # note that the loss can be negative if reg > 0
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ count += 1
+
+ return w, Cembedded, current_loss
+
+
+def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing
+
+ .. math::
+ \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+ grad_w : array-like, shape (D,)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`.
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ Cembedded_diff: numpy array, shape (nt, nt)
+ Difference between the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ """
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ a = trace_diffx - trace_diffw
+ b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+
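+ # Restricted to the segment (1 - gamma) * w + gamma * x, the loss reads a * gamma^2 + b * gamma + c:
+ # if a > 0 it is convex and the minimizer over [0, 1] is the clipped vertex -b / (2 * a); otherwise the
+ # minimum lies at an endpoint, and the sign of f(1) - f(0) = a + b selects which one.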
+ if a > 0:
+ gamma = min(1, max(0, - b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff
+
+
+def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
+ Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2
+
+
+ Such that :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) feature matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+
+ The stochastic algorithm used for estimating the attributed graph dictionary atoms is the one proposed in [38].
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns,ns).
+ Ys : list of S array-like, shape (ns, d)
+ List of feature matrices of variable size (ns,d) with d fixed.
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ alpha : float
+ Trade-off parameter of Fused Gromov-Wasserstein
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+ Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+ Number of epochs used to learn the dictionary. Default is 20.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate_C: float, optional
+ Learning rate used for the stochastic gradient descent on Cdict. Default is 1.
+ learning_rate_Y: float, optional
+ Learning rate used for the stochastic gradient descent on Ydict. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary structures Cdict.
+ If set to None (Default), the dictionary will be initialized randomly.
+ Else Cdict must have shape (D, nt, nt), i.e. match the provided number of atoms D and atom size nt.
+ Ydict_init: list of D array-like with shape (nt, d), optional
+ Used to initialize the dictionary features Ydict.
+ If set to None, the dictionary features will be initialized randomly.
+ Else Ydict must have shape (D, nt, d), where d is the feature dimension of the inputs Ys.
+ projection: str, optional
+ If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary.
+ Otherwise the set of atoms is :math:`R^{nt \times nt}`. Default is 'nonnegative_symmetric'.
+ use_log: bool, optional
+ If set to True, losses evolution by batches and epochs are tracked. Default is False.
+ use_adam_optimizer: bool, optional
+ If set to True, adam optimizer with default settings is used as an adaptive learning rate strategy.
+ Else perform SGD with fixed learning rate. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ Ydict_best_state : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+ If use_log is True, contains loss evolutions by batches and epochs.
+ References
+ ----------
+
+ .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+ "Online Graph Dictionary Learning"
+ International Conference on Machine Learning (ICML). 2021.
+ """
+ Cs0, Ys0 = Cs, Ys
+ nx = get_backend(*Cs0, *Ys0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ Ys = [nx.to_numpy(Y) for Y in Ys0]
+
+ d = Ys[0].shape[-1]
+ dataset_size = len(Cs)
+
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+ if Ydict_init is None:
+ # Initialize randomly features of dictionary atoms based on samples distribution by feature component
+ dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
+ Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+ else:
+ Ydict = nx.to_numpy(Ydict_init).copy()
+ assert Ydict.shape == (D, nt, d)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_adam_optimizer:
+ adam_moments_C = _initialize_adam_optimizer(Cdict)
+ adam_moments_Y = _initialize_adam_optimizer(Ydict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+
+ # Batch iterations
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ys_embedded = np.zeros((batch_size, nt, d))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+ # BCD solver for Fused Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q,
+ tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
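+ # The structure term is the same as in the GW case, now weighted by alpha; the feature part of the
+ # FGW loss, (1 - alpha) * sum_{i,k} ||Y_i - Yembedded_k||^2 T_{i,k}, has gradient
+ # 2 * (1 - alpha) * (diag(q) Yembedded - T^T Y) w.r.t. Yembedded, chained through Yembedded = sum_d w_d Ydict[d].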
+ grad_Cdict = np.zeros_like(Cdict)
+ grad_Ydict = np.zeros_like(Ydict)
+
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx])
+ grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ grad_Ydict *= 2 / batch_size
+
+ if use_adam_optimizer:
+ Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C)
+ Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y)
+ else:
+ Cdict -= learning_rate_C * grad_Cdict
+ Ydict -= learning_rate_Y * grad_Ydict
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ if verbose:
+ print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log
+
+
+def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
+ r"""
+ Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`
+
+ .. math::
+ \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+    - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+    - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+    - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C}`
+    - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+    Cdict : array-like, shape (D, nt, nt)
+        Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+    Ydict : array-like, shape (D, nt, d)
+        Feature matrices composing the dictionary on which to embed (C,Y).
+    alpha : float
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+    tol_outer : float, optional
+        Solver precision for the BCD algorithm. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+    w : array-like, shape (D,)
+        Fused Gromov-Wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary.
+    Cembedded : array-like, shape (nt, nt)
+        Embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+    Yembedded : array-like, shape (nt, d)
+        Embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
+    T : array-like, shape (ns, nt)
+        Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`.
+    current_loss : float
+        Reconstruction error.
+
+    References
+    ----------
+
+    .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
+        "Online Graph Dictionary Learning"
+        International Conference on Machine Learning (ICML). 2021.
+ """
+ C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict
+ nx = get_backend(C0, Y0, Cdict0, Ydict0)
+ C = nx.to_numpy(C0)
+ Y = nx.to_numpy(Y0)
+ Cdict = nx.to_numpy(Cdict0)
+ Ydict = nx.to_numpy(Ydict0)
+
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+ d = Y.shape[-1]
+ w = unif(D) # Initialize with uniform weights
+ ns = C.shape[-1]
+ nt = Cdict.shape[-1]
+
+ # modeling (C,Y)
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+ Yembedded = np.sum(w[:, None, None] * Ydict, axis=0)
+
+ # constants depending on q
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+    Ys_constM = (Y**2).dot(np.ones((d, nt)))  # constant term of the squared Euclidean pairwise feature distance matrix
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+
+        # 1. Solve FGW transport between (C, Y, p) and the current model (\sum_d w_d Cdict[d], \sum_d w_d Ydict[d], q), keeping the unmixing w fixed
+ Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T)
+        M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T)  # squared Euclidean distance matrix between features
+ T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True)
+ current_loss = log['fgw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w,
+ T, p, q, const_q, diag_q, current_loss, alpha, reg,
+ tol=tol_inner, max_iter=max_iter_inner, **kwargs)
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
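As a quick reference for the unmixing routine defined above, the snippet below (not part of the patch) projects a single attributed graph onto a hand-built 2-atom dictionary; the random dictionary and the sizes used are illustrative assumptions.

import numpy as np
import ot

# One attributed graph: distance matrix as structure, 2D coordinates as features
X, _ = ot.datasets.make_data_classif('3gauss', 10, random_state=0)
C, Y = ot.dist(X), X

# Hand-built dictionary of 2 symmetric nonnegative atoms with 6 nodes and 2D features
rng = np.random.RandomState(0)
Cdict = np.abs(rng.randn(2, 6, 6))
Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
Ydict = rng.randn(2, 6, 2)

w, Cembedded, Yembedded, T, loss = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
    C, Y, Cdict, Ydict, alpha=0.5, reg=0., p=None, q=None,
    tol_outer=1e-5, tol_inner=1e-5, max_iter_outer=20, max_iter_inner=200)

print(w, loss)   # simplex weights of shape (2,) and the FGW reconstruction error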
+
+def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs):
+ r"""
+    Returns, for a fixed admissible transport plan,
+    the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]}, \sum_d w_d \mathbf{Y_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+        \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l})^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{i,j} - reg \| \mathbf{w} \|_2^2
+
+ Such that :
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+    - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+    - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+    - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C}`
+    - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}`
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7.
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+    Cdict : array-like, shape (D, nt, nt)
+        Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+        Each matrix in the dictionary must have the same size (nt,nt).
+    Ydict : array-like, shape (D, nt, d)
+        Feature matrices composing the dictionary on which to embed (C,Y).
+        Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt,d)
+ Embedded features of (C,Y) onto the dictionary
+    w: array-like, shape (D,)
+ Linear unmixing of (C,Y) onto (Cdict,Ydict)
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution.
+ diag_q: array-like, shape (nt,nt)
+ diagonal matrix with values of q on the diagonal.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between (C,Y) and its model
+ p : array-like, shape (ns,)
+ Distribution in the source space (C,Y).
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+    w : ndarray, shape (D,)
+        Linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given the OT plan corresponding to the previous unmixing.
+    Cembedded : ndarray, shape (nt, nt)
+        Updated embedded structure :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+    Yembedded : ndarray, shape (nt, d)
+        Updated embedded features :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
+    current_loss : float
+        Updated reconstruction error.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+ ones_ns_d = np.ones(Y.shape)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+ previous_loss = current_loss
+
+ # 1) Compute gradient at current point w
+ # structure
+ grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ # feature
+ grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2))
+ grad_w -= reg * w
+ grad_w *= 2
+
+        # 2) Conditional gradient direction finding: x = argmin_{x in simplex} <x, grad_w>
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ Yembedded += gamma * Yembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ count += 1
+
+ return w, Cembedded, Yembedded, current_loss
+
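A side note on step 2) of the conditional gradient loop above: over the probability simplex, minimizing the linearized objective <x, grad_w> is achieved at a vertex, and ties are averaged. A standalone numpy illustration with an arbitrary toy gradient (not part of the patch):

import numpy as np

grad_w = np.array([0.3, -1.2, 0.7, -1.2])   # toy gradient with a tie at the minimum
min_ = np.min(grad_w)
x = (grad_w == min_).astype(np.float64)     # indicator of the minimal entries
x /= np.sum(x)                              # average over ties, stays on the simplex
print(x)                                    # [0.  0.5 0.  0.5]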
+
+def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs):
+ r"""
+    Compute the optimal step for the line-search problem of Fused Gromov-Wasserstein linear unmixing
+
+    .. math::
+        \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 T_{i,j} - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that :
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+    grad_w : array-like, shape (D,)
+        Gradient of the reconstruction loss with respect to w.
+    x: array-like, shape (D,)
+        Conditional gradient direction.
+    Y : array-like, shape (ns, d)
+        Feature matrix of the input space
+    Cdict : array-like, shape (D, nt, nt)
+        Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+        Each matrix in the dictionary must have the same size (nt,nt).
+    Ydict : array-like, shape (D, nt, d)
+        Feature matrices composing the dictionary on which to embed (C,Y).
+        Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt, nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt, d)
+ Embedded features of (C,Y) onto the dictionary
+ T: array-like, shape (ns, nt)
+ Fixed transport plan between (C,Y) and its current model.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ ones_ns_d: array-like, shape (ns, d)
+ :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+    Cembedded_diff: numpy array, shape (nt, nt)
+        Difference between the embedded structure matrices of the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+    Yembedded_diff: numpy array, shape (nt, d)
+        Difference between the embedded feature matrices of the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ """
+ # polynomial coefficients from quadratic objective (with respect to w) on structures
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+    # Constant factors appearing in the factorization a*gamma^2 + b*gamma + c of the Gromov-Wasserstein reconstruction loss
+ a_gw = trace_diffx - trace_diffw
+ b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+
+    # polynomial coefficients from quadratic objective (with respect to w) on features
+ Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0)
+ Yembedded_diff = Yembedded_x - Yembedded
+    # Constant factors appearing in the factorization a*gamma^2 + b*gamma + c of the feature (Wasserstein) reconstruction loss
+ a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T)
+ b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T)))
+
+ a = alpha * a_gw + (1 - alpha) * a_w
+ b = alpha * b_gw + (1 - alpha) * b_w
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+ if a > 0:
+ gamma = min(1, max(0, -b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff, Yembedded_diff
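The helper above minimizes the quadratic a * gamma**2 + b * gamma (plus a constant) over [0, 1] in closed form: the interior minimizer -b / (2 * a) clipped to [0, 1] when a > 0, otherwise the better endpoint. A standalone numpy sanity check of that rule with arbitrary coefficients (not part of the patch):

import numpy as np

def gamma_star(a, b):
    # closed-form minimizer of a * gamma**2 + b * gamma over [0, 1]
    if a > 0:
        return min(1., max(0., -b / (2 * a)))
    elif a + b < 0:   # concave or linear case: compare endpoints g(0) = 0 and g(1) = a + b
        return 1.
    else:
        return 0.

grid = np.linspace(0, 1, 10001)
for a, b in [(2., -1.), (2., 1.), (-1., -3.), (-1., 2.), (0., -1.)]:
    g_grid = a * grid**2 + b * grid
    g_star = a * gamma_star(a, b)**2 + b * gamma_star(a, b)
    assert abs(g_star - g_grid.min()) < 1e-6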
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 4b995d5..329f99c 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -3,6 +3,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License
@@ -26,6 +27,7 @@ def test_gromov(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
@@ -37,9 +39,10 @@ def test_gromov(nx):
C2b = nx.from_numpy(C2)
pb = nx.from_numpy(p)
qb = nx.from_numpy(q)
+ G0b = nx.from_numpy(G0)
- G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
- Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True))
+ G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True)
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True))
# check constraints
np.testing.assert_allclose(G, Gb, atol=1e-06)
@@ -56,9 +59,9 @@ def test_gromov(nx):
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
gwb = nx.to_numpy(gwb)
- gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+ gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False)
gw_valb = nx.to_numpy(
- ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
)
G = log['T']
@@ -91,6 +94,7 @@ def test_gromov_dtype_device(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
@@ -105,9 +109,10 @@ def test_gromov_dtype_device(nx):
C2b = nx.from_numpy(C2, type_as=tp)
pb = nx.from_numpy(p, type_as=tp)
qb = nx.from_numpy(q, type_as=tp)
+ G0b = nx.from_numpy(G0, type_as=tp)
- Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -123,6 +128,7 @@ def test_gromov_device_tf():
xt = xs[::-1].copy()
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1 /= C1.max()
@@ -134,8 +140,9 @@ def test_gromov_device_tf():
C2b = nx.from_numpy(C2)
pb = nx.from_numpy(p)
qb = nx.from_numpy(q)
- Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ G0b = nx.from_numpy(G0)
+ Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
nx.assert_same_dtype_device(C1b, gw_valb)
@@ -145,6 +152,7 @@ def test_gromov_device_tf():
C2b = nx.from_numpy(C2)
pb = nx.from_numpy(p)
qb = nx.from_numpy(q)
+        G0b = nx.from_numpy(G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
nx.assert_same_dtype_device(C1b, Gb)
@@ -554,6 +562,7 @@ def test_fgw(nx):
p = ot.unif(n_samples)
q = ot.unif(n_samples)
+ G0 = p[:, None] * q[None, :]
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
@@ -569,9 +578,10 @@ def test_fgw(nx):
C2b = nx.from_numpy(C2)
pb = nx.from_numpy(p)
qb = nx.from_numpy(q)
+ G0b = nx.from_numpy(G0)
- G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
- Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True)
Gb = nx.to_numpy(Gb)
# check constraints
@@ -586,8 +596,8 @@ def test_fgw(nx):
np.testing.assert_allclose(
Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
- fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
- fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True)
fgwb = nx.to_numpy(fgwb)
G = log['T']
@@ -698,3 +708,523 @@ def test_fgw_barycenter(nx):
Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
+
+
+def test_gromov_wasserstein_linear_unmixing(nx):
+ n = 10
+
+ X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42)
+
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cdict = np.stack([C1, C2])
+ p = ot.unif(n)
+
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ Cdictb = nx.from_numpy(Cdict)
+ pb = nx.from_numpy(p)
+ tol = 10**(-5)
+ # Tests without regularization
+ reg = 0.
+ unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1, Cdict, reg=reg, p=p, q=p,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1b, Cdictb, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2, Cdict, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2b, Cdictb, reg=reg, p=pb, q=pb,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+ # Tests with regularization
+
+ reg = 0.001
+ unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1, Cdict, reg=reg, p=p, q=p,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C1b, Cdictb, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2, Cdict, reg=reg, p=None, q=None,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing(
+ C2b, Cdictb, reg=reg, p=pb, q=pb,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+
+def test_gromov_wasserstein_dictionary_learning(nx):
+
+    # create a dataset composed of 2 distinct structures, each repeated n_samples // 2 times
+ shape = 10
+ n_samples = 2
+ n_atoms = 2
+ projection = 'nonnegative_symmetric'
+ X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42)
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)]
+ ps = [ot.unif(shape) for _ in range(n_samples)]
+ q = ot.unif(shape)
+
+ # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape)
+    # following the same procedure as implemented in gromov_wasserstein_dictionary_learning.
+ dataset_means = [C.mean() for C in Cs]
+ np.random.seed(0)
+ Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape))
+ if projection == 'nonnegative_symmetric':
+ Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1)))
+ Cdict_init[Cdict_init < 0.] = 0.
+ Csb = [nx.from_numpy(C) for C in Cs]
+ psb = [nx.from_numpy(p) for p in ps]
+ qb = nx.from_numpy(q)
+ Cdict_initb = nx.from_numpy(Cdict_init)
+
+    # Test: compare the reconstruction error obtained with the initial dictionary to that of the dictionary learned from this initialization
+ # > Compute initial reconstruction of samples on this random dictionary without backend
+ use_adam_optimizer = True
+ verbose = False
+ tol = 10**(-5)
+ epochs = 1
+
+ initial_total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict_init, p=ps[i], q=q, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ initial_total_reconstruction += reconstruction
+
+ # > Learn the dictionary using this init
+ Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init,
+ epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary without backend
+ total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict, p=None, q=None, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction += reconstruction
+
+ np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
+
+ # Test: Perform same experiments after going through backend
+
+ Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb,
+ epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Csb[i], Cdictb, p=psb[i], q=qb, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_b += reconstruction
+
+ np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
+ np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
+ np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
+
+    # Test: Perform the same comparison without providing the initial dictionary (an optional input),
+    # knowing that the initialization scheme is the same as the one used to set the benchmarked initialization.
+ np.random.seed(0)
+ Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict_bis, p=ps[i], q=q, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_bis += reconstruction
+
+ np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05)
+
+ # Test: Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Csb[i], Cdictb_bis, p=None, q=None, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_b_bis += reconstruction
+
+ np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05)
+ np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03)
+
+    # Test: Perform the same comparison without providing the initial dictionary (an optional input),
+    # and test other optimization settings not covered so far.
+    # We pass previously estimated dictionaries to speed up the process.
+ use_adam_optimizer = False
+ verbose = True
+ use_log = True
+
+ np.random.seed(0)
+ Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict,
+ epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_bis2 += reconstruction
+
+ np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
+
+ # Test: Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
+ Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb,
+ epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
+ Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_b_bis2 += reconstruction
+
+ np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05)
+
+
+def test_fused_gromov_wasserstein_linear_unmixing(nx):
+
+ n = 10
+ X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42)
+ F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42)
+
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cdict = np.stack([C1, C2])
+ Ydict = np.stack([F, F])
+ p = ot.unif(n)
+
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ Fb = nx.from_numpy(F)
+ Cdictb = nx.from_numpy(Cdict)
+ Ydictb = nx.from_numpy(Ydict)
+ pb = nx.from_numpy(p)
+ # Tests without regularization
+ reg = 0.
+
+ unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03)
+ np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+ # Tests with regularization
+ reg = 0.001
+
+ unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200
+ )
+
+ unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg,
+ tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200
+ )
+
+ np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06)
+ np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01)
+ np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06)
+ np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
+ np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03)
+ np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03)
+ np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03)
+ np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06)
+ np.testing.assert_allclose(C1b_emb.shape, (n, n))
+ np.testing.assert_allclose(C2b_emb.shape, (n, n))
+
+
+def test_fused_gromov_wasserstein_dictionary_learning(nx):
+
+    # create a dataset composed of 2 distinct structures, each repeated n_samples // 2 times
+ shape = 10
+ n_samples = 2
+ n_atoms = 2
+ projection = 'nonnegative_symmetric'
+ X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
+ X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42)
+ F, y = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
+
+ C1 = ot.dist(X1)
+ C2 = ot.dist(X2)
+ Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)]
+ Ys = [F.copy() for _ in range(n_samples)]
+ ps = [ot.unif(shape) for _ in range(n_samples)]
+ q = ot.unif(shape)
+
+ # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape)
+    # following the same procedure as implemented in gromov_wasserstein_dictionary_learning.
+ dataset_structure_means = [C.mean() for C in Cs]
+ np.random.seed(0)
+ Cdict_init = np.random.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape))
+ if projection == 'nonnegative_symmetric':
+ Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1)))
+ Cdict_init[Cdict_init < 0.] = 0.
+ dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys])
+ Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2))
+
+ Csb = [nx.from_numpy(C) for C in Cs]
+ Ysb = [nx.from_numpy(Y) for Y in Ys]
+ psb = [nx.from_numpy(p) for p in ps]
+ qb = nx.from_numpy(q)
+ Cdict_initb = nx.from_numpy(Cdict_init)
+ Ydict_initb = nx.from_numpy(Ydict_init)
+
+ # Test: Compute initial reconstruction of samples on this random dictionary
+ alpha = 0.5
+ use_adam_optimizer = True
+ verbose = False
+ tol = 1e-05
+ epochs = 1
+
+ initial_total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q,
+ alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ initial_total_reconstruction += reconstruction
+
+ # > Learn a dictionary using this given initialization and check that the reconstruction loss
+ # on the learned dictionary is lower than the one using its initialization.
+ Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction += reconstruction
+ # Compare both
+ np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
+
+ # Test: Perform same experiments after going through backend
+
+ Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb,
+ epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_b += reconstruction
+
+ np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
+ np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
+ np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
+ np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03)
+
+    # Test: Perform a similar experiment without providing the initial dictionary (an optional input)
+ np.random.seed(0)
+ Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_bis += reconstruction
+
+ np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05)
+
+ # > Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_b_bis += reconstruction
+ np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05)
+
+    # Test: without using the Adam optimizer, with log and verbose set to True
+ use_adam_optimizer = False
+ verbose = True
+ use_log = True
+
+    # > Provide the previously estimated dictionaries to speed up the test compared to starting from a random initialization.
+ np.random.seed(0)
+ Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_bis2 += reconstruction
+
+ np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
+
+ # > Same after going through backend
+ np.random.seed(0)
+ Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
+ Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb,
+ epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
+ projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
+ )
+
+ # > Compute reconstruction of samples on learned dictionary
+ total_reconstruction_b_bis2 = 0
+ for i in range(n_samples):
+ _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
+ Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0.,
+ tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
+ )
+ total_reconstruction_b_bis2 += reconstruction
+
+ # > Compare results with/without backend
+ np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05)