author    Rémi Flamary <remi.flamary@gmail.com>  2019-06-04 11:57:18 +0200
committer GitHub <noreply@github.com>  2019-06-04 11:57:18 +0200
commit    5a6b226de20624b51c2ff98bc30e5611a7a788c7 (patch)
tree      69b019aaa43ec7d69d97a48717eed27c01890c6e
parent    f66ab58c7c895011fd37bafd3e848828399c56c4 (diff)
parent    788a6506c9bf3b862a9652d74f65f8d07851e653 (diff)
Merge pull request #86 from tvayer/master
[MRG] Gromov-Wasserstein closed form for linesearch and integration of Fused Gromov-Wasserstein

This PR closes #82. Thank you @tvayer for all the work.
-rw-r--r--  README.md                        |   3
-rw-r--r--  examples/plot_barycenter_fgw.py  | 184
-rw-r--r--  examples/plot_fgw.py             | 173
-rw-r--r--  ot/bregman.py                    |   1
-rw-r--r--  ot/gromov.py                     | 388
-rw-r--r--  ot/optim.py                      | 145
-rw-r--r--  ot/utils.py                      |   8
-rw-r--r--  test/test_gromov.py              | 115
-rw-r--r--  test/test_optim.py               |   6
9 files changed, 972 insertions, 51 deletions
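
Before the per-file diff, a minimal sketch of the solver this PR introduces, on hypothetical toy data (names and signatures taken from the ot/gromov.py diff below; not part of the commit):

    # Toy sketch of the new FGW solver on random data with uniform marginals.
    import numpy as np
    import ot
    from ot.gromov import fused_gromov_wasserstein

    n = 20
    xs = np.random.randn(n, 2)   # source points (structure)
    xt = np.random.randn(n, 2)   # target points (structure)
    ys = np.random.randn(n, 1)   # source features
    yt = np.random.randn(n, 1)   # target features

    C1 = ot.dist(xs, xs)         # source structure matrix
    C1 /= C1.max()
    C2 = ot.dist(xt, xt)         # target structure matrix
    C2 /= C2.max()
    M = ot.dist(ys, yt)          # feature cost matrix
    M /= M.max()
    p, q = ot.unif(n), ot.unif(n)

    # alpha trades the feature term (alpha -> 0) against the structure term (alpha -> 1)
    G = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5)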
diff --git a/README.md b/README.md
index dd34a97..b6b215c 100644
--- a/README.md
+++ b/README.md
@@ -164,6 +164,7 @@ The contributors to this library are:
* Erwan Vautier (Gromov-Wasserstein)
* [Kilian Fatras](https://kilianfatras.github.io/)
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
+* [Vayer Titouan](https://tvayer.github.io/)
This toolbox benefits a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -233,3 +234,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+
+[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. (2019). [Optimal Transport for structured data with application on graphs](http://proceedings.mlr.press/v97/titouan19a.html) Proceedings of the 36th International Conference on Machine Learning (ICML).
diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py
new file mode 100644
index 0000000..e4be447
--- /dev/null
+++ b/examples/plot_barycenter_fgw.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Plot graphs' barycenter using FGW
+=================================
+
+This example illustrates the computation of the barycenter of labeled graphs using FGW [18].
+
+Requires networkx >=2
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+#%% load libraries
+import numpy as np
+import matplotlib.pyplot as plt
+import networkx as nx
+import math
+from scipy.sparse.csgraph import shortest_path
+import matplotlib.colors as mcol
+from matplotlib import cm
+from ot.gromov import fgw_barycenters
+#%% Graph functions
+
+
+def find_thresh(C, inf=0.5, sup=3, step=10):
+ """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected
+ Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested.
+ The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix
+ and the original matrix.
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ inf : float
+ The beginning of the linesearch
+ sup : float
+ The end of the linesearch
+ step : integer
+ Number of thresholds tested
+ """
+ dist = []
+ search = np.linspace(inf, sup, step)
+ for thresh in search:
+ Cprime = sp_to_adjency(C, 0, thresh)
+ SC = shortest_path(Cprime, method='D')
+ SC[SC == float('inf')] = 100
+ dist.append(np.linalg.norm(SC - C))
+ return search[np.argmin(dist)], dist
+
+
+def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
+ """ Thresholds the structure matrix in order to compute an adjency matrix.
+ All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ threshinf : float
+ The minimum value of distance above which the new value is set to 1
+ threshsup : float
+ The maximum value of distance below which the new value is set to 1
+ Returns
+ -------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The thresholded matrix. Each element is in {0, 1}
+ """
+ H = np.zeros_like(C)
+ np.fill_diagonal(H, np.diagonal(C))
+ C = C - H
+ C = np.minimum(np.maximum(C, threshinf), threshsup)
+ C[C == threshsup] = 0
+ C[C != 0] = 1
+
+ return C
+
+
+def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
+ """ Create a noisy circular graph
+ """
+ g = nx.Graph()
+ g.add_nodes_from(list(range(N)))
+ for i in range(N):
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
+ else:
+ g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))
+ g.add_edge(i, i + 1)
+ if structure_noise:
+ randomint = np.random.randint(0, p)
+ if randomint == 0:
+ if i <= N - 3:
+ g.add_edge(i, i + 2)
+ if i == N - 2:
+ g.add_edge(i, 0)
+ if i == N - 1:
+ g.add_edge(i, 1)
+ g.add_edge(N, 0)
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
+ else:
+ g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))
+ return g
+
+
+def graph_colors(nx_graph, vmin=0, vmax=7):
+ cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)
+ cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')
+ cpick.set_array([])
+ val_map = {}
+ for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():
+ val_map[k] = cpick.to_rgba(v)
+ colors = []
+ for node in nx_graph.nodes():
+ colors.append(val_map[node])
+ return colors
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% circular dataset
+# We build a dataset of noisy circular graphs.
+# Noise is added on the structures by random connections and on the features by gaussian noise.
+
+
+np.random.seed(30)
+X0 = []
+for k in range(9):
+ X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% Plot graphs
+
+plt.figure(figsize=(8, 10))
+for i in range(len(X0)):
+ plt.subplot(3, 3, i + 1)
+ g = X0[i]
+ pos = nx.kamada_kawai_layout(g)
+ nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)
+plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
+plt.show()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph.
+# Feature distances are the Euclidean distances.
+Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
+ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
+Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]
+lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
+sizebary = 15 # we choose a barycenter with 15 nodes
+
+A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95)
+
+##############################################################################
+# Plot Barycenter
+# -------------------------
+
+#%% Create the barycenter
+bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+for i, v in enumerate(A.ravel()):
+ bary.add_node(i, attr_name=v)
+
+#%%
+pos = nx.kamada_kawai_layout(bary)
+nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
+plt.suptitle('Barycenter', fontsize=20)
+plt.show()
diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py
new file mode 100644
index 0000000..43efc94
--- /dev/null
+++ b/examples/plot_fgw.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+Plot Fused Gromov-Wasserstein
+==============================
+
+This example illustrates the computation of FGW for 1D measures [18].
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import matplotlib.pyplot as pl
+import numpy as np
+import ot
+from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+# We create two 1D random measures
+n = 20 # number of points in the first distribution
+n2 = 30 # number of points in the second distribution
+sig = 1 # std of first distribution
+sig2 = 0.1 # std of second distribution
+
+np.random.seed(0)
+
+phi = np.arange(n)[:, None]
+xs = phi + sig * np.random.randn(n, 1)
+ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)
+
+phi2 = np.arange(n2)[:, None]
+xt = phi2 + sig * np.random.randn(n2, 1)
+yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)
+yt = yt[::-1, :]
+
+p = ot.unif(n)
+q = ot.unif(n2)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.close(10)
+pl.figure(10, (7, 7))
+
+pl.subplot(2, 1, 1)
+
+pl.scatter(ys, xs, c=phi, s=70)
+pl.ylabel('Feature value a', fontsize=20)
+pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)
+pl.xticks(())
+pl.yticks(())
+pl.subplot(2, 1, 2)
+pl.scatter(yt, xt, c=phi2, s=70)
+pl.xlabel('coordinates x/y', fontsize=25)
+pl.ylabel('Feature value b', fontsize=20)
+pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)
+pl.yticks(())
+pl.tight_layout()
+pl.show()
+
+##############################################################################
+# Create structure matrices and across-feature distance matrix
+# -------------------------------------------------------------
+
+#%% Structure matrices and across-features distance matrix
+C1 = ot.dist(xs)
+C2 = ot.dist(xt)
+M = ot.dist(ys, yt)
+w1 = ot.unif(C1.shape[0])
+w2 = ot.unif(C2.shape[0])
+Got = ot.emd([], [], M)
+
+##############################################################################
+# Plot matrices
+# -------------
+
+#%%
+cmap = 'Reds'
+pl.close(10)
+pl.figure(10, (5, 5))
+fs = 15
+l_x = [0, 5, 10, 15]
+l_y = [0, 5, 10, 15, 20, 25]
+gs = pl.GridSpec(5, 5)
+
+ax1 = pl.subplot(gs[3:, :2])
+
+pl.imshow(C1, cmap=cmap, interpolation='nearest')
+pl.title("$C_1$", fontsize=fs)
+pl.xlabel("$k$", fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+pl.xticks(l_x)
+pl.yticks(l_x)
+
+ax2 = pl.subplot(gs[:3, 2:])
+
+pl.imshow(C2, cmap=cmap, interpolation='nearest')
+pl.title("$C_2$", fontsize=fs)
+pl.ylabel("$l$", fontsize=fs)
+#pl.ylabel("$l$",fontsize=fs)
+pl.xticks(())
+pl.yticks(l_y)
+ax2.set_aspect('auto')
+
+ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)
+pl.imshow(M, cmap=cmap, interpolation='nearest')
+pl.yticks(l_x)
+pl.xticks(l_y)
+pl.ylabel("$i$", fontsize=fs)
+pl.title("$M_{AB}$", fontsize=fs)
+pl.xlabel("$j$", fontsize=fs)
+pl.tight_layout()
+ax3.set_aspect('auto')
+pl.show()
+
+##############################################################################
+# Compute FGW/GW
+# --------------
+
+#%% Computing FGW and GW
+alpha = 1e-3
+
+ot.tic()
+Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
+ot.toc()
+
+Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
+
+##############################################################################
+# Visualize transport matrices
+# ----------------------------
+
+#%% visu OT matrix
+cmap = 'Blues'
+fs = 15
+pl.figure(2, (13, 5))
+pl.clf()
+pl.subplot(1, 3, 1)
+pl.imshow(Got, cmap=cmap, interpolation='nearest')
+#pl.xlabel("$y$",fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+pl.xticks(())
+
+pl.title('Wasserstein ($M$ only)')
+
+pl.subplot(1, 3, 2)
+pl.imshow(Gg, cmap=cmap, interpolation='nearest')
+pl.title('Gromov ($C_1,C_2$ only)')
+pl.xticks(())
+pl.subplot(1, 3, 3)
+pl.imshow(Gwg, cmap=cmap, interpolation='nearest')
+pl.title('FGW ($M+C_1,C_2$)')
+
+pl.xlabel("$j$", fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+
+pl.tight_layout()
+pl.show()
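
A short hedged sketch of the role of alpha in this example (reusing M, C1, C2, p, q defined above): FGW approaches the Wasserstein plan as alpha goes to 0 and the Gromov-Wasserstein plan as alpha goes to 1.

    # alpha close to 0: the feature cost M dominates (Wasserstein-like plan)
    G_w_like = fused_gromov_wasserstein(M, C1, C2, p, q, alpha=1e-3)

    # alpha close to 1: the structure terms C1, C2 dominate (Gromov-like plan)
    G_gw_like = fused_gromov_wasserstein(M, C1, C2, p, q, alpha=1 - 1e-3)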
diff --git a/ot/bregman.py b/ot/bregman.py
index dc43834..321712b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -6,6 +6,7 @@ Bregman projections for regularized OT
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License
diff --git a/ot/gromov.py b/ot/gromov.py
index 7974546..ca96b31 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -9,17 +9,19 @@ Gromov-Wasserstein transport method
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License
import numpy as np
+
from .bregman import sinkhorn
-from .utils import dist
+from .utils import dist, UndefinedParameter
from .optim import cg
-def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'):
+def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
""" Return loss matrices and tensors for Gromov-Wasserstein fast computation
Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
@@ -32,12 +34,12 @@ def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'):
* C2 : Metric cost matrix in the target space
* T : A coupling between those two spaces
- The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
+ The square-loss function L(a,b)=|a-b|^2 is read as :
L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- * f1(a)=(a^2)/2
- * f2(b)=(b^2)/2
+ * f1(a)=(a^2)
+ * f2(b)=(b^2)
* h1(a)=a
- * h2(b)=b
+ * h2(b)=2*b
The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
@@ -77,16 +79,16 @@ def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'):
if loss_fun == 'square_loss':
def f1(a):
- return (a**2) / 2
+ return (a**2)
def f2(b):
- return (b**2) / 2
+ return (b**2)
def h1(a):
return a
def h2(b):
- return b
+ return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
return a * np.log(a + 1e-15) - a
@@ -268,7 +270,7 @@ def update_kl_loss(p, lambdas, T, Cs):
return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -306,6 +308,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
Print information along iterations
log : bool, optional
record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo search. Else the closed form is used.
+ If there are convergence issues, use False.
**kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
@@ -329,9 +334,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
"""
- T = np.eye(len(p), len(q))
-
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -342,14 +345,161 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
return res, log
else:
- return cg(p, q, 0, 1, f, df, G0, **kwargs)
+ return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW transport between two graphs, see [24]
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, M \rangle_F + \alpha \sum_{i,j,k,l} L(C1_{i,k}, C2_{j,l}) \gamma_{i,j} \gamma_{k,l}
+ s.t. \ \gamma 1 = p
+ \gamma^T 1 = q
+ \gamma \geq 0
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term (and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string, optional
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo search. Else the closed form is used.
+ If there are convergence issues, use False.
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ if log:
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ log['fgw_dist'] = log['loss'][-1]
+ return res, log
+ else:
+ return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW distance between two graphs, see [24]
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, M \rangle_F + \alpha \sum_{i,j,k,l} L(C1_{i,k}, C2_{j,l}) \gamma_{i,j} \gamma_{k,l}
+ s.t. \ \gamma 1 = p
+ \gamma^T 1 = q
+ \gamma \geq 0
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term (and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string, optional
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo search. Else the closed form is used.
+ If there are convergence issues, use False.
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary returned only if log==True in parameters
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ res, log0 = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ # bind the returned dict to a new name so the boolean `log` flag is not shadowed
+ log0['fgw_dist'] = log0['loss'][-1]
+ log0['T'] = res
+ if log:
+ return log0['fgw_dist'], log0
+ else:
+ return log0['fgw_dist']
+
+
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
@@ -387,7 +537,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
Print information along iterations
log : bool, optional
record log if True
-
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo search. Else the closed form is used.
+ If there are convergence issues, use False.
Returns
-------
gw_dist : float
@@ -407,9 +559,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
"""
- T = np.eye(len(p), len(q))
-
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -418,7 +568,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
log['T'] = res
if log:
@@ -495,7 +645,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
T = np.outer(p, q) # Initialization
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
cpt = 0
err = 1
@@ -815,3 +965,197 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
cpt += 1
return C
+
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
+ verbose=False, log=False, init_C=None, init_X=None):
+ """
+ Compute the FGW barycenter as presented in eq. (5) of [24].
+ Parameters
+ ----------
+ N : integer
+ Desired number of samples of the target barycenter
+ Ys: list of ndarray, each element has shape (ns,d)
+ Features of all samples
+ Cs : list of ndarray, each element has shape (ns,ns)
+ Structure matrices of all samples
+ ps : list of ndarray, each element has shape (ns,)
+ masses of all samples
+ lambdas : list of float
+ list of the S spaces' weights
+ alpha : float
+ Alpha parameter for the fgw distance
+ fixed_structure : bool
+ Whether to fix the structure of the barycenter during the updates
+ fixed_features : bool
+ Whether to fix the features of the barycenter during the updates
+ init_C : ndarray, shape (N,N), optional
+ initialization for the barycenter's structure matrix. If not set, a random initialization is used.
+ init_X : ndarray, shape (N,d), optional
+ initialization for the barycenter's features. If not set, a random initialization is used.
+ Returns
+ ----------
+ X : ndarray, shape (N,d)
+ Barycenter's features
+ C : ndarray, shape (N,N)
+ Barycenter's structure matrix
+ log_: dictionary
+ Only returned when log=True
+ T : list of (N,ns) transport matrices
+ Ms : all distance matrices between the features of the barycenter and the features of all samples, dist(X,Ys), each of shape (N,ns)
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ S = len(Cs)
+ d = Ys[0].shape[1] # dimension of the node features
+ if p is None:
+ p = np.ones(N) / N
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
+
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ if fixed_structure:
+ if init_C is None:
+ raise UndefinedParameter('If C is fixed it must be initialized')
+ else:
+ C = init_C
+ else:
+ if init_C is None:
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ else:
+ C = init_C
+
+ if fixed_features:
+ if init_X is None:
+ raise UndefinedParameter('If X is fixed it must be initialized')
+ else:
+ X = init_X
+ else:
+ if init_X is None:
+ X = np.zeros((N, d))
+ else:
+ X = init_X
+
+ T = [np.outer(p, q) for q in ps]
+
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
+
+ cpt = 0
+ err_feature = 1
+ err_structure = 1
+
+ if log:
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
+
+ while((err_feature > tol or err_structure > tol) and cpt < max_iter):
+ Cprev = C
+ Xprev = X
+
+ if not fixed_features:
+ Ys_temp = [y.T for y in Ys]
+ X = update_feature_matrix(lambdas, Ys_temp, T, p).T
+
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+
+ if not fixed_structure:
+ if loss_fun == 'square_loss':
+ T_temp = [t.T for t in T]
+ C = update_sructure_matrix(p, lambdas, T_temp, Cs)
+
+ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+
+ # T is N,ns
+ err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
+ err_structure = np.linalg.norm(C - Cprev)
+
+ if log:
+ log_['err_feature'].append(err_feature)
+ log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err_structure))
+ print('{:5d}|{:8e}|'.format(cpt, err_feature))
+
+ cpt += 1
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms
+
+ if log:
+ return X, C, log_
+ else:
+ return X, C
+
+
+def update_sructure_matrix(p, lambdas, T, Cs):
+ """
+ Updates C according to the L2 loss with the S Ts couplings
+ calculated at each iteration
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ T : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices
+ Returns
+ ----------
+ C : ndarray, shape (N,N)
+ updated C matrix
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.divide(tmpsum, ppt)
+
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
+ """
+ Updates the features with respect to the S Ts couplings calculated at each iteration.
+ See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24].
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ Ts : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Ys : list of S ndarray, shape(d,ns)
+ The features
+ Returns
+ ----------
+ X : ndarray, shape (d,N)
+
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ p = np.diag(np.array(1 / p).reshape(-1,))
+
+ tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))])
+
+ return tmpsum
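
To close this file, a minimal sketch of calling the new fgw_barycenters solver on random toy inputs (shapes follow the docstring above; not part of the commit):

    import numpy as np
    import ot
    from ot.gromov import fgw_barycenters

    S, N, ns = 3, 10, 12   # input graphs, barycenter size, nodes per input graph
    Ys = [np.random.randn(ns, 2) for _ in range(S)]           # node features
    Cs = [ot.dist(np.random.randn(ns, 2)) for _ in range(S)]  # structure matrices
    ps = [ot.unif(ns) for _ in range(S)]                      # node masses
    lambdas = [1.0 / S] * S                                   # graph weights

    # returns barycenter features X of shape (N, 2) and structure matrix C of shape (N, N)
    X, C = fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha=0.5)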
diff --git a/ot/optim.py b/ot/optim.py
index f31fae2..f94aceb 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -4,6 +4,7 @@ Optimization algorithms for OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License
@@ -72,8 +73,70 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
return alpha, fc[0], phi1
+def solve_linesearch(cost, G, deltaG, Mi, f_val,
+ armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
+ """
+ Solve the linesearch in the FW iterations
+ Parameters
+ ----------
+ cost : method
+ Cost in the FW for the linesearch
+ G : ndarray, shape(ns,nt)
+ The transport map at a given iteration of the FW
+ deltaG : ndarray (ns,nt)
+ Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
+ Mi : ndarray (ns,nt)
+ Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
+ f_val : float
+ Value of the cost at G
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo search. Else the closed form is used.
+ If there are convergence issues, use False.
+ C1 : ndarray (ns,ns), optional
+ Structure matrix in the source domain. Only used and necessary when armijo=False
+ C2 : ndarray (nt,nt), optional
+ Structure matrix in the target domain. Only used and necessary when armijo=False
+ reg : float, optional
+ Regularization parameter. Only used and necessary when armijo=False
+ Gc : ndarray (ns,nt)
+ Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
+ constC : ndarray (ns,nt)
+ Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
+ M : ndarray (ns,nt), optional
+ Cost matrix between the features. Only used and necessary when armijo=False
+ Returns
+ -------
+ alpha : float
+ The optimal step size of the FW
+ fc : int
+ number of function calls. Unused here (None when armijo=False)
+ f_val : float
+ The value of the cost for the next iteration
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ if armijo:
+ alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
+ else: # requires symmetric matrices
+ dot1 = np.dot(C1, deltaG)
+ dot12 = dot1.dot(C2)
+ a = -2 * reg * np.sum(dot12 * deltaG)
+ b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+ c = cost(G)
+
+ alpha = solve_1d_linesearch_quad(a, b, c)
+ fc = None
+ f_val = cost(G + alpha * deltaG)
+
+ return alpha, fc, f_val
+
+
def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
- stopThr=1e-9, verbose=False, log=False):
+ stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the general regularized OT problem with conditional gradient
@@ -111,11 +174,15 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ **kwargs : dict
+ Parameters for linesearch
Returns
-------
@@ -157,9 +224,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
while loop:
@@ -177,7 +244,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
deltaG = Gc - G
# line search
- alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
+ alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
G = G + alpha * deltaG
@@ -185,8 +252,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
if it >= numItermax:
loop = 0
- delta_fval = (f_val - old_fval) / abs(f_val)
- if abs(delta_fval) < stopThr:
+ abs_delta_fval = abs(f_val - old_fval)
+ relative_delta_fval = abs_delta_fval / abs(f_val)
+ if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
loop = 0
if log:
@@ -194,9 +262,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
if verbose:
if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
return G, log
@@ -205,7 +273,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
- numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
+ numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
"""
Solve the general regularized OT problem with the generalized conditional gradient
@@ -248,7 +316,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
numInnerItermax : int, optional
Max number of iterations of Sinkhorn
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -294,9 +364,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
while loop:
@@ -322,8 +392,10 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
if it >= numItermax:
loop = 0
- delta_fval = (f_val - old_fval) / abs(f_val)
- if abs(delta_fval) < stopThr:
+ abs_delta_fval = abs(f_val - old_fval)
+ relative_delta_fval = abs_delta_fval / abs(f_val)
+
+ if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
loop = 0
if log:
@@ -331,11 +403,42 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
if verbose:
if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
return G, log
else:
return G
+
+
+def solve_1d_linesearch_quad(a, b, c):
+ """
+ For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem:
+ .. math::
+ \mathop{\arg \min}_{x \in [0,1]} \quad f(x) = ax^{2} + bx + c
+
+ Parameters
+ ----------
+ a,b,c : float
+ The coefficients of the quadratic function
+
+ Returns
+ -------
+ x : float
+ The optimal value which leads to the minimal cost
+
+ """
+ f0 = c
+ df0 = b
+ f1 = a + f0 + df0
+
+ if a > 0: # convex
+ minimum = min(1, max(0, np.divide(-b, 2.0 * a)))
+ return minimum
+ else: # non convex
+ if f0 > f1:
+ return 1
+ else:
+ return 0
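
Along a FW segment G + t*deltaG the (F)GW cost is an exact quadratic in t, which is why solve_linesearch can use the coefficients a and b computed above instead of an Armijo search. A quick sanity sketch of the 1d solver (the triples mirror the new test in test_optim.py; not part of the commit):

    import numpy as np
    import ot

    t = np.linspace(0, 1, 1001)
    for a, b, c in [(1.0, -1.0, 0.0),    # convex: vertex at 0.5
                    (-1.0, 5.0, 0.0),    # concave, increasing on [0, 1]: minimum at 0
                    (-1.0, 0.5, 0.0)]:   # concave, f(1) < f(0): minimum at 1
        ts = ot.optim.solve_1d_linesearch_quad(a, b, c)
        f = a * t ** 2 + b * t + c
        assert abs(ts - t[np.argmin(f)]) < 1e-3   # matches a brute-force grid minimizer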
diff --git a/ot/utils.py b/ot/utils.py
index bb21b38..efd1288 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -487,3 +487,11 @@ class BaseEstimator(object):
(key, self.__class__.__name__))
setattr(self, key, value)
return self
+
+
+class UndefinedParameter(Exception):
+ """
+ Exception raised when an undefined parameter is used.
+
+ """
+ pass
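
For illustration only, a hedged sketch of how the new exception surfaces from fgw_barycenters (toy inputs; see ot/gromov.py above):

    import numpy as np
    import ot
    from ot.utils import UndefinedParameter

    ys = np.random.randn(5, 2)
    C1 = ot.dist(np.random.randn(5, 2))
    try:
        # fixed_structure=True without init_C triggers the new exception
        ot.gromov.fgw_barycenters(3, [ys], [C1], [ot.unif(5)], [1.0], alpha=0.5,
                                  fixed_structure=True)
    except UndefinedParameter as e:
        print(e)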
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 305ae84..70fa83f 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -2,6 +2,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License
@@ -15,7 +16,7 @@ def test_gromov():
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
xt = xs[::-1].copy()
@@ -36,6 +37,11 @@ def test_gromov():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04)
+
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
G = log['T']
@@ -55,7 +61,7 @@ def test_entropic_gromov():
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
xt = xs[::-1].copy()
@@ -92,12 +98,11 @@ def test_entropic_gromov():
def test_gromov_barycenter():
-
ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
@@ -120,12 +125,11 @@ def test_gromov_barycenter():
def test_gromov_entropic_barycenter():
-
ns = 50
nt = 60
- Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
- Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
@@ -145,3 +149,98 @@ def test_gromov_entropic_barycenter():
'kl_loss', 2e-3,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+
+def test_fgw():
+
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0], 2)
+ yt = ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M = ot.dist(ys, yt)
+ M /= M.max()
+
+ G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence fgw
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence fgw
+
+ Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04) # cf convergence gromov
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_fgw_barycenter():
+ np.random.seed(42)
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
+
+ ys = np.random.randn(Xs.shape[0], 2)
+ yt = np.random.randn(Xt.shape[0], 2)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ xalea = np.random.randn(n_samples, 2)
+ init_C = ot.dist(xalea, xalea)
+
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
+ fixed_structure=True, init_C=init_C, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ init_X = np.random.randn(n_samples, ys.shape[1])
+
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_optim.py b/test/test_optim.py
index dfefe59..ae31e1f 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -65,3 +65,9 @@ def test_generalized_conditional_gradient():
np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+
+
+def test_solve_1d_linesearch_quad_funct():
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)