From b3dc68feac355fa94c4237f4ecad65edc9f7a7e8 Mon Sep 17 00:00:00 2001 From: Hongwei Jin Date: Mon, 6 Dec 2021 08:14:30 -0600 Subject: [MRG] Fix issue 317 (#318) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix issue 317 * Update with docs and tests Co-authored-by: RĂ©mi Flamary --- ot/gromov.py | 24 +++++++++++++---------- test/test_gromov.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index 2a70070..dc95c74 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1368,6 +1368,8 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, ------- C : array-like, shape (`N`, `N`) Similarity matrix in the barycenter space (permutated arbitrarily) + log : dict + Log dictionary of error during iterations. Return only if `log=True` in parameters. References ---------- @@ -1401,7 +1403,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Cprev = C T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, - max_iter, 1e-4, verbose, log) for s in range(S)] + max_iter, 1e-4, verbose, log=False) for s in range(S)] if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) @@ -1414,9 +1416,6 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, err = nx.norm(C - Cprev) error.append(err) - if log: - log['err'].append(err) - if verbose: if cpt % 200 == 0: print('{:5s}|{:12s}'.format( @@ -1425,7 +1424,10 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, cpt += 1 - return C + if log: + return C, {"err": error} + else: + return C def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, @@ -1479,6 +1481,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, ------- C : array-like, shape (`N`, `N`) Similarity matrix in the barycenter space (permutated arbitrarily) + log : dict + Log dictionary of error during iterations. Return only if `log=True` in parameters. References ---------- @@ -1513,7 +1517,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, Cprev = C T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, - numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)] + numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)] if loss_fun == 'square_loss': C = update_square_loss(p, lambdas, T, Cs) @@ -1526,9 +1530,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, err = nx.norm(C - Cprev) error.append(err) - if log: - log['err'].append(err) - if verbose: if cpt % 200 == 0: print('{:5s}|{:12s}'.format( @@ -1537,7 +1538,10 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, cpt += 1 - return C + if log: + return C, {"err": error} + else: + return C def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, diff --git a/test/test_gromov.py b/test/test_gromov.py index 5c181f2..38a7fd7 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -385,6 +385,20 @@ def test_gromov_barycenter(nx): np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + # test of gromov_barycenters with `log` on + Cb_, err_ = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + ) + Cbb_, errb_ = ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + ) + Cbb_ = nx.to_numpy(Cbb_) + np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) + np.testing.assert_array_almost_equal(err_['err'], errb_['err']) + np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + Cb2 = ot.gromov.gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', max_iter=100, tol=1e-3, random_state=42 @@ -396,6 +410,20 @@ def test_gromov_barycenter(nx): np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + # test of gromov_barycenters with `log` on + Cb2_, err2_ = ot.gromov.gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + ) + Cb2b_, err2b_ = ot.gromov.gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + ) + Cb2b_ = nx.to_numpy(Cb2b_) + np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) + np.testing.assert_array_almost_equal(err2_['err'], err2_['err']) + np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + @pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(nx): @@ -429,6 +457,20 @@ def test_gromov_entropic_barycenter(nx): np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + # test of entropic_gromov_barycenters with `log` on + Cb_, err_ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + ) + Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + ) + Cbb_ = nx.to_numpy(Cbb_) + np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) + np.testing.assert_array_almost_equal(err_['err'], errb_['err']) + np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + Cb2 = ot.gromov.entropic_gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], 'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42 @@ -440,6 +482,20 @@ def test_gromov_entropic_barycenter(nx): np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + # test of entropic_gromov_barycenters with `log` on + Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], p, [.5, .5], + 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + ) + Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], + 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + ) + Cb2b_ = nx.to_numpy(Cb2b_) + np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) + np.testing.assert_array_almost_equal(err2_['err'], err2_['err']) + np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + def test_fgw(nx): n_samples = 50 # nb samples -- cgit v1.2.3