author     Rémi Flamary <remi.flamary@gmail.com>  2022-08-17 17:15:36 +0200
committer  Rémi Flamary <remi.flamary@gmail.com>  2022-08-17 17:15:36 +0200
commit     f2aaf401192dd1a9a14ee273d58466b5468f30a8 (patch)
tree       61edef0be917d1835585a5504aa7229c8ac3df55
parent     0138dcf636c3f3f0e63110b08a8249f065e1fa73 (diff)
debug sinkhorn divergence gradients
-rw-r--r--  ot/bregman.py         | 11
-rw-r--r--  test/test_bregman.py  | 28
2 files changed, 35 insertions(+), 4 deletions(-)
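The gradient problem addressed here comes from the cost matrix being computed outside the backend: the old code converted the samples to NumPy with nx.to_numpy before calling dist and then wrapped the result back with nx.from_numpy, which detaches the samples from any autograd graph. The sketch below is illustrative only and not part of the commit; it assumes PyTorch is installed and shows how such a NumPy round-trip loses gradients while a computation that stays on the backend keeps them.

    import torch

    # A tensor that participates in autograd.
    x = torch.arange(5, dtype=torch.float64, requires_grad=True)

    # Round-tripping through NumPy (as the old code path did via
    # nx.to_numpy / nx.from_numpy) yields a fresh tensor with no link
    # to x in the autograd graph.
    x_np = x.detach().numpy()
    x_back = torch.from_numpy(x_np)
    assert not x_back.requires_grad

    # Staying inside the backend keeps the graph, so gradients reach x.
    loss = (x ** 2).sum()
    loss.backward()
    assert x.grad is not None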
diff --git a/ot/bregman.py b/ot/bregman.py
index b1321a4..4e1a25c 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -3173,8 +3173,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
             return loss
     else:
-        M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
-        M = nx.from_numpy(M, type_as=a)
+        M = dist(X_s, X_t, metric=metric)
         if log:
             sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
@@ -3287,6 +3286,10 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         International Conference on Artficial Intelligence and Statistics,
         (AISTATS) 21, 2018
     '''
+    X_s, X_t = list_to_array(X_s, X_t)
+
+    nx = get_backend(X_s, X_t)
+
     if log:
         sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
                                                        numIterMax=numIterMax,
@@ -3313,7 +3316,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         log['log_sinkhorn_a'] = log_a
         log['log_sinkhorn_b'] = log_b
 
-        return max(0, sinkhorn_div), log
+        return nx.maximum(0, sinkhorn_div), log
     else:
         sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
@@ -3332,7 +3335,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
                                               warn=warn, **kwargs)
 
         sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
-        return max(0, sinkhorn_div)
+        return nx.maximum(0, sinkhorn_div)
 
 
 def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
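The other change in ot/bregman.py swaps Python's built-in max(0, sinkhorn_div) for nx.maximum(0, sinkhorn_div). With a backend tensor, the built-in returns the plain integer 0 whenever the divergence is numerically negative, so both the tensor type and the autograd graph are silently dropped; nx.maximum always returns a backend array. A small illustrative sketch, again assuming PyTorch (torch.maximum stands in here for what the torch backend roughly does; this is not code from the repository):

    import torch

    # A tiny negative value, as can happen numerically for the divergence.
    div = torch.tensor(-1e-12, dtype=torch.float64, requires_grad=True)

    clipped_builtin = max(0, div)                                 # the plain int 0: type and graph lost
    clipped_backend = torch.maximum(torch.zeros_like(div), div)   # stays a differentiable tensor

    print(type(clipped_builtin))   # <class 'int'>
    print(type(clipped_backend))   # <class 'torch.Tensor'>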
diff --git a/test/test_bregman.py b/test/test_bregman.py
index e128ea2..c674da6 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -879,6 +879,34 @@ def test_empirical_sinkhorn_divergence(nx):
     ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
 
 
+def test_empirical_sinkhorn_divergence_gradient():
+    # Test sinkhorn divergence
+    n = 10
+    a = np.linspace(1, n, n)
+    a /= a.sum()
+    b = ot.unif(n)
+    X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+    X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
+
+    nx = ot.backend.TorchBackend()
+
+    ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)
+
+    ab.requires_grad = True
+    bb.requires_grad = True
+    X_sb.requires_grad = True
+    X_tb.requires_grad = True
+
+    emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)
+
+    emp_sinkhorn_div.backward()
+
+    assert ab.grad is not None
+    assert bb.grad is not None
+    assert X_sb.grad is not None
+    assert X_tb.grad is not None
+
+
 def test_stabilized_vs_sinkhorn_multidim(nx):
     # test if stable version matches sinkhorn
     # for multidimensional inputs
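With gradients flowing end to end, the empirical Sinkhorn divergence can serve directly as a differentiable loss on backend tensors. The following hypothetical usage sketch is not part of the commit; it assumes PyTorch and a POT build that includes this change, and moves source samples toward a target distribution by gradient descent on the divergence:

    import torch
    import ot

    torch.manual_seed(0)
    n = 20
    X_s = torch.randn(n, 2, dtype=torch.float64, requires_grad=True)   # source samples (optimized)
    X_t = torch.randn(n, 2, dtype=torch.float64) + 3.0                 # fixed target samples

    optimizer = torch.optim.SGD([X_s], lr=0.5)
    for _ in range(10):
        optimizer.zero_grad()
        # default a, b give uniform weights on the samples
        loss = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.0)
        loss.backward()            # gradients now reach X_s
        optimizer.step()

    print(float(loss))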