summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 285b95c9..6701c7ba 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -31,7 +31,7 @@ def test_proj_on_diag():
def test_offdiag():
diag = np.array([[0, 1], [3, 5], [2, np.inf], [3, np.inf], [-np.inf, 8], [-np.inf, 12], [-np.inf, -np.inf],
[np.inf, np.inf], [-np.inf, np.inf], [-np.inf, np.inf]])
- assert np.array_equal(_offdiag(diag), [[0, 1], [3, 5]])
+ assert np.array_equal(_offdiag(diag, enable_autodiff=False), [[0, 1], [3, 5]])
def test_handle_essential_parts():