summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_distance.py
diff options
context:
space:
mode:
authorMarc Glisse <marc.glisse@inria.fr>2020-04-18 23:52:12 +0200
committerMarc Glisse <marc.glisse@inria.fr>2020-04-18 23:52:12 +0200
commitf93c403b81b4ccb98bfad8e4ef30cdf0e7333f6c (patch)
tree7dd0112b0088151f2a3b3dee82d973e6ce3a9396 /src/python/test/test_wasserstein_distance.py
parent17aaa979e4cdfe5faed9b2750d452171de4b67e1 (diff)
enable_autodiff for POT wasserstein_distance
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_distance.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py
index 7e0d0f5f..5bec5bd3 100755
--- a/src/python/test/test_wasserstein_distance.py
+++ b/src/python/test/test_wasserstein_distance.py
@@ -73,14 +73,20 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat
-def hera_wrap(delta):
+def hera_wrap(**extra):
def fun(*kargs,**kwargs):
- return hera(*kargs,**kwargs,delta=delta)
+ return hera(*kargs,**kwargs,**extra)
+ return fun
+
+def pot_wrap(**extra):
+ def fun(*kargs,**kwargs):
+ return pot(*kargs,**kwargs,**extra)
return fun
def test_wasserstein_distance_pot():
_basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True)
+ _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False)
def test_wasserstein_distance_hera():
- _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False)
- _basic_wasserstein(hera_wrap(.1), .1, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False)
+ _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False)