diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2021-06-01 10:10:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-01 10:10:54 +0200 |
commit | 184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch) | |
tree | 483a7274c91030fd644de49b03a5fad04af9deba /ot | |
parent | 1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff) |
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backend
* add jax backend
* cleanup windows
* proper convert for jax backend
* pep8
* try again windows tests
* test jax conversion
* try proper windows tests
* emd function uses backend
* better test partial OT
* proper tests to_numpy and template Backend
* pep8
* pep8 x2
* freaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no tast same for jax
* new independent tests per backend
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* working dist function
* working dist
* dist done in backend
* not in
* remove indexing
* change accuracy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backend discussion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* correct doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backend works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/lp/__init__.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient descent example on the weights
* pep98 of course
* proof read example by alex
* pep8 again
* encoding oos in translation
* correct legend
Co-authored-by: Nicolas Courty <ncourty@irisa.fr>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot')
-rw-r--r-- | ot/__init__.py | 1 | ||||
-rw-r--r-- | ot/backend.py | 536 | ||||
-rw-r--r-- | ot/bregman.py | 141 | ||||
-rw-r--r-- | ot/gpu/__init__.py | 4 | ||||
-rw-r--r-- | ot/lp/__init__.py | 137 | ||||
-rw-r--r-- | ot/utils.py | 128 |
6 files changed, 807 insertions, 140 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index 5a8a415..3b072c6 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -33,6 +33,7 @@ from . import smooth from . import stochastic from . import unbalanced from . import partial +from . import backend # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d diff --git a/ot/backend.py b/ot/backend.py new file mode 100644 index 0000000..d68f5cf --- /dev/null +++ b/ot/backend.py @@ -0,0 +1,536 @@ +# -*- coding: utf-8 -*- +""" +Multi-lib backend for POT +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# Nicolas Courty <ncourty@irisa.fr> +# +# License: MIT License + +import numpy as np + +try: + import torch + torch_type = torch.Tensor +except ImportError: + torch = False + torch_type = float + +try: + import jax + import jax.numpy as jnp + jax_type = jax.numpy.ndarray +except ImportError: + jax = False + jax_type = float + +str_type_error = "All array should be from the same type/backend. Current types are : {}" + + +def get_backend_list(): + """ returns the list of available backends)""" + lst = [NumpyBackend(), ] + + if torch: + lst.append(TorchBackend()) + + if jax: + lst.append(JaxBackend()) + + return lst + + +def get_backend(*args): + """returns the proper backend for a list of input arrays + + Also raises TypeError if all arrays are not from the same backend + """ + # check that some arrays given + if not len(args) > 0: + raise ValueError(" The function takes at least one parameter") + # check all same type + + if isinstance(args[0], np.ndarray): + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) + return NumpyBackend() + elif torch and isinstance(args[0], torch_type): + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) + return TorchBackend() + elif isinstance(args[0], jax_type): + return JaxBackend() + else: + raise ValueError("Unknown type of non 
implemented backend.") + + +def to_numpy(*args): + """returns numpy arrays from any compatible backend""" + + if len(args) == 1: + return get_backend(args[0]).to_numpy(args[0]) + else: + return [get_backend(a).to_numpy(a) for a in args] + + +class Backend(): + + __name__ = None + __type__ = None + + def __str__(self): + return self.__name__ + + # convert to numpy + def to_numpy(self, a): + raise NotImplementedError() + + # convert from numpy + def from_numpy(self, a, type_as=None): + raise NotImplementedError() + + def set_gradients(self, val, inputs, grads): + """ define the gradients for the value val wrt the inputs """ + raise NotImplementedError() + + def zeros(self, shape, type_as=None): + raise NotImplementedError() + + def ones(self, shape, type_as=None): + raise NotImplementedError() + + def arange(self, stop, start=0, step=1, type_as=None): + raise NotImplementedError() + + def full(self, shape, fill_value, type_as=None): + raise NotImplementedError() + + def eye(self, N, M=None, type_as=None): + raise NotImplementedError() + + def sum(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def cumsum(self, a, axis=None): + raise NotImplementedError() + + def max(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def min(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def maximum(self, a, b): + raise NotImplementedError() + + def minimum(self, a, b): + raise NotImplementedError() + + def dot(self, a, b): + raise NotImplementedError() + + def abs(self, a): + raise NotImplementedError() + + def exp(self, a): + raise NotImplementedError() + + def log(self, a): + raise NotImplementedError() + + def sqrt(self, a): + raise NotImplementedError() + + def norm(self, a): + raise NotImplementedError() + + def any(self, a): + raise NotImplementedError() + + def isnan(self, a): + raise NotImplementedError() + + def isinf(self, a): + raise NotImplementedError() + + def einsum(self, subscripts, *operands): + raise 
NotImplementedError() + + def sort(self, a, axis=-1): + raise NotImplementedError() + + def argsort(self, a, axis=None): + raise NotImplementedError() + + def flip(self, a, axis=None): + raise NotImplementedError() + + +class NumpyBackend(Backend): + + __name__ = 'numpy' + __type__ = np.ndarray + + def to_numpy(self, a): + return a + + def from_numpy(self, a, type_as=None): + if type_as is None: + return a + elif isinstance(a, float): + return a + else: + return a.astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # no gradients for numpy + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return np.zeros(shape) + else: + return np.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return np.ones(shape) + else: + return np.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return np.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return np.full(shape, fill_value) + else: + return np.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return np.eye(N, M) + else: + return np.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return np.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return np.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return np.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return np.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return np.maximum(a, b) + + def minimum(self, a, b): + return np.minimum(a, b) + + def dot(self, a, b): + return np.dot(a, b) + + def abs(self, a): + return np.abs(a) + + def exp(self, a): + return np.exp(a) + + def log(self, a): + return np.log(a) + + def sqrt(self, a): + return np.sqrt(a) + + def norm(self, a): + return 
np.sqrt(np.sum(np.square(a))) + + def any(self, a): + return np.any(a) + + def isnan(self, a): + return np.isnan(a) + + def isinf(self, a): + return np.isinf(a) + + def einsum(self, subscripts, *operands): + return np.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return np.sort(a, axis) + + def argsort(self, a, axis=-1): + return np.argsort(a, axis) + + def flip(self, a, axis=None): + return np.flip(a, axis) + + +class JaxBackend(Backend): + + __name__ = 'jax' + __type__ = jax_type + + def to_numpy(self, a): + return np.array(a) + + def from_numpy(self, a, type_as=None): + if type_as is None: + return jnp.array(a) + else: + return jnp.array(a).astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # no gradients for jax because it is functional + + # does not work + # from jax import custom_jvp + # @custom_jvp + # def f(*inputs): + # return val + # f.defjvps(*grads) + # return f(*inputs) + + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return jnp.zeros(shape) + else: + return jnp.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return jnp.ones(shape) + else: + return jnp.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return jnp.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return jnp.full(shape, fill_value) + else: + return jnp.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return jnp.eye(N, M) + else: + return jnp.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return jnp.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return jnp.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return jnp.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return jnp.min(a, axis, 
keepdims=keepdims) + + def maximum(self, a, b): + return jnp.maximum(a, b) + + def minimum(self, a, b): + return jnp.minimum(a, b) + + def dot(self, a, b): + return jnp.dot(a, b) + + def abs(self, a): + return jnp.abs(a) + + def exp(self, a): + return jnp.exp(a) + + def log(self, a): + return jnp.log(a) + + def sqrt(self, a): + return jnp.sqrt(a) + + def norm(self, a): + return jnp.sqrt(jnp.sum(jnp.square(a))) + + def any(self, a): + return jnp.any(a) + + def isnan(self, a): + return jnp.isnan(a) + + def isinf(self, a): + return jnp.isinf(a) + + def einsum(self, subscripts, *operands): + return jnp.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return jnp.sort(a, axis) + + def argsort(self, a, axis=-1): + return jnp.argsort(a, axis) + + def flip(self, a, axis=None): + return jnp.flip(a, axis) + + +class TorchBackend(Backend): + + __name__ = 'torch' + __type__ = torch_type + + def to_numpy(self, a): + return a.cpu().detach().numpy() + + def from_numpy(self, a, type_as=None): + if type_as is None: + return torch.from_numpy(a) + else: + return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device) + + def set_gradients(self, val, inputs, grads): + from torch.autograd import Function + + # define a function that takes inputs and return val + class ValFunction(Function): + @staticmethod + def forward(ctx, *inputs): + return val + + @staticmethod + def backward(ctx, grad_output): + # the gradients are grad + return grads + + return ValFunction.apply(*inputs) + + def zeros(self, shape, type_as=None): + if type_as is None: + return torch.zeros(shape) + else: + return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) + + def ones(self, shape, type_as=None): + if type_as is None: + return torch.ones(shape) + else: + return torch.ones(shape, dtype=type_as.dtype, device=type_as.device) + + def arange(self, stop, start=0, step=1, type_as=None): + if type_as is None: + return torch.arange(start, stop, step) + else: + return 
torch.arange(start, stop, step, device=type_as.device) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return torch.full(shape, fill_value) + else: + return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device) + + def eye(self, N, M=None, type_as=None): + if M is None: + M = N + if type_as is None: + return torch.eye(N, m=M) + else: + return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device) + + def sum(self, a, axis=None, keepdims=False): + if axis is None: + return torch.sum(a) + else: + return torch.sum(a, axis, keepdim=keepdims) + + def cumsum(self, a, axis=None): + if axis is None: + return torch.cumsum(a.flatten(), 0) + else: + return torch.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + if axis is None: + return torch.max(a) + else: + return torch.max(a, axis, keepdim=keepdims)[0] + + def min(self, a, axis=None, keepdims=False): + if axis is None: + return torch.min(a) + else: + return torch.min(a, axis, keepdim=keepdims)[0] + + def maximum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + return torch.maximum(a, b) + + def minimum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + return torch.minimum(a, b) + + def dot(self, a, b): + if len(a.shape) == len(b.shape) == 1: + return torch.dot(a, b) + elif len(a.shape) == 2 and len(b.shape) == 1: + return torch.mv(a, b) + else: + return torch.mm(a, b) + + def abs(self, a): + return torch.abs(a) + + def exp(self, a): + return torch.exp(a) + + def log(self, a): + return torch.log(a) + + def sqrt(self, a): + return torch.sqrt(a) + + def norm(self, a): + return 
torch.sqrt(torch.sum(torch.square(a))) + + def any(self, a): + return torch.any(a) + + def isnan(self, a): + return torch.isnan(a) + + def isinf(self, a): + return torch.isinf(a) + + def einsum(self, subscripts, *operands): + return torch.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + sorted0, indices = torch.sort(a, dim=axis) + return sorted0 + + def argsort(self, a, axis=-1): + sorted, indices = torch.sort(a, dim=axis) + return indices + + def flip(self, a, axis=None): + if axis is None: + return torch.flip(a, tuple(i for i in range(len(a.shape)))) + if isinstance(axis, int): + return torch.flip(a, (axis,)) + else: + return torch.flip(a, dims=axis) diff --git a/ot/bregman.py b/ot/bregman.py index 559db14..b10effd 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -19,7 +19,8 @@ import warnings import numpy as np from scipy.optimize import fmin_l_bfgs_b -from ot.utils import unif, dist +from ot.utils import unif, dist, list_to_array +from .backend import get_backend def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -43,17 +44,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in [2]_ + + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. 
This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -69,25 +89,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. 
Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. - - Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -166,17 +170,35 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. 
+ Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -191,28 +213,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. 
- Returns ------- - W : (n_hists) ndarray + W : (n_hists) float/array-like Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters + Examples -------- @@ -247,7 +255,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] """ - b = np.asarray(b, dtype=np.float64) + + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] @@ -339,14 +348,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) # init data dim_a = len(a) @@ -363,21 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) + K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K cpt = 0 @@ -386,13 +387,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, uprev = u vprev = v - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. 
/ np.dot(Kp, v) + KtransposeU = nx.dot(K.T, u) + v = b / KtransposeU + u = 1. / nx.dot(Kp, v) - if (np.any(KtransposeU == 0) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(KtransposeU == 0) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -403,11 +404,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) + tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - np.einsum('i,ij,j->j', u, K, v, out=tmp2) - err = np.linalg.norm(tmp2 - b) # violation of marginal + tmp2 = nx.einsum('i,ij,j->j', u, K, v) + err = nx.norm(tmp2 - b) # violation of marginal if log: log['err'].append(err) @@ -422,7 +423,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log['v'] = v if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py index 7478fb9..e939610 100644 --- a/ot/gpu/__init__.py +++ b/ot/gpu/__init__.py @@ -25,6 +25,8 @@ result of the function with parameter ``to_numpy=False``. # # License: MIT License +import warnings + from . import bregman from . import da from .bregman import sinkhorn @@ -34,7 +36,7 @@ from . import utils from .utils import dist, to_gpu, to_np - +warnings.warn('This module will be deprecated in the next minor release of POT', category=DeprecationWarning) __all__ = ["utils", "dist", "sinkhorn", diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d5c3a5e..c8c9da6 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -18,8 +18,9 @@ from . 
import cvx from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from ..utils import dist +from ..utils import dist, list_to_array from ..utils import parmap +from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -176,8 +177,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): r"""Solves the Earth Movers distance problem and returns the OT matrix - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + .. math:: \gamma = arg\min_\gamma <\gamma,M>_F s.t. \gamma 1 = a @@ -189,37 +189,41 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): - M is the metric cost matrix - a and b are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. warning:: Note that the M matrix in numpy needs to be a C-order + numpy.array in float64 format. It will be converted if not in this + format + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 - Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) - numItermax : int, optional (default=100000) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual - variables. 
Otherwise returns only the optimal transportation matrix. + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) - If True, centers the dual potential using function + If True, centers the dual potential using function :ref:`center_ot_dual`. Returns ------- - gamma: (ns x nt) numpy.ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is true, a dictionary containing the cost and dual - variables and exit status + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status Examples @@ -232,26 +236,37 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd(a,b,M) + >>> ot.emd(a, b, M) array([[0.5, 0. ], [0. , 0.5]]) References ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, + December). Displacement interpolation using Lagrangian mass transport. + In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. 
See Also -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT""" - + ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General + regularized OT""" + + # convert to numpy if list + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) + + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -262,6 +277,11 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0), err_msg='a and b vector must have the same sum') + b=b*a.sum()/b.sum() + asel = a != 0 bsel = b != 0 @@ -277,12 +297,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): if log: log = {} log['cost'] = cost - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code - return G, log - return G + return nx.from_numpy(G, type_as=M0), log + return nx.from_numpy(G, type_as=M0) def emd2(a, b, M, processes=multiprocessing.cpu_count(), @@ -303,20 +323,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), - M is the metric cost matrix - a and b are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. 
Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float64 Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 + b : (nt,) array-like, float64 Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) + M : (ns,nt) array-like, float64 + Loss matrix (for numpy c-order array with type float64) processes : int, optional (default=nb cpu) Nb of processes used for multiple emd computation (not used on windows) numItermax : int, optional (default=100000) @@ -333,9 +352,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), Returns ------- - W: float + W: float, array-like Optimal transportation loss for the given parameters - log: dictnp + log: dict If input log is true, a dictionary containing dual variables and exit status @@ -367,12 +386,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General regularized OT""" + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) + a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order= 'C') # problem with pikling Forks - if sys.platform.endswith('win32'): + if sys.platform.endswith('win32') or not nx.__name__ == 'numpy': processes = 1 # if empty array given then use uniform distributions @@ -400,12 +429,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), result_code_string = check_result(result_code) log = {} + G = nx.from_numpy(G, type_as=M0) if return_matrix: log['G'] = G - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code + 
cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0,b0, M0), (log['u'], log['v'], G)) return [cost, log] else: def f(b): @@ -418,6 +450,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) + G = nx.from_numpy(G, type_as=M0) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0,b0, M0), (nx.from_numpy(u, type_as=a0), + nx.from_numpy(v, type_as=b0),G)) + check_result(result_code) return cost @@ -637,6 +674,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, if b.ndim == 0 or len(b) == 0: b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum') + b=b*a.sum()/b.sum() + x_a_1d = x_a.reshape((-1,)) x_b_1d = x_b.reshape((-1,)) perm_a = np.argsort(x_a_1d) diff --git a/ot/utils.py b/ot/utils.py index 544c569..4dac0c5 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature +from .backend import get_backend __time_tic_toc = time.time() @@ -41,8 +42,11 @@ def toq(): def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" + + nx = get_backend(x1, x2) + if method.lower() in ['gaussian', 'gauss', 'rbf']: - K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K @@ -52,6 +56,66 @@ def laplacian(x): return L +def list_to_array(*lst): + """ Convert a list if in numpy format """ + if len(lst) > 1: + return [np.array(a) if isinstance(a, list) else a for a in lst] + else: + return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + + +def proj_simplex(v, z=1): + r""" compute the closest point (orthogonal projection) on the + generalized (n-1)-simplex of a vector v wrt. to the Euclidean + distance, thus solving: + .. 
math:: + \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2 + + s.t. \gamma^T 1= z + + \gamma\geq 0 + + If v is a 2d array, compute all the projections wrt. axis 0 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + Parameters + ---------- + v : {array-like}, shape (n, d) + z : int, optional + 'size' of the simplex (each vectors sum to z, 1 by default) + + Returns + ------- + h : ndarray, shape (n,d) + Array of projections on the simplex + """ + nx = get_backend(v) + n = v.shape[0] + if v.ndim == 1: + d1 = 1 + v = v[:, None] + else: + d1 = 0 + d = v.shape[1] + + # sort u in ascending order + u = nx.sort(v, axis=0) + # take the descending order + u = nx.flip(u, 0) + cssv = nx.cumsum(u, axis=0) - z + ind = nx.arange(n, type_as=v)[:, None] + 1 + cond = u - cssv / ind > 0 + rho = nx.sum(cond, 0) + theta = cssv[rho - 1, nx.arange(d)] / rho + w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v)) + if d1: + return w[:, 0] + else: + return w + + def unif(n): """ return a uniform histogram of length n (simplex) @@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False): """ Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair of vectors. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + Parameters ---------- X : {array-like}, shape (n_samples_1, n_features) Y : {array-like}, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. 
+ Returns ------- distances : {array}, shape (n_samples_1, n_samples_2) """ - XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] - YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] - distances = np.dot(X, Y.T) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + + nx = get_backend(X, Y) + + a2 = nx.einsum('ij,ij->i', X, X) + b2 = nx.einsum('ij,ij->i', Y, Y) + + c = -2 * nx.dot(X, Y.T) + c += a2[:, None] + c += b2[None, :] + + c = nx.maximum(c, 0) + + if not squared: + c = nx.sqrt(c) + if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 - return distances if squared else np.sqrt(distances, out=distances) + c = c * (1 - nx.eye(X.shape[0], type_as=c)) + + return c def dist(x1, x2=None, metric='sqeuclidean'): - """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + """Compute distance between samples in x1 and x2 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- - x1 : ndarray, shape (n1,d) + x1 : array-like, shape (n1,d) matrix with n1 samples of size d - x2 : array, shape (n2,d), optional + x2 : array-like, shape (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) metric : str | callable, optional - Name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', - 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sqeuclidean' or 'euclidean' on all backends. 
On numpy the function also + accepts from the scipy.spatial.distance.cdist function : 'braycurtis', + 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. Returns ------- - M : np.array (n1,n2) + M : array-like, shape (n1, n2) distance matrix computed with given metric """ @@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'): x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) - return cdist(x1, x2, metric=metric) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False) + else: + if not get_backend(x1, x2).__name__ == 'numpy': + raise NotImplementedError() + else: + return cdist(x1, x2, metric=metric) def dist0(n, method='lin_square'): |