Diffstat (limited to 'ot/backend.py')
-rw-r--r--  ot/backend.py  252
1 file changed, 236 insertions, 16 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 361ffba..0779243 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -534,9 +534,9 @@ class Backend():
"""
raise NotImplementedError()
- def zero_pad(self, a, pad_width):
+ def zero_pad(self, a, pad_width, value=0):
r"""
- Pads a tensor.
+ Pads a tensor with a given value (0 by default).
This function follows the api from :any:`numpy.pad`
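For illustration, a minimal usage sketch of the extended signature (assuming a backend instance obtained via ot.backend.get_backend; the array values are made up):

    import numpy as np
    from ot.backend import get_backend

    a = np.ones((2, 2))
    nx = get_backend(a)
    # pad one row/column on each side, filling with -1 instead of the default 0
    b = nx.zero_pad(a, pad_width=[(1, 1), (1, 1)], value=-1)
    # b.shape == (4, 4), with -1 on the border and 1 in the 2x2 center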
@@ -854,6 +854,21 @@ class Backend():
"""
raise NotImplementedError()
+ def kl_div(self, p, q, eps=1e-16):
+ r"""
+ Computes the Kullback-Leibler divergence.
+
+ This function follows the api from :any:`scipy.stats.entropy`.
+
+ The parameter eps is used to avoid numerical errors and is added inside the log.
+
+ .. math::
+ KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
+ """
+ raise NotImplementedError()
+
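A small numerical sketch of this definition (hypothetical values, using the NumPy backend through ot.backend.get_backend):

    import numpy as np
    from ot.backend import get_backend

    p = np.array([0.5, 0.5])
    q = np.array([0.9, 0.1])
    nx = get_backend(p, q)
    # sum_i p_i * log(p_i / q_i + eps) ~ 0.511 here
    kl = nx.kl_div(p, q)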
def isfinite(self, a):
r"""
Tests element-wise for finiteness (not infinity and not Not a Number).
@@ -880,6 +895,62 @@ class Backend():
"""
raise NotImplementedError()
+ def tile(self, a, reps):
+ r"""
+ Construct an array by repeating a the number of times given by reps
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html
+ """
+ raise NotImplementedError()
+
+ def floor(self, a):
+ r"""
+ Return the floor of the input element-wise
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html
+ """
+ raise NotImplementedError()
+
+ def prod(self, a, axis=None):
+ r"""
+ Return the product of array elements over a given axis.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html
+ """
+ raise NotImplementedError()
+
+ def sort2(self, a, axis=None):
+ r"""
+ Return the sorted array and the indices to sort the array
+
+ See: https://pytorch.org/docs/stable/generated/torch.sort.html
+ """
+ raise NotImplementedError()
+
+ def qr(self, a):
+ r"""
+ Return the QR factorization
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html
+ """
+ raise NotImplementedError()
+
+ def atan2(self, a, b):
+ r"""
+ Element-wise arc tangent of a/b choosing the quadrant correctly
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html
+ """
+ raise NotImplementedError()
+
+ def transpose(self, a, axes=None):
+ r"""
+ Returns a tensor that is a transposed version of a, with axes permuted (reversed when axes is None).
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
+ """
+ raise NotImplementedError()
+
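A rough sketch of how a few of these new primitives are meant to be called through a backend instance (NumPy shown; the same calls should dispatch to the other backends):

    import numpy as np
    from ot.backend import get_backend

    a = np.array([[3., 1.], [2., 4.]])
    nx = get_backend(a)
    vals, idx = nx.sort2(a, axis=-1)  # sorted values and the corresponding argsort indices
    q, r = nx.qr(a)                   # QR factorization of a
    at = nx.transpose(a)              # axes reversed when axes is None
    tiled = nx.tile(a, (2, 1))        # stack two copies of a along the first axis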
class NumpyBackend(Backend):
"""
@@ -1024,8 +1095,8 @@ class NumpyBackend(Backend):
def concatenate(self, arrays, axis=0):
return np.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return np.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return np.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return np.argmax(a, axis=axis)
@@ -1158,6 +1229,9 @@ class NumpyBackend(Backend):
def sqrtm(self, a):
return scipy.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return np.sum(p * np.log(p / q + eps))
+
def isfinite(self, a):
return np.isfinite(a)
@@ -1167,6 +1241,44 @@ class NumpyBackend(Backend):
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return np.tile(a, reps)
+
+ def floor(self, a):
+ return np.floor(a)
+
+ def prod(self, a, axis=0):
+ return np.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ np_version = tuple([int(k) for k in np.__version__.split(".")])
+ if np_version < (1, 22, 0):
+ # numpy.linalg.qr supports stacked (batched) matrices only from 1.22 on,
+ # so on older versions factorize each 2-D slice in a loop
+ M, N = a.shape[-2], a.shape[-1]
+ K = min(M, N)
+
+ if len(a.shape) >= 3:
+ n = a.shape[0]
+
+ qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N))
+
+ for i in range(a.shape[0]):
+ qs[i], rs[i] = np.linalg.qr(a[i])
+
+ else:
+ return np.linalg.qr(a)
+
+ return qs, rs
+ return np.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return np.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return np.transpose(a, axes)
+
class JaxBackend(Backend):
"""
@@ -1333,8 +1445,8 @@ class JaxBackend(Backend):
def concatenate(self, arrays, axis=0):
return jnp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return jnp.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return jnp.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return jnp.argmax(a, axis=axis)
@@ -1481,6 +1593,9 @@ class JaxBackend(Backend):
L, V = jnp.linalg.eigh(a)
return (V * jnp.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return jnp.sum(p * jnp.log(p / q + eps))
+
def isfinite(self, a):
return jnp.isfinite(a)
@@ -1490,6 +1605,27 @@ class JaxBackend(Backend):
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return jnp.tile(a, reps)
+
+ def floor(self, a):
+ return jnp.floor(a)
+
+ def prod(self, a, axis=0):
+ return jnp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return jnp.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return jnp.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return jnp.transpose(a, axes)
+
class TorchBackend(Backend):
"""
@@ -1507,15 +1643,19 @@ class TorchBackend(Backend):
def __init__(self):
- self.rng_ = torch.Generator()
+ self.rng_ = torch.Generator("cpu")
self.rng_.seed()
self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
torch.tensor(1, dtype=torch.float64)]
if torch.cuda.is_available():
+ self.rng_cuda_ = torch.Generator("cuda")
+ self.rng_cuda_.seed()
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+ else:
+ self.rng_cuda_ = torch.Generator("cpu")
from torch.autograd import Function
@@ -1704,13 +1844,13 @@ class TorchBackend(Backend):
def concatenate(self, arrays, axis=0):
return torch.cat(arrays, dim=axis)
- def zero_pad(self, a, pad_width):
+ def zero_pad(self, a, pad_width, value=0):
from torch.nn.functional import pad
# pad_width is an array of ndim tuples indicating how many values to add before
# and after each dimension. We first need to make it compliant with torch syntax,
# which starts with the last dim, then the second to last, etc.
how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
- return pad(a, how_pad)
+ return pad(a, how_pad, value=value)
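As a concrete illustration of the reordering above (made-up pad_width; this mirrors what the comprehension computes):

    pad_width = [(1, 2), (3, 4)]  # (before, after) for dim 0, then dim 1
    how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
    # how_pad == (3, 4, 1, 2): torch.nn.functional.pad expects the last dimension first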
def argmax(self, a, axis=None):
return torch.argmax(a, dim=axis)
@@ -1761,20 +1901,26 @@ class TorchBackend(Backend):
def seed(self, seed=None):
if isinstance(seed, int):
self.rng_.manual_seed(seed)
+ self.rng_cuda_.manual_seed(seed)
elif isinstance(seed, torch.Generator):
- self.rng_ = seed
+ if self.device_type(seed) == "GPU":
+ self.rng_cuda_ = seed
+ else:
+ self.rng_ = seed
else:
raise ValueError("Non compatible seed : {}".format(seed))
def rand(self, *size, type_as=None):
if type_as is not None:
- return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+ generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+ return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device)
else:
return torch.rand(size=size, generator=self.rng_)
def randn(self, *size, type_as=None):
if type_as is not None:
- return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+ generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+ return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device)
else:
return torch.randn(size=size, generator=self.rng_)
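A hedged sketch of how the per-device generators are intended to behave (assumes torch is installed; with a CUDA type_as tensor the CUDA generator would be picked instead):

    import torch
    from ot.backend import TorchBackend

    nx = TorchBackend()
    nx.seed(42)  # seeds both the CPU generator and, when available, the CUDA one
    t = torch.tensor(1., dtype=torch.float32)
    x = nx.rand(3, 2, type_as=t)  # drawn on CPU with the CPU generator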
@@ -1891,6 +2037,9 @@ class TorchBackend(Backend):
L, V = torch.linalg.eigh(a)
return (V * torch.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return torch.sum(p * torch.log(p / q + eps))
+
def isfinite(self, a):
return torch.isfinite(a)
@@ -1900,6 +2049,29 @@ class TorchBackend(Backend):
def is_floating_point(self, a):
return a.dtype.is_floating_point
+ def tile(self, a, reps):
+ return a.repeat(reps)
+
+ def floor(self, a):
+ return torch.floor(a)
+
+ def prod(self, a, axis=0):
+ return torch.prod(a, dim=axis)
+
+ def sort2(self, a, axis=-1):
+ return torch.sort(a, axis)
+
+ def qr(self, a):
+ return torch.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return torch.atan2(a, b)
+
+ def transpose(self, a, axes=None):
+ if axes is None:
+ axes = tuple(range(a.ndim)[::-1])
+ return a.permute(axes)
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -2062,8 +2234,8 @@ class CupyBackend(Backend): # pragma: no cover
def concatenate(self, arrays, axis=0):
return cp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return cp.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return cp.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return cp.argmax(a, axis=axis)
@@ -2238,6 +2410,9 @@ class CupyBackend(Backend): # pragma: no cover
L, V = cp.linalg.eigh(a)
return (V * self.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return cp.sum(p * cp.log(p / q + eps))
+
def isfinite(self, a):
return cp.isfinite(a)
@@ -2247,6 +2422,27 @@ class CupyBackend(Backend): # pragma: no cover
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return cp.tile(a, reps)
+
+ def floor(self, a):
+ return cp.floor(a)
+
+ def prod(self, a, axis=0):
+ return cp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return cp.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return cp.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return cp.transpose(a, axes)
+
class TensorflowBackend(Backend):
@@ -2417,8 +2613,8 @@ class TensorflowBackend(Backend):
def concatenate(self, arrays, axis=0):
return tnp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return tnp.pad(a, pad_width, mode="constant")
+ def zero_pad(self, a, pad_width, value=0):
+ return tnp.pad(a, pad_width, mode="constant", constant_values=value)
def argmax(self, a, axis=None):
return tnp.argmax(a, axis=axis)
@@ -2598,6 +2794,9 @@ class TensorflowBackend(Backend):
def sqrtm(self, a):
return tf.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return tnp.sum(p * tnp.log(p / q + eps))
+
def isfinite(self, a):
return tnp.isfinite(a)
@@ -2606,3 +2805,24 @@ class TensorflowBackend(Backend):
def is_floating_point(self, a):
return a.dtype.is_floating
+
+ def tile(self, a, reps):
+ return tnp.tile(a, reps)
+
+ def floor(self, a):
+ return tf.floor(a)
+
+ def prod(self, a, axis=0):
+ return tnp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return tf.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return tf.math.atan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return tf.transpose(a, perm=axes)