summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2022-08-17 17:36:20 +0200
committerGitHub <noreply@github.com>2022-08-17 17:36:20 +0200
commita24ec2554d003778de2c3e0550a214f83b984fa9 (patch)
tree22fb2ead9d78719274ac4539f69a87b7f1f9c300
parentf2aaf401192dd1a9a14ee273d58466b5468f30a8 (diff)
[MRG] Debug sinkhorn divergence bug and add proper test (#394)
* skip tests if not torch installed * update release.md
-rw-r--r--RELEASES.md1
-rw-r--r--test/test_bregman.py1
2 files changed, 2 insertions, 0 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 571cd74..e6e0ba4 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -9,6 +9,7 @@
#### Closed issues
+- Fixed an issue where sinkhorn divergence did not have gradients (Issue #393, PR #394)
- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
(Issue #371, PR #373)
- Fixed an issue where Sinkhorn solver assumed a symmetric cost matrix (Issue #374, PR #375)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index c674da6..0f47c3f 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -879,6 +879,7 @@ def test_empirical_sinkhorn_divergence(nx):
ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
+@pytest.mark.skipif(not torch, reason="No torch available")
def test_empirical_sinkhorn_divergence_gradient():
# Test sinkhorn divergence
n = 10