From f2aaf401192dd1a9a14ee273d58466b5468f30a8 Mon Sep 17 00:00:00 2001
From: Rémi Flamary
Date: Wed, 17 Aug 2022 17:15:36 +0200
Subject: debug sinkhorn divergence gradients

---
 ot/bregman.py        | 11 +++++++----
 test/test_bregman.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/ot/bregman.py b/ot/bregman.py
index b1321a4..4e1a25c 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -3173,8 +3173,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
             return loss
 
     else:
-        M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
-        M = nx.from_numpy(M, type_as=a)
+        M = dist(X_s, X_t, metric=metric)
 
         if log:
             sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
@@ -3287,6 +3286,10 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         International Conference on Artficial Intelligence and Statistics,
         (AISTATS) 21, 2018
     '''
+    X_s, X_t = list_to_array(X_s, X_t)
+
+    nx = get_backend(X_s, X_t)
+
     if log:
         sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
                                                        numIterMax=numIterMax,
@@ -3313,7 +3316,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         log['log_sinkhorn_a'] = log_a
         log['log_sinkhorn_b'] = log_b
 
-        return max(0, sinkhorn_div), log
+        return nx.maximum(0, sinkhorn_div), log
 
     else:
         sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
@@ -3332,7 +3335,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
                                               warn=warn, **kwargs)
 
         sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
-        return max(0, sinkhorn_div)
+        return nx.maximum(0, sinkhorn_div)
 
 
 def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
diff --git a/test/test_bregman.py b/test/test_bregman.py
index e128ea2..c674da6 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -879,6 +879,34 @@ def test_empirical_sinkhorn_divergence(nx):
     ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
 
 
+def test_empirical_sinkhorn_divergence_gradient():
+    # Test sinkhorn divergence
+    n = 10
+    a = np.linspace(1, n, n)
+    a /= a.sum()
+    b = ot.unif(n)
+    X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+    X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
+
+    nx = ot.backend.TorchBackend()
+
+    ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)
+
+    ab.requires_grad = True
+    bb.requires_grad = True
+    X_sb.requires_grad = True
+    X_tb.requires_grad = True
+
+    emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)
+
+    emp_sinkhorn_div.backward()
+
+    assert ab.grad is not None
+    assert bb.grad is not None
+    assert X_sb.grad is not None
+    assert X_tb.grad is not None
+
+
 def test_stabilized_vs_sinkhorn_multidim(nx):
     # test if stable version matches sinkhorn
     # for multidimensional inputs
-- 
cgit v1.2.3
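
Note (not part of the patch): the point of the change is that the cost matrix is no longer converted to numpy and back, and the final clamp uses nx.maximum instead of Python's max, so the whole divergence stays inside the active backend's autodiff graph. Below is a minimal usage sketch of what this enables, assuming a POT build that includes this patch and an installed PyTorch; the toy data, learning rate and iteration count are illustrative only and not taken from the patch.

# Minimal sketch (assumption: POT with this patch + torch installed):
# use the empirical Sinkhorn divergence as a differentiable loss to move
# source samples toward a fixed target cloud.
import torch
import ot

n = 20
X_t = torch.randn(n, 2, dtype=torch.float64) + 2.0                # fixed target samples
X_s = torch.randn(n, 2, dtype=torch.float64, requires_grad=True)  # source samples to optimize

optimizer = torch.optim.SGD([X_s], lr=1.0)
for _ in range(10):
    optimizer.zero_grad()
    # With the patch, dist() is computed on the torch tensors themselves,
    # so the divergence carries a gradient w.r.t. X_s (and w.r.t. the
    # weights a, b when they are passed as tensors with requires_grad).
    loss = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.0)
    loss.backward()
    optimizer.step()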