diff options
author | tvayer <titouan.vayer@gmail.com> | 2019-05-29 14:24:05 +0200 |
---|---|---|
committer | tvayer <titouan.vayer@gmail.com> | 2019-05-29 14:24:05 +0200 |
commit | 63bbeb34e48f02c97a762dab5232158d90a5cffc (patch) | |
tree | 853026b5854b6e4b01fdf750db139985b3dd596f /test/test_ot.py | |
parent | f70aabfcc11f92181e0dc987b341bad8ec030d75 (diff) | |
parent | f66ab58c7c895011fd37bafd3e848828399c56c4 (diff) |
Merge remote-tracking branch 'rflamary/master'
merge pot
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 399e549..7652394 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -70,7 +70,7 @@ def test_emd_empty(): def test_emd2_multi(): - n = 1000 # nb bins + n = 500 # nb bins # bin positions x = np.arange(n, dtype=np.float64) @@ -78,7 +78,7 @@ def test_emd2_multi(): # Gaussian distributions a = gauss(n, m=20, s=5) # m= mean, s= std - ls = np.arange(20, 1000, 20) + ls = np.arange(20, 500, 20) nb = len(ls) b = np.zeros((n, nb)) for i in range(nb): @@ -135,6 +135,21 @@ def test_lp_barycenter(): np.testing.assert_allclose(bary.sum(), 1) +def test_free_support_barycenter(): + + measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + measures_weights = [np.array([1.]), np.array([1.])] + + X_init = np.array([-12.]).reshape((1, 1)) + + # obvious barycenter location between two diracs + bar_locations = np.array([0.]).reshape((1, 1)) + + X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) + + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): @@ -192,11 +207,11 @@ def test_warnings(): def test_dual_variables(): - n = 5000 # nb bins - m = 6000 # nb bins + n = 500 # nb bins + m = 600 # nb bins - mean1 = 1000 - mean2 = 1100 + mean1 = 300 + mean2 = 400 # bin positions x = np.arange(n, dtype=np.float64) |