summaryrefslogtreecommitdiff
path: root/test/test_partial.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_partial.py')
-rwxr-xr-xtest/test_partial.py26
1 files changed, 25 insertions, 1 deletions
diff --git a/test/test_partial.py b/test/test_partial.py
index 8b1ca89..5960e4e 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -9,6 +9,30 @@ import numpy as np
import scipy as sp
import ot
+def test_partial_wasserstein_lagrange():
+
+ n_samples = 20 # nb samples (gaussian)
+ n_noise = 20 # nb of samples (noise)
+
+ mu = np.array([0, 0])
+ cov = np.array([[1, 0], [0, 2]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+ xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+ xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+ M = ot.dist(xs, xt)
+
+ p = ot.unif(n_samples + n_noise)
+ q = ot.unif(n_samples + n_noise)
+
+ m = 0.5
+
+ w0, log0 = ot.partial.partial_wasserstein_lagrange(p, q, M, 1, log=True)
+
+
+
def test_partial_wasserstein():
@@ -32,7 +56,7 @@ def test_partial_wasserstein():
w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=m, log=True)
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
- log=True)
+ log=True, verbose=True)
# check constratints
np.testing.assert_equal(