summaryrefslogtreecommitdiff
path: root/src/python/test/test_wasserstein_with_tensors.py
diff options
context:
space:
mode:
authorROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-09-29 13:23:56 +0200
committerROUVREAU Vincent <vincent.rouvreau@inria.fr>2020-09-29 13:23:56 +0200
commite7b7947adf13ec1dcb8c126a4373fa29baaecb63 (patch)
treedcc3eea7fd2de15ba2eb6e6da91f115c38a8519b /src/python/test/test_wasserstein_with_tensors.py
parenta304049bdcfb03aa848d8049923ab796e0761b56 (diff)
Added tests for wasserstein distance with tensorflow
Diffstat (limited to 'src/python/test/test_wasserstein_with_tensors.py')
-rwxr-xr-xsrc/python/test/test_wasserstein_with_tensors.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py
new file mode 100755
index 00000000..8957705d
--- /dev/null
+++ b/src/python/test/test_wasserstein_with_tensors.py
@@ -0,0 +1,25 @@
+""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+ See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+ Author(s): Mathieu Carriere
+
+ Copyright (C) 2020 Inria
+
+ Modification(s):
+ - YYYY/MM Author: Description of the modification
+"""
+
+from gudhi.wasserstein import wasserstein_distance as pot
+import numpy as np
+
+def test_wasserstein_distance_grad_tensorflow():
+ import tensorflow as tf
+
+ with tf.GradientTape() as tape:
+ diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True))
+ diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True))
+ dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True)
+ assert dist45 == 3.
+
+ grads = tape.gradient(dist45, [diag4, diag5])
+ assert np.array_equal(grads[0].values, [[-1., -1.]])
+ assert np.array_equal(grads[1].values, [[1., 1.], [-1., 1.]]) \ No newline at end of file