diff options
author | Clément Bonet <32179275+clbonet@users.noreply.github.com> | 2023-02-23 08:31:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-23 08:31:01 +0100 |
commit | 80e3c23bc968f866fd20344ddc443a3c7fcb3b0d (patch) | |
tree | e4c2e938896243842e290d8fcf78879a8f6960bf /ot/backend.py | |
parent | 97feeb32b6c069d7bb44cd995531c2b820d59771 (diff) |
[WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein (#434)
* W circle + SSW
* Tests + Example SSW_1
* Example Wasserstein Circle + Tests
* Wasserstein on the circle wrt Unif
* Example SSW unif
* pep8
* np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests
* np qr
* rm test python 3.11
* update names, tests, backend transpose
* Comment error batchs
* semidiscrete_wasserstein2_unif_circle example
* torch permute method instead of torch.permute for previous versions
* update comments and doc
* doc wasserstein circle model as [0,1[
* Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 204 |
1 files changed, 192 insertions, 12 deletions
diff --git a/ot/backend.py b/ot/backend.py index 337e040..0779243 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -534,9 +534,9 @@ class Backend(): """ raise NotImplementedError() - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value=0): r""" - Pads a tensor. + Pads a tensor with a given value (0 by default). This function follows the api from :any:`numpy.pad` @@ -895,6 +895,62 @@ class Backend(): """ raise NotImplementedError() + def tile(self, a, reps): + r""" + Construct an array by repeating a the number of times given by reps + + See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html + """ + raise NotImplementedError() + + def floor(self, a): + r""" + Return the floor of the input element-wise + + See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html + """ + raise NotImplementedError() + + def prod(self, a, axis=None): + r""" + Return the product of all elements. + + See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html + """ + raise NotImplementedError() + + def sort2(self, a, axis=None): + r""" + Return the sorted array and the indices to sort the array + + See: https://pytorch.org/docs/stable/generated/torch.sort.html + """ + raise NotImplementedError() + + def qr(self, a): + r""" + Return the QR factorization + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html + """ + raise NotImplementedError() + + def atan2(self, a, b): + r""" + Element wise arctangent + + See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html + """ + raise NotImplementedError() + + def transpose(self, a, axes=None): + r""" + Returns a tensor that is a transposed version of a. The given dimensions dim0 and dim1 are swapped. + + See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1039,8 +1095,8 @@ class NumpyBackend(Backend): def concatenate(self, arrays, axis=0): return np.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return np.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return np.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return np.argmax(a, axis=axis) @@ -1185,6 +1241,44 @@ class NumpyBackend(Backend): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return np.tile(a, reps) + + def floor(self, a): + return np.floor(a) + + def prod(self, a, axis=0): + return np.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + np_version = tuple([int(k) for k in np.__version__.split(".")]) + if np_version < (1, 22, 0): + M, N = a.shape[-2], a.shape[-1] + K = min(M, N) + + if len(a.shape) >= 3: + n = a.shape[0] + + qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N)) + + for i in range(a.shape[0]): + qs[i], rs[i] = np.linalg.qr(a[i]) + + else: + return np.linalg.qr(a) + + return qs, rs + return np.linalg.qr(a) + + def atan2(self, a, b): + return np.arctan2(a, b) + + def transpose(self, a, axes=None): + return np.transpose(a, axes) + class JaxBackend(Backend): """ @@ -1351,8 +1445,8 @@ class JaxBackend(Backend): def concatenate(self, arrays, axis=0): return jnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return jnp.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return jnp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return jnp.argmax(a, axis=axis) @@ -1511,6 +1605,27 @@ class JaxBackend(Backend): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return jnp.tile(a, reps) + + def floor(self, a): + return jnp.floor(a) + + def prod(self, a, axis=0): + return jnp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return jnp.linalg.qr(a) + + def atan2(self, a, b): + return jnp.arctan2(a, b) + + def transpose(self, a, axes=None): + return jnp.transpose(a, axes) + class TorchBackend(Backend): """ @@ -1729,13 +1844,13 @@ class TorchBackend(Backend): def concatenate(self, arrays, axis=0): return torch.cat(arrays, dim=axis) - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value=0): from torch.nn.functional import pad # pad_width is an array of ndim tuples indicating how many 0 before and after # we need to add. We first need to make it compliant with torch syntax, that # starts with the last dim, then second last, etc. how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl) - return pad(a, how_pad) + return pad(a, how_pad, value=value) def argmax(self, a, axis=None): return torch.argmax(a, dim=axis) @@ -1934,6 +2049,29 @@ class TorchBackend(Backend): def is_floating_point(self, a): return a.dtype.is_floating_point + def tile(self, a, reps): + return a.repeat(reps) + + def floor(self, a): + return torch.floor(a) + + def prod(self, a, axis=0): + return torch.prod(a, dim=axis) + + def sort2(self, a, axis=-1): + return torch.sort(a, axis) + + def qr(self, a): + return torch.linalg.qr(a) + + def atan2(self, a, b): + return torch.atan2(a, b) + + def transpose(self, a, axes=None): + if axes is None: + axes = tuple(range(a.ndim)[::-1]) + return a.permute(axes) + class CupyBackend(Backend): # pragma: no cover """ @@ -2096,8 +2234,8 @@ class CupyBackend(Backend): # pragma: no cover def concatenate(self, arrays, axis=0): return cp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return cp.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return cp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return cp.argmax(a, axis=axis) @@ -2284,6 +2422,27 @@ class CupyBackend(Backend): # pragma: no cover def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return cp.tile(a, reps) + + def floor(self, a): + return cp.floor(a) + + def prod(self, a, axis=0): + return cp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return cp.linalg.qr(a) + + def atan2(self, a, b): + return cp.arctan2(a, b) + + def transpose(self, a, axes=None): + return cp.transpose(a, axes) + class TensorflowBackend(Backend): @@ -2454,8 +2613,8 @@ class TensorflowBackend(Backend): def concatenate(self, arrays, axis=0): return tnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return tnp.pad(a, pad_width, mode="constant") + def zero_pad(self, a, pad_width, value=0): + return tnp.pad(a, pad_width, mode="constant", constant_values=value) def argmax(self, a, axis=None): return tnp.argmax(a, axis=axis) @@ -2646,3 +2805,24 @@ class TensorflowBackend(Backend): def is_floating_point(self, a): return a.dtype.is_floating + + def tile(self, a, reps): + return tnp.tile(a, reps) + + def floor(self, a): + return tf.floor(a) + + def prod(self, a, axis=0): + return tnp.prod(a, axis=axis) + + def sort2(self, a, axis=-1): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return tf.linalg.qr(a) + + def atan2(self, a, b): + return tf.math.atan2(a, b) + + def transpose(self, a, axes=None): + return tf.transpose(a, perm=axes) |