summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancisco Muñoz <femunoz@dim.uchile.cl>2023-05-10 07:41:15 -0400
committerGitHub <noreply@github.com>2023-05-10 13:41:15 +0200
commit8cc8dd2e8e13022b03bcd013becc08e7e18c404a (patch)
tree8305bdebda3479c1717c69e6e7133e07b4b444bd
parent03341c6953c06608ba17d6bf7cd35666bc069988 (diff)
[FIX] Refactor the function `utils.cost_normalization` to work with multiple backends (#472)
* [FEAT] Add the 'median' method to the backend base class and implements this method in the Numpy, Pytorch, Jax and Cupy backends * [TEST] Modify the 'cost_normalization' test to multiple backends * [REFACTOR] Refactor the 'utils.cost_normalization' function for multiple backends * [TEST] Update backend tests for method 'median' * [DEBUG] Fix the error in the test in the 'median' method with PyTorch backend * [TEST] Add the edge case where the 'median' method is not yet implemented in the Tensorflow backend. * [FEAT] Implement the 'median' method in the Tensorflow backend using Numpy * [DEBUG] For compatibility reasons, the median method in the Pytorch backend change using numpy * [DEBUG] The 'median' method checks the Pytorch version to decide whether to use torch.quantile or numpy * Add changes to RELEASES.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md2
-rw-r--r--ot/backend.py42
-rw-r--r--ot/utils.py10
-rw-r--r--test/test_backend.py6
-rw-r--r--test/test_utils.py28
5 files changed, 73 insertions, 15 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 97f4c44..0870b34 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -6,6 +6,7 @@
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
- Add tests on GPU for master branch and approved PR (PR #473)
+- Add `median` method to all inherited classes of `backend.Backend` (PR #472)
#### Closed issues
@@ -16,6 +17,7 @@
- Faster Bures-Wasserstein distance with NumPy backend (PR #468)
- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (PR #474)
+- Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472)
## 0.9.0
diff --git a/ot/backend.py b/ot/backend.py
index d661c74..9aa14e6 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -574,6 +574,16 @@ class Backend():
"""
raise NotImplementedError()
+ def median(self, a, axis=None):
+ r"""
+ Computes the median of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.median`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.median.html
+ """
+ raise NotImplementedError()
+
def std(self, a, axis=None):
r"""
Computes the standard deviation of a tensor along given dimensions.
@@ -1123,6 +1133,9 @@ class NumpyBackend(Backend):
def mean(self, a, axis=None):
return np.mean(a, axis=axis)
+ def median(self, a, axis=None):
+ return np.median(a, axis=axis)
+
def std(self, a, axis=None):
return np.std(a, axis=axis)
@@ -1482,6 +1495,9 @@ class JaxBackend(Backend):
def mean(self, a, axis=None):
return jnp.mean(a, axis=axis)
+ def median(self, a, axis=None):
+ return jnp.median(a, axis=axis)
+
def std(self, a, axis=None):
return jnp.std(a, axis=axis)
@@ -1899,6 +1915,22 @@ class TorchBackend(Backend):
else:
return torch.mean(a)
+ def median(self, a, axis=None):
+ from packaging import version
+ # Since version 1.11.0, interpolation is available
+ if version.parse(torch.__version__) >= version.parse("1.11.0"):
+ if axis is not None:
+ return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis)
+ else:
+ return torch.quantile(a, 0.5, interpolation="midpoint")
+
+ # Else, use numpy
+ warnings.warn("The median is being computed using numpy and the array has been detached "
+ "in the Pytorch backend.")
+ a_ = self.to_numpy(a)
+ a_median = np.median(a_, axis=axis)
+ return self.from_numpy(a_median, type_as=a)
+
def std(self, a, axis=None):
if axis is not None:
return torch.std(a, dim=axis, unbiased=False)
@@ -2289,6 +2321,9 @@ class CupyBackend(Backend): # pragma: no cover
def mean(self, a, axis=None):
return cp.mean(a, axis=axis)
+ def median(self, a, axis=None):
+ return cp.median(a, axis=axis)
+
def std(self, a, axis=None):
return cp.std(a, axis=axis)
@@ -2678,6 +2713,13 @@ class TensorflowBackend(Backend):
def mean(self, a, axis=None):
return tnp.mean(a, axis=axis)
+ def median(self, a, axis=None):
+ warnings.warn("The median is being computed using numpy and the array has been detached "
+ "in the Tensorflow backend.")
+ a_ = self.to_numpy(a)
+ a_median = np.median(a_, axis=axis)
+ return self.from_numpy(a_median, type_as=a)
+
def std(self, a, axis=None):
return tnp.std(a, axis=axis)
diff --git a/ot/utils.py b/ot/utils.py
index 3343028..091b268 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -359,16 +359,18 @@ def cost_normalization(C, norm=None):
The input cost matrix normalized according to given norm.
"""
+ nx = get_backend(C)
+
if norm is None:
pass
elif norm == "median":
- C /= float(np.median(C))
+ C /= float(nx.median(C))
elif norm == "max":
- C /= float(np.max(C))
+ C /= float(nx.max(C))
elif norm == "log":
- C = np.log(1 + C)
+ C = nx.log(1 + C)
elif norm == "loglog":
- C = np.log1p(np.log1p(C))
+ C = nx.log(1 + nx.log(1 + C))
else:
raise ValueError('Norm %s is not a valid option.\n'
'Valid options are:\n'
diff --git a/test/test_backend.py b/test/test_backend.py
index fedc62f..799ac54 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -222,6 +222,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.mean(M)
with pytest.raises(NotImplementedError):
+ nx.median(M)
+ with pytest.raises(NotImplementedError):
nx.std(M)
with pytest.raises(NotImplementedError):
nx.linspace(0, 1, 50)
@@ -519,6 +521,10 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('mean')
+ A = nx.median(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('median')
+
A = nx.std(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('std')
diff --git a/test/test_utils.py b/test/test_utils.py
index 658214d..87f4dc4 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -270,25 +270,31 @@ def test_clean_zeros():
assert len(b) == n - nz2
-def test_cost_normalization():
+def test_cost_normalization(nx):
C = np.random.rand(10, 10)
+ C1 = nx.from_numpy(C)
# does nothing
- M0 = ot.utils.cost_normalization(C)
- np.testing.assert_allclose(C, M0)
+ M0 = ot.utils.cost_normalization(C1)
+ M1 = nx.to_numpy(M0)
+ np.testing.assert_allclose(C, M1)
- M = ot.utils.cost_normalization(C, 'median')
- np.testing.assert_allclose(np.median(M), 1)
+ M = ot.utils.cost_normalization(C1, 'median')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(np.median(M1), 1)
- M = ot.utils.cost_normalization(C, 'max')
- np.testing.assert_allclose(M.max(), 1)
+ M = ot.utils.cost_normalization(C1, 'max')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(M1.max(), 1)
- M = ot.utils.cost_normalization(C, 'log')
- np.testing.assert_allclose(M.max(), np.log(1 + C).max())
+ M = ot.utils.cost_normalization(C1, 'log')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(M1.max(), np.log(1 + C).max())
- M = ot.utils.cost_normalization(C, 'loglog')
- np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max())
+ M = ot.utils.cost_normalization(C1, 'loglog')
+ M1 = nx.to_numpy(M)
+ np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max())
def test_check_params():