diff options
author | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2020-10-30 14:53:48 +0100 |
---|---|---|
committer | ROUVREAU Vincent <vincent.rouvreau@inria.fr> | 2020-10-30 14:53:48 +0100 |
commit | 218be4f4c88f94c2c281cce16d5a9cd41fdd2fd8 (patch) | |
tree | 28480df9197e6a8bdf407cc5dff6c5eda9a10835 /src | |
parent | cf30dde33d2463172af32de208909f4638343bec (diff) | |
parent | 6b995c03793096459a333c907b606770113b96d7 (diff) |
Merge branch 'master' into coxeter_integration
Diffstat (limited to 'src')
-rw-r--r-- | src/Simplex_tree/include/gudhi/Simplex_tree.h | 31 | ||||
-rw-r--r-- | src/Simplex_tree/test/simplex_tree_unit_test.cpp | 51 | ||||
-rw-r--r-- | src/Subsampling/include/gudhi/choose_n_farthest_points.h | 7 | ||||
-rw-r--r-- | src/Subsampling/test/test_choose_n_farthest_points.cpp | 16 | ||||
-rw-r--r-- | src/cmake/modules/GUDHI_third_party_libraries.cmake | 1 | ||||
-rw-r--r-- | src/python/CMakeLists.txt | 11 | ||||
-rw-r--r-- | src/python/doc/installation.rst | 5 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pxd | 1 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pyx | 22 | ||||
-rw-r--r-- | src/python/gudhi/subsampling.pyx | 21 | ||||
-rwxr-xr-x | src/python/test/test_simplex_tree.py | 26 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 24 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_with_tensors.py | 47 |
13 files changed, 216 insertions, 47 deletions
diff --git a/src/Simplex_tree/include/gudhi/Simplex_tree.h b/src/Simplex_tree/include/gudhi/Simplex_tree.h index 889dbd00..85d6c3b0 100644 --- a/src/Simplex_tree/include/gudhi/Simplex_tree.h +++ b/src/Simplex_tree/include/gudhi/Simplex_tree.h @@ -1667,6 +1667,37 @@ class Simplex_tree { return sh; // None of its faces has the same filtration. } + public: + /** \brief This function resets the filtration value of all the simplices of dimension at least min_dim. Resets all + * the Simplex_tree when `min_dim = 0`. + * `reset_filtration` may break the filtration property with `min_dim > 0`, and it is the user's responsibility to + * make it a valid filtration (using a large enough `filt_value`, or calling `make_filtration_non_decreasing` + * afterwards for instance). + * @param[in] filt_value The new filtration value. + * @param[in] min_dim The minimal dimension. Default value is 0. + */ + void reset_filtration(Filtration_value filt_value, int min_dim = 0) { + rec_reset_filtration(&root_, filt_value, min_dim); + clear_filtration(); // Drop the cache. + } + + private: + /** \brief Recursively resets filtration value when minimal depth <= 0. + * @param[in] sib Siblings to be parsed. + * @param[in] filt_value The new filtration value. + * @param[in] min_depth The minimal depth. + */ + void rec_reset_filtration(Siblings * sib, Filtration_value filt_value, int min_depth) { + for (auto sh = sib->members().begin(); sh != sib->members().end(); ++sh) { + if (min_depth <= 0) { + sh->second.assign_filtration(filt_value); + } + if (has_children(sh)) { + rec_reset_filtration(sh->second.children(), filt_value, min_depth - 1); + } + } + } + private: Vertex_handle null_vertex_; /** \brief Total number of simplices in the complex, without the empty simplex.*/ diff --git a/src/Simplex_tree/test/simplex_tree_unit_test.cpp b/src/Simplex_tree/test/simplex_tree_unit_test.cpp index 9b5fa8fe..bdd41d34 100644 --- a/src/Simplex_tree/test/simplex_tree_unit_test.cpp +++ b/src/Simplex_tree/test/simplex_tree_unit_test.cpp @@ -940,3 +940,54 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(generators, typeST, list_of_tested_variants) { BOOST_CHECK(st.edge_with_same_filtration(st.find({1,5}))==st.find({1,5})); } } + +BOOST_AUTO_TEST_CASE_TEMPLATE(simplex_tree_reset_filtration, typeST, list_of_tested_variants) { + std::clog << "********************************************************************" << std::endl; + std::clog << "TEST RESET FILTRATION" << std::endl; + typeST st; + + st.insert_simplex_and_subfaces({2, 1, 0}, 3.); + st.insert_simplex_and_subfaces({3, 0}, 2.); + st.insert_simplex_and_subfaces({3, 4, 5}, 3.); + st.insert_simplex_and_subfaces({0, 1, 6, 7}, 4.); + + /* Inserted simplex: */ + /* 1 6 */ + /* o---o */ + /* /X\7/ */ + /* o---o---o---o */ + /* 2 0 3\X/4 */ + /* o */ + /* 5 */ + + for (auto f_simplex : st.skeleton_simplex_range(3)) { + std::clog << "vertex = ("; + for (auto vertex : st.simplex_vertex_range(f_simplex)) { + std::clog << vertex << ","; + } + std::clog << ") - filtration = " << st.filtration(f_simplex); + std::clog << " - dimension = " << st.dimension(f_simplex) << std::endl; + // Guaranteed by construction + BOOST_CHECK(st.filtration(f_simplex) >= 2.); + } + + // dimension until 5 even if simplex tree is of dimension 3 to test the limits + for(int dimension = 5; dimension >= 0; dimension --) { + std::clog << "### reset_filtration - dimension = " << dimension << "\n"; + st.reset_filtration(0., dimension); + for (auto f_simplex : st.skeleton_simplex_range(3)) { + std::clog << "vertex = ("; + for (auto vertex : st.simplex_vertex_range(f_simplex)) { + std::clog << vertex << ","; + } + std::clog << ") - filtration = " << st.filtration(f_simplex); + std::clog << " - dimension = " << st.dimension(f_simplex) << std::endl; + if (st.dimension(f_simplex) < dimension) + BOOST_CHECK(st.filtration(f_simplex) >= 2.); + else + BOOST_CHECK(st.filtration(f_simplex) == 0.); + } + } + +} + diff --git a/src/Subsampling/include/gudhi/choose_n_farthest_points.h b/src/Subsampling/include/gudhi/choose_n_farthest_points.h index 66421a69..b70af8a0 100644 --- a/src/Subsampling/include/gudhi/choose_n_farthest_points.h +++ b/src/Subsampling/include/gudhi/choose_n_farthest_points.h @@ -48,7 +48,8 @@ enum : std::size_t { * \tparam PointOutputIterator Output iterator whose value type is Kernel::Point_d. * \tparam DistanceOutputIterator Output iterator for distances. * \details It chooses `final_size` points from a random access range - * `input_pts` and outputs them in the output iterator `output_it`. It also + * `input_pts` (or the number of distinct points if `final_size` is larger) + * and outputs them in the output iterator `output_it`. It also * outputs the distance from each of those points to the set of previous * points in `dist_it`. * @param[in] k A kernel object. @@ -99,7 +100,7 @@ void choose_n_farthest_points(Kernel const &k, *dist_it++ = dist_to_L[curr_max_w]; std::size_t i = 0; for (auto&& p : input_pts) { - double curr_dist = sqdist(p, *(std::begin(input_pts) + curr_max_w)); + double curr_dist = sqdist(p, input_pts[curr_max_w]); if (curr_dist < dist_to_L[i]) dist_to_L[i] = curr_dist; ++i; @@ -111,6 +112,8 @@ void choose_n_farthest_points(Kernel const &k, curr_max_dist = dist_to_L[i]; curr_max_w = i; } + // If all that remains are duplicates of points already taken, stop. + if (curr_max_dist == 0) break; } } diff --git a/src/Subsampling/test/test_choose_n_farthest_points.cpp b/src/Subsampling/test/test_choose_n_farthest_points.cpp index 5c4bd4cb..b318d58e 100644 --- a/src/Subsampling/test/test_choose_n_farthest_points.cpp +++ b/src/Subsampling/test/test_choose_n_farthest_points.cpp @@ -39,7 +39,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(test_choose_farthest_point, Kernel, list_of_tested for (FT k = 0; k < 5; k += 1.0) for (FT l = 0; l < 5; l += 1.0) { std::vector<FT> point({i, j, k, l}); - points.push_back(Point_d(point.begin(), point.end())); + points.emplace_back(point.begin(), point.end()); } landmarks.clear(); @@ -75,7 +75,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(test_choose_farthest_point_limits, Kernel, list_of landmarks.clear(); distances.clear(); std::vector<FT> point({0.0, 0.0, 0.0, 0.0}); - points.push_back(Point_d(point.begin(), point.end())); + points.emplace_back(point.begin(), point.end()); // Choose -1 farthest points in a one point cloud Gudhi::subsampling::choose_n_farthest_points(k, points, -1, -1, std::back_inserter(landmarks), std::back_inserter(distances)); BOOST_CHECK(landmarks.size() == 1 && distances.size() == 1); @@ -92,8 +92,16 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(test_choose_farthest_point_limits, Kernel, list_of landmarks.clear(); distances.clear(); std::vector<FT> point2({1.0, 0.0, 0.0, 0.0}); - points.push_back(Point_d(point2.begin(), point2.end())); - // Choose all farthest points in a one point cloud + points.emplace_back(point2.begin(), point2.end()); + // Choose all farthest points among 2 points + Gudhi::subsampling::choose_n_farthest_points(k, points, -1, -1, std::back_inserter(landmarks), std::back_inserter(distances)); + BOOST_CHECK(landmarks.size() == 2 && distances.size() == 2); + BOOST_CHECK(distances[0] == std::numeric_limits<FT>::infinity()); + BOOST_CHECK(distances[1] == 1); + landmarks.clear(); distances.clear(); + + // Ignore duplicated points + points.emplace_back(point.begin(), point.end()); Gudhi::subsampling::choose_n_farthest_points(k, points, -1, -1, std::back_inserter(landmarks), std::back_inserter(distances)); BOOST_CHECK(landmarks.size() == 2 && distances.size() == 2); BOOST_CHECK(distances[0] == std::numeric_limits<FT>::infinity()); diff --git a/src/cmake/modules/GUDHI_third_party_libraries.cmake b/src/cmake/modules/GUDHI_third_party_libraries.cmake index 1fbc4244..9dadac4f 100644 --- a/src/cmake/modules/GUDHI_third_party_libraries.cmake +++ b/src/cmake/modules/GUDHI_third_party_libraries.cmake @@ -155,6 +155,7 @@ if( PYTHONINTERP_FOUND ) find_python_module("pykeops") find_python_module("eagerpy") find_python_module_no_version("hnswlib") + find_python_module("tensorflow") endif() if(NOT GUDHI_PYTHON_PATH) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 4f26481e..c09996fe 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -103,6 +103,9 @@ if(PYTHONINTERP_FOUND) if(EAGERPY_FOUND) add_gudhi_debug_info("EagerPy version ${EAGERPY_VERSION}") endif() + if(TENSORFLOW_FOUND) + add_gudhi_debug_info("TensorFlow version ${TENSORFLOW_VERSION}") + endif() set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_RESULT_OF_USE_DECLTYPE', ") set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_ALL_NO_LIB', ") @@ -496,11 +499,17 @@ if(PYTHONINTERP_FOUND) # Wasserstein if(OT_FOUND AND PYBIND11_FOUND) - if(TORCH_FOUND AND EAGERPY_FOUND) + # EagerPy dependency because of enable_autodiff=True + if(EAGERPY_FOUND) add_gudhi_py_test(test_wasserstein_distance) endif() add_gudhi_py_test(test_wasserstein_barycenter) endif() + if(OT_FOUND) + if(TORCH_FOUND AND TENSORFLOW_FOUND AND EAGERPY_FOUND) + add_gudhi_py_test(test_wasserstein_with_tensors) + endif() + endif() # Representations if(SKLEARN_FOUND AND MATPLOTLIB_FOUND) diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst index 2f161d66..66efe45a 100644 --- a/src/python/doc/installation.rst +++ b/src/python/doc/installation.rst @@ -394,6 +394,11 @@ mathematics, science, and engineering. :class:`~gudhi.point_cloud.knn.KNearestNeighbors` can use the Python package `SciPy <http://scipy.org>`_ as a backend if explicitly requested. +TensorFlow +---------- + +`TensorFlow <https://www.tensorflow.org>`_ is currently only used in some automatic differentiation tests. + Bug reports and contributions ***************************** diff --git a/src/python/gudhi/simplex_tree.pxd b/src/python/gudhi/simplex_tree.pxd index 75e94e0b..3b494ba3 100644 --- a/src/python/gudhi/simplex_tree.pxd +++ b/src/python/gudhi/simplex_tree.pxd @@ -58,6 +58,7 @@ cdef extern from "Simplex_tree_interface.h" namespace "Gudhi": void compute_extended_filtration() nogil vector[vector[pair[int, pair[double, double]]]] compute_extended_persistence_subdiagrams(vector[pair[int, pair[double, double]]] dgm, double min_persistence) nogil Simplex_tree_interface_full_featured* collapse_edges(int nb_collapse_iteration) nogil + void reset_filtration(double filtration, int dimension) nogil # Iterators over Simplex tree pair[vector[int], double] get_simplex_and_filtration(Simplex_tree_simplex_handle f_simplex) nogil Simplex_tree_simplices_iterator get_simplices_iterator_begin() nogil diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index dfb1d985..910711a9 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -328,7 +328,7 @@ cdef class SimplexTree: return self.get_ptr().prune_above_filtration(filtration) def expansion(self, max_dim): - """Expands the Simplex_tree containing only its one skeleton + """Expands the simplex tree containing only its one skeleton until dimension max_dim. The expanded simplicial complex until dimension :math:`d` @@ -338,7 +338,7 @@ cdef class SimplexTree: The filtration value assigned to a simplex is the maximal filtration value of one of its edges. - The Simplex_tree must contain no simplex of dimension bigger than + The simplex tree must contain no simplex of dimension bigger than 1 when calling the method. :param max_dim: The maximal dimension. @@ -358,6 +358,20 @@ cdef class SimplexTree: """ return self.get_ptr().make_filtration_non_decreasing() + def reset_filtration(self, filtration, min_dim = 0): + """This function resets the filtration value of all the simplices of dimension at least min_dim. Resets all the + simplex tree when `min_dim = 0`. + `reset_filtration` may break the filtration property with `min_dim > 0`, and it is the user's responsibility to + make it a valid filtration (using a large enough `filt_value`, or calling `make_filtration_non_decreasing` + afterwards for instance). + + :param filtration: New threshold value. + :type filtration: float. + :param min_dim: The minimal dimension. Default value is 0. + :type min_dim: int. + """ + self.get_ptr().reset_filtration(filtration, min_dim) + def extend_filtration(self): """ Extend filtration for computing extended persistence. This function only uses the filtration values at the 0-dimensional simplices, and computes the extended persistence @@ -366,14 +380,14 @@ cdef class SimplexTree: .. note:: Note that after calling this function, the filtration - values are actually modified within the Simplex_tree. + values are actually modified within the simplex tree. The function :func:`extended_persistence` retrieves the original values. .. note:: Note that this code creates an extra vertex internally, so you should make sure that - the Simplex_tree does not contain a vertex with the largest possible value (i.e., 4294967295). + the simplex tree does not contain a vertex with the largest possible value (i.e., 4294967295). """ self.get_ptr().compute_extended_filtration() diff --git a/src/python/gudhi/subsampling.pyx b/src/python/gudhi/subsampling.pyx index f77c6f75..b11d07e5 100644 --- a/src/python/gudhi/subsampling.pyx +++ b/src/python/gudhi/subsampling.pyx @@ -33,7 +33,7 @@ def choose_n_farthest_points(points=None, off_file='', nb_points=0, starting_poi The iteration starts with the landmark `starting point`. :param points: The input point set. - :type points: Iterable[Iterable[float]]. + :type points: Iterable[Iterable[float]] Or @@ -42,14 +42,15 @@ def choose_n_farthest_points(points=None, off_file='', nb_points=0, starting_poi And in both cases - :param nb_points: Number of points of the subsample. - :type nb_points: unsigned. + :param nb_points: Number of points of the subsample (the subsample may be \ + smaller if there are fewer than nb_points distinct input points) + :type nb_points: int :param starting_point: The iteration starts with the landmark `starting \ - point`,which is the index of the point to start with. If not set, this \ + point`, which is the index of the point to start with. If not set, this \ index is chosen randomly. - :type starting_point: unsigned. + :type starting_point: int :returns: The subsample point set. - :rtype: List[List[float]]. + :rtype: List[List[float]] """ if off_file: if os.path.isfile(off_file): @@ -76,7 +77,7 @@ def pick_n_random_points(points=None, off_file='', nb_points=0): """Subsample a point set by picking random vertices. :param points: The input point set. - :type points: Iterable[Iterable[float]]. + :type points: Iterable[Iterable[float]] Or @@ -86,7 +87,7 @@ def pick_n_random_points(points=None, off_file='', nb_points=0): And in both cases :param nb_points: Number of points of the subsample. - :type nb_points: unsigned. + :type nb_points: int :returns: The subsample point set. :rtype: List[List[float]] """ @@ -107,7 +108,7 @@ def sparsify_point_set(points=None, off_file='', min_squared_dist=0.0): between any two points is greater than or equal to min_squared_dist. :param points: The input point set. - :type points: Iterable[Iterable[float]]. + :type points: Iterable[Iterable[float]] Or @@ -118,7 +119,7 @@ def sparsify_point_set(points=None, off_file='', min_squared_dist=0.0): :param min_squared_dist: Minimum squared distance separating the output \ points. - :type min_squared_dist: float. + :type min_squared_dist: float :returns: The subsample point set. :rtype: List[List[float]] """ diff --git a/src/python/test/test_simplex_tree.py b/src/python/test/test_simplex_tree.py index 83be0602..ac2b59c7 100755 --- a/src/python/test/test_simplex_tree.py +++ b/src/python/test/test_simplex_tree.py @@ -356,5 +356,27 @@ def test_collapse_edges(): st.collapse_edges() assert st.num_simplices() == 9 assert st.find([1, 3]) == False - for simplex in st.get_skeleton(0): - assert simplex[1] == 1. + for simplex in st.get_skeleton(0): + assert simplex[1] == 1. + +def test_reset_filtration(): + st = SimplexTree() + + assert st.insert([0, 1, 2], 3.) == True + assert st.insert([0, 3], 2.) == True + assert st.insert([3, 4, 5], 3.) == True + assert st.insert([0, 1, 6, 7], 4.) == True + + # Guaranteed by construction + for simplex in st.get_simplices(): + assert st.filtration(simplex[0]) >= 2. + + # dimension until 5 even if simplex tree is of dimension 3 to test the limits + for dimension in range(5, -1, -1): + st.reset_filtration(0., dimension) + for simplex in st.get_skeleton(3): + print(simplex) + if len(simplex[0]) < (dimension) + 1: + assert st.filtration(simplex[0]) >= 2. + else: + assert st.filtration(simplex[0]) == 0. diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 90d26809..e3b521d6 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -97,27 +97,3 @@ def test_wasserstein_distance_pot(): def test_wasserstein_distance_hera(): _basic_wasserstein(hera_wrap(delta=1e-12), 1e-12, test_matching=False) _basic_wasserstein(hera_wrap(delta=.1), .1, test_matching=False) - -def test_wasserstein_distance_grad(): - import torch - - diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True) - diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) - diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) - assert diag1.grad is None and diag2.grad is None and diag3.grad is None - dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) - dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) - dist12.backward() - dist30.backward() - assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() - diag4 = torch.tensor([[0., 10.]], requires_grad=True) - diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) - dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) - assert dist45 == 3. - dist45.backward() - assert np.array_equal(diag4.grad, [[-1., -1.]]) - assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) - diag6 = torch.tensor([[5., 10.]], requires_grad=True) - pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() - # https://github.com/jonasrauber/eagerpy/issues/6 - # assert np.array_equal(diag6.grad, [[0., 0.]]) diff --git a/src/python/test/test_wasserstein_with_tensors.py b/src/python/test/test_wasserstein_with_tensors.py new file mode 100755 index 00000000..e3f1411a --- /dev/null +++ b/src/python/test/test_wasserstein_with_tensors.py @@ -0,0 +1,47 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Mathieu Carriere + + Copyright (C) 2020 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +from gudhi.wasserstein import wasserstein_distance as pot +import numpy as np +import torch +import tensorflow as tf + +def test_wasserstein_distance_grad(): + diag1 = torch.tensor([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]], requires_grad=True) + diag2 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) + diag3 = torch.tensor([[2.8, 4.45], [9.5, 14.1]], requires_grad=True) + assert diag1.grad is None and diag2.grad is None and diag3.grad is None + dist12 = pot(diag1, diag2, internal_p=2, order=2, enable_autodiff=True) + dist30 = pot(diag3, torch.tensor([]), internal_p=2, order=2, enable_autodiff=True) + dist12.backward() + dist30.backward() + assert not torch.isnan(diag1.grad).any() and not torch.isnan(diag2.grad).any() and not torch.isnan(diag3.grad).any() + diag4 = torch.tensor([[0., 10.]], requires_grad=True) + diag5 = torch.tensor([[1., 11.], [3., 4.]], requires_grad=True) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + assert dist45 == 3. + dist45.backward() + assert np.array_equal(diag4.grad, [[-1., -1.]]) + assert np.array_equal(diag5.grad, [[1., 1.], [-1., 1.]]) + diag6 = torch.tensor([[5., 10.]], requires_grad=True) + pot(diag6, diag6, internal_p=2, order=2, enable_autodiff=True).backward() + # https://github.com/jonasrauber/eagerpy/issues/6 + # assert np.array_equal(diag6.grad, [[0., 0.]]) + +def test_wasserstein_distance_grad_tensorflow(): + with tf.GradientTape() as tape: + diag4 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[0., 10.]]), trainable=True)) + diag5 = tf.convert_to_tensor(tf.Variable(initial_value=np.array([[1., 11.], [3., 4.]]), trainable=True)) + dist45 = pot(diag4, diag5, internal_p=1, order=1, enable_autodiff=True) + assert dist45 == 3. + + grads = tape.gradient(dist45, [diag4, diag5]) + assert np.array_equal(grads[0].values, [[-1., -1.]]) + assert np.array_equal(grads[1].values, [[1., 1.], [-1., 1.]])
\ No newline at end of file |