summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorHongwei Jin <jinhw1989@gmail.com>2021-12-06 08:14:30 -0600
committerGitHub <noreply@github.com>2021-12-06 15:14:30 +0100
commitb3dc68feac355fa94c4237f4ecad65edc9f7a7e8 (patch)
tree5a14cdd17b89aef8b2a1461bfc1c01f8cc95c3b1 /test
parentca69658400dc2ef6a7d3e531acffcd107400085f (diff)
[MRG] Fix issue 317 (#318)
* Fix issue 317 * Update with docs and tests Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test')
-rw-r--r--test/test_gromov.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 5c181f2..38a7fd7 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -385,6 +385,20 @@ def test_gromov_barycenter(nx):
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+ # test of gromov_barycenters with `log` on
+ Cb_, err_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_, errb_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_ = nx.to_numpy(Cbb_)
+ np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
+
Cb2 = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', max_iter=100, tol=1e-3, random_state=42
@@ -396,6 +410,20 @@ def test_gromov_barycenter(nx):
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+ # test of gromov_barycenters with `log` on
+ Cb2_, err2_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_, err2b_ = ot.gromov.gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_ = nx.to_numpy(Cb2b_)
+ np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
+
@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter(nx):
@@ -429,6 +457,20 @@ def test_gromov_entropic_barycenter(nx):
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
+ # test of entropic_gromov_barycenters with `log` on
+ Cb_, err_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cbb_ = nx.to_numpy(Cbb_)
+ np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
+
Cb2 = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
@@ -440,6 +482,20 @@ def test_gromov_entropic_barycenter(nx):
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))
+ # test of entropic_gromov_barycenters with `log` on
+ Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1, C2], [p1, p2], p, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters(
+ n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
+ 'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ )
+ Cb2b_ = nx.to_numpy(Cb2b_)
+ np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
+ np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
+
def test_fgw(nx):
n_samples = 50 # nb samples