path: root/ot/backend.py
author	Rémi Flamary <remi.flamary@gmail.com>	2021-06-01 10:10:54 +0200
committer	GitHub <noreply@github.com>	2021-06-01 10:10:54 +0200
commit	184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree	483a7274c91030fd644de49b03a5fad04af9deba /ot/backend.py
parent	1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * now testing all functions in backend * add jax backend * cleanup windows * proper convert for jax backend * pep8 * try again windows tests * test jax conversion * try proper windows tests * emd function uses backend * better test partial OT * proper tests for to_numpy and template Backend * pep8 * freaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * no test same for jax * new independent tests per backend * add tests for gradients * deprecate ot.gpu * working dist function * dist done in backend * not in * remove indexing * change accuracy for jax * first pull backend * projection simplex * projection simplex no ci * pep8 * add backend discussion to quickstart guide * pep8 + better doc * proper links * correct doctest * big debug documentation * doctest again * backend test + doc proj simplex * correction test_utils * correction cumsum * correction flip * more debug * pep8 * proj_simplex * backend works for sort * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * comments from Alex (parts 1 and 2) * optimize test gromov * proj_simplex on vectors * add gradient descent example on the weights * pep8 of course * example proofread by Alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty <ncourty@irisa.fr> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/backend.py')
-rw-r--r--	ot/backend.py	536
1 file changed, 536 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py
new file mode 100644
index 0000000..d68f5cf
--- /dev/null
+++ b/ot/backend.py
@@ -0,0 +1,536 @@
+# -*- coding: utf-8 -*-
+"""
+Multi-lib backend for POT
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+
+try:
+ import torch
+ torch_type = torch.Tensor
+except ImportError:
+ torch = False
+ torch_type = float
+
+try:
+ import jax
+ import jax.numpy as jnp
+ jax_type = jax.numpy.ndarray
+except ImportError:
+ jax = False
+ jax_type = float
+
+str_type_error = "All arrays should be from the same type/backend. Current types are: {}"
+
+
+def get_backend_list():
+    """returns the list of available backends"""
+ lst = [NumpyBackend(), ]
+
+ if torch:
+ lst.append(TorchBackend())
+
+ if jax:
+ lst.append(JaxBackend())
+
+ return lst
+
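+# Example (illustrative): the content of the list depends on which optional
+# dependencies could be imported above, e.g.
+#   get_backend_list()  # -> [NumpyBackend(), TorchBackend(), JaxBackend()]
+# when both torch and jax are installed, and [NumpyBackend()] otherwise.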
+
+def get_backend(*args):
+    """returns the proper backend for a list of input arrays
+
+    Also raises ValueError if the arrays are not all from the same backend
+    """
+    # check that at least one array was given
+    if len(args) == 0:
+        raise ValueError("The function takes at least one parameter")
+ # check all same type
+
+ if isinstance(args[0], np.ndarray):
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+ return NumpyBackend()
+ elif torch and isinstance(args[0], torch_type):
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+ return TorchBackend()
+    elif isinstance(args[0], jax_type):
+        if not len(set(type(a) for a in args)) == 1:
+            raise ValueError(str_type_error.format([type(a) for a in args]))
+        return JaxBackend()
+ else:
+        raise ValueError("Unknown or not implemented backend type.")
+
+
+def to_numpy(*args):
+ """returns numpy arrays from any compatible backend"""
+
+ if len(args) == 1:
+ return get_backend(args[0]).to_numpy(args[0])
+ else:
+ return [get_backend(a).to_numpy(a) for a in args]
+
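+# Usage sketch (illustrative, assuming torch is installed): POT functions use
+# get_backend to dispatch on the type of their inputs, e.g.
+#
+#   a, b = torch.rand(3), torch.rand(3)
+#   nx = get_backend(a, b)     # -> TorchBackend instance
+#   c = nx.maximum(a, b)       # stays a torch.Tensor
+#   c_np = to_numpy(c)         # numpy array, detached from the graph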
+
+class Backend():
+
+ __name__ = None
+ __type__ = None
+
+ def __str__(self):
+ return self.__name__
+
+ # convert to numpy
+ def to_numpy(self, a):
+ raise NotImplementedError()
+
+ # convert from numpy
+ def from_numpy(self, a, type_as=None):
+ raise NotImplementedError()
+
+ def set_gradients(self, val, inputs, grads):
+        """define the gradients for the value val w.r.t. the inputs"""
+ raise NotImplementedError()
+
+ def zeros(self, shape, type_as=None):
+ raise NotImplementedError()
+
+ def ones(self, shape, type_as=None):
+ raise NotImplementedError()
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ raise NotImplementedError()
+
+ def full(self, shape, fill_value, type_as=None):
+ raise NotImplementedError()
+
+ def eye(self, N, M=None, type_as=None):
+ raise NotImplementedError()
+
+ def sum(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def cumsum(self, a, axis=None):
+ raise NotImplementedError()
+
+ def max(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def min(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def maximum(self, a, b):
+ raise NotImplementedError()
+
+ def minimum(self, a, b):
+ raise NotImplementedError()
+
+ def dot(self, a, b):
+ raise NotImplementedError()
+
+ def abs(self, a):
+ raise NotImplementedError()
+
+ def exp(self, a):
+ raise NotImplementedError()
+
+ def log(self, a):
+ raise NotImplementedError()
+
+ def sqrt(self, a):
+ raise NotImplementedError()
+
+ def norm(self, a):
+ raise NotImplementedError()
+
+ def any(self, a):
+ raise NotImplementedError()
+
+ def isnan(self, a):
+ raise NotImplementedError()
+
+ def isinf(self, a):
+ raise NotImplementedError()
+
+ def einsum(self, subscripts, *operands):
+ raise NotImplementedError()
+
+ def sort(self, a, axis=-1):
+ raise NotImplementedError()
+
+    def argsort(self, a, axis=-1):
+ raise NotImplementedError()
+
+ def flip(self, a, axis=None):
+ raise NotImplementedError()
+
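+# Sketch of a custom backend (hypothetical, for illustration only): a subclass
+# sets __name__ and __type__ and overrides the methods needed by the POT
+# functions it will be used with, e.g.
+#
+#   class MyBackend(Backend):
+#       __name__ = 'mybackend'
+#       __type__ = MyArray            # hypothetical array class
+#
+#       def to_numpy(self, a):
+#           return np.asarray(a)
+#
+#       def sum(self, a, axis=None, keepdims=False):
+#           return my_lib.sum(a, axis=axis, keepdims=keepdims)  # hypothetical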
+
+class NumpyBackend(Backend):
+
+ __name__ = 'numpy'
+ __type__ = np.ndarray
+
+ def to_numpy(self, a):
+ return a
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return a
+ elif isinstance(a, float):
+ return a
+ else:
+ return a.astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # no gradients for numpy
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return np.zeros(shape)
+ else:
+ return np.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return np.ones(shape)
+ else:
+ return np.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return np.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return np.full(shape, fill_value)
+ else:
+ return np.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return np.eye(N, M)
+ else:
+ return np.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return np.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return np.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return np.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return np.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return np.maximum(a, b)
+
+ def minimum(self, a, b):
+ return np.minimum(a, b)
+
+ def dot(self, a, b):
+ return np.dot(a, b)
+
+ def abs(self, a):
+ return np.abs(a)
+
+ def exp(self, a):
+ return np.exp(a)
+
+ def log(self, a):
+ return np.log(a)
+
+ def sqrt(self, a):
+ return np.sqrt(a)
+
+ def norm(self, a):
+ return np.sqrt(np.sum(np.square(a)))
+
+ def any(self, a):
+ return np.any(a)
+
+ def isnan(self, a):
+ return np.isnan(a)
+
+ def isinf(self, a):
+ return np.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return np.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return np.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return np.argsort(a, axis)
+
+ def flip(self, a, axis=None):
+ return np.flip(a, axis)
+
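+# Example (illustrative): for the numpy backend, type_as only controls the
+# dtype of the result, e.g.
+#   nx = NumpyBackend()
+#   nx.zeros((2, 3), type_as=np.zeros(1, dtype=np.float32)).dtype  # float32
+# (device placement only matters for the torch backend below)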
+
+class JaxBackend(Backend):
+
+ __name__ = 'jax'
+ __type__ = jax_type
+
+ def to_numpy(self, a):
+ return np.array(a)
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return jnp.array(a)
+ else:
+ return jnp.array(a).astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # no gradients for jax because it is functional
+
+ # does not work
+ # from jax import custom_jvp
+ # @custom_jvp
+ # def f(*inputs):
+ # return val
+ # f.defjvps(*grads)
+ # return f(*inputs)
+
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.zeros(shape)
+ else:
+ return jnp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.ones(shape)
+ else:
+ return jnp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return jnp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return jnp.full(shape, fill_value)
+ else:
+ return jnp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return jnp.eye(N, M)
+ else:
+ return jnp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return jnp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return jnp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return jnp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return jnp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return jnp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return jnp.minimum(a, b)
+
+ def dot(self, a, b):
+ return jnp.dot(a, b)
+
+ def abs(self, a):
+ return jnp.abs(a)
+
+ def exp(self, a):
+ return jnp.exp(a)
+
+ def log(self, a):
+ return jnp.log(a)
+
+ def sqrt(self, a):
+ return jnp.sqrt(a)
+
+ def norm(self, a):
+ return jnp.sqrt(jnp.sum(jnp.square(a)))
+
+ def any(self, a):
+ return jnp.any(a)
+
+ def isnan(self, a):
+ return jnp.isnan(a)
+
+ def isinf(self, a):
+ return jnp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return jnp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return jnp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return jnp.argsort(a, axis)
+
+ def flip(self, a, axis=None):
+ return jnp.flip(a, axis)
+
+
+class TorchBackend(Backend):
+
+ __name__ = 'torch'
+ __type__ = torch_type
+
+ def to_numpy(self, a):
+ return a.cpu().detach().numpy()
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return torch.from_numpy(a)
+ else:
+ return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
+
+ def set_gradients(self, val, inputs, grads):
+ from torch.autograd import Function
+
+        # define an autograd Function that takes the inputs and returns val
+ class ValFunction(Function):
+ @staticmethod
+ def forward(ctx, *inputs):
+ return val
+
+ @staticmethod
+ def backward(ctx, grad_output):
+                # return the pre-computed gradients, one per input
+ return grads
+
+ return ValFunction.apply(*inputs)
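+
+    # Usage sketch (illustrative): this lets a value computed outside of
+    # autograd (e.g. by a compiled solver) expose pre-computed gradients.
+    # Assuming val is a scalar tensor and grads matches inputs one-to-one:
+    #   val = nx.set_gradients(val, (a, b), (grad_a, grad_b))
+    #   val.backward()   # now a.grad == grad_a and b.grad == grad_b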
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return torch.zeros(shape)
+ else:
+ return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return torch.ones(shape)
+ else:
+ return torch.ones(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ if type_as is None:
+ return torch.arange(start, stop, step)
+ else:
+ return torch.arange(start, stop, step, device=type_as.device)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return torch.full(shape, fill_value)
+ else:
+ return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device)
+
+ def eye(self, N, M=None, type_as=None):
+ if M is None:
+ M = N
+ if type_as is None:
+ return torch.eye(N, m=M)
+ else:
+ return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device)
+
+ def sum(self, a, axis=None, keepdims=False):
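+        # torch.sum requires an explicit dim, so handle axis=None separately;
+        # note that torch uses 'keepdim' where numpy uses 'keepdims'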
+ if axis is None:
+ return torch.sum(a)
+ else:
+ return torch.sum(a, axis, keepdim=keepdims)
+
+ def cumsum(self, a, axis=None):
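+        # numpy flattens the input when axis is None, while torch.cumsum
+        # requires a dim, so emulate the numpy behaviour explicitly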
+ if axis is None:
+ return torch.cumsum(a.flatten(), 0)
+ else:
+ return torch.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
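+        # with a dim, torch.max returns a (values, indices) tuple; keep only
+        # the values to match numpy semantics (same for min below)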
+ if axis is None:
+ return torch.max(a)
+ else:
+ return torch.max(a, axis, keepdim=keepdims)[0]
+
+ def min(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.min(a)
+ else:
+ return torch.min(a, axis, keepdim=keepdims)[0]
+
+ def maximum(self, a, b):
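+        # torch.maximum expects tensors, so promote python scalars to
+        # 1-element tensors with the dtype and device of the other operand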
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ return torch.maximum(a, b)
+
+ def minimum(self, a, b):
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ return torch.minimum(a, b)
+
+ def dot(self, a, b):
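+        # np.dot-like dispatch: torch.dot only handles 1-D tensors, so use
+        # mv (matrix-vector) and mm (matrix-matrix) for the 2-D cases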
+ if len(a.shape) == len(b.shape) == 1:
+ return torch.dot(a, b)
+ elif len(a.shape) == 2 and len(b.shape) == 1:
+ return torch.mv(a, b)
+ else:
+ return torch.mm(a, b)
+
+ def abs(self, a):
+ return torch.abs(a)
+
+ def exp(self, a):
+ return torch.exp(a)
+
+ def log(self, a):
+ return torch.log(a)
+
+ def sqrt(self, a):
+ return torch.sqrt(a)
+
+ def norm(self, a):
+ return torch.sqrt(torch.sum(torch.square(a)))
+
+ def any(self, a):
+ return torch.any(a)
+
+ def isnan(self, a):
+ return torch.isnan(a)
+
+ def isinf(self, a):
+ return torch.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return torch.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ sorted0, indices = torch.sort(a, dim=axis)
+ return sorted0
+
+ def argsort(self, a, axis=-1):
+        _, indices = torch.sort(a, dim=axis)
+ return indices
+
+ def flip(self, a, axis=None):
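+        # torch.flip needs explicit dims: flip over all axes when axis is
+        # None to match numpy, and wrap a single int axis into a tuple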
+ if axis is None:
+ return torch.flip(a, tuple(i for i in range(len(a.shape))))
+ if isinstance(axis, int):
+ return torch.flip(a, (axis,))
+ else:
+ return torch.flip(a, dims=axis)
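+
+
+# Example (illustrative sketch): with the torch backend, type_as propagates
+# both dtype and device, e.g., assuming a CUDA device is available:
+#   a = torch.zeros(3, dtype=torch.float32, device='cuda')
+#   nx = get_backend(a)
+#   M = nx.ones((3, 3), type_as=a)   # float32 tensor on the same GPU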