author    tvayer <titouan.vayer@gmail.com>  2019-05-29 17:05:38 +0200
committer tvayer <titouan.vayer@gmail.com>  2019-05-29 17:05:48 +0200
commit    e1bd94bb7e85a0d2fd0fcd7642b06da12c1db6db (patch)
tree      1e85920b878ab715d211db56f99e25bfa2482fd3
parent    d4320382fa8873d15dcaec7adca3a4723c142515 (diff)
code review1
-rw-r--r--  examples/plot_barycenter_fgw.py |  30
-rw-r--r--  examples/plot_fgw.py            |  32
-rw-r--r--  ot/gromov.py                    | 108
-rw-r--r--  ot/optim.py                     |  31
-rw-r--r--  test/test_gromov.py             |  57
5 files changed, 204 insertions, 54 deletions
diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py
index 9eea036..e4be447 100644
--- a/examples/plot_barycenter_fgw.py
+++ b/examples/plot_barycenter_fgw.py
@@ -125,7 +125,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
colors.append(val_map[node])
return colors
-#%% create dataset
+##############################################################################
+# Generate data
+# -------------
+
+#%% circular dataset
# We build a dataset of noisy circular graphs.
# Noise is added on the structures by random connections and on the features by gaussian noise.
@@ -135,7 +139,11 @@ X0 = []
for k in range(9):
X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
-#%% Plot dataset
+##############################################################################
+# Plot data
+# ---------
+
+#%% Plot graphs
plt.figure(figsize=(8, 10))
for i in range(len(X0)):
@@ -146,9 +154,11 @@ for i in range(len(X0)):
plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
plt.show()
+##############################################################################
+# Barycenter computation
+# ----------------------
-#%%
-# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph.
# Feature distances are the Euclidean distances.
Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
@@ -156,14 +166,16 @@ Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()])
lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
sizebary = 15 # we choose a barycenter with 15 nodes
-#%%
-
A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)
-#%%
+##############################################################################
+# Plot Barycenter
+# -------------------------
+
+#%% Create the barycenter
bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
-for i in range(len(A.ravel())):
- bary.add_node(i, attr_name=float(A.ravel()[i]))
+for i, v in enumerate(A.ravel()):
+ bary.add_node(i, attr_name=v)
#%%
pos = nx.kamada_kawai_layout(bary)
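
For context on the API change exercised by this example: fgw_barycenters now defaults to log=False, so a caller that unpacks three values must pass log=True explicitly. A minimal sketch on tiny synthetic inputs (sizes and variable names here are illustrative, not taken from the example above):

import numpy as np
import ot

np.random.seed(0)
# two small "graphs": structure matrices, node features and uniform node weights
Cs = [ot.dist(np.random.randn(6, 2)), ot.dist(np.random.randn(8, 2))]
Ys = [np.random.randn(6, 1), np.random.randn(8, 1)]
ps = [ot.unif(6), ot.unif(8)]
lambdas = [0.5, 0.5]
sizebary = 5

# log=True is needed to keep the three-value unpacking used in the example
X, C, log = ot.gromov.fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95,
                                      p=ot.unif(sizebary), log=True)
print(X.shape, C.shape, sorted(log.keys()))
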
diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py
index ae3c487..43efc94 100644
--- a/examples/plot_fgw.py
+++ b/examples/plot_fgw.py
@@ -22,12 +22,16 @@ import numpy as np
import ot
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+##############################################################################
+# Generate data
+# -------------
+
#%% parameters
# We create two 1D random measures
-n = 20
-n2 = 30
-sig = 1
-sig2 = 0.1
+n = 20 # number of points in the first distribution
+n2 = 30 # number of points in the second distribution
+sig = 1 # std of first distribution
+sig2 = 0.1 # std of second distribution
np.random.seed(0)
@@ -43,6 +47,10 @@ yt = yt[::-1, :]
p = ot.unif(n)
q = ot.unif(n2)
+##############################################################################
+# Plot data
+# ---------
+
#%% plot the distributions
pl.close(10)
@@ -64,15 +72,22 @@ pl.yticks(())
pl.tight_layout()
pl.show()
+##############################################################################
+# Create structure matrices and across-feature distance matrix
+# ------------------------------------------------------------
#%% Structure matrices and across-features distance matrix
C1 = ot.dist(xs)
-C2 = ot.dist(xt).T
+C2 = ot.dist(xt)
M = ot.dist(ys, yt)
w1 = ot.unif(C1.shape[0])
w2 = ot.unif(C2.shape[0])
Got = ot.emd([], [], M)
+##############################################################################
+# Plot matrices
+# -------------
+
#%%
cmap = 'Reds'
pl.close(10)
@@ -112,6 +127,9 @@ pl.tight_layout()
ax3.set_aspect('auto')
pl.show()
+##############################################################################
+# Compute FGW/GW
+# --------------
#%% Computing FGW and GW
alpha = 1e-3
@@ -123,6 +141,10 @@ ot.toc()
#%reload_ext WGW
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
+##############################################################################
+# Visualize transport matrices
+# ----------------------------
+
#%% visu OT matrix
cmap = 'Blues'
fs = 15
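
As a quick reference for the new log option wired into fused_gromov_wasserstein above, a minimal end-to-end sketch (toy data, not the distributions built in this example):

import numpy as np
import ot

np.random.seed(0)
xs, xt = np.random.randn(20, 1), np.random.randn(30, 1)   # structure samples
ys, yt = np.random.randn(20, 1), np.random.randn(30, 1)   # feature samples

C1, C2 = ot.dist(xs), ot.dist(xt)      # intra-domain structure matrices
M = ot.dist(ys, yt)                    # across-feature cost matrix
p, q = ot.unif(20), ot.unif(30)

# with log=True the solver also returns a dict containing 'fgw_dist'
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q,
                                            loss_fun='square_loss',
                                            alpha=1e-3, log=True)
print(G.shape, log['fgw_dist'])
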
diff --git a/ot/gromov.py b/ot/gromov.py
index 5a57dc8..53349b7 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -10,6 +10,7 @@ Gromov-Wasserstein transport method
# Nicolas Courty <ncourty@irisa.fr>
# Rémi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+#
# License: MIT License
import numpy as np
@@ -351,9 +352,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
"""
- Computes the FGW distance between two graphs see [3]
+ Computes the FGW transport between two graphs, see [24]
.. math::
\gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + \alpha*\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
s.t. \gamma 1 = p
@@ -377,7 +378,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
distribution in the source space
q : ndarray, shape (nt,)
distribution in the target space
- loss_fun : string,optionnal
+ loss_fun : string, optional
loss function used for the solver
max_iter : int, optional
Max number of iterations
@@ -416,7 +417,86 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
def df(G):
return gwggrad(constC, hC1, hC2, G)
- return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ if log:
+ res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ log['fgw_dist'] = log['loss'][-1]
+ return res, log
+ else:
+ return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
+ """
+ Computes the FGW distance between two graphs, see [24]
+ .. math::
+ \gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + \alpha*\sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ s.t. \gamma 1 = p
+ \gamma^T 1= q
+ \gamma\geq 0
+ where :
+ - M is the (ns,nt) metric cost matrix
+ - :math:`f` is the regularization term ( and df is its gradient)
+ - a and b are source and target weights (sum to 1)
+ - L is a loss function to account for the misfit between the similarity matrices
+ The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ Parameters
+ ----------
+ M : ndarray, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : ndarray, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : ndarray, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : ndarray, shape (ns,)
+ distribution in the source space
+ q : ndarray, shape (nt,)
+ distribution in the target space
+ loss_fun : string, optional
+ loss function used for the solver
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True, the step of the line-search is found via an Armijo line search. Else a closed form is used.
+ If there are convergence issues, use False.
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+ G0 = p[:, None] * q[None, :]
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G)
+
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G)
+
+ # keep the boolean flag 'log' distinct from the solver's returned log dict
+ res, log_cg = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ fgw_dist = log_cg['loss'][-1]
+ if log:
+ log_cg['fgw_dist'] = fgw_dist
+ log_cg['T'] = res
+ return fgw_dist, log_cg
+ else:
+ return fgw_dist
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
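
The two entry points above differ only in what they return: fused_gromov_wasserstein yields the coupling, while fused_gromov_wasserstein2 yields the FGW value (with the coupling stored in the log when log=True). A minimal sketch on toy inputs, assuming the corrected log handling above:

import numpy as np
import ot

np.random.seed(0)
C1 = ot.dist(np.random.randn(10, 2))
C2 = ot.dist(np.random.randn(12, 2))
M = ot.dist(np.random.randn(10, 1), np.random.randn(12, 1))
p, q = ot.unif(10), ot.unif(12)

G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5)
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss',
                                               alpha=0.5, log=True)
print(G.shape)               # (10, 12) transport plan
print(fgw, log['T'].shape)   # scalar FGW value and the plan kept in the log
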
@@ -889,7 +969,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
- verbose=False, log=True, init_C=None, init_X=None):
+ verbose=False, log=False, init_C=None, init_X=None):
"""
Compute the fgw barycenter as presented eq (5) in [24].
----------
@@ -919,7 +999,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Barycenters' features
C : ndarray, shape (N,N)
Barycenters' structure matrix
- log_:
+ log_: dictionary
+ Only returned when log=True
T : list of (N,ns) transport matrices
Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns)
References
@@ -1015,14 +1096,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
-
- log_['Ts_iter'].append(T)
err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
err_structure = np.linalg.norm(C - Cprev)
if log:
log_['err_feature'].append(err_feature)
log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
if verbose:
if cpt % 200 == 0:
@@ -1032,11 +1112,15 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
print('{:5d}|{:8e}|'.format(cpt, err_feature))
cpt += 1
- log_['T'] = T # from target to Ys
- log_['p'] = p
- log_['Ms'] = Ms # Ms are N,ns
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms # Ms are N,ns
- return X, C, log_
+ if log:
+ return X, C, log_
+ else:
+ return X, C
def update_sructure_matrix(p, lambdas, T, Cs):
diff --git a/ot/optim.py b/ot/optim.py
index 7d103e2..4d428d9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -5,6 +5,7 @@ Optimization algorithms for OT
# Author: Remi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+#
# License: MIT License
import numpy as np
@@ -88,20 +89,20 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
f_val : float
Value of the cost at G
- armijo : bool, optionnal
+ armijo : bool, optional
If True, the step of the line-search is found via an Armijo line search. Else a closed form is used.
If there are convergence issues, use False.
- C1 : ndarray (ns,ns), optionnal
+ C1 : ndarray (ns,ns), optional
Structure matrix in the source domain. Only used when armijo=False
- C2 : ndarray (nt,nt), optionnal
+ C2 : ndarray (nt,nt), optional
Structure matrix in the target domain. Only used when armijo=False
- reg : float, optionnal
+ reg : float, optional
Regularization parameter. Only used when armijo=False
Gc : ndarray (ns,nt)
Optimal map found by linearization in the FW algorithm. Only used when armijo=False
constC : ndarray (ns,nt)
Constant for the gromov cost. See [24]. Only used when armijo=False
- M : ndarray (ns,nt), optionnal
+ M : ndarray (ns,nt), optional
Cost matrix between the features. Only used when armijo=False
Returns
-------
@@ -223,9 +224,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
while loop:
@@ -261,8 +262,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
if verbose:
if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
@@ -363,9 +364,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
it = 0
if verbose:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
- print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
while loop:
@@ -402,8 +403,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
if verbose:
if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+ print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+ 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
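
The widened verbose header above also shows up when ot.optim.cg is used directly. A minimal sketch with a simple quadratic regularizer (toy data; the regularizer is chosen here only to exercise the solver and is not part of the patch):

import numpy as np
import ot

np.random.seed(0)
a, b = ot.unif(20), ot.unif(20)
M = ot.dist(np.random.randn(20, 2), np.random.randn(20, 2))

def f(G):
    # squared Frobenius norm regularizer
    return 0.5 * np.sum(G ** 2)

def df(G):
    # its gradient
    return G

# verbose=True now prints: It. | Loss | Relative loss | Absolute loss
G = ot.optim.cg(a, b, M, reg=1.0, f=f, df=df, verbose=True)
print(G.shape)
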
diff --git a/test/test_gromov.py b/test/test_gromov.py
index cd180d4..ec85abf 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -2,6 +2,7 @@
# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
#
# License: MIT License
@@ -10,6 +11,8 @@ import ot
def test_gromov():
+ np.random.seed(42)
+
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -36,6 +39,11 @@ def test_gromov():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence gromov
+ Id = (1 / n_samples) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04)
+
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
G = log['T']
@@ -50,6 +58,8 @@ def test_gromov():
def test_entropic_gromov():
+ np.random.seed(42)
+
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -92,6 +102,7 @@ def test_entropic_gromov():
def test_gromov_barycenter():
+ np.random.seed(42)
ns = 50
nt = 60
@@ -120,7 +131,7 @@ def test_gromov_barycenter():
def test_gromov_entropic_barycenter():
-
+ np.random.seed(42)
ns = 50
nt = 60
@@ -148,6 +159,8 @@ def test_gromov_entropic_barycenter():
def test_fgw():
+ np.random.seed(42)
+
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -180,8 +193,26 @@ def test_fgw():
np.testing.assert_allclose(
q, G.sum(0), atol=1e-04) # cf convergence fgw
+ Id = (1 / n_samples) * np.eye(n_samples, n_samples)
+
+ np.testing.assert_allclose(
+ G, np.flipud(Id), atol=1e-04) # cf convergence fgw
+
+ fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+
+ G = log['T']
+
+ np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence fgw
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence fgw
+
def test_fgw_barycenter():
+ np.random.seed(42)
ns = 50
nt = 60
@@ -196,28 +227,28 @@ def test_fgw_barycenter():
C2 = ot.dist(Xt)
n_samples = 3
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
- fixed_structure=True, init_C=init_C, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
+ fixed_structure=True, init_C=init_C, fixed_features=False,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
init_X = np.random.randn(n_samples, ys.shape[1])
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
+ X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples), loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
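
To exercise the new seeded tests locally, the module can be run on its own; a sketch using pytest's Python API (pytest is assumed to be installed, and this invocation is not part of the patch):

import pytest

# run only the FGW-related tests from this module, verbosely
pytest.main(["test/test_gromov.py", "-k", "fgw", "-v"])
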