diff options
author | Francisco Muñoz <femunoz@dim.uchile.cl> | 2023-05-10 07:41:15 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-10 13:41:15 +0200 |
commit | 8cc8dd2e8e13022b03bcd013becc08e7e18c404a (patch) | |
tree | 8305bdebda3479c1717c69e6e7133e07b4b444bd | |
parent | 03341c6953c06608ba17d6bf7cd35666bc069988 (diff) |
[FIX] Refactor the function `utils.cost_normalization` to work with multiple backends (#472)
* [FEAT] Add the 'median' method to the backend base class and implement this method in the Numpy, Pytorch, Jax and Cupy backends
* [TEST] Modify the 'cost_normalization' test to cover multiple backends
* [REFACTOR] Refactor the 'utils.cost_normalization' function for multiple backends
* [TEST] Update backend tests for method 'median'
* [DEBUG] Fix the error in the test in the 'median' method with PyTorch backend
* [TEST] Add the edge case where the 'median' method is not yet implemented in the Tensorflow backend.
* [FEAT] Implement the 'median' method in the Tensorflow backend using Numpy
* [DEBUG] For compatibility reasons, the 'median' method in the Pytorch backend was changed to use numpy
* [DEBUG] The 'median' method checks the Pytorch version to decide whether to use torch.quantile or numpy
* Add changes to RELEASES.md
---------
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r-- | RELEASES.md | 2 | ||||
-rw-r--r-- | ot/backend.py | 42 | ||||
-rw-r--r-- | ot/utils.py | 10 | ||||
-rw-r--r-- | test/test_backend.py | 6 | ||||
-rw-r--r-- | test/test_utils.py | 28 |
5 files changed, 73 insertions, 15 deletions
diff --git a/RELEASES.md b/RELEASES.md index 97f4c44..0870b34 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,7 @@ - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463) - Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459) - Add tests on GPU for master branch and approved PR (PR #473) +- Add `median` method to all inherited classes of `backend.Backend` (PR #472) #### Closed issues @@ -16,6 +17,7 @@ - Faster Bures-Wasserstein distance with NumPy backend (PR #468) - Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471) - Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (RP #474) +- Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472) ## 0.9.0 diff --git a/ot/backend.py b/ot/backend.py index d661c74..9aa14e6 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -574,6 +574,16 @@ class Backend(): """ raise NotImplementedError() + def median(self, a, axis=None): + r""" + Computes the median of a tensor along given dimensions. + + This function follows the api from :any:`numpy.median` + + See: https://numpy.org/doc/stable/reference/generated/numpy.median.html + """ + raise NotImplementedError() + def std(self, a, axis=None): r""" Computes the standard deviation of a tensor along given dimensions. 
@@ -1123,6 +1133,9 @@ class NumpyBackend(Backend): def mean(self, a, axis=None): return np.mean(a, axis=axis) + def median(self, a, axis=None): + return np.median(a, axis=axis) + def std(self, a, axis=None): return np.std(a, axis=axis) @@ -1482,6 +1495,9 @@ class JaxBackend(Backend): def mean(self, a, axis=None): return jnp.mean(a, axis=axis) + def median(self, a, axis=None): + return jnp.median(a, axis=axis) + def std(self, a, axis=None): return jnp.std(a, axis=axis) @@ -1899,6 +1915,22 @@ class TorchBackend(Backend): else: return torch.mean(a) + def median(self, a, axis=None): + from packaging import version + # Since version 1.11.0, interpolation is available + if version.parse(torch.__version__) >= version.parse("1.11.0"): + if axis is not None: + return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis) + else: + return torch.quantile(a, 0.5, interpolation="midpoint") + + # Else, use numpy + warnings.warn("The median is being computed using numpy and the array has been detached " + "in the Pytorch backend.") + a_ = self.to_numpy(a) + a_median = np.median(a_, axis=axis) + return self.from_numpy(a_median, type_as=a) + def std(self, a, axis=None): if axis is not None: return torch.std(a, dim=axis, unbiased=False) @@ -2289,6 +2321,9 @@ class CupyBackend(Backend): # pragma: no cover def mean(self, a, axis=None): return cp.mean(a, axis=axis) + def median(self, a, axis=None): + return cp.median(a, axis=axis) + def std(self, a, axis=None): return cp.std(a, axis=axis) @@ -2678,6 +2713,13 @@ class TensorflowBackend(Backend): def mean(self, a, axis=None): return tnp.mean(a, axis=axis) + def median(self, a, axis=None): + warnings.warn("The median is being computed using numpy and the array has been detached " + "in the Tensorflow backend.") + a_ = self.to_numpy(a) + a_median = np.median(a_, axis=axis) + return self.from_numpy(a_median, type_as=a) + def std(self, a, axis=None): return tnp.std(a, axis=axis) diff --git a/ot/utils.py b/ot/utils.py index 
3343028..091b268 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -359,16 +359,18 @@ def cost_normalization(C, norm=None): The input cost matrix normalized according to given norm. """ + nx = get_backend(C) + if norm is None: pass elif norm == "median": - C /= float(np.median(C)) + C /= float(nx.median(C)) elif norm == "max": - C /= float(np.max(C)) + C /= float(nx.max(C)) elif norm == "log": - C = np.log(1 + C) + C = nx.log(1 + C) elif norm == "loglog": - C = np.log1p(np.log1p(C)) + C = nx.log(1 + nx.log(1 + C)) else: raise ValueError('Norm %s is not a valid option.\n' 'Valid options are:\n' diff --git a/test/test_backend.py b/test/test_backend.py index fedc62f..799ac54 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -222,6 +222,8 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.mean(M) with pytest.raises(NotImplementedError): + nx.median(M) + with pytest.raises(NotImplementedError): nx.std(M) with pytest.raises(NotImplementedError): nx.linspace(0, 1, 50) @@ -519,6 +521,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('mean') + A = nx.median(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('median') + A = nx.std(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('std') diff --git a/test/test_utils.py b/test/test_utils.py index 658214d..87f4dc4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -270,25 +270,31 @@ def test_clean_zeros(): assert len(b) == n - nz2 -def test_cost_normalization(): +def test_cost_normalization(nx): C = np.random.rand(10, 10) + C1 = nx.from_numpy(C) # does nothing - M0 = ot.utils.cost_normalization(C) - np.testing.assert_allclose(C, M0) + M0 = ot.utils.cost_normalization(C1) + M1 = nx.to_numpy(M0) + np.testing.assert_allclose(C, M1) - M = ot.utils.cost_normalization(C, 'median') - np.testing.assert_allclose(np.median(M), 1) + M = ot.utils.cost_normalization(C1, 'median') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(np.median(M1), 1) - M = 
ot.utils.cost_normalization(C, 'max') - np.testing.assert_allclose(M.max(), 1) + M = ot.utils.cost_normalization(C1, 'max') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), 1) - M = ot.utils.cost_normalization(C, 'log') - np.testing.assert_allclose(M.max(), np.log(1 + C).max()) + M = ot.utils.cost_normalization(C1, 'log') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), np.log(1 + C).max()) - M = ot.utils.cost_normalization(C, 'loglog') - np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max()) + M = ot.utils.cost_normalization(C1, 'loglog') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max()) def test_check_params(): |