author    Rémi Flamary <remi.flamary@gmail.com>    2021-06-01 10:10:54 +0200
committer GitHub <noreply@github.com>              2021-06-01 10:10:54 +0200
commit    184f8f4f7ac78f1dd7f653496d2753211a4e3426 (patch)
tree      483a7274c91030fd644de49b03a5fad04af9deba /ot
parent    1f16614954e2522fbdb1598c5b1f5c3630c68472 (diff)
[MRG] POT numpy/torch/jax backends (#249)
* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* now testing all functions in backend
* add jax backend
* cleanup windows
* proper convert for jax backend
* pep8
* try again windows tests
* test jax conversion
* try proper windows tests
* emd function uses backend
* better test partial OT
* proper tests to_numpy and template Backend
* pep8 x2
* sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* no test same for jax
* new independent tests per backend
* add tests for gradients
* deprecate ot.gpu
* working dist function
* dist done in backend
* not in
* remove indexing
* change accuracy for jax
* first pull backend
* projection simplex (several iterations, no ci)
* add backend discussion to quickstart guide
* pep8 + better doc
* proper links
* correct doctest
* big debug documentation
* doctest again (several rounds)
* backend test + doc proj simplex
* correction test_utils
* correction cumsum
* correction flip (v2)
* more debug + pep8
* proj_simplex
* backend works for sort
* update doc
* apply review suggestions from Alexandre Gramfort to test/test_utils.py, docs/source/quickstart.rst, docs/source/readme.rst, ot/utils.py and ot/lp/__init__.py
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add gradient descent example on the weights
* pep8 of course
* proofread example by alex
* pep8 again
* encoding oos in translation
* correct legend

Co-authored-by: Nicolas Courty <ncourty@irisa.fr>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
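What this enables in practice, as a minimal sketch (assuming this commit is installed together with torch): exact solvers such as ot.emd2 now accept torch tensors, return torch values, and expose gradients through the dual potentials via the new set_gradients mechanism.

import numpy as np
import torch
import ot

# cost between random source/target samples, via the now
# backend-compatible ot.dist (numpy in, numpy out here)
x_s = np.random.randn(20, 2)
x_t = np.random.randn(30, 2)
M = torch.from_numpy(ot.dist(x_s, x_t))  # float64 torch cost matrix

# uniform weights as torch leaf tensors so the loss can be differentiated
a = torch.full((20,), 1.0 / 20, dtype=torch.float64, requires_grad=True)
b = torch.full((30,), 1.0 / 30, dtype=torch.float64)

loss = ot.emd2(a, b, M)  # exact OT loss, returned as a torch scalar
loss.backward()          # gradient of the loss wrt a is the dual potential
print(loss.item(), a.grad)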
Diffstat (limited to 'ot')
-rw-r--r--  ot/__init__.py       1
-rw-r--r--  ot/backend.py      536
-rw-r--r--  ot/bregman.py      141
-rw-r--r--  ot/gpu/__init__.py   4
-rw-r--r--  ot/lp/__init__.py  137
-rw-r--r--  ot/utils.py        128
6 files changed, 807 insertions, 140 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 5a8a415..3b072c6 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -33,6 +33,7 @@ from . import smooth
from . import stochastic
from . import unbalanced
from . import partial
+from . import backend
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
diff --git a/ot/backend.py b/ot/backend.py
new file mode 100644
index 0000000..d68f5cf
--- /dev/null
+++ b/ot/backend.py
@@ -0,0 +1,536 @@
+# -*- coding: utf-8 -*-
+"""
+Multi-lib backend for POT
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+
+try:
+ import torch
+ torch_type = torch.Tensor
+except ImportError:
+ torch = False
+ torch_type = float
+
+try:
+ import jax
+ import jax.numpy as jnp
+ jax_type = jax.numpy.ndarray
+except ImportError:
+ jax = False
+ jax_type = float
+
+str_type_error = "All arrays should be from the same type/backend. Current types are: {}"
+
+
+def get_backend_list():
+ """ returns the list of available backends)"""
+ lst = [NumpyBackend(), ]
+
+ if torch:
+ lst.append(TorchBackend())
+
+ if jax:
+ lst.append(JaxBackend())
+
+ return lst
+
+
+def get_backend(*args):
+ """returns the proper backend for a list of input arrays
+
+ Also raises TypeError if all arrays are not from the same backend
+ """
+ # check that some arrays are given
+ if not len(args) > 0:
+ raise ValueError("The function takes at least one parameter")
+ # check all same type
+
+ if isinstance(args[0], np.ndarray):
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+ return NumpyBackend()
+ elif torch and isinstance(args[0], torch_type):
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+ return TorchBackend()
+ elif jax and isinstance(args[0], jax_type):
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+ return JaxBackend()
+ else:
+ raise ValueError("Unknown type or backend not implemented.")
+
+
+def to_numpy(*args):
+ """returns numpy arrays from any compatible backend"""
+
+ if len(args) == 1:
+ return get_backend(args[0]).to_numpy(args[0])
+ else:
+ return [get_backend(a).to_numpy(a) for a in args]
+
+
+class Backend():
+
+ __name__ = None
+ __type__ = None
+
+ def __str__(self):
+ return self.__name__
+
+ # convert to numpy
+ def to_numpy(self, a):
+ raise NotImplementedError()
+
+ # convert from numpy
+ def from_numpy(self, a, type_as=None):
+ raise NotImplementedError()
+
+ def set_gradients(self, val, inputs, grads):
+ """ define the gradients for the value val wrt the inputs """
+ raise NotImplementedError()
+
+ def zeros(self, shape, type_as=None):
+ raise NotImplementedError()
+
+ def ones(self, shape, type_as=None):
+ raise NotImplementedError()
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ raise NotImplementedError()
+
+ def full(self, shape, fill_value, type_as=None):
+ raise NotImplementedError()
+
+ def eye(self, N, M=None, type_as=None):
+ raise NotImplementedError()
+
+ def sum(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def cumsum(self, a, axis=None):
+ raise NotImplementedError()
+
+ def max(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def min(self, a, axis=None, keepdims=False):
+ raise NotImplementedError()
+
+ def maximum(self, a, b):
+ raise NotImplementedError()
+
+ def minimum(self, a, b):
+ raise NotImplementedError()
+
+ def dot(self, a, b):
+ raise NotImplementedError()
+
+ def abs(self, a):
+ raise NotImplementedError()
+
+ def exp(self, a):
+ raise NotImplementedError()
+
+ def log(self, a):
+ raise NotImplementedError()
+
+ def sqrt(self, a):
+ raise NotImplementedError()
+
+ def norm(self, a):
+ raise NotImplementedError()
+
+ def any(self, a):
+ raise NotImplementedError()
+
+ def isnan(self, a):
+ raise NotImplementedError()
+
+ def isinf(self, a):
+ raise NotImplementedError()
+
+ def einsum(self, subscripts, *operands):
+ raise NotImplementedError()
+
+ def sort(self, a, axis=-1):
+ raise NotImplementedError()
+
+ def argsort(self, a, axis=None):
+ raise NotImplementedError()
+
+ def flip(self, a, axis=None):
+ raise NotImplementedError()
+
+
+class NumpyBackend(Backend):
+
+ __name__ = 'numpy'
+ __type__ = np.ndarray
+
+ def to_numpy(self, a):
+ return a
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return a
+ elif isinstance(a, float):
+ return a
+ else:
+ return a.astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # no gradients for numpy
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return np.zeros(shape)
+ else:
+ return np.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return np.ones(shape)
+ else:
+ return np.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return np.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return np.full(shape, fill_value)
+ else:
+ return np.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return np.eye(N, M)
+ else:
+ return np.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return np.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return np.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return np.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return np.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return np.maximum(a, b)
+
+ def minimum(self, a, b):
+ return np.minimum(a, b)
+
+ def dot(self, a, b):
+ return np.dot(a, b)
+
+ def abs(self, a):
+ return np.abs(a)
+
+ def exp(self, a):
+ return np.exp(a)
+
+ def log(self, a):
+ return np.log(a)
+
+ def sqrt(self, a):
+ return np.sqrt(a)
+
+ def norm(self, a):
+ return np.sqrt(np.sum(np.square(a)))
+
+ def any(self, a):
+ return np.any(a)
+
+ def isnan(self, a):
+ return np.isnan(a)
+
+ def isinf(self, a):
+ return np.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return np.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return np.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return np.argsort(a, axis)
+
+ def flip(self, a, axis=None):
+ return np.flip(a, axis)
+
+
+class JaxBackend(Backend):
+
+ __name__ = 'jax'
+ __type__ = jax_type
+
+ def to_numpy(self, a):
+ return np.array(a)
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return jnp.array(a)
+ else:
+ return jnp.array(a).astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # no gradients for jax because it is functional
+
+ # does not work
+ # from jax import custom_jvp
+ # @custom_jvp
+ # def f(*inputs):
+ # return val
+ # f.defjvps(*grads)
+ # return f(*inputs)
+
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.zeros(shape)
+ else:
+ return jnp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.ones(shape)
+ else:
+ return jnp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return jnp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return jnp.full(shape, fill_value)
+ else:
+ return jnp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return jnp.eye(N, M)
+ else:
+ return jnp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return jnp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return jnp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return jnp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return jnp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return jnp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return jnp.minimum(a, b)
+
+ def dot(self, a, b):
+ return jnp.dot(a, b)
+
+ def abs(self, a):
+ return jnp.abs(a)
+
+ def exp(self, a):
+ return jnp.exp(a)
+
+ def log(self, a):
+ return jnp.log(a)
+
+ def sqrt(self, a):
+ return jnp.sqrt(a)
+
+ def norm(self, a):
+ return jnp.sqrt(jnp.sum(jnp.square(a)))
+
+ def any(self, a):
+ return jnp.any(a)
+
+ def isnan(self, a):
+ return jnp.isnan(a)
+
+ def isinf(self, a):
+ return jnp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return jnp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return jnp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return jnp.argsort(a, axis)
+
+ def flip(self, a, axis=None):
+ return jnp.flip(a, axis)
+
+
+class TorchBackend(Backend):
+
+ __name__ = 'torch'
+ __type__ = torch_type
+
+ def to_numpy(self, a):
+ return a.cpu().detach().numpy()
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return torch.from_numpy(a)
+ else:
+ return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
+
+ def set_gradients(self, val, inputs, grads):
+ from torch.autograd import Function
+
+ # define a function that takes inputs and return val
+ class ValFunction(Function):
+ @staticmethod
+ def forward(ctx, *inputs):
+ return val
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # the gradients are grad
+ return grads
+
+ return ValFunction.apply(*inputs)
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return torch.zeros(shape)
+ else:
+ return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return torch.ones(shape)
+ else:
+ return torch.ones(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ if type_as is None:
+ return torch.arange(start, stop, step)
+ else:
+ return torch.arange(start, stop, step, device=type_as.device)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return torch.full(shape, fill_value)
+ else:
+ return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device)
+
+ def eye(self, N, M=None, type_as=None):
+ if M is None:
+ M = N
+ if type_as is None:
+ return torch.eye(N, m=M)
+ else:
+ return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device)
+
+ def sum(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.sum(a)
+ else:
+ return torch.sum(a, axis, keepdim=keepdims)
+
+ def cumsum(self, a, axis=None):
+ if axis is None:
+ return torch.cumsum(a.flatten(), 0)
+ else:
+ return torch.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.max(a)
+ else:
+ return torch.max(a, axis, keepdim=keepdims)[0]
+
+ def min(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.min(a)
+ else:
+ return torch.min(a, axis, keepdim=keepdims)[0]
+
+ def maximum(self, a, b):
+ if isinstance(a, (int, float)):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, (int, float)):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ return torch.maximum(a, b)
+
+ def minimum(self, a, b):
+ if isinstance(a, (int, float)):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, (int, float)):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ return torch.minimum(a, b)
+
+ def dot(self, a, b):
+ if len(a.shape) == len(b.shape) == 1:
+ return torch.dot(a, b)
+ elif len(a.shape) == 2 and len(b.shape) == 1:
+ return torch.mv(a, b)
+ else:
+ return torch.mm(a, b)
+
+ def abs(self, a):
+ return torch.abs(a)
+
+ def exp(self, a):
+ return torch.exp(a)
+
+ def log(self, a):
+ return torch.log(a)
+
+ def sqrt(self, a):
+ return torch.sqrt(a)
+
+ def norm(self, a):
+ return torch.sqrt(torch.sum(torch.square(a)))
+
+ def any(self, a):
+ return torch.any(a)
+
+ def isnan(self, a):
+ return torch.isnan(a)
+
+ def isinf(self, a):
+ return torch.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return torch.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ sorted0, indices = torch.sort(a, dim=axis)
+ return sorted0
+
+ def argsort(self, a, axis=-1):
+ _, indices = torch.sort(a, dim=axis)
+ return indices
+
+ def flip(self, a, axis=None):
+ if axis is None:
+ return torch.flip(a, tuple(i for i in range(len(a.shape))))
+ if isinstance(axis, int):
+ return torch.flip(a, (axis,))
+ else:
+ return torch.flip(a, dims=axis)
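The dispatch pattern this module encourages, sketched with a hypothetical helper (softmin below is an illustration, not part of this commit): fetch the backend once with get_backend, then use only the Backend primitives, and the same function runs unchanged on numpy, torch and jax arrays.

import numpy as np
from ot.backend import get_backend

def softmin(M, reg):
    # smooth row-wise minimum -reg * log(sum_j exp(-M_ij / reg)),
    # written only against the Backend API so it is backend-agnostic
    nx = get_backend(M)
    K = nx.exp(-M / reg)
    return -reg * nx.log(nx.sum(K, axis=1))

print(softmin(np.random.rand(3, 4), 0.1))  # numpy in, numpy out
# the same call works on a torch tensor and keeps its dtype/device:
# import torch; print(softmin(torch.rand(3, 4), 0.1))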
diff --git a/ot/bregman.py b/ot/bregman.py
index 559db14..b10effd 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -19,7 +19,8 @@ import warnings
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
-from ot.utils import unif, dist
+from ot.utils import unif, dist, list_to_array
+from .backend import get_backend
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -43,17 +44,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in [2]_
+
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :any:`ot.bregman.screenkhorn` aims at providing a
+ fast approximation of the Sinkhorn problem.
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
sample weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -69,25 +89,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
- **Choosing a Sinkhorn solver**
-
- By default and when using a regularization parameter that is not too small
- the default sinkhorn solver should be enough. If you need to use a small
- regularization to get sharper OT matrices, you should use the
- :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
- errors. This last solver can be very slow in practice and might not even
- converge to a reasonable OT matrix in a finite time. This is why
- :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value
- of the regularization (and using warm start) sometimes leads to better
- solutions. Note that the greedy version of the sinkhorn
- :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
- version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
- fast approximation of the Sinkhorn problem.
-
-
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -166,17 +170,35 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (histograms, both sum to 1)
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ **Choosing a Sinkhorn solver**
+
+ By default and when using a regularization parameter that is not too small
+ the default sinkhorn solver should be enough. If you need to use a small
+ regularization to get sharper OT matrices, you should use the
+ :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ errors. This last solver can be very slow in practice and might not even
+ converge to a reasonable OT matrix in a finite time. This is why
+ :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ of the regularization (and using warm start) sometimes leads to better
+ solutions. Note that the greedy version of the sinkhorn
+ :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :any:`ot.bregman.screenkhorn` aims at providing a
+ fast approximation of the Sinkhorn problem.
+
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
sample weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -191,28 +213,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
- **Choosing a Sinkhorn solver**
-
- By default and when using a regularization parameter that is not too small
- the default sinkhorn solver should be enough. If you need to use a small
- regularization to get sharper OT matrices, you should use the
- :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
- errors. This last solver can be very slow in practice and might not even
- converge to a reasonable OT matrix in a finite time. This is why
- :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value
- of the regularization (and using warm start) sometimes leads to better
- solutions. Note that the greedy version of the sinkhorn
- :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
- version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
- fast approximation of the Sinkhorn problem.
-
Returns
-------
- W : (n_hists) ndarray
+ W : (n_hists,) float or array-like
Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
+
Examples
--------
@@ -247,7 +255,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
"""
- b = np.asarray(b, dtype=np.float64)
+
+ b = list_to_array(b)
if len(b.shape) < 2:
b = b[:, None]
@@ -339,14 +348,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
# init data
dim_a = len(a)
@@ -363,21 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # print(reg)
-
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
-
- # print(np.min(K))
- tmp2 = np.empty(b.shape, dtype=M.dtype)
+ K = nx.exp(M / (-reg))
Kp = (1 / a).reshape(-1, 1) * K
cpt = 0
@@ -386,13 +387,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
uprev = u
vprev = v
- KtransposeU = np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.dot(Kp, v)
+ KtransposeU = nx.dot(K.T, u)
+ v = b / KtransposeU
+ u = 1. / nx.dot(Kp, v)
- if (np.any(KtransposeU == 0)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(KtransposeU == 0)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
@@ -403,11 +404,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
- np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
+ tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
- np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b) # violation of marginal
+ tmp2 = nx.einsum('i,ij,j->j', u, K, v)
+ err = nx.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
@@ -422,7 +423,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['v'] = v
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
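With sinkhorn_knopp rewritten against nx, the entropic solver can be fed backend arrays directly; a small sketch, assuming torch is available:

import torch
import ot

x_s = torch.randn(5, 2, dtype=torch.float64)
x_t = torch.randn(6, 2, dtype=torch.float64)

a = torch.full((5,), 1.0 / 5, dtype=torch.float64)   # uniform weights
b = torch.full((6,), 1.0 / 6, dtype=torch.float64)
M = ot.dist(x_s, x_t)       # squared Euclidean cost, stays a torch tensor

G = ot.sinkhorn(a, b, M, 1e-1)   # transport plan, returned as torch
print(type(G), G.sum())          # marginals of G match a and b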
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index 7478fb9..e939610 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -25,6 +25,8 @@ result of the function with parameter ``to_numpy=False``.
#
# License: MIT License
+import warnings
+
from . import bregman
from . import da
from .bregman import sinkhorn
@@ -34,7 +36,7 @@ from . import utils
from .utils import dist, to_gpu, to_np
-
+warnings.warn('This module will be deprecated in the next minor release of POT', category=DeprecationWarning)
__all__ = ["utils", "dist", "sinkhorn",
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index d5c3a5e..c8c9da6 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -18,8 +18,9 @@ from . import cvx
from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import dist
+from ..utils import dist, list_to_array
from ..utils import parmap
+from ..backend import get_backend
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -176,8 +177,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix
- .. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ .. math:: \gamma = \mathop{\arg \min}_\gamma \langle \gamma, M \rangle_F
s.t. \gamma 1 = a
@@ -189,37 +189,41 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
- M is the metric cost matrix
- a and b are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. warning:: Note that the M matrix in numpy needs to be a C-order
+ numpy.array in float64 format. It will be converted if not in this
+ format.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
- Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
- numItermax : int, optional (default=100000)
+ b : (nt,) array-like, float
+ Target histogram (uniform weight if empty list)
+ M : (ns,nt) array-like, float
+ Loss matrix (c-order array in numpy with type float64)
+ numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
- algorithm if it has not converged.
- log: bool, optional (default=False)
- If True, returns a dictionary containing the cost and dual
- variables. Otherwise returns only the optimal transportation matrix.
+ algorithm if it has not converged.
+ log: bool, optional (default=False)
+ If True, returns a dictionary containing the cost and dual variables.
+ Otherwise returns only the optimal transportation matrix.
center_dual: boolean, optional (default=True)
- If True, centers the dual potential using function
+ If True, centers the dual potential using function
:ref:`center_ot_dual`.
Returns
-------
- gamma: (ns x nt) numpy.ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is true, a dictionary containing the cost and dual
- variables and exit status
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the cost and dual
+ variables and exit status
Examples
@@ -232,26 +236,37 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.emd(a,b,M)
+ >>> ot.emd(a, b, M)
array([[0.5, 0. ],
[0. , 0.5]])
References
----------
- .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
- (2011, December). Displacement interpolation using Lagrangian mass
- transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
- 158). ACM.
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
+ December). Displacement interpolation using Lagrangian mass transport.
+ In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
See Also
--------
- ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT"""
-
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT"""
+
+ # convert to numpy if list
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
+ # ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -262,6 +277,11 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0), b.sum(0),
+ err_msg='a and b vectors must have the same sum')
+ b = b * a.sum() / b.sum()
+
asel = a != 0
bsel = b != 0
@@ -277,12 +297,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
if log:
log = {}
log['cost'] = cost
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
- return G, log
- return G
+ return nx.from_numpy(G, type_as=M0), log
+ return nx.from_numpy(G, type_as=M0)
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
@@ -303,20 +323,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- M is the metric cost matrix
- a and b are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Uses the algorithm proposed in [1]_
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float64
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
+ b : (nt,) array-like, float64
Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
+ M : (ns,nt) array-like, float64
+ Loss matrix (for numpy: c-order array with type float64)
processes : int, optional (default=nb cpu)
Nb of processes used for multiple emd computation (not used on windows)
numItermax : int, optional (default=100000)
@@ -333,9 +352,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
Returns
-------
- W: float
+ W: float, array-like
Optimal transportation loss for the given parameters
- log: dictnp
+ log: dict
If input log is true, a dictionary containing dual
variables and exit status
@@ -367,12 +386,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
ot.bregman.sinkhorn : Entropic regularized OT
ot.optim.cg : General regularized OT"""
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order='C')
# problem with pickling Forks
- if sys.platform.endswith('win32'):
+ if sys.platform.endswith('win32') or nx.__name__ != 'numpy':
processes = 1
# if empty array given then use uniform distributions
@@ -400,12 +429,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
result_code_string = check_result(result_code)
log = {}
+ G = nx.from_numpy(G, type_as=M0)
if return_matrix:
log['G'] = G
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
def f(b):
@@ -418,6 +450,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
+ G = nx.from_numpy(G, type_as=M0)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
+ nx.from_numpy(v, type_as=b0), G))
+
check_result(result_code)
return cost
@@ -637,6 +674,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0), b.sum(0), err_msg='a and b vectors must have the same sum')
+ b = b * a.sum() / b.sum()
+
x_a_1d = x_a.reshape((-1,))
x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)
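How emd2 stays differentiable although the network simplex runs in numpy: the torch backend's set_gradients wraps the detached result in an autograd Function whose backward returns the precomputed dual potentials. A standalone sketch of that idea, where numpy_solver and its gradient are hypothetical stand-ins for emd_c and the potential u:

import numpy as np
import torch
from torch.autograd import Function

def numpy_solver(a_np):
    # stand-in: returns a value and its known gradient wrt a
    return a_np.sum(), np.ones_like(a_np)

class ValFunction(Function):
    @staticmethod
    def forward(ctx, val, grad_a, a):
        ctx.save_for_backward(grad_a)
        return val                    # value computed outside autograd

    @staticmethod
    def backward(ctx, grad_output):
        grad_a, = ctx.saved_tensors
        # chain rule: scale the stored gradient by the incoming one
        return None, None, grad_output * grad_a

a = torch.rand(4, dtype=torch.float64, requires_grad=True)
val_np, grad_np = numpy_solver(a.detach().numpy())
val = ValFunction.apply(torch.tensor(val_np), torch.from_numpy(grad_np), a)
val.backward()
print(a.grad)   # equals grad_np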
diff --git a/ot/utils.py b/ot/utils.py
index 544c569..4dac0c5 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
+from .backend import get_backend
__time_tic_toc = time.time()
@@ -41,8 +42,11 @@ def toq():
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
"""Compute kernel matrix"""
+
+ nx = get_backend(x1, x2)
+
if method.lower() in ['gaussian', 'gauss', 'rbf']:
- K = np.exp(-dist(x1, x2) / (2 * sigma**2))
+ K = nx.exp(-dist(x1, x2) / (2 * sigma**2))
return K
@@ -52,6 +56,66 @@ def laplacian(x):
return L
+def list_to_array(*lst):
+ """ Convert a list if in numpy format """
+ if len(lst) > 1:
+ return [np.array(a) if isinstance(a, list) else a for a in lst]
+ else:
+ return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
+
+
+def proj_simplex(v, z=1):
+ r""" compute the closest point (orthogonal projection) on the
+ generalized (n-1)-simplex of a vector v wrt. to the Euclidean
+ distance, thus solving:
+ .. math::
+ \mathcal{P}(w) \in \mathop{\arg \min}_\gamma \| \gamma - v \|_2
+
+ s.t. \gamma^T 1 = z
+
+ \gamma \geq 0
+
+ If v is a 2d array, compute all the projections wrt. axis 0
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ Parameters
+ ----------
+ v : {array-like}, shape (n, d)
+ z : int, optional
+ 'size' of the simplex (each vector sums to z, 1 by default)
+
+ Returns
+ -------
+ h : array-like, shape (n, d)
+ Array of projections on the simplex
+ """
+ nx = get_backend(v)
+ n = v.shape[0]
+ if v.ndim == 1:
+ d1 = 1
+ v = v[:, None]
+ else:
+ d1 = 0
+ d = v.shape[1]
+
+ # sort u in ascending order
+ u = nx.sort(v, axis=0)
+ # take the descending order
+ u = nx.flip(u, 0)
+ cssv = nx.cumsum(u, axis=0) - z
+ ind = nx.arange(n, type_as=v)[:, None] + 1
+ cond = u - cssv / ind > 0
+ rho = nx.sum(cond, 0)
+ theta = cssv[rho - 1, nx.arange(d)] / rho
+ w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v))
+ if d1:
+ return w[:, 0]
+ else:
+ return w
+
+
def unif(n):
""" return a uniform histogram of length n (simplex)
@@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False):
"""
Considering the rows of X (and Y=X) as vectors, compute the
distance matrix between each pair of vectors.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
Parameters
----------
X : {array-like}, shape (n_samples_1, n_features)
Y : {array-like}, shape (n_samples_2, n_features)
squared : boolean, optional
Return squared Euclidean distances.
+
Returns
-------
distances : {array}, shape (n_samples_1, n_samples_2)
"""
- XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
- YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
- distances = np.dot(X, Y.T)
- distances *= -2
- distances += XX
- distances += YY
- np.maximum(distances, 0, out=distances)
+
+ nx = get_backend(X, Y)
+
+ a2 = nx.einsum('ij,ij->i', X, X)
+ b2 = nx.einsum('ij,ij->i', Y, Y)
+
+ c = -2 * nx.dot(X, Y.T)
+ c += a2[:, None]
+ c += b2[None, :]
+
+ c = nx.maximum(c, 0)
+
+ if not squared:
+ c = nx.sqrt(c)
+
if X is Y:
- # Ensure that distances between vectors and themselves are set to 0.0.
- # This may not be the case due to floating point rounding errors.
- distances.flat[::distances.shape[0] + 1] = 0.0
- return distances if squared else np.sqrt(distances, out=distances)
+ c = c * (1 - nx.eye(X.shape[0], type_as=c))
+
+ return c
def dist(x1, x2=None, metric='sqeuclidean'):
- """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+ """Compute distance between samples in x1 and x2
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
- x1 : ndarray, shape (n1,d)
+ x1 : array-like, shape (n1,d)
matrix with n1 samples of size d
- x2 : array, shape (n2,d), optional
+ x2 : array-like, shape (n2,d), optional
matrix with n2 samples of size d (if None then x2=x1)
metric : str | callable, optional
- Name of the metric to be computed (full list in the doc of scipy), If a string,
- the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
- 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
- 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
+ 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also
+ accepts the metrics of the scipy.spatial.distance.cdist function: 'braycurtis',
+ 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
+ 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
+ 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
Returns
-------
- M : np.array (n1,n2)
+ M : array-like, shape (n1, n2)
distance matrix computed with given metric
"""
@@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'):
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
- return cdist(x1, x2, metric=metric)
+ elif metric == "euclidean":
+ return euclidean_distances(x1, x2, squared=False)
+ else:
+ if get_backend(x1, x2).__name__ != 'numpy':
+ raise NotImplementedError()
+ return cdist(x1, x2, metric=metric)
def dist0(n, method='lin_square'):
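proj_simplex is what the new gradient-descent-on-the-weights example relies on: after each gradient step the weights are projected back onto the simplex so they remain a valid histogram. A hedged usage sketch (the same calls work with numpy or jax arrays):

import torch
from ot.utils import proj_simplex

w = torch.tensor([0.8, -0.3, 0.7], dtype=torch.float64)  # off the simplex
w = proj_simplex(w)          # closest point with w >= 0 and sum(w) == 1
print(w, w.sum())

# columns of a 2d array are projected independently (axis 0):
W = proj_simplex(torch.rand(5, 3, dtype=torch.float64))
print(W.sum(0))              # each column now sums to 1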