summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHicham Janati <hicham.janati100@gmail.com>2022-05-09 14:30:04 +0200
committerGitHub <noreply@github.com>2022-05-09 14:30:04 +0200
commitc1ccfc45350f8db3fa78d91b84eb4286bcf36e69 (patch)
treeab1086116414350346cc678d8e9c3c68a31224cf
parent726e84e1e9f2832ea5ad156f62a5e3636c1fd3d3 (diff)
[MRG] Fix barycenter mass (#375)
* fix transpose in sinkhorn barycenters * add test for assymetric cost barycenters * fix pep8 Co-authored-by: Hicham Janati <hicham.janati@inria.fr>
-rw-r--r--ot/bregman.py2
-rw-r--r--test/test_bregman.py35
2 files changed, 36 insertions, 1 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index c06af2f..34dcadb 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1511,7 +1511,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
for ii in range(numItermax):
- UKv = u * nx.dot(K, A / nx.dot(K, u))
+ UKv = u * nx.dot(K.T, A / nx.dot(K, u))
u = (u.T * geometricBar(weights, UKv)).T / UKv
if ii % 10 == 1:
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6c37984..112bfca 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -491,6 +491,41 @@ def test_barycenter(nx, method, verbose, warn):
@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter_assymetric_cost(nx, method, verbose, warn):
+ n_bins = 20 # nb bins
+
+ # Gaussian distributions
+ A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+
+ # creating matrix A containing all distributions
+ A = A[:, None]
+
+ # assymetric loss matrix + normalization
+ rng = np.random.RandomState(42)
+ M = rng.randn(n_bins, n_bins) ** 2
+ M /= M.max()
+
+ A_nx, M_nx = nx.from_numpy(A, M)
+ reg = 1e-2
+
+ if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter(A_nx, M_nx, reg, method=method)
+ else:
+ # wasserstein
+ bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+ ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
product(["sinkhorn", "sinkhorn_log"],
[True, False], [True, False]))
def test_barycenter_debiased(nx, method, verbose, warn):