-rw-r--r--  README.md             1
-rw-r--r--  ot/gromov.py         12
-rw-r--r--  test/test_gromov.py  75
-rw-r--r--  test/test_optim.py    5
4 files changed, 89 insertions, 4 deletions
diff --git a/README.md b/README.md
index 9951773..9692344 100644
--- a/README.md
+++ b/README.md
@@ -164,6 +164,7 @@ The contributors to this library are:
* Erwan Vautier (Gromov-Wasserstein)
* [Kilian Fatras](https://kilianfatras.github.io/)
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
+* [Vayer Titouan](https://tvayer.github.io/)
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
diff --git a/ot/gromov.py b/ot/gromov.py
index ad68a1c..297b194 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -926,6 +926,10 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
+
+ class UndefinedParameter(Exception):
+ pass
+
S = len(Cs)
d = Ys[0].shape[1] #dimension on the node features
if p is None:
@@ -938,7 +942,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
if fixed_structure:
if init_C is None:
- C=Cs[0]
+ raise UndefinedParameter('If C is fixed it must be initialized')
else:
C=init_C
else:
@@ -950,7 +954,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
if fixed_features:
if init_X is None:
- X=Ys[0]
+ raise UndefinedParameter('If X is fixed it must be initialized')
else :
X= init_X
else:
@@ -1004,13 +1008,13 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
# Cs is ns,ns
# p is N,1
# ps is ns,1
-
+
T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
log_['Ts_iter'].append(T)
- err_feature = np.linalg.norm(X - Xprev.reshape(d,N))
+ err_feature = np.linalg.norm(X - Xprev.reshape(N,d))
err_structure = np.linalg.norm(C - Cprev)
if log:
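The ot/gromov.py hunks replace the silent fallbacks (Cs[0] when the structure is fixed, Ys[0] when the features are fixed) with an explicit UndefinedParameter error, and correct the feature-convergence norm so Xprev is reshaped to (N, d) rather than (d, N). A minimal sketch of the new calling contract, with illustrative sizes and data that are not part of the patch:

import numpy as np
import ot

ns, d, N = 20, 2, 5                     # illustrative sizes
ys = np.random.randn(ns, d)             # node features of one input graph
C1 = ot.dist(np.random.randn(ns, d))    # its structure (pairwise distance) matrix

# With fixed_structure=True the barycenter structure must now be passed via init_C;
# leaving it out raises UndefinedParameter instead of silently using Cs[0].
try:
    ot.gromov.fgw_barycenters(N, [ys], [C1], [ot.unif(ns)], [1.], 0.5,
                              fixed_structure=True, fixed_features=False,
                              p=ot.unif(N), loss_fun='square_loss',
                              max_iter=10, tol=1e-3)
except Exception as err:                # UndefinedParameter is defined inside fgw_barycenters
    print(type(err).__name__, err)      # -> UndefinedParameter If C is fixed it must be initialized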
diff --git a/test/test_gromov.py b/test/test_gromov.py
index fb86274..07cd874 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -143,3 +143,78 @@ def test_gromov_entropic_barycenter():
'kl_loss', 2e-3,
max_iter=100, tol=1e-3)
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+
+def test_fgw():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+
+ xt = xs[::-1].copy()
+
+ ys = np.random.randn(xs.shape[0],2)
+ yt= ys[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ M=ot.dist(ys,yt)
+ M/=M.max()
+
+ G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5)
+
+ # check constraints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence fgw
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence fgw
+
+
+def test_fgw_barycenter():
+
+ ns = 50
+ nt = 60
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
+
+ ys = np.random.randn(Xs.shape[0],2)
+ yt= np.random.randn(Xt.shape[0],2)
+
+ C1 = ot.dist(Xs)
+ C2 = ot.dist(Xt)
+
+ n_samples = 3
+ X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
+ fixed_structure=False,fixed_features=False,
+ p=ot.unif(n_samples),loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ xalea = np.random.randn(n_samples, 2)
+ init_C = ot.dist(xalea, xalea)
+
+ X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5,
+ fixed_structure=True,init_C=init_C,fixed_features=False,
+ p=ot.unif(n_samples),loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+
+ init_X=np.random.randn(n_samples,ys.shape[1])
+
+ X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
+ fixed_structure=False,fixed_features=True, init_X=init_X,
+ p=ot.unif(n_samples),loss_fun='square_loss',
+ max_iter=100, tol=1e-3)
+ np.testing.assert_allclose(C.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
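In test_fgw the target samples are just the source samples reversed (xt = xs[::-1], yt = ys[::-1]), so the feature cost M drives the optimal coupling toward the anti-diagonal permutation. A stricter optional assertion along those lines, not part of the patch, could follow the marginal checks:

# Hypothetical extra check: sample i should be matched to sample n-1-i,
# i.e. G should be close to the flipped identity scaled by 1/n_samples.
Id = np.eye(n_samples) / n_samples
np.testing.assert_allclose(G, np.flipud(Id), atol=1e-04)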
diff --git a/test/test_optim.py b/test/test_optim.py
index dfefe59..1188ef6 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -65,3 +65,8 @@ def test_generalized_conditional_gradient():
np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
+
+def test_solve_1d_linesearch_quad_funct():
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1,-1,0),0.5)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,5,0),0)
+ np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,0.5,0),1)
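The expected values in test_solve_1d_linesearch_quad_funct are consistent with minimizing f(x) = a*x**2 + b*x + c over [0, 1]: take the vertex -b/(2a) clipped to the interval when the parabola is convex, otherwise the better endpoint. A reference sketch under that assumption (solve_quad_on_unit_interval is a hypothetical stand-in, not the library function):

import numpy as np

def solve_quad_on_unit_interval(a, b, c):
    # Argmin of f(x) = a*x**2 + b*x + c over [0, 1] (reference sketch).
    f = lambda x: a * x ** 2 + b * x + c
    if a > 0:
        return float(np.clip(-b / (2. * a), 0., 1.))  # convex: clip the vertex into [0, 1]
    return 0. if f(0.) <= f(1.) else 1.               # otherwise the minimum sits at an endpoint

assert solve_quad_on_unit_interval(1, -1, 0) == 0.5   # matches the first expected value
assert solve_quad_on_unit_interval(-1, 5, 0) == 0.    # matches the second
assert solve_quad_on_unit_interval(-1, 0.5, 0) == 1.  # matches the third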