summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/gromov.py2
-rw-r--r--test/test_gromov.py31
-rw-r--r--test/test_plot.py2
3 files changed, 34 insertions, 1 deletions
diff --git a/ot/gromov.py b/ot/gromov.py
index b1e9ee0..2a23873 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -307,6 +307,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, **kwargs):
Print information along iterations
log : bool, optional
record log if True
+ **kwargs : dict
+ parameters can be directly pased to the ot.optim.cg solver
Returns
-------
diff --git a/test/test_gromov.py b/test/test_gromov.py
index e808292..625e62a 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -28,7 +28,36 @@ def test_gromov():
C1 /= C1.max()
C2 /= C2.max()
- G = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+ G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
+
+ # check constratints
+ np.testing.assert_allclose(
+ p, G.sum(1), atol=1e-04) # cf convergence gromov
+ np.testing.assert_allclose(
+ q, G.sum(0), atol=1e-04) # cf convergence gromov
+
+
+def test_entropic_gromov():
+ n_samples = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
+
+ xt = xs[::-1].copy()
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ C1 = ot.dist(xs, xs)
+ C2 = ot.dist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ G = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4)
# check constratints
np.testing.assert_allclose(
diff --git a/test/test_plot.py b/test/test_plot.py
index f7debee..a50ed14 100644
--- a/test/test_plot.py
+++ b/test/test_plot.py
@@ -12,6 +12,7 @@ matplotlib.use('Agg')
def test_plot1D_mat():
import ot
+ import ot.plot
n_bins = 100 # nb bins
@@ -32,6 +33,7 @@ def test_plot1D_mat():
def test_plot2D_samples_mat():
import ot
+ import ot.plot
n_bins = 50 # nb samples