Diffstat (limited to 'ot/backend.py')
-rw-r--r--  ot/backend.py  252
1 file changed, 236 insertions, 16 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 361ffba..0779243 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -534,9 +534,9 @@ class Backend():
         """
         raise NotImplementedError()
 
-    def zero_pad(self, a, pad_width):
+    def zero_pad(self, a, pad_width, value=0):
         r"""
-        Pads a tensor.
+        Pads a tensor with a given value (0 by default).
 
         This function follows the api from :any:`numpy.pad`
 
@@ -854,6 +854,21 @@ class Backend():
         """
         raise NotImplementedError()
 
+    def kl_div(self, p, q, eps=1e-16):
+        r"""
+        Computes the Kullback-Leibler divergence.
+
+        This function follows the api from :any:`scipy.stats.entropy`.
+
+        Parameter eps is used to avoid numerical errors and is added in the log.
+
+        .. math::
+            KL(p, q) = \sum_i p(i) \log \left(\frac{p(i)}{q(i)} + \epsilon\right)
+
+        See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
+        """
+        raise NotImplementedError()
+
     def isfinite(self, a):
         r"""
         Tests element-wise for finiteness (not infinity and not Not a Number).
@@ -880,6 +895,62 @@ class Backend():
         """
         raise NotImplementedError()
 
+    def tile(self, a, reps):
+        r"""
+        Constructs an array by repeating a the number of times given by reps.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html
+        """
+        raise NotImplementedError()
+
+    def floor(self, a):
+        r"""
+        Returns the floor of the input element-wise.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html
+        """
+        raise NotImplementedError()
+
+    def prod(self, a, axis=None):
+        r"""
+        Returns the product of all elements.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html
+        """
+        raise NotImplementedError()
+
+    def sort2(self, a, axis=None):
+        r"""
+        Returns the sorted array and the indices that sort it.
+
+        See: https://pytorch.org/docs/stable/generated/torch.sort.html
+        """
+        raise NotImplementedError()
+
+    def qr(self, a):
+        r"""
+        Returns the QR factorization.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html
+        """
+        raise NotImplementedError()
+
+    def atan2(self, a, b):
+        r"""
+        Element-wise arc tangent of a / b, choosing the quadrant correctly.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html
+        """
+        raise NotImplementedError()
+
+    def transpose(self, a, axes=None):
+        r"""
+        Returns a tensor that is a transposed version of a, with its axes
+        permuted according to axes (reversed by default).
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -1024,8 +1095,8 @@ class NumpyBackend(Backend):
     def concatenate(self, arrays, axis=0):
         return np.concatenate(arrays, axis)
 
-    def zero_pad(self, a, pad_width):
-        return np.pad(a, pad_width)
+    def zero_pad(self, a, pad_width, value=0):
+        return np.pad(a, pad_width, constant_values=value)
 
     def argmax(self, a, axis=None):
         return np.argmax(a, axis=axis)
@@ -1158,6 +1229,9 @@ class NumpyBackend(Backend):
     def sqrtm(self, a):
         return scipy.linalg.sqrtm(a)
 
+    def kl_div(self, p, q, eps=1e-16):
+        return np.sum(p * np.log(p / q + eps))
+
     def isfinite(self, a):
         return np.isfinite(a)
 
@@ -1167,6 +1241,44 @@ class NumpyBackend(Backend):
     def is_floating_point(self, a):
         return a.dtype.kind == "f"
 
+    def tile(self, a, reps):
+        return np.tile(a, reps)
+
+    def floor(self, a):
+        return np.floor(a)
+
+    def prod(self, a, axis=0):
+        return np.prod(a, axis=axis)
+
+    def sort2(self, a, axis=-1):
+        return self.sort(a, axis), self.argsort(a, axis)
+
+    def qr(self, a):
+        np_version = tuple([int(k) for k in np.__version__.split(".")])
+        if np_version < (1, 22, 0):
+            M, N = a.shape[-2], a.shape[-1]
+            K = min(M, N)
+
+            if len(a.shape) >= 3:
+                n = a.shape[0]
+
+                qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N))
+
+                for i in range(a.shape[0]):
+                    qs[i], rs[i] = np.linalg.qr(a[i])
+
+            else:
+                return np.linalg.qr(a)
+
+            return qs, rs
+        return np.linalg.qr(a)
+
+    def atan2(self, a, b):
+        return np.arctan2(a, b)
+
+    def transpose(self, a, axes=None):
+        return np.transpose(a, axes)
+
 
 class JaxBackend(Backend):
     """
@@ -1333,8 +1445,8 @@ class JaxBackend(Backend):
     def concatenate(self, arrays, axis=0):
         return jnp.concatenate(arrays, axis)
 
-    def zero_pad(self, a, pad_width):
-        return jnp.pad(a, pad_width)
+    def zero_pad(self, a, pad_width, value=0):
+        return jnp.pad(a, pad_width, constant_values=value)
 
     def argmax(self, a, axis=None):
         return jnp.argmax(a, axis=axis)
@@ -1481,6 +1593,9 @@ class JaxBackend(Backend):
         L, V = jnp.linalg.eigh(a)
         return (V * jnp.sqrt(L)[None, :]) @ V.T
 
+    def kl_div(self, p, q, eps=1e-16):
+        return jnp.sum(p * jnp.log(p / q + eps))
+
     def isfinite(self, a):
         return jnp.isfinite(a)
 
@@ -1490,6 +1605,27 @@ class JaxBackend(Backend):
     def is_floating_point(self, a):
         return a.dtype.kind == "f"
 
+    def tile(self, a, reps):
+        return jnp.tile(a, reps)
+
+    def floor(self, a):
+        return jnp.floor(a)
+
+    def prod(self, a, axis=0):
+        return jnp.prod(a, axis=axis)
+
+    def sort2(self, a, axis=-1):
+        return self.sort(a, axis), self.argsort(a, axis)
+
+    def qr(self, a):
+        return jnp.linalg.qr(a)
+
+    def atan2(self, a, b):
+        return jnp.arctan2(a, b)
+
+    def transpose(self, a, axes=None):
+        return jnp.transpose(a, axes)
+
 
 class TorchBackend(Backend):
     """
@@ -1507,15 +1643,19 @@ class TorchBackend(Backend):
 
     def __init__(self):
 
-        self.rng_ = torch.Generator()
+        self.rng_ = torch.Generator("cpu")
         self.rng_.seed()
 
         self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
                               torch.tensor(1, dtype=torch.float64)]
 
         if torch.cuda.is_available():
+            self.rng_cuda_ = torch.Generator("cuda")
+            self.rng_cuda_.seed()
             self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
             self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+        else:
+            self.rng_cuda_ = torch.Generator("cpu")
 
         from torch.autograd import Function
 
@@ -1704,13 +1844,13 @@ class TorchBackend(Backend):
     def concatenate(self, arrays, axis=0):
         return torch.cat(arrays, dim=axis)
 
-    def zero_pad(self, a, pad_width):
+    def zero_pad(self, a, pad_width, value=0):
         from torch.nn.functional import pad
         # pad_width is an array of ndim tuples indicating how many 0 before and after
         # we need to add. We first need to make it compliant with torch syntax, that
         # starts with the last dim, then second last, etc.
         how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
-        return pad(a, how_pad)
+        return pad(a, how_pad, value=value)
 
     def argmax(self, a, axis=None):
         return torch.argmax(a, dim=axis)
@@ -1761,20 +1901,26 @@ class TorchBackend(Backend):
     def seed(self, seed=None):
         if isinstance(seed, int):
             self.rng_.manual_seed(seed)
+            self.rng_cuda_.manual_seed(seed)
         elif isinstance(seed, torch.Generator):
-            self.rng_ = seed
+            if self.device_type(seed) == "GPU":
+                self.rng_cuda_ = seed
+            else:
+                self.rng_ = seed
         else:
             raise ValueError("Non compatible seed : {}".format(seed))
 
     def rand(self, *size, type_as=None):
         if type_as is not None:
-            return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+            generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+            return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device)
         else:
            return torch.rand(size=size, generator=self.rng_)
 
     def randn(self, *size, type_as=None):
         if type_as is not None:
-            return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+            generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+            return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device)
         else:
             return torch.randn(size=size, generator=self.rng_)
 
@@ -1891,6 +2037,9 @@ class TorchBackend(Backend):
         L, V = torch.linalg.eigh(a)
         return (V * torch.sqrt(L)[None, :]) @ V.T
 
+    def kl_div(self, p, q, eps=1e-16):
+        return torch.sum(p * torch.log(p / q + eps))
+
     def isfinite(self, a):
         return torch.isfinite(a)
 
@@ -1900,6 +2049,29 @@ class TorchBackend(Backend):
     def is_floating_point(self, a):
         return a.dtype.is_floating_point
 
+    def tile(self, a, reps):
+        return a.repeat(reps)
+
+    def floor(self, a):
+        return torch.floor(a)
+
+    def prod(self, a, axis=0):
+        return torch.prod(a, dim=axis)
+
+    def sort2(self, a, axis=-1):
+        return torch.sort(a, axis)
+
+    def qr(self, a):
+        return torch.linalg.qr(a)
+
+    def atan2(self, a, b):
+        return torch.atan2(a, b)
+
+    def transpose(self, a, axes=None):
+        if axes is None:
+            axes = tuple(range(a.ndim)[::-1])
+        return a.permute(axes)
+
 
 class CupyBackend(Backend):  # pragma: no cover
     """
@@ -2062,8 +2234,8 @@ class CupyBackend(Backend):  # pragma: no cover
     def concatenate(self, arrays, axis=0):
         return cp.concatenate(arrays, axis)
 
-    def zero_pad(self, a, pad_width):
-        return cp.pad(a, pad_width)
+    def zero_pad(self, a, pad_width, value=0):
+        return cp.pad(a, pad_width, constant_values=value)
 
     def argmax(self, a, axis=None):
         return cp.argmax(a, axis=axis)
@@ -2238,6 +2410,9 @@ class CupyBackend(Backend):  # pragma: no cover
         L, V = cp.linalg.eigh(a)
         return (V * self.sqrt(L)[None, :]) @ V.T
 
+    def kl_div(self, p, q, eps=1e-16):
+        return cp.sum(p * cp.log(p / q + eps))
+
     def isfinite(self, a):
         return cp.isfinite(a)
 
@@ -2247,6 +2422,27 @@ class CupyBackend(Backend):  # pragma: no cover
     def is_floating_point(self, a):
         return a.dtype.kind == "f"
 
+    def tile(self, a, reps):
+        return cp.tile(a, reps)
+
+    def floor(self, a):
+        return cp.floor(a)
+
+    def prod(self, a, axis=0):
+        return cp.prod(a, axis=axis)
+
+    def sort2(self, a, axis=-1):
+        return self.sort(a, axis), self.argsort(a, axis)
+
+    def qr(self, a):
+        return cp.linalg.qr(a)
+
+    def atan2(self, a, b):
+        return cp.arctan2(a, b)
+
+    def transpose(self, a, axes=None):
+        return cp.transpose(a, axes)
+
 
 class TensorflowBackend(Backend):
 
@@ -2417,8 +2613,8 @@ class TensorflowBackend(Backend):
     def concatenate(self, arrays, axis=0):
         return tnp.concatenate(arrays, axis)
 
-    def zero_pad(self, a, pad_width):
-        return tnp.pad(a, pad_width, mode="constant")
+    def zero_pad(self, a, pad_width, value=0):
+        return tnp.pad(a, pad_width, mode="constant", constant_values=value)
 
    def argmax(self, a, axis=None):
         return tnp.argmax(a, axis=axis)
@@ -2598,6 +2794,9 @@ class TensorflowBackend(Backend):
     def sqrtm(self, a):
         return tf.linalg.sqrtm(a)
 
+    def kl_div(self, p, q, eps=1e-16):
+        return tnp.sum(p * tnp.log(p / q + eps))
+
     def isfinite(self, a):
         return tnp.isfinite(a)
 
@@ -2606,3 +2805,24 @@ class TensorflowBackend(Backend):
 
     def is_floating_point(self, a):
         return a.dtype.is_floating
+
+    def tile(self, a, reps):
+        return tnp.tile(a, reps)
+
+    def floor(self, a):
+        return tf.floor(a)
+
+    def prod(self, a, axis=0):
+        return tnp.prod(a, axis=axis)
+
+    def sort2(self, a, axis=-1):
+        return self.sort(a, axis), self.argsort(a, axis)
+
+    def qr(self, a):
+        return tf.linalg.qr(a)
+
+    def atan2(self, a, b):
+        return tf.math.atan2(a, b)
+
+    def transpose(self, a, axes=None):
+        return tf.transpose(a, perm=axes)
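
A minimal sketch of the new kl_div through POT's backend dispatcher (assuming a build with this patch applied; get_backend is POT's existing entry point). The eps term only guards against log(0), so for strictly positive distributions the result agrees with scipy.stats.entropy(p, q) up to eps:

    import numpy as np
    from ot.backend import get_backend

    p = np.array([0.2, 0.3, 0.5])
    q = np.array([0.4, 0.4, 0.2])
    nx = get_backend(p, q)

    # KL(p, q) = sum_i p(i) * log(p(i) / q(i) + eps)
    print(nx.kl_div(p, q))  # ~0.2332, matching scipy.stats.entropy(p, q)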
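Likewise for sort2 and the new value argument of zero_pad, a small illustrative example on the NumPy backend (values in comments worked by hand):

    import numpy as np
    from ot.backend import get_backend

    a = np.array([[3.0, 1.0, 2.0]])
    nx = get_backend(a)

    vals, idx = nx.sort2(a, axis=-1)  # sorted values plus the sorting indices
    print(vals)  # [[1. 2. 3.]]
    print(idx)   # [[1 2 0]]

    padded = nx.zero_pad(a, [(0, 0), (1, 1)], value=-1.0)  # pad with -1 instead of 0
    print(padded)  # [[-1.  3.  1.  2. -1.]]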
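The torch zero_pad keeps accepting numpy-style pad_width (one (before, after) pair per dimension, outermost dimension first), while torch.nn.functional.pad wants a flat tuple starting from the last dimension. The generator expression in the diff performs exactly that conversion:

    # numpy-style pad_width for a 2-D tensor: rows padded by (1, 2), columns by (3, 4)
    pad_width = [(1, 2), (3, 4)]

    # reverse the per-dimension pairs, then flatten, as TorchBackend.zero_pad does
    how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
    assert how_pad == (3, 4, 1, 2)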
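The TorchBackend changes keep one random generator per device: torch.rand raises if handed a CPU generator for a CUDA tensor, so rand and randn now pick rng_cuda_ whenever type_as lives on the GPU. A usage sketch, assuming this patched backend:

    import torch
    from ot.backend import TorchBackend

    nx = TorchBackend()
    nx.seed(42)  # seeds the CPU generator and, when CUDA is available, the CUDA one

    x = nx.rand(2, 3, type_as=torch.tensor(1.0))  # drawn with the CPU generator
    if torch.cuda.is_available():
        y = nx.rand(2, 3, type_as=torch.tensor(1.0, device="cuda"))  # CUDA generator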
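Finally, the NumpyBackend.qr fallback exists because np.linalg.qr only accepts stacked (batched) matrices from NumPy 1.22 on; older versions make the backend loop over the leading batch dimension. Either path yields the same reduced factorization shapes:

    import numpy as np
    from ot.backend import NumpyBackend

    nx = NumpyBackend()
    a = np.random.rand(5, 4, 3)  # batch of 5 matrices, each 4 x 3
    Q, R = nx.qr(a)
    print(Q.shape, R.shape)      # (5, 4, 3) (5, 3, 3): reduced QR per matrix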