summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>2021-11-02 13:42:02 +0100
committerGitHub <noreply@github.com>2021-11-02 13:42:02 +0100
commita335324d008e8982be61d7ace937815a2bfa98f9 (patch)
tree83c7f637597f10f6f3d20b15532e53fc65b51f22 /test
parent0cb2b2efe901ed74c614046d250518769f870313 (diff)
[MRG] Backend for gromov (#294)
* bregman: small correction * gromov backend first draft * Removing decorators * Reworked casting method * Bug solve * Removing casting * Bug solve * toarray renamed todense ; expand_dims removed * Warning (jax not supporting sparse matrix) moved * Mistake corrected * test backend * Sparsity test for older versions of pytorch * Trying pytorch/1.10 * Attempt to correct torch sparse bug * Backend version of gromov tests * Random state introduced for remaining gromov functions * review changes * code coverage * Docs (first draft, to be continued) * Gromov docs * Prettified docs * mistake corrected in the docs * little change Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_backend.py56
-rw-r--r--test/test_bregman.py4
-rw-r--r--test/test_gromov.py297
3 files changed, 265 insertions, 92 deletions
diff --git a/test/test_backend.py b/test/test_backend.py
index 5853282..0f11ace 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -207,6 +207,22 @@ def test_empty_backend():
nx.stack([M, M])
with pytest.raises(NotImplementedError):
nx.reshape(M, (5, 3, 2))
+ with pytest.raises(NotImplementedError):
+ nx.coo_matrix(M, M, M)
+ with pytest.raises(NotImplementedError):
+ nx.issparse(M)
+ with pytest.raises(NotImplementedError):
+ nx.tocsr(M)
+ with pytest.raises(NotImplementedError):
+ nx.eliminate_zeros(M)
+ with pytest.raises(NotImplementedError):
+ nx.todense(M)
+ with pytest.raises(NotImplementedError):
+ nx.where(M, M, M)
+ with pytest.raises(NotImplementedError):
+ nx.copy(M)
+ with pytest.raises(NotImplementedError):
+ nx.allclose(M, M)
def test_func_backends(nx):
@@ -216,6 +232,11 @@ def test_func_backends(nx):
v = rnd.randn(3)
val = np.array([1.0])
+ # Sparse tensors test
+ sp_row = np.array([0, 3, 1, 0, 3])
+ sp_col = np.array([0, 3, 1, 2, 2])
+ sp_data = np.array([4, 5, 7, 9, 0])
+
lst_tot = []
for nx in [ot.backend.NumpyBackend(), nx]:
@@ -229,6 +250,10 @@ def test_func_backends(nx):
vb = nx.from_numpy(v)
val = nx.from_numpy(val)
+ sp_rowb = nx.from_numpy(sp_row)
+ sp_colb = nx.from_numpy(sp_col)
+ sp_datab = nx.from_numpy(sp_data)
+
A = nx.set_gradients(val, v, v)
lst_b.append(nx.to_numpy(A))
lst_name.append('set_gradients')
@@ -438,6 +463,37 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('reshape')
+ sp_Mb = nx.coo_matrix(sp_datab, sp_rowb, sp_colb, shape=(4, 4))
+ nx.todense(Mb)
+ lst_b.append(nx.to_numpy(nx.todense(sp_Mb)))
+ lst_name.append('coo_matrix')
+
+ assert not nx.issparse(Mb), 'Assert fail on: issparse (expected False)'
+ assert nx.issparse(sp_Mb) or nx.__name__ == "jax", 'Assert fail on: issparse (expected True)'
+
+ A = nx.tocsr(sp_Mb)
+ lst_b.append(nx.to_numpy(nx.todense(A)))
+ lst_name.append('tocsr')
+
+ A = nx.eliminate_zeros(nx.copy(sp_datab), threshold=5.)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('eliminate_zeros (dense)')
+
+ A = nx.eliminate_zeros(sp_Mb)
+ lst_b.append(nx.to_numpy(nx.todense(A)))
+ lst_name.append('eliminate_zeros (sparse)')
+
+ A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('where')
+
+ A = nx.copy(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('copy')
+
+ assert nx.allclose(Mb, Mb), 'Assert fail on: allclose (expected True)'
+ assert not nx.allclose(2 * Mb, Mb), 'Assert fail on: allclose (expected False)'
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
diff --git a/test/test_bregman.py b/test/test_bregman.py
index c1120ba..6923d31 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -477,8 +477,8 @@ def test_lazy_empirical_sinkhorn(nx):
b = ot.unif(n)
numIterMax = 1000
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n), (n, 1))
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n, dtype=np.float64), (n, 1))
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='euclidean')
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 0242d72..509c54d 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -8,11 +8,12 @@
import numpy as np
import ot
+from ot.backend import NumpyBackend
import pytest
-def test_gromov():
+def test_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -31,37 +32,50 @@ def test_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
+ Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True))
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
- np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04)
+ np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04)
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
+ gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+ gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
+ np.testing.assert_allclose(gwb, 0, atol=1e-1, rtol=1e-1)
- np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False
+ np.testing.assert_allclose(gw_val, gw_valb, atol=1e-06)
+ np.testing.assert_allclose(gwb, gw_valb, atol=1e-1, rtol=1e-1) # cf log=False
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
-def test_entropic_gromov():
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_entropic_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -80,30 +94,44 @@ def test_entropic_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
+ Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein(
+ C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
+ ))
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
gw, log = ot.gromov.entropic_gromov_wasserstein2(
C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
+ gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
+ C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
+ np.testing.assert_allclose(gw, gwb, atol=1e-06)
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
-def test_pointwise_gromov():
+def test_pointwise_gromov(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -122,33 +150,52 @@ def test_pointwise_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
def loss(x, y):
return np.abs(x - y)
+ def lossb(x, y):
+ return nx.abs(x - y)
+
G, log = ot.gromov.pointwise_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- assert log['gw_dist_estimated'] == 0.0
- assert log['gw_dist_std'] == 0.0
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
G, log = ot.gromov.pointwise_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ G = NumpyBackend().todense(G)
+ Gb, logb = ot.gromov.pointwise_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(nx.todense(Gb))
- assert log['gw_dist_estimated'] == 0.10342276348494964
- assert log['gw_dist_std'] == 0.0015952535464736394
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
-def test_sampled_gromov():
+@pytest.skip_backend("jax", reason="test very slow with jax backend")
+def test_sampled_gromov(nx):
n_samples = 50 # nb samples
- mu_s = np.array([0, 0])
- cov_s = np.array([[1, 0], [0, 1]])
+ mu_s = np.array([0, 0], dtype=np.float64)
+ cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64)
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
@@ -163,23 +210,35 @@ def test_sampled_gromov():
C1 /= C1.max()
C2 /= C2.max()
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
def loss(x, y):
return np.abs(x - y)
+ def lossb(x, y):
+ return nx.abs(x - y)
+
G, log = ot.gromov.sampled_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb, logb = ot.gromov.sampled_gromov_wasserstein(
+ C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+ Gb = nx.to_numpy(Gb)
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- assert log['gw_dist_estimated'] == 0.05679474884977278
- assert log['gw_dist_std'] == 0.0005986592106971995
+ np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
+ np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
-def test_gromov_barycenter():
+def test_gromov_barycenter(nx):
ns = 10
nt = 20
@@ -188,26 +247,42 @@ def test_gromov_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
n_samples = 3
- Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', # 5e-4,
- max_iter=100, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ p = ot.unif(n_samples)
- Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', # 5e-4,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Cb = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
@pytest.mark.filterwarnings("ignore:divide")
-def test_gromov_entropic_barycenter():
+def test_gromov_entropic_barycenter(nx):
ns = 10
nt = 20
@@ -216,26 +291,41 @@ def test_gromov_entropic_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1 = ot.unif(ns)
+ p2 = ot.unif(nt)
n_samples = 2
- Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'square_loss', 1e-3,
- max_iter=50, tol=1e-3,
- verbose=True)
- np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
-
- Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
- [ot.unif(ns), ot.unif(nt)
- ], ot.unif(n_samples), [.5, .5],
- 'kl_loss', 1e-3,
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
-
-
-def test_fgw():
+ p = ot.unif(n_samples)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Cb = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ )
+ Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42
+ ))
+ np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
+ np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+
+ Cb2 = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ )
+ Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
+ ))
+ np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
+ np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+
+
+def test_fgw(nx):
n_samples = 50 # nb samples
mu_s = np.array([0, 0])
@@ -260,33 +350,46 @@ def test_fgw():
M = ot.dist(ys, yt)
M /= M.max()
+ Mb = nx.from_numpy(M)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ pb = nx.from_numpy(p)
+ qb = nx.from_numpy(q)
+
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ Gb = nx.to_numpy(Gb)
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence fgw
+ p, Gb.sum(1), atol=1e-04) # cf convergence fgw
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence fgw
+ q, Gb.sum(0), atol=1e-04) # cf convergence fgw
Id = (1 / (1.0 * n_samples)) * np.eye(n_samples, n_samples)
np.testing.assert_allclose(
- G, np.flipud(Id), atol=1e-04) # cf convergence gromov
+ Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
+ fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
G = log['T']
+ Gb = nx.to_numpy(logb['T'])
- np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
+ np.testing.assert_allclose(fgw, fgwb, atol=1e-08)
+ np.testing.assert_allclose(fgwb, 0, atol=1e-1, rtol=1e-1)
# check constraints
+ np.testing.assert_allclose(G, Gb, atol=1e-06)
np.testing.assert_allclose(
- p, G.sum(1), atol=1e-04) # cf convergence gromov
+ p, Gb.sum(1), atol=1e-04) # cf convergence gromov
np.testing.assert_allclose(
- q, G.sum(0), atol=1e-04) # cf convergence gromov
+ q, Gb.sum(0), atol=1e-04) # cf convergence gromov
-def test_fgw_barycenter():
+def test_fgw_barycenter(nx):
np.random.seed(42)
ns = 50
@@ -300,30 +403,44 @@ def test_fgw_barycenter():
C1 = ot.dist(Xs)
C2 = ot.dist(Xt)
-
+ p1, p2 = ot.unif(ns), ot.unif(nt)
n_samples = 3
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ p = ot.unif(n_samples)
+
+ ysb = nx.from_numpy(ys)
+ ytb = nx.from_numpy(yt)
+ C1b = nx.from_numpy(C1)
+ C2b = nx.from_numpy(C2)
+ p1b = nx.from_numpy(p1)
+ p2b = nx.from_numpy(p2)
+ pb = nx.from_numpy(p)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
+ fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345
+ )
xalea = np.random.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
-
- X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], ps=[ot.unif(ns), ot.unif(nt)], lambdas=[.5, .5], alpha=0.5,
- fixed_structure=True, init_C=init_C, fixed_features=False,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Cb = nx.from_numpy(init_C)
+
+ Xb, Cb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5],
+ alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))
init_X = np.random.randn(n_samples, ys.shape[1])
-
- X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5,
- fixed_structure=False, fixed_features=True, init_X=init_X,
- p=ot.unif(n_samples), loss_fun='square_loss',
- max_iter=100, tol=1e-3, log=True)
- np.testing.assert_allclose(C.shape, (n_samples, n_samples))
- np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
+ init_Xb = nx.from_numpy(init_X)
+
+ Xb, Cb, logb = ot.gromov.fgw_barycenters(
+ n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
+ fixed_structure=False, fixed_features=True, init_X=init_Xb,
+ p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765
+ )
+ Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb)
+ np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
+ np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1]))