diff options
author | Vivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net> | 2018-07-09 16:58:00 +0900 |
---|---|---|
committer | Vivien Seguy <vivienseguy@h40.57.229.10.1016746.vlan.kuins.net> | 2018-07-09 16:58:00 +0900 |
commit | 08e5c0a91bc9b4afcd375109c08a14bcb0b1bd51 (patch) | |
tree | fb9f81a6d8f0b8c6134157a5833cd3db3ea044ce | |
parent | 46712790c1276f1ecb3496362a8117e153782ede (diff) |
add test free support barycenter algorithm + cleaning
-rw-r--r-- | examples/plot_free_support_barycenter.py | 6 | ||||
-rw-r--r-- | test/test_ot.py | 6 |
2 files changed, 6 insertions, 6 deletions
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py index b2e62c8..b6efc59 100644 --- a/examples/plot_free_support_barycenter.py +++ b/examples/plot_free_support_barycenter.py @@ -48,9 +48,9 @@ for i in range(N): # Compute free support barycenter # ------------- -k = 10 # number of Diracs of the barycenter -X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations -b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) +k = 10 # number of Diracs of the barycenter +X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations +b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) diff --git a/test/test_ot.py b/test/test_ot.py index dafc03f..45e777a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -137,13 +137,13 @@ def test_lp_barycenter(): def test_free_support_barycenter(): - measures_locations = [np.array([-1.]).reshape((1,1)), np.array([1.]).reshape((1,1))] + measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] measures_weights = [np.array([1.]), np.array([1.])] - X_init = np.array([-12.]).reshape((1,1)) + X_init = np.array([-12.]).reshape((1, 1)) # obvious barycenter location between two diracs - bar_locations = np.array([0.]).reshape((1,1)) + bar_locations = np.array([0.]).reshape((1, 1)) X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) |