diff options
-rw-r--r-- | README.md | 13 | ||||
-rw-r--r-- | examples/plot_barycenter_fgw.py | 172 | ||||
-rw-r--r-- | examples/plot_fgw.py | 151 | ||||
-rw-r--r-- | ot/bregman.py | 1 | ||||
-rw-r--r-- | ot/gromov.py | 315 | ||||
-rw-r--r-- | ot/optim.py | 101 | ||||
-rw-r--r-- | test/test_gromov.py | 75 | ||||
-rw-r--r-- | test/test_optim.py | 5 |
8 files changed, 800 insertions, 33 deletions
@@ -164,6 +164,7 @@ The contributors to this library are: * Erwan Vautier (Gromov-Wasserstein) * [Kilian Fatras](https://kilianfatras.github.io/) * [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) +* [Vayer Titouan](https://tvayer.github.io/) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -221,15 +222,3 @@ You can also post bug reports and feature requests in Github issues. Make sure t [16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). [Smooth and Sparse Optimal Transport](https://arxiv.org/abs/1710.06276). Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). - -[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016). - -[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018) - -[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning - -[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66. - -[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31 - -[23] Aude, G., Peyré, G., Cuturi, M., [Learning Generative Models with Sinkhorn Divergences](https://arxiv.org/abs/1706.00292), Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py new file mode 100644 index 0000000..f416629 --- /dev/null +++ b/examples/plot_barycenter_fgw.py @@ -0,0 +1,172 @@ +# -*- coding: utf-8 -*- +""" +================================= +Plot graphs' barycenter using FGW +================================= + +This example illustrates the computation barycenter of labeled graphs using FGW + +Requires networkx >=2 + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + +""" + +# Author: Titouan Vayer <titouan.vayer@irisa.fr> +# +# License: MIT License + +#%% load libraries +import numpy as np +import matplotlib.pyplot as plt +import networkx as nx +import math +from scipy.sparse.csgraph import shortest_path +import matplotlib.colors as mcol +from matplotlib import cm +from ot.gromov import fgw_barycenters +#%% Graph functions + +def find_thresh(C,inf=0.5,sup=3,step=10): + """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected + Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested. + The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix + and the original matrix. + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + inf : float + The beginning of the linesearch + sup : float + The end of the linesearch + step : integer + Number of thresholds tested + """ + dist=[] + search=np.linspace(inf,sup,step) + for thresh in search: + Cprime=sp_to_adjency(C,0,thresh) + SC=shortest_path(Cprime,method='D') + SC[SC==float('inf')]=100 + dist.append(np.linalg.norm(SC-C)) + return search[np.argmin(dist)],dist + +def sp_to_adjency(C,threshinf=0.2,threshsup=1.8): + """ Thresholds the structure matrix in order to compute an adjency matrix. + All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0 + Parameters + ---------- + C : ndarray, shape (n_nodes,n_nodes) + The structure matrix to threshold + threshinf : float + The minimum value of distance from which the new value is set to 1 + threshsup : float + The maximum value of distance from which the new value is set to 1 + Returns + ------- + C : ndarray, shape (n_nodes,n_nodes) + The threshold matrix. Each element is in {0,1} + """ + H=np.zeros_like(C) + np.fill_diagonal(H,np.diagonal(C)) + C=C-H + C=np.minimum(np.maximum(C,threshinf),threshsup) + C[C==threshsup]=0 + C[C!=0]=1 + + return C + +def build_noisy_circular_graph(N=20,mu=0,sigma=0.3,with_noise=False,structure_noise=False,p=None): + """ Create a noisy circular graph + """ + g=nx.Graph() + g.add_nodes_from(list(range(N))) + for i in range(N): + noise=float(np.random.normal(mu,sigma,1)) + if with_noise: + g.add_node(i,attr_name=math.sin((2*i*math.pi/N))+noise) + else: + g.add_node(i,attr_name=math.sin(2*i*math.pi/N)) + g.add_edge(i,i+1) + if structure_noise: + randomint=np.random.randint(0,p) + if randomint==0: + if i<=N-3: + g.add_edge(i,i+2) + if i==N-2: + g.add_edge(i,0) + if i==N-1: + g.add_edge(i,1) + g.add_edge(N,0) + noise=float(np.random.normal(mu,sigma,1)) + if with_noise: + g.add_node(N,attr_name=math.sin((2*N*math.pi/N))+noise) + else: + g.add_node(N,attr_name=math.sin(2*N*math.pi/N)) + return g + +def graph_colors(nx_graph,vmin=0,vmax=7): + cnorm = mcol.Normalize(vmin=vmin,vmax=vmax) + cpick = cm.ScalarMappable(norm=cnorm,cmap='viridis') + cpick.set_array([]) + val_map = {} + for k,v in nx.get_node_attributes(nx_graph,'attr_name').items(): + val_map[k]=cpick.to_rgba(v) + colors=[] + for node in nx_graph.nodes(): + colors.append(val_map[node]) + return colors + +#%% create dataset +# We build a dataset of noisy circular graphs. +# Noise is added on the structures by random connections and on the features by gaussian noise. + +np.random.seed(30) +X0=[] +for k in range(9): + X0.append(build_noisy_circular_graph(np.random.randint(15,25),with_noise=True,structure_noise=True,p=3)) + +#%% Plot dataset + +plt.figure(figsize=(8,10)) +for i in range(len(X0)): + plt.subplot(3,3,i+1) + g=X0[i] + pos=nx.kamada_kawai_layout(g) + nx.draw(g,pos=pos,node_color = graph_colors(g,vmin=-1,vmax=1),with_labels=False,node_size=100) +plt.suptitle('Dataset of noisy graphs. Color indicates the label',fontsize=20) +plt.show() + + + +#%% +# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph +# Features distances are the euclidean distances +Cs=[shortest_path(nx.adjacency_matrix(x)) for x in X0] +ps=[np.ones(len(x.nodes()))/len(x.nodes()) for x in X0] +Ys=[np.array([v for (k,v) in nx.get_node_attributes(x,'attr_name').items()]).reshape(-1,1) for x in X0] +lambdas=np.array([np.ones(len(Ys))/len(Ys)]).ravel() +sizebary=15 # we choose a barycenter with 15 nodes + +#%% + +A,C,log=fgw_barycenters(sizebary,Ys,Cs,ps,lambdas,alpha=0.95) + +#%% +bary=nx.from_numpy_matrix(sp_to_adjency(C,threshinf=0,threshsup=find_thresh(C,sup=100,step=100)[0])) +for i in range(len(A.ravel())): + bary.add_node(i,attr_name=float(A.ravel()[i])) + +#%% +pos = nx.kamada_kawai_layout(bary) +nx.draw(bary,pos=pos,node_color = graph_colors(bary,vmin=-1,vmax=1),with_labels=False) +plt.suptitle('Barycenter',fontsize=20) +plt.show() + + + + diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py new file mode 100644 index 0000000..bfa7fb4 --- /dev/null +++ b/examples/plot_fgw.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +""" +============================== +Plot Fused-gromov-Wasserstein +============================== + +This example illustrates the computation of FGW for 1D measures[18]. + +.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + +""" + +# Author: Titouan Vayer <titouan.vayer@irisa.fr> +# +# License: MIT License + +import matplotlib.pyplot as pl +import numpy as np +import ot +from ot.gromov import gromov_wasserstein,fused_gromov_wasserstein + +#%% parameters +# We create two 1D random measures +n=20 +n2=30 +sig=1 +sig2=0.1 + +np.random.seed(0) + +phi=np.arange(n)[:,None] +xs=phi+sig*np.random.randn(n,1) +ys=np.vstack((np.ones((n//2,1)),0*np.ones((n//2,1))))+sig2*np.random.randn(n,1) + +phi2=np.arange(n2)[:,None] +xt=phi2+sig*np.random.randn(n2,1) +yt=np.vstack((np.ones((n2//2,1)),0*np.ones((n2//2,1))))+sig2*np.random.randn(n2,1) +yt= yt[::-1,:] + +p=ot.unif(n) +q=ot.unif(n2) + +#%% plot the distributions + +pl.close(10) +pl.figure(10,(7,7)) + +pl.subplot(2,1,1) + +pl.scatter(ys,xs,c=phi,s=70) +pl.ylabel('Feature value a',fontsize=20) +pl.title('$\mu=\sum_i \delta_{x_i,a_i}$',fontsize=25, usetex=True, y=1) +pl.xticks(()) +pl.yticks(()) +pl.subplot(2,1,2) +pl.scatter(yt,xt,c=phi2,s=70) +pl.xlabel('coordinates x/y',fontsize=25) +pl.ylabel('Feature value b',fontsize=20) +pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$',fontsize=25, usetex=True, y=1) +pl.yticks(()) +pl.tight_layout() +pl.show() + + +#%% Structure matrices and across-features distance matrix +C1=ot.dist(xs) +C2=ot.dist(xt).T +M=ot.dist(ys,yt) +w1=ot.unif(C1.shape[0]) +w2=ot.unif(C2.shape[0]) +Got=ot.emd([],[],M) + +#%% +cmap='Reds' +pl.close(10) +pl.figure(10,(5,5)) +fs=15 +l_x=[0,5,10,15] +l_y=[0,5,10,15,20,25] +gs = pl.GridSpec(5, 5) + +ax1=pl.subplot(gs[3:,:2]) + +pl.imshow(C1,cmap=cmap,interpolation='nearest') +pl.title("$C_1$",fontsize=fs) +pl.xlabel("$k$",fontsize=fs) +pl.ylabel("$i$",fontsize=fs) +pl.xticks(l_x) +pl.yticks(l_x) + +ax2=pl.subplot(gs[:3,2:]) + +pl.imshow(C2,cmap=cmap,interpolation='nearest') +pl.title("$C_2$",fontsize=fs) +pl.ylabel("$l$",fontsize=fs) +#pl.ylabel("$l$",fontsize=fs) +pl.xticks(()) +pl.yticks(l_y) +ax2.set_aspect('auto') + +ax3=pl.subplot(gs[3:,2:],sharex=ax2,sharey=ax1) +pl.imshow(M,cmap=cmap,interpolation='nearest') +pl.yticks(l_x) +pl.xticks(l_y) +pl.ylabel("$i$",fontsize=fs) +pl.title("$M_{AB}$",fontsize=fs) +pl.xlabel("$j$",fontsize=fs) +pl.tight_layout() +ax3.set_aspect('auto') +pl.show() + + +#%% Computing FGW and GW +alpha=1e-3 + +ot.tic() +Gwg,logw=fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=alpha,verbose=True,log=True) +ot.toc() + +#%reload_ext WGW +Gg,log=gromov_wasserstein(C1,C2,p,q,loss_fun='square_loss',verbose=True,log=True) + +#%% visu OT matrix +cmap='Blues' +fs=15 +pl.figure(2,(13,5)) +pl.clf() +pl.subplot(1,3,1) +pl.imshow(Got,cmap=cmap,interpolation='nearest') +#pl.xlabel("$y$",fontsize=fs) +pl.ylabel("$i$",fontsize=fs) +pl.xticks(()) + +pl.title('Wasserstein ($M$ only)') + +pl.subplot(1,3,2) +pl.imshow(Gg,cmap=cmap,interpolation='nearest') +pl.title('Gromov ($C_1,C_2$ only)') +pl.xticks(()) +pl.subplot(1,3,3) +pl.imshow(Gwg,cmap=cmap,interpolation='nearest') +pl.title('FGW ($M+C_1,C_2$)') + +pl.xlabel("$j$",fontsize=fs) +pl.ylabel("$i$",fontsize=fs) + +pl.tight_layout() +pl.show()
\ No newline at end of file diff --git a/ot/bregman.py b/ot/bregman.py index dc43834..321712b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -6,6 +6,7 @@ Bregman projections for regularized OT # Author: Remi Flamary <remi.flamary@unice.fr> # Nicolas Courty <ncourty@irisa.fr> # Kilian Fatras <kilian.fatras@irisa.fr> +# Titouan Vayer <titouan.vayer@irisa.fr> # # License: MIT License diff --git a/ot/gromov.py b/ot/gromov.py index 7974546..33134a2 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -9,17 +9,18 @@ Gromov-Wasserstein transport method # Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
# Rémi Flamary <remi.flamary@unice.fr>
-#
+# Titouan Vayer <titouan.vayer@irisa.fr>
# License: MIT License
import numpy as np
+
from .bregman import sinkhorn
from .utils import dist
from .optim import cg
-def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'):
+def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
""" Return loss matrices and tensors for Gromov-Wasserstein fast computation
Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
@@ -77,16 +78,16 @@ def init_matrix(C1, C2, T, p, q, loss_fun='square_loss'): if loss_fun == 'square_loss':
def f1(a):
- return (a**2) / 2
+ return (a**2)
def f2(b):
- return (b**2) / 2
+ return (b**2)
def h1(a):
return a
def h2(b):
- return b
+ return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
return a * np.log(a + 1e-15) - a
@@ -268,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs): return np.exp(np.divide(tmpsum, ppt))
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
"""
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
@@ -306,6 +307,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): Print information along iterations
log : bool, optional
record log if True
+ amijo : bool, optional
+ If True the steps of the line-search is found via an amijo research. Else closed form is used.
+ If there is convergence issues use False.
**kwargs : dict
parameters can be directly pased to the ot.optim.cg solver
@@ -329,9 +333,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): """
- T = np.eye(len(p), len(q))
-
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -342,14 +344,81 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): return gwggrad(constC, hC1, hC2, G)
if log:
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
return res, log
else:
- return cg(p, q, 0, 1, f, df, G0, **kwargs)
+ return cg(p, q, 0, 1, f, df, G0, amijo=amijo, **kwargs)
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs):
+ """
+ Computes the FGW distance between two graphs see [3]
+ .. math::
+ \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term ( and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix respresentative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix espresentative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string,optionnal
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ amijo : bool, optional
+ If True the steps of the line-search is found via an amijo research. Else closed form is used.
+ If there is convergence issues use False.
+ **kwargs : dict
+ parameters can be directly pased to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
@@ -387,7 +456,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): Print information along iterations
log : bool, optional
record log if True
-
+ amijo : bool, optional
+ If True the steps of the line-search is found via an amijo research. Else closed form is used.
+ If there is convergence issues use False.
Returns
-------
gw_dist : float
@@ -407,9 +478,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): """
- T = np.eye(len(p), len(q))
-
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
G0 = p[:, None] * q[None, :]
@@ -418,7 +487,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, **kwargs): def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, **kwargs)
+ res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
log['T'] = res
if log:
@@ -495,7 +564,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, T = np.outer(p, q) # Initialization
- constC, hC1, hC2 = init_matrix(C1, C2, T, p, q, loss_fun)
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
cpt = 0
err = 1
@@ -815,3 +884,213 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, cpt += 1
return C
+
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
+ verbose=False, log=True, init_C=None, init_X=None):
+ """
+ Compute the fgw barycenter as presented eq (5) in [3].
+ ----------
+ N : integer
+ Desired number of samples of the target barycenter
+ Ys: list of ndarray, each element has shape (ns,d)
+ Features of all samples
+ Cs : list of ndarray, each element has shape (ns,ns)
+ Structure matrices of all samples
+ ps : list of ndarray, each element has shape (ns,)
+ masses of all samples
+ lambdas : list of float
+ list of the S spaces' weights
+ alpha : float
+ Alpha parameter for the fgw distance
+ fixed_structure : bool
+ Wether to fix the structure of the barycenter during the updates
+ fixed_features : bool
+ Wether to fix the feature of the barycenter during the updates
+ init_C : ndarray, shape (N,N), optional
+ initialization for the barycenters' structure matrix. If not set random init
+ init_X : ndarray, shape (N,d), optional
+ initialization for the barycenters' features. If not set random init
+ Returns
+ ----------
+ X : ndarray, shape (N,d)
+ Barycenters' features
+ C : ndarray, shape (N,N)
+ Barycenters' structure matrix
+ log_:
+ T : list of (N,ns) transport matrices
+ Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns)
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ class UndefinedParameter(Exception):
+ pass
+
+ S = len(Cs)
+ d = Ys[0].shape[1] # dimension on the node features
+ if p is None:
+ p = np.ones(N) / N
+
+ Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
+ Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
+
+ lambdas = np.asarray(lambdas, dtype=np.float64)
+
+ if fixed_structure:
+ if init_C is None:
+ raise UndefinedParameter('If C is fixed it must be initialized')
+ else:
+ C = init_C
+ else:
+ if init_C is None:
+ xalea = np.random.randn(N, 2)
+ C = dist(xalea, xalea)
+ else:
+ C = init_C
+
+ if fixed_features:
+ if init_X is None:
+ raise UndefinedParameter('If X is fixed it must be initialized')
+ else:
+ X = init_X
+ else:
+ if init_X is None:
+ X = np.zeros((N, d))
+ else:
+ X = init_X
+
+ T = [np.outer(p, q) for q in ps]
+
+ # X is N,d
+ # Ys is ns,d
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+ # Ms is N,ns
+
+ cpt = 0
+ err_feature = 1
+ err_structure = 1
+
+ if log:
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
+
+ while((err_feature > tol or err_structure > tol) and cpt < max_iter):
+ Cprev = C
+ Xprev = X
+
+ if not fixed_features:
+ Ys_temp = [y.T for y in Ys]
+ X = update_feature_matrix(lambdas, Ys_temp, T, p).T
+
+ # X must be N,d
+ # Ys must be ns,d
+ Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+
+ if not fixed_structure:
+ if loss_fun == 'square_loss':
+ # T must be ns,N
+ # Cs must be ns,ns
+ # p must be N,1
+ T_temp = [t.T for t in T]
+ C = update_sructure_matrix(p, lambdas, T_temp, Cs)
+
+ # Ys must be d,ns
+ # Ts must be N,ns
+ # p must be N,1
+ # Ms is N,ns
+ # C is N,N
+ # Cs is ns,ns
+ # p is N,1
+ # ps is ns,1
+
+ T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+
+ # T is N,ns
+
+ log_['Ts_iter'].append(T)
+ err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
+ err_structure = np.linalg.norm(C - Cprev)
+
+ if log:
+ log_['err_feature'].append(err_feature)
+ log_['err_structure'].append(err_structure)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err_structure))
+ print('{:5d}|{:8e}|'.format(cpt, err_feature))
+
+ cpt += 1
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms # Ms are N,ns
+
+ return X, C, log_
+
+
+def update_sructure_matrix(p, lambdas, T, Cs):
+ """
+ Updates C according to the L2 Loss kernel with the S Ts couplings
+ calculated at each iteration
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ T : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Cs : list of S ndarray, shape(ns,ns)
+ Metric cost matrices
+ Returns
+ ----------
+ C : ndarray, shape (nt,nt)
+ updated C matrix
+ """
+ tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
+ ppt = np.outer(p, p)
+
+ return np.divide(tmpsum, ppt)
+
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
+ """
+ Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3]
+ calculated at each iteration
+ Parameters
+ ----------
+ p : ndarray, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ list of the S spaces' weights
+ Ts : list of S np.ndarray(ns,N)
+ the S Ts couplings calculated at each iteration
+ Ys : list of S ndarray, shape(d,ns)
+ The features
+ Returns
+ ----------
+ X : ndarray, shape (d,N)
+
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ p = np.diag(np.array(1 / p).reshape(-1,))
+
+ tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))])
+
+ return tmpsum
diff --git a/ot/optim.py b/ot/optim.py index f31fae2..cbfb187 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -4,7 +4,7 @@ Optimization algorithms for OT """ # Author: Remi Flamary <remi.flamary@unice.fr> -# +# Titouan Vayer <titouan.vayer@irisa.fr> # License: MIT License import numpy as np @@ -72,8 +72,70 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, return alpha, fc[0], phi1 +def do_linesearch(cost, G, deltaG, Mi, f_val, + amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): + """ + Solve the linesearch in the FW iterations + Parameters + ---------- + cost : method + The FGW cost + G : ndarray, shape(ns,nt) + The transport map at a given iteration of the FW + deltaG : ndarray (ns,nt) + Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration + Mi : ndarray (ns,nt) + Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost + f_val : float + Value of the cost at G + amijo : bool, optionnal + If True the steps of the line-search is found via an amijo research. Else closed form is used. + If there is convergence issues use False. + C1 : ndarray (ns,ns), optionnal + Structure matrix in the source domain. Only used when amijo=False + C2 : ndarray (nt,nt), optionnal + Structure matrix in the target domain. Only used when amijo=False + reg : float, optionnal + Regularization parameter. Corresponds to the alpha parameter of FGW. Only used when amijo=False + Gc : ndarray (ns,nt) + Optimal map found by linearization in the FW algorithm. Only used when amijo=False + constC : ndarray (ns,nt) + Constant for the gromov cost. See [3]. Only used when amijo=False + M : ndarray (ns,nt), optionnal + Cost matrix between the features. Only used when amijo=False + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + f_val : float + The value of the cost for the next iteration + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain + and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + """ + if amijo: + alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + else: # requires symetric matrices + dot1 = np.dot(C1, deltaG) + dot12 = dot1.dot(C2) + a = -2 * reg * np.sum(dot12 * deltaG) + b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + c = cost(G) + + alpha = solve_1d_linesearch_quad_funct(a, b, c) + fc = None + f_val = cost(G + alpha * deltaG) + + return alpha, fc, f_val + + def cg(a, b, M, reg, f, df, G0=None, numItermax=200, - stopThr=1e-9, verbose=False, log=False): + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the general regularized OT problem with conditional gradient @@ -116,6 +178,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, Print information along iterations log : bool, optional record log if True + kwargs : dict + Parameters for linesearch Returns ------- @@ -177,7 +241,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, deltaG = Gc - G # line search - alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) G = G + alpha * deltaG @@ -339,3 +403,34 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, return G, log else: return G + + +def solve_1d_linesearch_quad_funct(a, b, c): + """ + Solve on 0,1 the following problem: + .. math:: + \min f(x)=a*x^{2}+b*x+c + + Parameters + ---------- + a,b,c : float + The coefficients of the quadratic function + + Returns + ------- + x : float + The optimal value which leads to the minimal cost + + """ + f0 = c + df0 = b + f1 = a + f0 + df0 + + if a > 0: # convex + minimum = min(1, max(0, -b / (2 * a))) + return minimum + else: # non convex + if f0 > f1: + return 1 + else: + return 0 diff --git a/test/test_gromov.py b/test/test_gromov.py index 305ae84..43b63e1 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -145,3 +145,78 @@ def test_gromov_entropic_barycenter(): 'kl_loss', 2e-3,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+def test_fgw():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0],2)
+ yt= ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M=ot.dist(ys,yt)
+ M/=M.max()
+
+ G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5)
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence fgw
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence fgw
+
+
+def test_fgw_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+
+ ys = np.random.randn(Xs.shape[0],2)
+ yt= np.random.randn(Xt.shape[0],2)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
+ fixed_structure=False,fixed_features=False,
+ p=ot.unif(n_samples),loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ xalea = np.random.randn(n_samples, 2)
+ init_C = ot.dist(xalea, xalea)
+
+ X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5,
+ fixed_structure=True,init_C=init_C,fixed_features=False,
+ p=ot.unif(n_samples),loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ init_X=np.random.randn(n_samples,ys.shape[1])
+
+ X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
+ fixed_structure=False,fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples),loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
diff --git a/test/test_optim.py b/test/test_optim.py index dfefe59..1188ef6 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -65,3 +65,8 @@ def test_generalized_conditional_gradient(): np.testing.assert_allclose(a, G.sum(1), atol=1e-05) np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + +def test_solve_1d_linesearch_quad_funct(): + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1,-1,0),0.5) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,5,0),0) + np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,0.5,0),1) |