From 7a65086dd340265d0223eb8ffb5c9a5152a82dff Mon Sep 17 00:00:00 2001 From: ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com> Date: Mon, 25 Oct 2021 11:36:21 +0200 Subject: [MRG] Bregman backend (#280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bregman * Resolve conflicts * Bug solve * Bregman updated for JAX compatibility * Tests coherence between backend improved * No longer enforcing 64 bits operations on Jax except for tests * Now using mixtures, to make backend dependent tests with less code * Better test skipping code * Pep8 + test optimizations * redundancy removed * Docs * Typo corrected * Typo * Typo * Docs * Docs * pep8 * Backend docs * Prettier docs * Mistake corrected * small changes * Better wording Co-authored-by: Rémi Flamary --- ot/backend.py | 581 +++++++++++++++++++++++++++++++-- ot/bregman.py | 970 ++++++++++++++++++++++++++++++------------------------- ot/gromov.py | 6 +- ot/smooth.py | 4 +- ot/unbalanced.py | 14 +- 5 files changed, 1106 insertions(+), 469 deletions(-) (limited to 'ot') diff --git a/ot/backend.py b/ot/backend.py index 2ed40af..a4a4757 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1,6 +1,22 @@ # -*- coding: utf-8 -*- """ Multi-lib backend for POT + +The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch, +or Jax, POT code should work nonetheless. +To achieve that, POT provides backend classes which implements functions in their respective backend +imitating Numpy API. As a convention, we use nx instead of np to refer to the backend. + +Examples +-------- + +>>> from ot.utils import list_to_array +>>> from ot.backend import get_backend +>>> def f(a, b): # the function does not know which backend to use +... a, b = list_to_array(a, b) # if a list in given, make it an array +... nx = get_backend(a, b) # infer the backend from the arguments +... c = nx.dot(a, b) # now use the backend to do any calculation +... return c """ # Author: Remi Flamary @@ -9,6 +25,7 @@ Multi-lib backend for POT # License: MIT License import numpy as np +import scipy.special as scipy try: import torch @@ -20,6 +37,7 @@ except ImportError: try: import jax import jax.numpy as jnp + import jax.scipy.special as jscipy jax_type = jax.numpy.ndarray except ImportError: jax = False @@ -29,7 +47,7 @@ str_type_error = "All array should be from the same type/backend. Current types def get_backend_list(): - """ returns the list of available backends)""" + """Returns the list of available backends""" lst = [NumpyBackend(), ] if torch: @@ -42,7 +60,7 @@ def get_backend_list(): def get_backend(*args): - """returns the proper backend for a list of input arrays + """Returns the proper backend for a list of input arrays Also raises TypeError if all arrays are not from the same backend """ @@ -50,14 +68,12 @@ def get_backend(*args): if not len(args) > 0: raise ValueError(" The function takes at least one parameter") # check all same type + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) if isinstance(args[0], np.ndarray): - if not len(set(type(a) for a in args)) == 1: - raise ValueError(str_type_error.format([type(a) for a in args])) return NumpyBackend() - elif torch and isinstance(args[0], torch_type): - if not len(set(type(a) for a in args)) == 1: - raise ValueError(str_type_error.format([type(a) for a in args])) + elif isinstance(args[0], torch_type): return TorchBackend() elif isinstance(args[0], jax_type): return JaxBackend() @@ -66,7 +82,7 @@ def get_backend(*args): def to_numpy(*args): - """returns numpy arrays from any compatible backend""" + """Returns numpy arrays from any compatible backend""" if len(args) == 1: return get_backend(args[0]).to_numpy(args[0]) @@ -75,6 +91,13 @@ def to_numpy(*args): class Backend(): + """ + Backend abstract class. + Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend` + + - The `__name__` class attribute refers to the name of the backend. + - The `__type__` class attribute refers to the data structure used by the backend. + """ __name__ = None __type__ = None @@ -84,90 +107,426 @@ class Backend(): # convert to numpy def to_numpy(self, a): + """Returns the numpy version of a tensor""" raise NotImplementedError() # convert from numpy def from_numpy(self, a, type_as=None): + """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" raise NotImplementedError() def set_gradients(self, val, inputs, grads): - """ define the gradients for the value val wrt the inputs """ + """Define the gradients for the value val wrt the inputs """ raise NotImplementedError() def zeros(self, shape, type_as=None): + r""" + Creates a tensor full of zeros. + + This function follow the api from :any:`numpy.zeros` + + See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html + """ raise NotImplementedError() def ones(self, shape, type_as=None): + r""" + Creates a tensor full of ones. + + This function follow the api from :any:`numpy.ones` + + See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html + """ raise NotImplementedError() def arange(self, stop, start=0, step=1, type_as=None): + r""" + Returns evenly spaced values within a given interval. + + This function follow the api from :any:`numpy.arange` + + See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html + """ raise NotImplementedError() def full(self, shape, fill_value, type_as=None): + r""" + Creates a tensor with given shape, filled with given value. + + This function follow the api from :any:`numpy.full` + + See: https://numpy.org/doc/stable/reference/generated/numpy.full.html + """ raise NotImplementedError() def eye(self, N, M=None, type_as=None): + r""" + Creates the identity matrix of given size. + + This function follow the api from :any:`numpy.eye` + + See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html + """ raise NotImplementedError() def sum(self, a, axis=None, keepdims=False): + r""" + Sums tensor elements over given dimensions. + + This function follow the api from :any:`numpy.sum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html + """ raise NotImplementedError() def cumsum(self, a, axis=None): + r""" + Returns the cumulative sum of tensor elements over given dimensions. + + This function follow the api from :any:`numpy.cumsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html + """ raise NotImplementedError() def max(self, a, axis=None, keepdims=False): + r""" + Returns the maximum of an array or maximum along given dimensions. + + This function follow the api from :any:`numpy.amax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html + """ raise NotImplementedError() def min(self, a, axis=None, keepdims=False): + r""" + Returns the maximum of an array or maximum along given dimensions. + + This function follow the api from :any:`numpy.amin` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html + """ raise NotImplementedError() def maximum(self, a, b): + r""" + Returns element-wise maximum of array elements. + + This function follow the api from :any:`numpy.maximum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html + """ raise NotImplementedError() def minimum(self, a, b): + r""" + Returns element-wise minimum of array elements. + + This function follow the api from :any:`numpy.minimum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html + """ raise NotImplementedError() def dot(self, a, b): + r""" + Returns the dot product of two tensors. + + This function follow the api from :any:`numpy.dot` + + See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html + """ raise NotImplementedError() def abs(self, a): + r""" + Computes the absolute value element-wise. + + This function follow the api from :any:`numpy.absolute` + + See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html + """ raise NotImplementedError() def exp(self, a): + r""" + Computes the exponential value element-wise. + + This function follow the api from :any:`numpy.exp` + + See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html + """ raise NotImplementedError() def log(self, a): + r""" + Computes the natural logarithm, element-wise. + + This function follow the api from :any:`numpy.log` + + See: https://numpy.org/doc/stable/reference/generated/numpy.log.html + """ raise NotImplementedError() def sqrt(self, a): + r""" + Returns the non-ngeative square root of a tensor, element-wise. + + This function follow the api from :any:`numpy.sqrt` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html + """ + raise NotImplementedError() + + def power(self, a, exponents): + r""" + First tensor elements raised to powers from second tensor, element-wise. + + This function follow the api from :any:`numpy.power` + + See: https://numpy.org/doc/stable/reference/generated/numpy.power.html + """ raise NotImplementedError() def norm(self, a): + r""" + Computes the matrix frobenius norm. + + This function follow the api from :any:`numpy.linalg.norm` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html + """ raise NotImplementedError() def any(self, a): + r""" + Tests whether any tensor element along given dimensions evaluates to True. + + This function follow the api from :any:`numpy.any` + + See: https://numpy.org/doc/stable/reference/generated/numpy.any.html + """ raise NotImplementedError() def isnan(self, a): + r""" + Tests element-wise for NaN and returns result as a boolean tensor. + + This function follow the api from :any:`numpy.isnan` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html + """ raise NotImplementedError() def isinf(self, a): + r""" + Tests element-wise for positive or negative infinity and returns result as a boolean tensor. + + This function follow the api from :any:`numpy.isinf` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html + """ raise NotImplementedError() def einsum(self, subscripts, *operands): + r""" + Evaluates the Einstein summation convention on the operands. + + This function follow the api from :any:`numpy.einsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html + """ raise NotImplementedError() def sort(self, a, axis=-1): + r""" + Returns a sorted copy of a tensor. + + This function follow the api from :any:`numpy.sort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html + """ raise NotImplementedError() def argsort(self, a, axis=None): + r""" + Returns the indices that would sort a tensor. + + This function follow the api from :any:`numpy.argsort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html + """ + raise NotImplementedError() + + def searchsorted(self, a, v, side='left'): + r""" + Finds indices where elements should be inserted to maintain order in given tensor. + + This function follow the api from :any:`numpy.searchsorted` + + See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html + """ raise NotImplementedError() def flip(self, a, axis=None): + r""" + Reverses the order of elements in a tensor along given dimensions. + + This function follow the api from :any:`numpy.flip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html + """ + raise NotImplementedError() + + def clip(self, a, a_min, a_max): + """ + Limits the values in a tensor. + + This function follow the api from :any:`numpy.clip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html + """ + raise NotImplementedError() + + def repeat(self, a, repeats, axis=None): + r""" + Repeats elements of a tensor. + + This function follow the api from :any:`numpy.repeat` + + See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html + """ + raise NotImplementedError() + + def take_along_axis(self, arr, indices, axis): + r""" + Gathers elements of a tensor along given dimensions. + + This function follow the api from :any:`numpy.take_along_axis` + + See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html + """ + raise NotImplementedError() + + def concatenate(self, arrays, axis=0): + r""" + Joins a sequence of tensors along an existing dimension. + + This function follow the api from :any:`numpy.concatenate` + + See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html + """ + raise NotImplementedError() + + def zero_pad(self, a, pad_width): + r""" + Pads a tensor. + + This function follow the api from :any:`numpy.pad` + + See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + """ + raise NotImplementedError() + + def argmax(self, a, axis=None): + r""" + Returns the indices of the maximum values of a tensor along given dimensions. + + This function follow the api from :any:`numpy.argmax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html + """ + raise NotImplementedError() + + def mean(self, a, axis=None): + r""" + Computes the arithmetic mean of a tensor along given dimensions. + + This function follow the api from :any:`numpy.mean` + + See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html + """ + raise NotImplementedError() + + def std(self, a, axis=None): + r""" + Computes the standard deviation of a tensor along given dimensions. + + This function follow the api from :any:`numpy.std` + + See: https://numpy.org/doc/stable/reference/generated/numpy.std.html + """ + raise NotImplementedError() + + def linspace(self, start, stop, num): + r""" + Returns a specified number of evenly spaced values over a given interval. + + This function follow the api from :any:`numpy.linspace` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html + """ + raise NotImplementedError() + + def meshgrid(self, a, b): + r""" + Returns coordinate matrices from coordinate vectors (Numpy convention). + + This function follow the api from :any:`numpy.meshgrid` + + See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html + """ + raise NotImplementedError() + + def diag(self, a, k=0): + r""" + Extracts or constructs a diagonal tensor. + + This function follow the api from :any:`numpy.diag` + + See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html + """ + raise NotImplementedError() + + def unique(self, a): + r""" + Finds unique elements of given tensor. + + This function follow the api from :any:`numpy.unique` + + See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html + """ + raise NotImplementedError() + + def logsumexp(self, a, axis=None): + r""" + Computes the log of the sum of exponentials of input elements. + + This function follow the api from :any:`scipy.special.logsumexp` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html + """ + raise NotImplementedError() + + def stack(self, arrays, axis=0): + r""" + Joins a sequence of tensors along a new dimension. + + This function follow the api from :any:`numpy.stack` + + See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html + """ raise NotImplementedError() class NumpyBackend(Backend): + """ + NumPy implementation of the backend + + - `__name__` is "numpy" + - `__type__` is np.ndarray + """ __name__ = 'numpy' __type__ = np.ndarray @@ -184,7 +543,7 @@ class NumpyBackend(Backend): return a.astype(type_as.dtype) def set_gradients(self, val, inputs, grads): - # no gradients for numpy + # No gradients for numpy return val def zeros(self, shape, type_as=None): @@ -247,6 +606,9 @@ class NumpyBackend(Backend): def sqrt(self, a): return np.sqrt(a) + def power(self, a, exponents): + return np.power(a, exponents) + def norm(self, a): return np.sqrt(np.sum(np.square(a))) @@ -268,11 +630,70 @@ class NumpyBackend(Backend): def argsort(self, a, axis=-1): return np.argsort(a, axis) + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return np.searchsorted(a, v, side) + else: + # this is a not very efficient way to make numpy + # searchsorted work on 2d arrays + ret = np.empty(v.shape, dtype=int) + for i in range(a.shape[0]): + ret[i, :] = np.searchsorted(a[i, :], v[i, :], side) + return ret + def flip(self, a, axis=None): return np.flip(a, axis) + def clip(self, a, a_min, a_max): + return np.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return np.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return np.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return np.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return np.pad(a, pad_width) + + def argmax(self, a, axis=None): + return np.argmax(a, axis=axis) + + def mean(self, a, axis=None): + return np.mean(a, axis=axis) + + def std(self, a, axis=None): + return np.std(a, axis=axis) + + def linspace(self, start, stop, num): + return np.linspace(start, stop, num) + + def meshgrid(self, a, b): + return np.meshgrid(a, b) + + def diag(self, a, k=0): + return np.diag(a, k) + + def unique(self, a): + return np.unique(a) + + def logsumexp(self, a, axis=None): + return scipy.logsumexp(a, axis=axis) + + def stack(self, arrays, axis=0): + return np.stack(arrays, axis) + class JaxBackend(Backend): + """ + JAX implementation of the backend + + - `__name__` is "jax" + - `__type__` is jax.numpy.ndarray + """ __name__ = 'jax' __type__ = jax_type @@ -359,6 +780,9 @@ class JaxBackend(Backend): def sqrt(self, a): return jnp.sqrt(a) + def power(self, a, exponents): + return jnp.power(a, exponents) + def norm(self, a): return jnp.sqrt(jnp.sum(jnp.square(a))) @@ -380,11 +804,67 @@ class JaxBackend(Backend): def argsort(self, a, axis=-1): return jnp.argsort(a, axis) + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return jnp.searchsorted(a, v, side) + else: + # this is a not very efficient way to make jax numpy + # searchsorted work on 2d arrays + return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])]) + def flip(self, a, axis=None): return jnp.flip(a, axis) + def clip(self, a, a_min, a_max): + return jnp.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return jnp.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return jnp.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return jnp.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return jnp.pad(a, pad_width) + + def argmax(self, a, axis=None): + return jnp.argmax(a, axis=axis) + + def mean(self, a, axis=None): + return jnp.mean(a, axis=axis) + + def std(self, a, axis=None): + return jnp.std(a, axis=axis) + + def linspace(self, start, stop, num): + return jnp.linspace(start, stop, num) + + def meshgrid(self, a, b): + return jnp.meshgrid(a, b) + + def diag(self, a, k=0): + return jnp.diag(a, k) + + def unique(self, a): + return jnp.unique(a) + + def logsumexp(self, a, axis=None): + return jscipy.logsumexp(a, axis=axis) + + def stack(self, arrays, axis=0): + return jnp.stack(arrays, axis) + class TorchBackend(Backend): + """ + PyTorch implementation of the backend + + - `__name__` is "torch" + - `__type__` is torch.Tensor + """ __name__ = 'torch' __type__ = torch_type @@ -487,22 +967,23 @@ class TorchBackend(Backend): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - return torch.maximum(a, b) + if torch.__version__ >= '1.7.0': + return torch.maximum(a, b) + else: + return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] def minimum(self, a, b): if isinstance(a, int) or isinstance(a, float): a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) if isinstance(b, int) or isinstance(b, float): b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) - return torch.minimum(a, b) + if torch.__version__ >= '1.7.0': + return torch.minimum(a, b) + else: + return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] def dot(self, a, b): - if len(a.shape) == len(b.shape) == 1: - return torch.dot(a, b) - elif len(a.shape) == 2 and len(b.shape) == 1: - return torch.mv(a, b) - else: - return torch.mm(a, b) + return torch.matmul(a, b) def abs(self, a): return torch.abs(a) @@ -516,6 +997,9 @@ class TorchBackend(Backend): def sqrt(self, a): return torch.sqrt(a) + def power(self, a, exponents): + return torch.pow(a, exponents) + def norm(self, a): return torch.sqrt(torch.sum(torch.square(a))) @@ -539,6 +1023,10 @@ class TorchBackend(Backend): sorted, indices = torch.sort(a, dim=axis) return indices + def searchsorted(self, a, v, side='left'): + right = (side != 'left') + return torch.searchsorted(a, v, right=right) + def flip(self, a, axis=None): if axis is None: return torch.flip(a, tuple(i for i in range(len(a.shape)))) @@ -546,3 +1034,60 @@ class TorchBackend(Backend): return torch.flip(a, (axis,)) else: return torch.flip(a, dims=axis) + + def clip(self, a, a_min, a_max): + return torch.clamp(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return torch.repeat_interleave(a, repeats, dim=axis) + + def take_along_axis(self, arr, indices, axis): + return torch.gather(arr, axis, indices) + + def concatenate(self, arrays, axis=0): + return torch.cat(arrays, dim=axis) + + def zero_pad(self, a, pad_width): + from torch.nn.functional import pad + # pad_width is an array of ndim tuples indicating how many 0 before and after + # we need to add. We first need to make it compliant with torch syntax, that + # starts with the last dim, then second last, etc. + how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl) + return pad(a, how_pad) + + def argmax(self, a, axis=None): + return torch.argmax(a, dim=axis) + + def mean(self, a, axis=None): + if axis is not None: + return torch.mean(a, dim=axis) + else: + return torch.mean(a) + + def std(self, a, axis=None): + if axis is not None: + return torch.std(a, dim=axis, unbiased=False) + else: + return torch.std(a, unbiased=False) + + def linspace(self, start, stop, num): + return torch.linspace(start, stop, num, dtype=torch.float64) + + def meshgrid(self, a, b): + X, Y = torch.meshgrid(a, b) + return X.T, Y.T + + def diag(self, a, k=0): + return torch.diag(a, diagonal=k) + + def unique(self, a): + return torch.unique(a) + + def logsumexp(self, a, axis=None): + if axis is not None: + return torch.logsumexp(a, dim=axis) + else: + return torch.logsumexp(a, dim=tuple(range(len(a.shape)))) + + def stack(self, arrays, axis=0): + return torch.stack(arrays, dim=axis) diff --git a/ot/bregman.py b/ot/bregman.py index 317c902..b59ee1b 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -19,7 +19,6 @@ import warnings import numpy as np from scipy.optimize import fmin_l_bfgs_b -from scipy.special import logsumexp from ot.utils import unif, dist, list_to_array from .backend import get_backend @@ -35,36 +34,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in [2]_ + scaling algorithm as proposed in :ref:`[2] ` **Choosing a Sinkhorn solver** By default and when using a regularization parameter that is not too small the default sinkhorn solver should be enough. If you need to use a small regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical errors. This last solver can be very slow in practice and might not even converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value of the regularization (and using warm start) sometimes leads to better solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a fast approximation of the Sinkhorn problem. @@ -74,7 +73,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -85,7 +84,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -109,7 +108,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) - + .. _references-sinkhorn: References ---------- @@ -125,9 +124,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling :ref:`[9] ` :ref:`[10] ` """ @@ -161,21 +160,21 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, .. math:: W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` **Choosing a Sinkhorn solver** @@ -199,17 +198,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters + method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -234,7 +233,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, array([0.26894142]) - + .. _references-sinkhorn2: References ---------- @@ -244,7 +243,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 @@ -252,9 +251,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.greenkhorn : Greenkhorn [21] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] + ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` + ot.bregman.greenkhorn : Greenkhorn :ref:`[21] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` """ @@ -291,21 +290,21 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -320,7 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -337,6 +336,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, [0.13447071, 0.36552929]]) + .. _references-sinkhorn-knopp: References ---------- @@ -388,7 +388,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, while (err > stopThr and cpt < numItermax): uprev = u vprev = v - KtransposeU = nx.dot(K.T, u) v = b / KtransposeU u = 1. / nx.dot(Kp, v) @@ -444,53 +443,46 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, r""" Solve the entropic regularization optimal transport problem and return the OT matrix - The algorithm used is based on the paper - - Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration - by Jason Altschuler, Jonathan Weed, Philippe Rigollet - appeared at NIPS 2017 - - which is a stochastic version of the Sinkhorn-Knopp algorithm [2]. + The algorithm used is based on the paper :ref:`[22] ` which is a stochastic version of the Sinkhorn-Knopp algorithm :ref:`[2] ` The function solves the following optimization problem: .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) - + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) log : bool, optional record log if True Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -507,11 +499,13 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, [0.13447071, 0.36552929]]) + .. _references-greenkhorn: References ---------- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + + .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 See Also @@ -521,60 +515,58 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + if nx.__name__ == "jax": + raise TypeError("JAX arrays have been received. Greenkhorn is not compatible with JAX") if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] dim_a = a.shape[0] dim_b = b.shape[0] - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty_like(M) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) - u = np.full(dim_a, 1. / dim_a) - v = np.full(dim_b, 1. / dim_b) - G = u[:, np.newaxis] * K * v[np.newaxis, :] + u = nx.full((dim_a,), 1. / dim_a, type_as=K) + v = nx.full((dim_b,), 1. / dim_b, type_as=K) + G = u[:, None] * K * v[None, :] - viol = G.sum(1) - a - viol_2 = G.sum(0) - b + viol = nx.sum(G, axis=1) - a + viol_2 = nx.sum(G, axis=0) - b stopThr_val = 1 - if log: log = dict() log['u'] = u log['v'] = v for i in range(numItermax): - i_1 = np.argmax(np.abs(viol)) - i_2 = np.argmax(np.abs(viol_2)) - m_viol_1 = np.abs(viol[i_1]) - m_viol_2 = np.abs(viol_2[i_2]) - stopThr_val = np.maximum(m_viol_1, m_viol_2) + i_1 = nx.argmax(nx.abs(viol)) + i_2 = nx.argmax(nx.abs(viol_2)) + m_viol_1 = nx.abs(viol[i_1]) + m_viol_2 = nx.abs(viol_2[i_2]) + stopThr_val = nx.maximum(m_viol_1, m_viol_2) if m_viol_1 > m_viol_2: old_u = u[i_1] - u[i_1] = a[i_1] / (K[i_1, :].dot(v)) - G[i_1, :] = u[i_1] * K[i_1, :] * v - - viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1] - viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v) + new_u = a[i_1] / (K[i_1, :].dot(v)) + G[i_1, :] = new_u * K[i_1, :] * v + viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1] + viol_2 += (K[i_1, :].T * (new_u - old_u) * v) + u[i_1] = new_u else: old_v = v[i_2] - v[i_2] = b[i_2] / (K[:, i_2].T.dot(u)) - G[:, i_2] = u * K[:, i_2] * v[i_2] + new_v = b[i_2] / (K[:, i_2].T.dot(u)) + G[:, i_2] = u * K[:, i_2] * new_v # aviol = (G@one_m - a) # aviol_2 = (G.T@one_n - b) - viol += (-old_v + v[i_2]) * K[:, i_2] * u - viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2] - + viol += (-old_v + new_v) * K[:, i_2] * u + viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2] + v[i_2] = new_v # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) if stopThr_val <= stopThr: @@ -603,41 +595,41 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in [2]_ but with the log stabilization - proposed in [10]_ an defined in [9]_ (Algo 3.1) . + scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization + proposed in :ref:`[10] ` an defined in :ref:`[9] ` (Algo 3.1) . Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) + b : array-like, shape (dim_b,) samples in the target domain - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 tau : float - thershold for max value in u or v for log scaling - warmstart : tible of vectors - if given then sarting values for alpha an beta log scalings + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + warmstart : table of vectors + if given then starting values for alpha and beta log scalings numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -645,7 +637,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -662,6 +654,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, [0.13447071, 0.36552929]]) + .. _references-sinkhorn-stabilized: References ---------- @@ -679,19 +672,19 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] # test if multiple target if len(b.shape) > 1: n_hists = b.shape[1] - a = a[:, np.newaxis] + a = a[:, None] else: n_hists = 0 @@ -706,25 +699,25 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, # we assume that no distances are null except those of the diagonal of # distances if warmstart is None: - alpha, beta = np.zeros(dim_a), np.zeros(dim_b) + alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) else: alpha, beta = warmstart if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b + u, v = nx.ones(dim_a, type_as=M) / dim_a, nx.ones(dim_b, type_as=M) / dim_b def get_K(alpha, beta): """log space computation""" - return np.exp(-(M - alpha.reshape((dim_a, 1)) + return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg) def get_Gamma(alpha, beta, u, v): """log space gamma computation""" - return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) - / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b)))) + return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) + / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) # print(np.min(K)) @@ -739,33 +732,35 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, vprev = v # sinkhorn update - v = b / (np.dot(K.T, u) + 1e-16) - u = a / (np.dot(K, v) + 1e-16) + v = b / (nx.dot(K.T, u) + 1e-16) + u = a / (nx.dot(K, v) + 1e-16) # remove numerical problems and store them in K - if np.abs(u).max() > tau or np.abs(v).max() > tau: + if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(np.log(v)) else: - alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) + alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: - u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b K = get_K(alpha, beta) if cpt % print_period == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - err_u = abs(u - uprev).max() - err_u /= max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() - err_v /= max(abs(v).max(), abs(vprev).max(), 1.) + err_u = nx.max(nx.abs(u - uprev)) + err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0) + err_v = nx.max(nx.abs(v - vprev)) + err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0) err = 0.5 * (err_u + err_v) else: transp = get_Gamma(alpha, beta, u, v) - err = np.linalg.norm((np.sum(transp, axis=0) - b)) + err = nx.norm(nx.sum(transp, axis=0) - b) if log: log['err'].append(err) @@ -781,7 +776,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, if cpt >= numItermax: loop = False - if np.any(np.isnan(u)) or np.any(np.isnan(v)): + if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -795,26 +790,28 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, if n_hists: alpha = alpha[:, None] beta = beta[:, None] - logu = alpha / reg + np.log(u) - logv = beta / reg + np.log(v) + logu = alpha / reg + nx.log(u) + logv = beta / reg + nx.log(v) log['logu'] = logu log['logv'] = logv - log['alpha'] = alpha + reg * np.log(u) - log['beta'] = beta + reg * np.log(v) + log['alpha'] = alpha + reg * nx.log(u) + log['beta'] = beta + reg * nx.log(v) log['warmstart'] = (log['alpha'], log['beta']) if n_hists: - res = np.zeros((n_hists)) - for i in range(n_hists): - res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + res = nx.stack([ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ]) return res, log else: return get_Gamma(alpha, beta, u, v), log else: if n_hists: - res = np.zeros((n_hists)) - for i in range(n_hists): - res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + res = nx.stack([ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ]) return res else: return get_Gamma(alpha, beta, u, v) @@ -833,45 +830,45 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in [2]_ but with the log stabilization - proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2 + scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization + proposed in :ref:`[10] ` and the log scaling proposed in :ref:`[9] ` algorithm 3.2 Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) + b : array-like, shape (dim_b,) samples in the target domain - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 tau : float - thershold for max value in u or v for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` for log scaling warmstart : tuple of vectors - if given then sarting values for alpha an beta log scalings + if given then starting values for alpha and beta log scalings numItermax : int, optional Max number of iterations numInnerItermax : int, optional - Max number of iterationsin the inner slog stabilized sinkhorn + Max number of iterations in the inner slog stabilized sinkhorn epsilon0 : int, optional first epsilon regularization value (then exponential decrease to reg) stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -879,7 +876,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -895,7 +892,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) - + .. _references-sinkhorn-epsilon-scaling: References ---------- @@ -903,6 +900,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + See Also -------- ot.lp.emd : Unregularized OT @@ -910,14 +910,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] # init data dim_a = len(a) @@ -934,7 +934,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, # we assume that no distances are null except those of the diagonal of # distances if warmstart is None: - alpha, beta = np.zeros(dim_a), np.zeros(dim_b) + alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) else: alpha, beta = warmstart @@ -964,15 +964,13 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, # we can speed up the process by checking for the error only all # the 10th iterations transp = G - err = np.linalg.norm( - (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2 + err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2 if log: log['err'].append(err) if verbose: if cpt % (print_period * 10) == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err)) if err <= stopThr and cpt > numItermin: @@ -991,23 +989,31 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def geometricBar(weights, alldistribT): """return the weighted geometric mean of distributions""" + weights, alldistribT = list_to_array(weights, alldistribT) + nx = get_backend(weights, alldistribT) assert (len(weights) == alldistribT.shape[1]) - return np.exp(np.dot(np.log(alldistribT), weights.T)) + return nx.exp(nx.dot(nx.log(alldistribT), weights.T)) def geometricMean(alldistribT): """return the geometric mean of distributions""" - return np.exp(np.mean(np.log(alldistribT), axis=1)) + alldistribT = list_to_array(alldistribT) + nx = get_backend(alldistribT) + return nx.exp(nx.mean(nx.log(alldistribT), axis=1)) def projR(gamma, p): """return the KL projection on the row constrints """ - return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T + gamma, p = list_to_array(gamma, p) + nx = get_backend(gamma, p) + return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T def projC(gamma, q): """return the KL projection on the column constrints """ - return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10)) + gamma, q = list_to_array(gamma, q) + nx = get_backend(gamma, q) + return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10) def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, @@ -1021,28 +1027,28 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` Parameters ---------- - A : ndarray, shape (dim, n_hists) - n_hists training distributions a_i of size dim - M : ndarray, shape (dim, dim) + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 method : str (optional) method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' - weights : ndarray, shape (n_hists,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -1051,12 +1057,13 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-barycenter: References ---------- @@ -1089,26 +1096,26 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` Parameters ---------- - A : ndarray, shape (dim, n_hists) - n_hists training distributions a_i of size dim - M : ndarray, shape (dim, dim) + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 - weights : ndarray, shape (n_hists,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -1117,12 +1124,13 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-barycenter-sinkhorn: References ---------- @@ -1130,8 +1138,12 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, """ + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + if weights is None: - weights = np.ones(A.shape[1]) / A.shape[1] + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] else: assert (len(weights) == A.shape[1]) @@ -1139,21 +1151,22 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, log = {'err': []} # M = M/np.median(M) # suggested by G. Peyre - K = np.exp(-M / reg) + K = nx.exp(-M / reg) cpt = 0 err = 1 - UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T) + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + u = (geometricMean(UKv) / UKv.T).T while (err > stopThr and cpt < numItermax): cpt = cpt + 1 - UKv = u * np.dot(K, np.divide(A, np.dot(K, u))) + UKv = u * nx.dot(K, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv if cpt % 10 == 1: - err = np.sum(np.std(UKv, axis=1)) + err = nx.sum(nx.std(UKv, axis=1)) # log and verbose print if log: @@ -1174,8 +1187,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - with stabilization. + r"""Compute the entropic regularized wasserstein barycenter of distributions A with stabilization. The function solves the following optimization problem: @@ -1184,28 +1196,28 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` Parameters ---------- - A : ndarray, shape (dim, n_hists) - n_hists training distributions a_i of size dim - M : ndarray, shape (dim, dim) + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 tau : float - thershold for max value in u or v for log scaling - weights : ndarray, shape (n_hists,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -1214,12 +1226,13 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-barycenter-stabilized: References ---------- @@ -1227,49 +1240,48 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, """ + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones((n_hists,), type_as=M) / n_hists else: assert (len(weights) == A.shape[1]) if log: log = {'err': []} - u = np.ones((dim, n_hists)) / dim - v = np.ones((dim, n_hists)) / dim + u = nx.ones((dim, n_hists), type_as=M) / dim + v = nx.ones((dim, n_hists), type_as=M) / dim - # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) cpt = 0 err = 1. - alpha = np.zeros(dim) - beta = np.zeros(dim) - q = np.ones(dim) / dim + alpha = nx.zeros((dim,), type_as=M) + beta = nx.zeros((dim,), type_as=M) + q = nx.ones((dim,), type_as=M) / dim while (err > stopThr and cpt < numItermax): qprev = q - Kv = K.dot(v) + Kv = nx.dot(K, v) u = A / (Kv + 1e-16) - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) q = geometricBar(weights, Ktu) Q = q[:, None] v = Q / (Ktu + 1e-16) absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha += reg * nx.log(nx.max(u, 1)) + beta += reg * nx.log(nx.max(v, 1)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(tuple(v.shape), type_as=v) + Kv = nx.dot(K, v) + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -1278,7 +1290,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, if (cpt % 10 == 0 and not absorbing) or cpt == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(u * Kv - A).max() + err = nx.max(nx.abs(u * Kv - A)) if log: log['err'].append(err) if verbose: @@ -1314,24 +1326,24 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` - - reg is the regularization strength scalar value + - `reg` is the regularization strength scalar value - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] ` Parameters ---------- - A : ndarray, shape (n_hists, width, height) - n distributions (2D images) of size width x height + A : array-like, shape (n_hists, width, height) + `n` distributions (2D images) of size `width` x `height` reg : float Regularization term >0 - weights : ndarray, shape (n_hists,) + weights : array-like, shape (n_hists,) Weights of each image on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) stabThr : float, optional Stabilization threshold to avoid numerical precision issue verbose : bool, optional @@ -1341,64 +1353,73 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, Returns ------- - a : ndarray, shape (width, height) + a : array-like, shape (width, height) 2D Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + + .. _references-convolutional-barycenter-2d: References ---------- - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). - Convolutional wasserstein distances: Efficient optimal transportation on geometric domains - ACM Transactions on Graphics (TOG), 34(4), 66 + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 """ + A = list_to_array(A) + + nx = get_backend(A) + if weights is None: - weights = np.ones(A.shape[0]) / A.shape[0] + weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] else: assert (len(weights) == A.shape[0]) if log: log = {'err': []} - b = np.zeros_like(A[0, :, :]) - U = np.ones_like(A) - KV = np.ones_like(A) + b = nx.zeros(A.shape[1:], type_as=A) + U = nx.ones(A.shape, type_as=A) + KV = nx.ones(A.shape, type_as=A) cpt = 0 err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions - t = np.linspace(0, 1, A.shape[1]) - [Y, X] = np.meshgrid(t, t) - xi1 = np.exp(-(X - Y) ** 2 / reg) + t = nx.linspace(0, 1, A.shape[1]) + [Y, X] = nx.meshgrid(t, t) + xi1 = nx.exp(-(X - Y) ** 2 / reg) - t = np.linspace(0, 1, A.shape[2]) - [Y, X] = np.meshgrid(t, t) - xi2 = np.exp(-(X - Y) ** 2 / reg) + t = nx.linspace(0, 1, A.shape[2]) + [Y, X] = nx.meshgrid(t, t) + xi2 = nx.exp(-(X - Y) ** 2 / reg) def K(x): - return np.dot(np.dot(xi1, x), xi2) + return nx.dot(nx.dot(xi1, x), xi2) while (err > stopThr and cpt < numItermax): bold = b cpt = cpt + 1 - b = np.zeros_like(A[0, :, :]) - for r in range(A.shape[0]): - KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :]))) - b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :])) - b = np.exp(b) + b = nx.zeros(A.shape[1:], type_as=A) + KV_cols = [] for r in range(A.shape[0]): - U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :]) - + KV_col_r = K(A[r, :, :] / nx.maximum(stabThr, K(U[r, :, :]))) + b += weights[r] * nx.log(nx.maximum(stabThr, U[r, :, :] * KV_col_r)) + KV_cols.append(KV_col_r) + KV = nx.stack(KV_cols) + b = nx.exp(b) + + U = nx.stack([ + b / nx.maximum(stabThr, KV[r, :, :]) + for r in range(A.shape[0]) + ]) if cpt % 10 == 1: - err = np.sum(np.abs(bold - b)) + err = nx.sum(nx.abs(bold - b)) # log and verbose print if log: log['err'].append(err) @@ -1424,34 +1445,35 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, The function solve the following optimization problem: .. math:: - \mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h}) + + \mathbf{h} = arg\min_\mathbf{h} (1- \alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\alpha W_{M_0,reg_0}(\mathbf{h}_0,\mathbf{h}) where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn) - - :math: `\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` - - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting - - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization - - :math:`\\alpha`weight data fitting and regularization + - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (`dim_a`, `dim_a`) for OT data fitting + - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization term and the cost matrix (`dim_prior`, `n_atoms`) regularization + - :math:`\\alpha` weight data fitting and regularization - The optimization problem is solved suing the algorithm described in [4] + The optimization problem is solved following the algorithm described in :ref:`[4] ` Parameters ---------- - a : ndarray, shape (dim_a) + a : array-like, shape (dim_a) observed distribution (histogram, sums to 1) - D : ndarray, shape (dim_a, n_atoms) + D : array-like, shape (dim_a, n_atoms) dictionary matrix - M : ndarray, shape (dim_a, dim_a) + M : array-like, shape (dim_a, dim_a) loss matrix - M0 : ndarray, shape (n_atoms, dim_prior) + M0 : array-like, shape (n_atoms, dim_prior) loss matrix - h0 : ndarray, shape (n_atoms,) + h0 : array-like, shape (n_atoms,) prior on the estimated unmixing h reg : float Regularization term >0 (Wasserstein data fitting) @@ -1462,7 +1484,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -1471,11 +1493,13 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, Returns ------- - h : ndarray, shape (n_atoms,) + h : array-like, shape (n_atoms,) Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + + .. _references-unmix: References ---------- @@ -1483,11 +1507,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, """ + a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0) + + nx = get_backend(a, D, M, M0, h0) + # M = M/np.median(M) - K = np.exp(-M / reg) + K = nx.exp(-M / reg) # M0 = M0/np.median(M0) - K0 = np.exp(-M0 / reg0) + K0 = nx.exp(-M0 / reg0) old = h0 err = 1 @@ -1499,16 +1527,16 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, while (err > stopThr and cpt < numItermax): K = projC(K, a) K0 = projC(K0, h0) - new = np.sum(K0, axis=1) + new = nx.sum(K0, axis=1) # we recombine the current selection from dictionnary - inv_new = np.dot(D, new) - other = np.sum(K, axis=1) + inv_new = nx.dot(D, new) + other = nx.sum(K, axis=1) # geometric interpolation - delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new)) + delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new)) K = projR(K, delta) - K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0) + K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0) - err = np.linalg.norm(np.sum(K0, axis=1) - old) + err = nx.norm(nx.sum(K0, axis=1) - old) old = new if log: log['err'].append(err) @@ -1522,14 +1550,14 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, if log: log['niter'] = cpt - return np.sum(K0, axis=1), log + return nx.sum(K0, axis=1), log else: - return np.sum(K0, axis=1) + return nx.sum(K0, axis=1) def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, stopThr=1e-6, verbose=False, log=False, **kwargs): - r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] + r'''Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] ` The function solves the following optimization problem: @@ -1542,12 +1570,12 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, where : - - :math:`\lambda_k` is the weight of k-th source domain - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) - - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes - - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C + - :math:`\lambda_k` is the weight of `k`-th source domain + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain defined as in [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source domain and `C` is the number of classes + - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C` - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` - - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. @@ -1556,11 +1584,11 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Parameters ---------- - Xs : list of K np.ndarray(nsk,d) + Xs : list of K array-like(nsk,d) features of all source domains' samples - Ys : list of K np.ndarray(nsk,) + Ys : list of K array-like(nsk,) labels of all source domains' samples - Xt : np.ndarray (nt,d) + Xt : array-like (nt,d) samples in the target domain reg : float Regularization term > 0 @@ -1577,12 +1605,13 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Returns ------- - h : (C,) ndarray + h : (C,) array-like proportion estimation in the target domain log : dict log dictionary return only if log==True in parameters + .. _references-jcpot-barycenter: References ---------- @@ -1591,7 +1620,14 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. ''' - nbclasses = len(np.unique(Ys[0])) + + Xs = list_to_array(*Xs) + Ys = list_to_array(*Ys) + Xt = list_to_array(Xt) + + nx = get_backend(*Xs, *Ys, Xt) + + nbclasses = len(nx.unique(Ys[0])) nbdomains = len(Xs) # log dictionary @@ -1608,19 +1644,19 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, dom = {} nsk = Xs[d].shape[0] # get number of elements for this domain dom['nbelem'] = nsk - classes = np.unique(Ys[d]) # get number of classes for this domain + classes = nx.unique(Ys[d]) # get number of classes for this domain # format classes to start from 0 for convenience - if np.min(classes) != 0: - Ys[d] = Ys[d] - np.min(classes) - classes = np.unique(Ys[d]) + if nx.min(classes) != 0: + Ys[d] -= nx.min(classes) + classes = nx.unique(Ys[d]) # build the corresponding D_1 and D_2 matrices - Dtmp1 = np.zeros((nbclasses, nsk)) - Dtmp2 = np.zeros((nbclasses, nsk)) + Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) + Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) for c in classes: - nbelemperclass = np.sum(Ys[d] == c) + nbelemperclass = nx.sum(Ys[d] == c) if nbelemperclass != 0: Dtmp1[int(c), Ys[d] == c] = 1. Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass) @@ -1631,36 +1667,34 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Mtmp = dist(Xs[d], Xt, metric=metric) M.append(Mtmp) - Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype) - np.divide(Mtmp, -reg, out=Ktmp) - np.exp(Ktmp, out=Ktmp) + Ktmp = nx.exp(-Mtmp / reg) K.append(Ktmp) # uniform target distribution - a = unif(np.shape(Xt)[0]) + a = nx.from_numpy(unif(np.shape(Xt)[0])) cpt = 0 # iterations count err = 1 - old_bary = np.ones((nbclasses)) + old_bary = nx.ones((nbclasses,), type_as=Xs[0]) while (err > stopThr and cpt < numItermax): - bary = np.zeros((nbclasses)) + bary = nx.zeros((nbclasses,), type_as=Xs[0]) # update coupling matrices for marginal constraints w.r.t. uniform target distribution for d in range(nbdomains): K[d] = projC(K[d], a) - other = np.sum(K[d], axis=1) - bary = bary + np.log(np.dot(D1[d], other)) / nbdomains + other = nx.sum(K[d], axis=1) + bary += nx.log(nx.dot(D1[d], other)) / nbdomains - bary = np.exp(bary) + bary = nx.exp(bary) # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] for d in range(nbdomains): - new = np.dot(D2[d].T, bary) + new = nx.dot(D2[d].T, bary) K[d] = projR(K[d], new) - err = np.linalg.norm(bary - old_bary) + err = nx.norm(bary - old_bary) cpt = cpt + 1 old_bary = bary @@ -1672,7 +1706,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err)) - bary = bary / np.sum(bary) + bary = bary / nx.sum(bary) if log: log['niter'] = cpt @@ -1697,39 +1731,38 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', .. math:: \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix + - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`a` and :math:`b` are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (n_samples_a, dim) + X_s : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (n_samples_b, dim) + X_t : array-like, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (n_samples_a,) + a : array-like, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : array-like, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) - If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional - Size of the batcheses used to compute the sinkhorn update without memory overhead. + Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations @@ -1739,7 +1772,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Returns ------- - gamma : ndarray, shape (n_samples_a, n_samples_b) + gamma : array-like, shape (n_samples_a, n_samples_b) Regularized optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1766,18 +1799,23 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' + + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(ns) + a = nx.from_numpy(unif(ns)) if b is None: - b = unif(nt) + b = nx.from_numpy(unif(nt)) if isLazy: if log: dict_log = {"err": []} - log_a, log_b = np.log(a), np.log(b) - f, g = np.zeros(ns), np.zeros(nt) + log_a, log_b = nx.log(a), nx.log(b) + f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) if isinstance(batchSize, int): bs, bt = batchSize, batchSize @@ -1788,27 +1826,44 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', range_s, range_t = range(0, ns, bs), range(0, nt, bt) - lse_f = np.zeros(ns) - lse_g = np.zeros(nt) + lse_f = nx.zeros((ns,), type_as=a) + lse_g = nx.zeros((nt,), type_as=a) + + X_s_np = nx.to_numpy(X_s) + X_t_np = nx.to_numpy(X_t) for i_ot in range(numIterMax): + lse_f_cols = [] for i in range_s: - M = dist(X_s[i:i + bs, :], X_t, metric=metric) - lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1) + M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = nx.from_numpy(M, type_as=a) + lse_f_cols.append( + nx.logsumexp(g[None, :] - M / reg, axis=1) + ) + lse_f = nx.concatenate(lse_f_cols, axis=0) f = log_a - lse_f + lse_g_cols = [] for j in range_t: - M = dist(X_s, X_t[j:j + bt, :], metric=metric) - lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0) + M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric) + M = nx.from_numpy(M, type_as=a) + lse_g_cols.append( + nx.logsumexp(f[:, None] - M / reg, axis=0) + ) + lse_g = nx.concatenate(lse_g_cols, axis=0) g = log_b - lse_g if (i_ot + 1) % 10 == 0: - m1 = np.zeros_like(a) + m1_cols = [] for i in range_s: - M = dist(X_s[i:i + bs, :], X_t, metric=metric) - m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1) - err = np.abs(m1 - a).sum() + M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = nx.from_numpy(M, type_as=a) + m1_cols.append( + nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1) + ) + m1 = nx.concatenate(m1_cols, axis=0) + err = nx.sum(nx.abs(m1 - a)) if log: dict_log["err"].append(err) @@ -1826,8 +1881,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', return (f, g) else: - M = dist(X_s, X_t, metric=metric) - + M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) + M = nx.from_numpy(M, type_as=a) if log: pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) return pi, log @@ -1848,39 +1903,38 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num .. math:: W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b \gamma\geq 0 where : - - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix + - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`a` and :math:`b` are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (n_samples_a, dim) + X_s : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (n_samples_b, dim) + X_t : array-like, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (n_samples_a,) + a : array-like, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : array-like, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) - If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional - Size of the batcheses used to compute the sinkhorn update without memory overhead. + Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations @@ -1890,7 +1944,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Returns ------- - W : (n_hists) ndarray or float + W : (n_hists) array-like or float Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1918,11 +1972,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(ns) + a = nx.from_numpy(unif(ns)) if b is None: - b = unif(nt) + b = nx.from_numpy(unif(nt)) if isLazy: if log: @@ -1936,10 +1994,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num range_s = range(0, ns, bs) loss = 0 + + X_s_np = nx.to_numpy(X_s) + X_t_np = nx.to_numpy(X_t) + for i in range_s: - M_block = dist(X_s[i:i + bs, :], X_t, metric=metric) - pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) - loss += np.sum(M_block * pi_block) + M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M_block = nx.from_numpy(M_block, type_as=a) + pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) + loss += nx.sum(M_block * pi_block) if log: return loss, dict_log @@ -1947,7 +2010,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num return loss else: - M = dist(X_s, X_t, metric=metric) + M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) + M = nx.from_numpy(M, type_as=a) if log: sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, @@ -1975,10 +2039,10 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) - S &= W - 1/2 * (W_a + W_b) + S &= W - \frac{W_a + W_b}{2} .. math:: - s.t. \gamma 1 = a + s.t. \ \gamma 1 = a \gamma^T 1= b @@ -1997,27 +2061,27 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli \gamma_b\geq 0 where : - - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b)) + - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`a` and :math:`b` are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (n_samples_a, dim) + X_s : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (n_samples_b, dim) + X_t : array-like, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (n_samples_a,) + a : array-like, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : array-like, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -2025,7 +2089,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Returns ------- - W : (1,) ndarray + W : (1,) array-like Optimal transportation symmetrized loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -2083,47 +2147,54 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): - r"""" + r""" Screening Sinkhorn Algorithm for Regularized Optimal Transport - The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem: + The function solves an approximate dual of Sinkhorn divergence :ref:`[2] ` which is written as the following optimization problem: - ..math:: - (u, v) = \argmin_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - + .. math:: - where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and + (u, v) = arg\min_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - - s.t. e^{u_i} \geq \epsilon / \kappa, for all i \in {1, ..., ns} + where: - e^{v_j} \geq \epsilon \kappa, for all j \in {1, ..., nt} + .. math:: - The parameters \kappa and \epsilon are determined w.r.t the couple number budget of points (ns_budget, nt_budget), see Equation (5) in [26] + B(u,v) = \mathrm{diag}(e^u) K \mathrm{diag}(e^v) \text{, with } K = e^{-M/reg} \text{ and} + + .. math:: + + s.t. \ e^{u_i} \geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} + + e^{v_j} \geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} + + The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] ` Parameters ---------- - a : `numpy.ndarray`, shape=(ns,) + a : array-like, shape=(ns,) samples weights in the source domain - b : `numpy.ndarray`, shape=(nt,) + b : array-like, shape=(nt,) samples weights in the target domain - M : `numpy.ndarray`, shape=(ns, nt) + M : array-like, shape=(ns, nt) Cost matrix reg : `float` Level of the entropy regularisation - ns_budget : `int`, deafult=None - Number budget of points to be keeped in the source domain - If it is None then 50% of the source sample points will be keeped + ns_budget : `int`, default=None + Number budget of points to be kept in the source domain. + If it is None then 50% of the source sample points will be kept - nt_budget : `int`, deafult=None - Number budget of points to be keeped in the target domain - If it is None then 50% of the target sample points will be keeped + nt_budget : `int`, default=None + Number budget of points to be kept in the target domain. + If it is None then 50% of the target sample points will be kept uniform : `bool`, default=False - If `True`, the source and target distribution are supposed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt + If `True`, the source and target distribution are supposed to be uniform, i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` restricted : `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver @@ -2133,15 +2204,16 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res Maximum number of iterations in LBFGS solver maxfun : `int`, default=10000 - Maximum number of function evaluations in LBFGS solver + Maximum number of function evaluations in LBFGS solver pgtol : `float`, default=1e-09 Final objective function accuracy in LBFGS solver verbose : `bool`, default=False - If `True`, dispaly informations about the cardinals of the active sets and the paramerters kappa + If `True`, display informations about the cardinals of the active sets and the parameters kappa and epsilon + Dependency ---------- To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/) @@ -2151,15 +2223,19 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res Returns ------- - gamma : `numpy.ndarray`, shape=(ns, nt) + gamma : array-like, shape=(ns, nt) Screened optimal transportation matrix for the given parameters log : `dict`, default=False Log dictionary return only if log==True in parameters + .. _references-screenkhorn: References ----------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 """ @@ -2171,9 +2247,12 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") bottleneck = np - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + if nx.__name__ == "jax": + raise TypeError("JAX arrays have been received but screenkhorn is not compatible with JAX.") + ns, nt = M.shape # by default, we keep only 50% of the sample data points @@ -2183,9 +2262,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res nt_budget = int(np.floor(0.5 * nt)) # calculate the Gibbs kernel - K = np.empty_like(M) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) def projection(u, epsilon): u[u <= epsilon] = epsilon @@ -2197,8 +2274,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res if ns_budget == ns and nt_budget == nt: # full number of budget points (ns, nt) = (ns_budget, nt_budget) - Isel = np.ones(ns, dtype=bool) - Jsel = np.ones(nt, dtype=bool) + Isel = nx.from_numpy(np.ones(ns, dtype=bool)) + Jsel = nx.from_numpy(np.ones(nt, dtype=bool)) epsilon = 0.0 kappa = 1.0 @@ -2214,57 +2291,61 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res K_IJc = [] K_IcJ = [] - vec_eps_IJc = np.zeros(nt) - vec_eps_IcJ = np.zeros(ns) + vec_eps_IJc = nx.zeros((nt,), type_as=M) + vec_eps_IcJ = nx.zeros((ns,), type_as=M) else: # sum of rows and columns of K - K_sum_cols = K.sum(axis=1) - K_sum_rows = K.sum(axis=0) + K_sum_cols = nx.sum(K, axis=1) + K_sum_rows = nx.sum(K, axis=0) if uniform: if ns / ns_budget < 4: - aK_sort = np.sort(K_sum_cols) + aK_sort = nx.sort(K_sum_cols) epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: - aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1] + aK_sort = nx.from_numpy( + bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1] + ) epsilon_u_square = a[0] / aK_sort if nt / nt_budget < 4: - bK_sort = np.sort(K_sum_rows) + bK_sort = nx.sort(K_sum_rows) epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: - bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1] + bK_sort = nx.from_numpy( + bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1] + ) epsilon_v_square = b[0] / bK_sort else: aK = a / K_sum_cols bK = b / K_sum_rows - aK_sort = np.sort(aK)[::-1] + aK_sort = nx.flip(nx.sort(aK), axis=0) epsilon_u_square = aK_sort[ns_budget - 1] - bK_sort = np.sort(bK)[::-1] + bK_sort = nx.flip(nx.sort(bK), axis=0) epsilon_v_square = bK_sort[nt_budget - 1] # active sets I and J (see Lemma 1 in [26]) Isel = a >= epsilon_u_square * K_sum_cols Jsel = b >= epsilon_v_square * K_sum_rows - if sum(Isel) != ns_budget: + if nx.sum(Isel) != ns_budget: if uniform: aK = a / K_sum_cols - aK_sort = np.sort(aK)[::-1] - epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean() + aK_sort = nx.flip(nx.sort(aK), axis=0) + epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1]) Isel = a >= epsilon_u_square * K_sum_cols - ns_budget = sum(Isel) + ns_budget = nx.sum(Isel) - if sum(Jsel) != nt_budget: + if nx.sum(Jsel) != nt_budget: if uniform: bK = b / K_sum_rows - bK_sort = np.sort(bK)[::-1] - epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean() + bK_sort = nx.flip(nx.sort(bK), axis=0) + epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1]) Jsel = b >= epsilon_v_square * K_sum_rows - nt_budget = sum(Jsel) + nt_budget = nx.sum(Jsel) epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4) kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2) @@ -2282,7 +2363,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res K_IcJ = K[np.ix_(Ic, Jsel)] K_IJc = K[np.ix_(Isel, Jc)] - K_min = K_IJ.min() + #K_min = K_IJ.min() + K_min = nx.min(K_IJ) if K_min == 0: K_min = np.finfo(float).tiny @@ -2290,10 +2372,10 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res a_I = a[Isel] b_J = b[Jsel] if not uniform: - a_I_min = a_I.min() - a_I_max = a_I.max() - b_J_max = b_J.max() - b_J_min = b_J.min() + a_I_min = nx.min(a_I) + a_I_max = nx.max(a_I) + b_J_max = nx.max(b_J) + b_J_min = nx.min(b_J) else: a_I_min = a_I[0] a_I_max = a_I[0] @@ -2309,24 +2391,30 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget # pre-calculated constants for the objective - vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) - vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0) + vec_eps_IJc = epsilon * kappa * nx.sum( + K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :], + axis=1 + ) + vec_eps_IcJ = (epsilon / kappa) * nx.sum( + nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ, + axis=0 + ) # initialisation - u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa) - v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa) + u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M) + v0 = nx.full((nt_budget,), 1. / nt_budget + epsilon * kappa, type_as=M) # pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26]) if restricted: if ns_budget != ns or nt_budget != nt: - cst_u = kappa * epsilon * K_IJc.sum(axis=1) - cst_v = epsilon * K_IcJ.sum(axis=0) / kappa + cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1) + cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa cpt = 1 while cpt < 5: # 5 iterations - K_IJ_v = np.dot(K_IJ.T, u0) + cst_v + K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v v0 = b_J / (kappa * K_IJ_v) - KIJ_u = np.dot(K_IJ, v0) + cst_u + KIJ_u = nx.dot(K_IJ, v0) + cst_u u0 = (kappa * a_I) / KIJ_u cpt += 1 @@ -2343,9 +2431,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res """ cpt = 1 while cpt < max_iter: - K_IJ_v = np.dot(K_IJ.T, usc) + cst_v + K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v vsc = b_J / (kappa * K_IJ_v) - KIJ_u = np.dot(K_IJ, vsc) + cst_u + KIJ_u = nx.dot(K_IJ, vsc) + cst_u usc = (kappa * a_I) / KIJ_u cpt += 1 @@ -2355,17 +2443,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res return usc, vsc def screened_obj(usc, vsc): - part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, - np.log(vsc)) - part_IJc = np.dot(usc, vec_eps_IJc) - part_IcJ = np.dot(vec_eps_IcJ, vsc) + part_IJ = ( + nx.dot(nx.dot(usc, K_IJ), vsc) + - kappa * nx.dot(a_I, nx.log(usc)) + - (1. / kappa) * nx.dot(b_J, nx.log(vsc)) + ) + part_IJc = nx.dot(usc, vec_eps_IJc) + part_IcJ = nx.dot(vec_eps_IcJ, vsc) psi_epsilon = part_IJ + part_IJc + part_IcJ return psi_epsilon def screened_grad(usc, vsc): # gradients of Psi_(kappa,epsilon) w.r.t u and v - grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc - grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc + grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc + grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc return grad_u, grad_v def bfgspost(theta): @@ -2375,20 +2466,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res f = screened_obj(u, v) # gradient g_u, g_v = screened_grad(u, v) - g = np.hstack([g_u, g_v]) - return f, g + g = nx.concatenate([g_u, g_v], axis=0) + return nx.to_numpy(f), nx.to_numpy(g) # ----------------------------------------------------------------------------------------------------------------# # Step 2: L-BFGS-B solver # # ----------------------------------------------------------------------------------------------------------------# u0, v0 = restricted_sinkhorn(u0, v0) - theta0 = np.hstack([u0, v0]) + theta0 = nx.concatenate([u0, v0], axis=0) bounds = bounds_u + bounds_v # constraint bounds def obj(theta): - return bfgspost(theta) + return bfgspost(nx.from_numpy(theta, type_as=M)) theta, _, _ = fmin_l_bfgs_b(func=obj, x0=theta0, @@ -2396,12 +2487,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res maxfun=maxfun, pgtol=pgtol, maxiter=maxiter) + theta = nx.from_numpy(theta) usc = theta[:ns_budget] vsc = theta[ns_budget:] - usc_full = np.full(ns, epsilon / kappa) - vsc_full = np.full(nt, epsilon * kappa) + usc_full = nx.full((ns,), epsilon / kappa, type_as=M) + vsc_full = nx.full((nt,), epsilon * kappa, type_as=M) usc_full[Isel] = usc vsc_full[Jsel] = vsc @@ -2413,7 +2505,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res log['Jsel'] = Jsel gamma = usc_full[:, None] * K * vsc_full[None, :] - gamma = gamma / gamma.sum() + gamma = gamma / nx.sum(gamma) if log: return gamma, log diff --git a/ot/gromov.py b/ot/gromov.py index a27217a..85b1549 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -1161,7 +1161,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter : int, optional Max number of iterations tol : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations. log : bool, optional @@ -1267,7 +1267,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, max_iter : int, optional Max number of iterations tol : float, optional - Stop threshol on error (>0). + Stop threshold on error (>0). verbose : bool, optional Print information along iterations. log : bool, optional @@ -1365,7 +1365,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ max_iter : int, optional Max number of iterations tol : float, optional - Stop threshol on error (>0). + Stop threshold on error (>0). verbose : bool, optional Print information along iterations. log : bool, optional diff --git a/ot/smooth.py b/ot/smooth.py index 81f6a3e..ea26bae 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -458,7 +458,7 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -552,7 +552,7 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional diff --git a/ot/unbalanced.py b/ot/unbalanced.py index e37f10c..6a61aa1 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -58,7 +58,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -186,7 +186,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -300,7 +300,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -482,7 +482,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -691,7 +691,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -841,7 +841,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -971,7 +971,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional -- cgit v1.2.3