From 9076f02903ba2fb9ea9fe704764a755cad8dcd63 Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Mon, 12 Jun 2023 12:01:48 +0200 Subject: [FEAT] Entropic gw/fgw/srgw/srfgw solvers (#455) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add entropic fgw + fgw bary + srgw + srfgw with tests * add exemples for entropic srgw - srfgw solvers * add PPA solvers for GW/FGW + complete previous commits * update readme * add tests * add examples + tests + warning in entropic solvers + releases * reduce testing runtimes for test_gromov * fix conflicts * optional marginals * improve coverage * gromov doc harmonization * fix pep8 * complete optional marginal for entropic srfgw --------- Co-authored-by: Rémi Flamary --- test/test_gromov.py | 796 +++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 728 insertions(+), 68 deletions(-) (limited to 'test') diff --git a/test/test_gromov.py b/test/test_gromov.py index 1beb818..13ff3fe 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -34,8 +34,11 @@ def test_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) - G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) - Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=G0b, verbose=True)) + G = ot.gromov.gromov_wasserstein( + C1, C2, None, q, 'square_loss', G0=G0, verbose=True, + alpha_min=0., alpha_max=1.) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True)) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -48,8 +51,8 @@ def test_gromov(nx): np.testing.assert_allclose(Gb, np.flipud(Id), atol=1e-04) - gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, log=True) - gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', armijo=True, log=True) + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, None, q, 'kl_loss', armijo=True, log=True) + gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, None, 'kl_loss', armijo=True, log=True) gwb = nx.to_numpy(gwb) gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', armijo=True, G0=G0, log=False) @@ -312,11 +315,11 @@ def test_entropic_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G, log = ot.gromov.entropic_gromov_wasserstein( - C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-2, verbose=True, log=True) + C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-2, max_iter=10, verbose=True, log=True) Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, - epsilon=1e-2, verbose=True, log=False + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, + epsilon=1e-2, max_iter=10, verbose=True, log=False )) # check constraints @@ -327,10 +330,10 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, max_iter=10, epsilon=1e-2, log=True) gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, max_iter=10, epsilon=1e-2, log=True) gwb = nx.to_numpy(gwb) @@ -348,6 +351,65 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_entropic_proximal_gromov(nx): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + G, log = ot.gromov.entropic_gromov_wasserstein( + C1, C2, None, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=True, numItermax=1) + Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=False, numItermax=1 + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + gw, log = ot.gromov.entropic_gromov_wasserstein2( + C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + gwb, logb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + gwb = nx.to_numpy(gwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(gw, gwb, atol=1e-06) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") def test_asymmetric_entropic_gromov(nx): n_samples = 10 # nb samples np.random.seed(0) @@ -363,10 +425,10 @@ def test_asymmetric_entropic_gromov(nx): C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, - epsilon=1e-1, verbose=True, log=False) + epsilon=1e-1, max_iter=5, verbose=True, log=False) Gb = nx.to_numpy(ot.gromov.entropic_gromov_wasserstein( C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, - epsilon=1e-1, verbose=True, log=False + epsilon=1e-1, max_iter=5, verbose=True, log=False )) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -376,11 +438,11 @@ def test_asymmetric_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, - max_iter=10, epsilon=1e-1, log=False) + C1, C2, None, None, 'kl_loss', symmetric=False, G0=None, + max_iter=5, epsilon=1e-1, log=False) gwb = ot.gromov.entropic_gromov_wasserstein2( C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, - max_iter=10, epsilon=1e-1, log=False) + max_iter=5, epsilon=1e-1, log=False) gwb = nx.to_numpy(gwb) np.testing.assert_allclose(gw, gwb, atol=1e-06) @@ -414,15 +476,300 @@ def test_entropic_gromov_dtype_device(nx): C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp) - Gb = ot.gromov.entropic_gromov_wasserstein( - C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True - ) - gw_valb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True - ) + for solver in ['PGD', 'PPA']: + Gb = ot.gromov.entropic_gromov_wasserstein( + C1b, C2b, pb, qb, 'square_loss', epsilon=1e-1, max_iter=5, + solver=solver, verbose=True + ) + gw_valb = ot.gromov.entropic_gromov_wasserstein2( + C1b, C2b, pb, qb, 'square_loss', epsilon=1e-1, max_iter=5, + solver=solver, verbose=True + ) - nx.assert_same_dtype_device(C1b, Gb) - nx.assert_same_dtype_device(C1b, gw_valb) + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_fgw(nx): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, None, None, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, verbose=True, log=True) + Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=10, verbose=True, log=False + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( + M, C1, C2, p, q, 'kl_loss', symmetric=True, G0=None, + max_iter=10, epsilon=1e-1, log=True) + fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=10, epsilon=1e-1, log=True) + fgwb = nx.to_numpy(fgwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +def test_entropic_proximal_fgw(nx): + n_samples = 10 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=True, numItermax=1) + Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None, + epsilon=1e-1, max_iter=10, solver='PPA', verbose=True, log=False, numItermax=1 + )) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + fgw, log = ot.gromov.entropic_fused_gromov_wasserstein2( + M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None, + max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + fgwb, logb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=5, epsilon=1e-1, solver='PPA', warmstart=True, log=True) + fgwb = nx.to_numpy(fgwb) + + G = log['T'] + Gb = nx.to_numpy(logb['T']) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + +def test_asymmetric_entropic_fgw(nx): + n_samples = 10 # nb samples + np.random.seed(0) + C1 = np.random.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + np.random.shuffle(idx) + C2 = C1[idx, :][:, idx] + + ys = np.random.randn(n_samples, 2) + yt = ys[idx, :] + M = ot.dist(ys, yt) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + G = ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0, + max_iter=5, epsilon=1e-1, verbose=True, log=False) + Gb = nx.to_numpy(ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=False, G0=None, + max_iter=5, epsilon=1e-1, verbose=True, log=False + )) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose( + p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, Gb.sum(0), atol=1e-04) # cf convergence gromov + + fgw = ot.gromov.entropic_fused_gromov_wasserstein2( + M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None, + max_iter=5, epsilon=1e-1, log=False) + fgwb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b, + max_iter=5, epsilon=1e-1, log=False) + fgwb = nx.to_numpy(fgwb) + + np.testing.assert_allclose(fgw, fgwb, atol=1e-06) + np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_fgw_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Mb, C1b, C2b, pb, qb = nx.from_numpy(M, C1, C2, p, q, type_as=tp) + + for solver in ['PGD', 'PPA']: + Gb = ot.gromov.entropic_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, qb, 'square_loss', epsilon=0.1, max_iter=5, + solver=solver, verbose=True + ) + fgw_valb = ot.gromov.entropic_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, qb, 'square_loss', epsilon=0.1, max_iter=5, + solver=solver, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, fgw_valb) + + +def test_entropic_fgw_barycenter(nx): + ns = 5 + nt = 10 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + ys = np.random.randn(Xs.shape[0], 2) + yt = np.random.randn(Xt.shape[0], 2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + p1 = ot.unif(ns) + p2 = ot.unif(nt) + n_samples = 2 + p = ot.unif(n_samples) + + ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) + + X, C, log = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', 0.1, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42, + solver='PPA', numItermax=1, log=True + ) + Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 0.1, + max_iter=10, tol=1e-3, verbose=False, warmstartT=True, random_state=42, + solver='PPA', numItermax=1, log=False) + Xb, Cb = nx.to_numpy(Xb, Cb) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X, Xb, atol=1e-06) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + # test with 'kl_loss' and log=True + # providing init_C, init_Y + generator = ot.utils.check_random_state(42) + xalea = generator.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + init_Y = np.zeros((n_samples, ys.shape[1]), dtype=ys.dtype) + init_Yb = nx.from_numpy(init_Y) + + X, C, log = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], p, None, 'kl_loss', 0.1, + max_iter=10, tol=1e-3, verbose=False, warmstartT=False, random_state=42, + solver='PPA', numItermax=1, init_C=init_C, init_Y=init_Y, log=True + ) + Xb, Cb, logb = ot.gromov.entropic_fused_gromov_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', 0.1, + max_iter=10, tol=1e-3, verbose=False, warmstartT=False, random_state=42, + solver='PPA', numItermax=1, init_C=init_Cb, init_Y=init_Yb, log=True) + Xb, Cb = nx.to_numpy(Xb, Cb) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X, Xb, atol=1e-06) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + np.testing.assert_array_almost_equal(log['err_feature'], nx.to_numpy(*logb['err_feature'])) + np.testing.assert_array_almost_equal(log['err_structure'], nx.to_numpy(*logb['err_structure'])) def test_pointwise_gromov(nx): @@ -539,11 +886,11 @@ def test_gromov_barycenter(nx): C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) Cb = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], + n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42 ) Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42 )) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) @@ -551,12 +898,12 @@ def test_gromov_barycenter(nx): # test of gromov_barycenters with `log` on Cb_, err_ = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True + n_samples, [C1, C2], [p1, p2], p, None, 'square_loss', max_iter=100, + tol=1e-3, verbose=False, warmstartT=True, random_state=42, log=True ) Cbb_, errb_ = ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'square_loss', max_iter=100, + tol=1e-3, verbose=False, warmstartT=True, random_state=42, log=True ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) @@ -565,23 +912,31 @@ def test_gromov_barycenter(nx): Cb2 = ot.gromov.gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + 'kl_loss', max_iter=100, tol=1e-3, warmstartT=True, random_state=42 ) Cb2b = nx.to_numpy(ot.gromov.gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, random_state=42 + 'kl_loss', max_iter=100, tol=1e-3, warmstartT=True, random_state=42 )) np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) # test of gromov_barycenters with `log` on + # providing init_C + generator = ot.utils.check_random_state(42) + xalea = generator.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + Cb2_, err2_ = ot.gromov.gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True + n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', max_iter=100, + tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C ) Cb2b_, err2b_ = ot.gromov.gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], 'kl_loss', + max_iter=100, tol=1e-3, verbose=True, random_state=42, + init_C=init_Cb, log=True ) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) @@ -607,24 +962,24 @@ def test_gromov_entropic_barycenter(nx): C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) Cb = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 ) Cbb = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', 1e-3, max_iter=50, tol=1e-3, verbose=True, random_state=42 + n_samples, [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 1e-3, + max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42 )) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) # test of entropic_gromov_barycenters with `log` on Cb_, err_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + n_samples, [C1, C2], [p1, p2], p, None, + 'square_loss', 1e-3, max_iter=10, tol=1e-3, verbose=True, random_state=42, log=True ) Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'square_loss', 1e-3, max_iter=10, tol=1e-3, verbose=True, random_state=42, log=True ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) @@ -633,23 +988,32 @@ def test_gromov_entropic_barycenter(nx): Cb2 = ot.gromov.entropic_gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 ) Cb2b = nx.to_numpy(ot.gromov.entropic_gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, random_state=42 )) np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) # test of entropic_gromov_barycenters with `log` on + # providing init_C + generator = ot.utils.check_random_state(42) + xalea = generator.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', 1e-3, + max_iter=10, tol=1e-3, warmstartT=True, verbose=True, random_state=42, + init_C=init_C, log=True ) Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'kl_loss', 1e-3, max_iter=10, tol=1e-3, warmstartT=True, verbose=True, + random_state=42, init_Cb=init_Cb, log=True ) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) @@ -685,8 +1049,8 @@ def test_fgw(nx): Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, None, q, 'square_loss', alpha=0.5, armijo=True, symmetric=None, G0=G0, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, None, 'square_loss', alpha=0.5, armijo=True, symmetric=True, G0=G0b, log=True) Gb = nx.to_numpy(Gb) # check constraints @@ -701,8 +1065,8 @@ def test_fgw(nx): np.testing.assert_allclose( Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov - fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True) + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, None, 'square_loss', armijo=True, symmetric=True, G0=None, alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, None, qb, 'square_loss', armijo=True, symmetric=None, G0=G0b, alpha=0.5, log=True) fgwb = nx.to_numpy(fgwb) G = log['T'] @@ -923,6 +1287,9 @@ def test_fgw_barycenter(nx): C1 = ot.dist(Xs) C2 = ot.dist(Xt) + C1 /= C1.max() + C2 /= C2.max() + p1, p2 = ot.unif(ns), ot.unif(nt) n_samples = 3 p = ot.unif(n_samples) @@ -930,18 +1297,19 @@ def test_fgw_barycenter(nx): ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, + n_samples, [ysb, ytb], [C1b, C2b], None, [.5, .5], 0.5, fixed_structure=False, fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345 ) xalea = np.random.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) + init_C /= init_C.max() init_Cb = nx.from_numpy(init_C) Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=[.5, .5], + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, - p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3 + p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 ) Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) @@ -953,11 +1321,21 @@ def test_fgw_barycenter(nx): Xb, Cb, logb = ot.gromov.fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, fixed_features=True, init_X=init_Xb, - p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, log=True, random_state=98765 + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True ) - Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) - np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) - np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + # add test with 'kl_loss' + X, C = ot.gromov.fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', + max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True, random_state=12345 + ) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) def test_gromov_wasserstein_linear_unmixing(nx): @@ -1501,8 +1879,11 @@ def test_semirelaxed_gromov(nx): # asymmetric C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) - Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=False, log=True, G0=None) + G, log = ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, None, loss_fun='square_loss', symmetric=False, log=True, + G0=None, alpha_min=0., alpha_max=1.) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -1510,8 +1891,10 @@ def test_semirelaxed_gromov(nx): np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) - srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1, C2, None, loss_fun='square_loss', symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) G = log2['T'] Gb = nx.to_numpy(logb2['T']) @@ -1527,16 +1910,20 @@ def test_semirelaxed_gromov(nx): C1 = 0.5 * (C1 + C1.T) C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_gromov_wasserstein(C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) + G, log = ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) + Gb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun='square_loss', symmetric=None, log=True, G0=None) srgw_ = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=True, log=False, G0=G0) @@ -1661,7 +2048,7 @@ def test_semirelaxed_fgw(nx): # asymmetric Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) # check constraints @@ -1670,7 +2057,7 @@ def test_semirelaxed_fgw(nx): np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) G = log2['T'] Gb = nx.to_numpy(logb2['T']) @@ -1819,3 +2206,276 @@ def test_srfgw_helper_backend(nx): res, log = ot.optim.semirelaxed_cg(pb, qb, (1 - alpha) * Mb, alpha, f, df, Gb, line_search, log=True, numItermax=1e4, stopThr=1e-9, stopThr2=1e-9) # check constraints np.testing.assert_allclose(res, Gb, atol=1e-06) + + +def test_entropic_semirelaxed_gromov(nx): + np.random.seed(0) + # unbalanced proportions + list_n = [30, 15] + nt = 2 + ns = np.sum(list_n) + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns), dtype=np.float64) + C2 = np.array([[0.8, 0.05], + [0.05, 1.]], dtype=np.float64) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + p = ot.unif(ns, type_as=C1) + q0 = ot.unif(C2.shape[0], type_as=C1) + G0 = p[:, None] * q0[None, :] + # asymmetric + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + epsilon = 0.1 + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=G0) + Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=None) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, None, loss_fun='square_loss', epsilon=epsilon, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0) + + G = log2['T'] + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + + np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_semirelaxed_gromov_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + C1b, C2b, pb = nx.from_numpy(C1, C2, p, type_as=tp) + + Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True + ) + gw_valb = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, gw_valb) + + +def test_entropic_semirelaxed_fgw(nx): + np.random.seed(0) + list_n = [16, 8] + nt = 2 + ns = 24 + # create directed sbm with C2 as connectivity matrix + C1 = np.zeros((ns, ns)) + C2 = np.array([[0.7, 0.05], + [0.05, 0.9]]) + for i in range(nt): + for j in range(nt): + ni, nj = list_n[i], list_n[j] + xij = np.random.binomial(size=(ni, nj), n=1, p=C2[i, j]) + C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + F1 = np.zeros((ns, 1)) + F1[:16] = np.random.normal(loc=0., scale=0.01, size=(16, 1)) + F1[16:] = np.random.normal(loc=1., scale=0.01, size=(8, 1)) + F2 = np.zeros((2, 1)) + F2[1, :] = 1. + M = (F1 ** 2).dot(np.ones((1, nt))) + np.ones((ns, 1)).dot((F2 ** 2).T) - 2 * F1.dot(F2.T) + + p = ot.unif(ns) + q0 = ot.unif(C2.shape[0]) + G0 = p[:, None] * q0[None, :] + + # asymmetric + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + + # symmetric + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + + srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) + + G = log2['T'] + Gb = nx.to_numpy(logb2['T']) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + + np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) + np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + + +@pytest.skip_backend("tf", reason="test very slow with tf backend") +def test_entropic_semirelaxed_fgw_dtype_device(nx): + # setup + n_samples = 5 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + M = ot.dist(ys, yt) + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Mb, C1b, C2b, pb = nx.from_numpy(M, C1, C2, p, type_as=tp) + + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True + ) + fgw_valb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, 'square_loss', epsilon=0.1, verbose=True + ) + + nx.assert_same_dtype_device(C1b, Gb) + nx.assert_same_dtype_device(C1b, fgw_valb) + + +def test_not_implemented_solver(): + # test sinkhorn + n_samples = 5 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + xt = xs[::-1].copy() + + ys = np.random.randn(xs.shape[0], 2) + yt = ys[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + M = ot.dist(ys, yt) + + solver = 'not_implemented' + # entropic gw and fgw + with pytest.raises(ValueError): + ot.gromov.entropic_gromov_wasserstein( + C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) + with pytest.raises(ValueError): + ot.gromov.entropic_fused_gromov_wasserstein( + M, C1, C2, p, q, 'square_loss', epsilon=1e-1, solver=solver) + + # exact and entropic srgw and srfgw loss functions + loss_fun = 'kl_loss' + with pytest.raises(NotImplementedError): + ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun, armijo=False) + with pytest.raises(NotImplementedError): + ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun, epsilon=0.1) + with pytest.raises(NotImplementedError): + ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun) + with pytest.raises(NotImplementedError): + ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p, loss_fun, epsilon=0.1) -- cgit v1.2.3