From c1ccfc45350f8db3fa78d91b84eb4286bcf36e69 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Mon, 9 May 2022 14:30:04 +0200 Subject: [MRG] Fix barycenter mass (#375) * fix transpose in sinkhorn barycenters * add test for assymetric cost barycenters * fix pep8 Co-authored-by: Hicham Janati --- ot/bregman.py | 2 +- test/test_bregman.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/ot/bregman.py b/ot/bregman.py index c06af2f..34dcadb 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1511,7 +1511,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, for ii in range(numItermax): - UKv = u * nx.dot(K, A / nx.dot(K, u)) + UKv = u * nx.dot(K.T, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv if ii % 10 == 1: diff --git a/test/test_bregman.py b/test/test_bregman.py index 6c37984..112bfca 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -490,6 +490,41 @@ def test_barycenter(nx, method, verbose, warn): ot.bregman.barycenter(A_nx, M_nx, reg, log=True) +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter_assymetric_cost(nx, method, verbose, warn): + n_bins = 20 # nb bins + + # Gaussian distributions + A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + + # creating matrix A containing all distributions + A = A[:, None] + + # assymetric loss matrix + normalization + rng = np.random.RandomState(42) + M = rng.randn(n_bins, n_bins) ** 2 + M /= M.max() + + A_nx, M_nx = nx.from_numpy(A, M) + reg = 1e-2 + + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter(A_nx, M_nx, reg, method=method) + else: + # wasserstein + bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) + + ot.bregman.barycenter(A_nx, M_nx, reg, log=True) + + @pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_log"], [True, False], [True, False])) -- cgit v1.2.3