summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorFrancisco Muñoz <femunoz@dim.uchile.cl>2023-05-10 07:41:15 -0400
committerGitHub <noreply@github.com>2023-05-10 13:41:15 +0200
commit8cc8dd2e8e13022b03bcd013becc08e7e18c404a (patch)
tree8305bdebda3479c1717c69e6e7133e07b4b444bd /ot
parent03341c6953c06608ba17d6bf7cd35666bc069988 (diff)
[FIX] Refactor the function `utils.cost_normalization` to work with multiple backends (#472)
* [FEAT] Add the 'median' method to the backend base class and implements this method in the Numpy, Pytorch, Jax and Cupy backends * [TEST] Modify the 'cost_normalization' test to multiple backends * [REFACTOR] Refactor the 'utils.cost_normalization' function for multiple backends * [TEST] Update backend tests for method 'median' * [DEBUG] Fix the error in the test in the 'median' method with PyTorch backend * [TEST] Add the edge case where the 'median' method is not yet implemented in the Tensorflow backend. * [FEAT] Implement the 'median' method in the Tensorflow backend using Numpy * [DEBUG] For compatibility reasons, the median method in the Pytorch backend change using numpy * [DEBUG] The 'median' method checks the Pytorch version to decide whether to use torch.quantile or numpy * Add changes to RELEASES.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot')
-rw-r--r--ot/backend.py42
-rw-r--r--ot/utils.py10
2 files changed, 48 insertions, 4 deletions
diff --git a/ot/backend.py b/ot/backend.py
index d661c74..9aa14e6 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -574,6 +574,16 @@ class Backend():
"""
raise NotImplementedError()
+ def median(self, a, axis=None):
+ r"""
+ Computes the median of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.median`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.median.html
+ """
+ raise NotImplementedError()
+
def std(self, a, axis=None):
r"""
Computes the standard deviation of a tensor along given dimensions.
@@ -1123,6 +1133,9 @@ class NumpyBackend(Backend):
def mean(self, a, axis=None):
return np.mean(a, axis=axis)
+ def median(self, a, axis=None):
+ return np.median(a, axis=axis)
+
def std(self, a, axis=None):
return np.std(a, axis=axis)
@@ -1482,6 +1495,9 @@ class JaxBackend(Backend):
def mean(self, a, axis=None):
return jnp.mean(a, axis=axis)
+ def median(self, a, axis=None):
+ return jnp.median(a, axis=axis)
+
def std(self, a, axis=None):
return jnp.std(a, axis=axis)
@@ -1899,6 +1915,22 @@ class TorchBackend(Backend):
else:
return torch.mean(a)
+ def median(self, a, axis=None):
+ from packaging import version
+ # Since version 1.11.0, interpolation is available
+ if version.parse(torch.__version__) >= version.parse("1.11.0"):
+ if axis is not None:
+ return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis)
+ else:
+ return torch.quantile(a, 0.5, interpolation="midpoint")
+
+ # Else, use numpy
+ warnings.warn("The median is being computed using numpy and the array has been detached "
+ "in the Pytorch backend.")
+ a_ = self.to_numpy(a)
+ a_median = np.median(a_, axis=axis)
+ return self.from_numpy(a_median, type_as=a)
+
def std(self, a, axis=None):
if axis is not None:
return torch.std(a, dim=axis, unbiased=False)
@@ -2289,6 +2321,9 @@ class CupyBackend(Backend): # pragma: no cover
def mean(self, a, axis=None):
return cp.mean(a, axis=axis)
+ def median(self, a, axis=None):
+ return cp.median(a, axis=axis)
+
def std(self, a, axis=None):
return cp.std(a, axis=axis)
@@ -2678,6 +2713,13 @@ class TensorflowBackend(Backend):
def mean(self, a, axis=None):
return tnp.mean(a, axis=axis)
+ def median(self, a, axis=None):
+ warnings.warn("The median is being computed using numpy and the array has been detached "
+ "in the Tensorflow backend.")
+ a_ = self.to_numpy(a)
+ a_median = np.median(a_, axis=axis)
+ return self.from_numpy(a_median, type_as=a)
+
def std(self, a, axis=None):
return tnp.std(a, axis=axis)
diff --git a/ot/utils.py b/ot/utils.py
index 3343028..091b268 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -359,16 +359,18 @@ def cost_normalization(C, norm=None):
The input cost matrix normalized according to given norm.
"""
+ nx = get_backend(C)
+
if norm is None:
pass
elif norm == "median":
- C /= float(np.median(C))
+ C /= float(nx.median(C))
elif norm == "max":
- C /= float(np.max(C))
+ C /= float(nx.max(C))
elif norm == "log":
- C = np.log(1 + C)
+ C = nx.log(1 + C)
elif norm == "loglog":
- C = np.log1p(np.log1p(C))
+ C = nx.log(1 + nx.log(1 + C))
else:
raise ValueError('Norm %s is not a valid option.\n'
'Valid options are:\n'