summaryrefslogtreecommitdiff
path: root/test/test_bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_bregman.py')
-rw-r--r--test/test_bregman.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/test/test_bregman.py b/test/test_bregman.py
index e128ea2..c674da6 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -879,6 +879,34 @@ def test_empirical_sinkhorn_divergence(nx):
ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
+def test_empirical_sinkhorn_divergence_gradient():
+ # Test sinkhorn divergence
+ n = 10
+ a = np.linspace(1, n, n)
+ a /= a.sum()
+ b = ot.unif(n)
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
+
+ nx = ot.backend.TorchBackend()
+
+ ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t)
+
+ ab.requires_grad = True
+ bb.requires_grad = True
+ X_sb.requires_grad = True
+ X_tb.requires_grad = True
+
+ emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)
+
+ emp_sinkhorn_div.backward()
+
+ assert ab.grad is not None
+ assert bb.grad is not None
+ assert X_sb.grad is not None
+ assert X_tb.grad is not None
+
+
def test_stabilized_vs_sinkhorn_multidim(nx):
# test if stable version matches sinkhorn
# for multidimensional inputs