diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-04-18 23:52:12 +0200 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2020-04-18 23:52:12 +0200 |
commit | f93c403b81b4ccb98bfad8e4ef30cdf0e7333f6c (patch) | |
tree | 7dd0112b0088151f2a3b3dee82d973e6ce3a9396 /src/python/test/test_wasserstein_distance.py | |
parent | 17aaa979e4cdfe5faed9b2750d452171de4b67e1 (diff) |
enable_autodiff for POT wasserstein_distance
Diffstat (limited to 'src/python/test/test_wasserstein_distance.py')
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 7e0d0f5f..5bec5bd3 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -73,14 +73,20 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat -def hera_wrap(delta): +def hera_wrap(**extra): def fun(*kargs,**kwargs): - return hera(*kargs,**kwargs,delta=delta) + return hera(*kargs,**kwargs,**extra) + return fun + +def pot_wrap(**extra): + def fun(*kargs,**kwargs): + return pot(*kargs,**kwargs,**extra) return fun def test_wasserstein_distance_pot(): _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) + _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False) def test_wasserstein_distance_hera(): - _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False) - _basic_wasserstein(hera_wrap(.1), .1, test_matching=False) + _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) + _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) |