diff options
-rw-r--r-- | ot/gromov.py | 2 | ||||
-rw-r--r-- | test/test_gromov.py | 31 | ||||
-rw-r--r-- | test/test_plot.py | 2 |
3 files changed, 34 insertions, 1 deletions
diff --git a/ot/gromov.py b/ot/gromov.py index b1e9ee0..2a23873 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -307,6 +307,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs): Print information along iterations
log : bool, optional
record log if True
+ **kwargs : dict
+ parameters can be directly pased to the ot.optim.cg solver
Returns
-------
diff --git a/test/test_gromov.py b/test/test_gromov.py index e808292..625e62a 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -28,7 +28,36 @@ def test_gromov(): C1 /= C1.max()
C2 /= C2.max()
- G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+ G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_entropic_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4)
# check constratints
np.testing.assert_allclose(
diff --git a/test/test_plot.py b/test/test_plot.py index f7debee..a50ed14 100644 --- a/test/test_plot.py +++ b/test/test_plot.py @@ -12,6 +12,7 @@ matplotlib.use('Agg') def test_plot1D_mat(): import ot + import ot.plot n_bins = 100 # nb bins @@ -32,6 +33,7 @@ def test_plot1D_mat(): def test_plot2D_samples_mat(): import ot + import ot.plot n_bins = 50 # nb samples |