From f93c403b81b4ccb98bfad8e4ef30cdf0e7333f6c Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Sat, 18 Apr 2020 23:52:12 +0200 Subject: enable_autodiff for POT wasserstein_distance --- src/python/test/test_wasserstein_distance.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'src/python/test/test_wasserstein_distance.py') diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 7e0d0f5f..5bec5bd3 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -73,14 +73,20 @@ def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True, test_mat -def hera_wrap(delta): +def hera_wrap(**extra): def fun(*kargs,**kwargs): - return hera(*kargs,**kwargs,delta=delta) + return hera(*kargs,**kwargs,**extra) + return fun + +def pot_wrap(**extra): + def fun(*kargs,**kwargs): + return pot(*kargs,**kwargs,**extra) return fun def test_wasserstein_distance_pot(): _basic_wasserstein(pot, 1e-15, test_infinity=False, test_matching=True) + _basic_wasserstein(pot_wrap(enable_autodiff=True), 1e-15, test_infinity=False, test_matching=False) def test_wasserstein_distance_hera(): - _basic_wasserstein(hera_wrap(1e-12), 1e-12, test_matching=False) - _basic_wasserstein(hera_wrap(.1), .1, test_matching=False) + _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) + _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) -- cgit v1.2.3