Diffstat (limited to 'ot')
-rw-r--r-- | ot/__init__.py | 24
-rw-r--r-- | ot/backend.py | 1502
-rw-r--r-- | ot/bregman.py | 2811
-rw-r--r-- | ot/da.py | 507
-rw-r--r-- | ot/datasets.py | 12
-rw-r--r-- | ot/dr.py | 156
-rw-r--r-- | ot/gpu/__init__.py | 12
-rw-r--r-- | ot/gpu/bregman.py | 12
-rw-r--r-- | ot/gpu/da.py | 2
-rw-r--r-- | ot/gromov.py | 1312
-rw-r--r-- | ot/helpers/__init__.py | 3
-rw-r--r-- | ot/helpers/openmp_helpers.py | 85
-rw-r--r-- | ot/helpers/pre_build_helpers.py | 87
-rw-r--r-- | ot/lp/EMD.h | 5
-rw-r--r-- | ot/lp/EMD_wrapper.cpp | 124
-rw-r--r-- | ot/lp/__init__.py | 597
-rw-r--r-- | ot/lp/cvx.py | 3
-rw-r--r-- | ot/lp/emd_wrap.pyx | 32
-rw-r--r-- | ot/lp/full_bipartitegraph.h | 27
-rw-r--r-- | ot/lp/full_bipartitegraph_omp.h | 234
-rw-r--r-- | ot/lp/network_simplex_simple.h | 212
-rw-r--r-- | ot/lp/network_simplex_simple_omp.h | 1699
-rw-r--r-- | ot/lp/solver_1d.py | 367
-rw-r--r-- | ot/optim.py | 189
-rwxr-xr-x | ot/partial.py | 352
-rw-r--r-- | ot/plot.py | 10
-rw-r--r-- | ot/regpath.py | 827
-rw-r--r-- | ot/sliced.py | 258
-rw-r--r-- | ot/smooth.py | 183
-rw-r--r-- | ot/stochastic.py | 192
-rw-r--r-- | ot/unbalanced.py | 220
-rw-r--r-- | ot/utils.py | 269
32 files changed, 9780 insertions, 2545 deletions
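For orientation, the central addition in this changeset is the multi-lib backend (ot/backend.py below), which lets the same POT code run unchanged on NumPy, PyTorch, or JAX arrays. A minimal sketch of the intended call pattern, assuming only NumPy is installed (the numeric output is illustrative):

>>> import numpy as np
>>> from ot.backend import get_backend
>>> a = np.array([0.5, 0.5])
>>> M = np.array([[0., 1.], [1., 0.]])
>>> nx = get_backend(a, M)            # infers the backend from the input arrays
>>> str(nx)
'numpy'
>>> float(nx.sum(nx.exp(-M / 1.0)))   # nx mirrors the NumPy API, e.g. K = exp(-M/reg)
2.7357588823428847

The same snippet runs with torch tensors or jax arrays as inputs, in which case get_backend returns TorchBackend or JaxBackend instead.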
diff --git a/ot/__init__.py b/ot/__init__.py index 0e6e2e2..b6dc2b4 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -5,7 +5,8 @@ :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim` :py:mod:`ot.utils`, :py:mod:`ot.datasets`, :py:mod:`ot.gromov`, :py:mod:`ot.smooth` - :py:mod:`ot.stochastic` + :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`, + :py:mod:`ot.unbalanced`. The following sub-modules are not imported due to additional dependencies: @@ -33,21 +34,30 @@ from . import smooth from . import stochastic from . import unbalanced from . import partial +from . import backend +from . import regpath # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d from .bregman import sinkhorn, sinkhorn2, barycenter -from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 +from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, + sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm +from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance +from .gromov import (gromov_wasserstein, gromov_wasserstein2, + gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.7.0" +__version__ = "0.8.0" -__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', - 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', +__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', + 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', + 'emd2_1d', 'wasserstein_1d', 'backend', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2'] + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', + 'max_sliced_wasserstein_distance', + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/backend.py b/ot/backend.py new file mode 100644 index 0000000..a044f84 --- /dev/null +++ b/ot/backend.py @@ -0,0 +1,1502 @@ +# -*- coding: utf-8 -*- +""" +Multi-lib backend for POT + +The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch, +or Jax, POT code should work nonetheless. +To achieve that, POT provides backend classes which implement functions in their respective backend +imitating the Numpy API. As a convention, we use nx instead of np to refer to the backend. + +Examples +-------- + +>>> from ot.utils import list_to_array +>>> from ot.backend import get_backend +>>> def f(a, b): # the function does not know which backend to use +... a, b = list_to_array(a, b) # if a list is given, make it an array +... nx = get_backend(a, b) # infer the backend from the arguments +... c = nx.dot(a, b) # now use the backend to do any calculation +... 
return c +""" + +# Author: Remi Flamary <remi.flamary@polytechnique.edu> +# Nicolas Courty <ncourty@irisa.fr> +# +# License: MIT License + +import numpy as np +import scipy.special as scipy +from scipy.sparse import issparse, coo_matrix, csr_matrix + +try: + import torch + torch_type = torch.Tensor +except ImportError: + torch = False + torch_type = float + +try: + import jax + import jax.numpy as jnp + import jax.scipy.special as jscipy + jax_type = jax.numpy.ndarray +except ImportError: + jax = False + jax_type = float + +str_type_error = "All array should be from the same type/backend. Current types are : {}" + + +def get_backend_list(): + """Returns the list of available backends""" + lst = [NumpyBackend(), ] + + if torch: + lst.append(TorchBackend()) + + if jax: + lst.append(JaxBackend()) + + return lst + + +def get_backend(*args): + """Returns the proper backend for a list of input arrays + + Also raises TypeError if all arrays are not from the same backend + """ + # check that some arrays given + if not len(args) > 0: + raise ValueError(" The function takes at least one parameter") + # check all same type + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) + + if isinstance(args[0], np.ndarray): + return NumpyBackend() + elif isinstance(args[0], torch_type): + return TorchBackend() + elif isinstance(args[0], jax_type): + return JaxBackend() + else: + raise ValueError("Unknown type of non implemented backend.") + + +def to_numpy(*args): + """Returns numpy arrays from any compatible backend""" + + if len(args) == 1: + return get_backend(args[0]).to_numpy(args[0]) + else: + return [get_backend(a).to_numpy(a) for a in args] + + +class Backend(): + """ + Backend abstract class. + Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend` + + - The `__name__` class attribute refers to the name of the backend. + - The `__type__` class attribute refers to the data structure used by the backend. + """ + + __name__ = None + __type__ = None + __type_list__ = None + + rng_ = None + + def __str__(self): + return self.__name__ + + # convert to numpy + def to_numpy(self, a): + """Returns the numpy version of a tensor""" + raise NotImplementedError() + + # convert from numpy + def from_numpy(self, a, type_as=None): + """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" + raise NotImplementedError() + + def set_gradients(self, val, inputs, grads): + """Define the gradients for the value val wrt the inputs """ + raise NotImplementedError() + + def zeros(self, shape, type_as=None): + r""" + Creates a tensor full of zeros. + + This function follows the api from :any:`numpy.zeros` + + See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html + """ + raise NotImplementedError() + + def ones(self, shape, type_as=None): + r""" + Creates a tensor full of ones. + + This function follows the api from :any:`numpy.ones` + + See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html + """ + raise NotImplementedError() + + def arange(self, stop, start=0, step=1, type_as=None): + r""" + Returns evenly spaced values within a given interval. 
+ + This function follows the api from :any:`numpy.arange` + + See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html + """ + raise NotImplementedError() + + def full(self, shape, fill_value, type_as=None): + r""" + Creates a tensor with given shape, filled with given value. + + This function follows the api from :any:`numpy.full` + + See: https://numpy.org/doc/stable/reference/generated/numpy.full.html + """ + raise NotImplementedError() + + def eye(self, N, M=None, type_as=None): + r""" + Creates the identity matrix of given size. + + This function follows the api from :any:`numpy.eye` + + See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html + """ + raise NotImplementedError() + + def sum(self, a, axis=None, keepdims=False): + r""" + Sums tensor elements over given dimensions. + + This function follows the api from :any:`numpy.sum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html + """ + raise NotImplementedError() + + def cumsum(self, a, axis=None): + r""" + Returns the cumulative sum of tensor elements over given dimensions. + + This function follows the api from :any:`numpy.cumsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html + """ + raise NotImplementedError() + + def max(self, a, axis=None, keepdims=False): + r""" + Returns the maximum of an array or maximum along given dimensions. + + This function follows the api from :any:`numpy.amax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html + """ + raise NotImplementedError() + + def min(self, a, axis=None, keepdims=False): + r""" + Returns the minimum of an array or minimum along given dimensions. + + This function follows the api from :any:`numpy.amin` + + See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html + """ + raise NotImplementedError() + + def maximum(self, a, b): + r""" + Returns element-wise maximum of array elements. + + This function follows the api from :any:`numpy.maximum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html + """ + raise NotImplementedError() + + def minimum(self, a, b): + r""" + Returns element-wise minimum of array elements. + + This function follows the api from :any:`numpy.minimum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html + """ + raise NotImplementedError() + + def dot(self, a, b): + r""" + Returns the dot product of two tensors. + + This function follows the api from :any:`numpy.dot` + + See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html + """ + raise NotImplementedError() + + def abs(self, a): + r""" + Computes the absolute value element-wise. + + This function follows the api from :any:`numpy.absolute` + + See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html + """ + raise NotImplementedError() + + def exp(self, a): + r""" + Computes the exponential value element-wise. + + This function follows the api from :any:`numpy.exp` + + See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html + """ + raise NotImplementedError() + + def log(self, a): + r""" + Computes the natural logarithm, element-wise. + + This function follows the api from :any:`numpy.log` + + See: https://numpy.org/doc/stable/reference/generated/numpy.log.html + """ + raise NotImplementedError() + + def sqrt(self, a): + r""" + Returns the non-negative square root of a tensor, element-wise. 
+ + This function follows the api from :any:`numpy.sqrt` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html + """ + raise NotImplementedError() + + def power(self, a, exponents): + r""" + First tensor elements raised to powers from second tensor, element-wise. + + This function follows the api from :any:`numpy.power` + + See: https://numpy.org/doc/stable/reference/generated/numpy.power.html + """ + raise NotImplementedError() + + def norm(self, a): + r""" + Computes the matrix frobenius norm. + + This function follows the api from :any:`numpy.linalg.norm` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html + """ + raise NotImplementedError() + + def any(self, a): + r""" + Tests whether any tensor element along given dimensions evaluates to True. + + This function follows the api from :any:`numpy.any` + + See: https://numpy.org/doc/stable/reference/generated/numpy.any.html + """ + raise NotImplementedError() + + def isnan(self, a): + r""" + Tests element-wise for NaN and returns result as a boolean tensor. + + This function follows the api from :any:`numpy.isnan` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html + """ + raise NotImplementedError() + + def isinf(self, a): + r""" + Tests element-wise for positive or negative infinity and returns result as a boolean tensor. + + This function follows the api from :any:`numpy.isinf` + + See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html + """ + raise NotImplementedError() + + def einsum(self, subscripts, *operands): + r""" + Evaluates the Einstein summation convention on the operands. + + This function follows the api from :any:`numpy.einsum` + + See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html + """ + raise NotImplementedError() + + def sort(self, a, axis=-1): + r""" + Returns a sorted copy of a tensor. + + This function follows the api from :any:`numpy.sort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html + """ + raise NotImplementedError() + + def argsort(self, a, axis=None): + r""" + Returns the indices that would sort a tensor. + + This function follows the api from :any:`numpy.argsort` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html + """ + raise NotImplementedError() + + def searchsorted(self, a, v, side='left'): + r""" + Finds indices where elements should be inserted to maintain order in given tensor. + + This function follows the api from :any:`numpy.searchsorted` + + See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html + """ + raise NotImplementedError() + + def flip(self, a, axis=None): + r""" + Reverses the order of elements in a tensor along given dimensions. + + This function follows the api from :any:`numpy.flip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html + """ + raise NotImplementedError() + + def clip(self, a, a_min, a_max): + """ + Limits the values in a tensor. + + This function follows the api from :any:`numpy.clip` + + See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html + """ + raise NotImplementedError() + + def repeat(self, a, repeats, axis=None): + r""" + Repeats elements of a tensor. + + This function follows the api from :any:`numpy.repeat` + + See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html + """ + raise NotImplementedError() + + def take_along_axis(self, arr, indices, axis): + r""" + Gathers elements of a tensor along given dimensions. 
+ + This function follows the api from :any:`numpy.take_along_axis` + + See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html + """ + raise NotImplementedError() + + def concatenate(self, arrays, axis=0): + r""" + Joins a sequence of tensors along an existing dimension. + + This function follows the api from :any:`numpy.concatenate` + + See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html + """ + raise NotImplementedError() + + def zero_pad(self, a, pad_width): + r""" + Pads a tensor. + + This function follows the api from :any:`numpy.pad` + + See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + """ + raise NotImplementedError() + + def argmax(self, a, axis=None): + r""" + Returns the indices of the maximum values of a tensor along given dimensions. + + This function follows the api from :any:`numpy.argmax` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html + """ + raise NotImplementedError() + + def mean(self, a, axis=None): + r""" + Computes the arithmetic mean of a tensor along given dimensions. + + This function follows the api from :any:`numpy.mean` + + See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html + """ + raise NotImplementedError() + + def std(self, a, axis=None): + r""" + Computes the standard deviation of a tensor along given dimensions. + + This function follows the api from :any:`numpy.std` + + See: https://numpy.org/doc/stable/reference/generated/numpy.std.html + """ + raise NotImplementedError() + + def linspace(self, start, stop, num): + r""" + Returns a specified number of evenly spaced values over a given interval. + + This function follows the api from :any:`numpy.linspace` + + See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html + """ + raise NotImplementedError() + + def meshgrid(self, a, b): + r""" + Returns coordinate matrices from coordinate vectors (Numpy convention). + + This function follows the api from :any:`numpy.meshgrid` + + See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html + """ + raise NotImplementedError() + + def diag(self, a, k=0): + r""" + Extracts or constructs a diagonal tensor. + + This function follows the api from :any:`numpy.diag` + + See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html + """ + raise NotImplementedError() + + def unique(self, a): + r""" + Finds unique elements of given tensor. + + This function follows the api from :any:`numpy.unique` + + See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html + """ + raise NotImplementedError() + + def logsumexp(self, a, axis=None): + r""" + Computes the log of the sum of exponentials of input elements. + + This function follows the api from :any:`scipy.special.logsumexp` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html + """ + raise NotImplementedError() + + def stack(self, arrays, axis=0): + r""" + Joins a sequence of tensors along a new dimension. + + This function follows the api from :any:`numpy.stack` + + See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html + """ + raise NotImplementedError() + + def outer(self, a, b): + r""" + Computes the outer product between two vectors. + + This function follows the api from :any:`numpy.outer` + + See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html + """ + raise NotImplementedError() + + def reshape(self, a, shape): + r""" + Gives a new shape to a tensor without changing its data. 
+ + This function follows the api from :any:`numpy.reshape` + + See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html + """ + raise NotImplementedError() + + def seed(self, seed=None): + r""" + Sets the seed for the random generator. + + This function follows the api from :any:`numpy.random.seed` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html + """ + raise NotImplementedError() + + def rand(self, *size, type_as=None): + r""" + Generate uniform random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + + def randn(self, *size, type_as=None): + r""" + Generate normal Gaussian random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + r""" + Creates a sparse tensor in COOrdinate format. + + This function follows the api from :any:`scipy.sparse.coo_matrix` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html + """ + raise NotImplementedError() + + def issparse(self, a): + r""" + Checks whether or not the input tensor is a sparse tensor. + + This function follows the api from :any:`scipy.sparse.issparse` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html + """ + raise NotImplementedError() + + def tocsr(self, a): + r""" + Converts this matrix to Compressed Sparse Row format. + + This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html + """ + raise NotImplementedError() + + def eliminate_zeros(self, a, threshold=0.): + r""" + Removes entries smaller than the given threshold from the sparse tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros` + + See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html + """ + raise NotImplementedError() + + def todense(self, a): + r""" + Converts a sparse tensor to a dense tensor. + + This function follows the api from :any:`scipy.sparse.csr_matrix.toarray` + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html + """ + raise NotImplementedError() + + def where(self, condition, x, y): + r""" + Returns elements chosen from x or y depending on condition. + + This function follows the api from :any:`numpy.where` + + See: https://numpy.org/doc/stable/reference/generated/numpy.where.html + """ + raise NotImplementedError() + + def copy(self, a): + r""" + Returns a copy of the given tensor. + + This function follows the api from :any:`numpy.copy` + + See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html + """ + raise NotImplementedError() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + r""" + Returns True if two arrays are element-wise equal within a tolerance. + + This function follows the api from :any:`numpy.allclose` + + See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html + """ + raise NotImplementedError() + + def dtype_device(self, a): + r""" + Returns the dtype and the device of the given tensor. 
+ """ + raise NotImplementedError() + + def assert_same_dtype_device(self, a, b): + r""" + Checks whether or not the two given inputs have the same dtype as well as the same device + """ + raise NotImplementedError() + + +class NumpyBackend(Backend): + """ + NumPy implementation of the backend + + - `__name__` is "numpy" + - `__type__` is np.ndarray + """ + + __name__ = 'numpy' + __type__ = np.ndarray + __type_list__ = [np.array(1, dtype=np.float32), + np.array(1, dtype=np.float64)] + + rng_ = np.random.RandomState() + + def to_numpy(self, a): + return a + + def from_numpy(self, a, type_as=None): + if type_as is None: + return a + elif isinstance(a, float): + return a + else: + return a.astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # No gradients for numpy + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return np.zeros(shape) + else: + return np.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return np.ones(shape) + else: + return np.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return np.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return np.full(shape, fill_value) + else: + return np.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return np.eye(N, M) + else: + return np.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return np.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return np.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return np.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return np.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return np.maximum(a, b) + + def minimum(self, a, b): + return np.minimum(a, b) + + def dot(self, a, b): + return np.dot(a, b) + + def abs(self, a): + return np.abs(a) + + def exp(self, a): + return np.exp(a) + + def log(self, a): + return np.log(a) + + def sqrt(self, a): + return np.sqrt(a) + + def power(self, a, exponents): + return np.power(a, exponents) + + def norm(self, a): + return np.sqrt(np.sum(np.square(a))) + + def any(self, a): + return np.any(a) + + def isnan(self, a): + return np.isnan(a) + + def isinf(self, a): + return np.isinf(a) + + def einsum(self, subscripts, *operands): + return np.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return np.sort(a, axis) + + def argsort(self, a, axis=-1): + return np.argsort(a, axis) + + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return np.searchsorted(a, v, side) + else: + # this is a not very efficient way to make numpy + # searchsorted work on 2d arrays + ret = np.empty(v.shape, dtype=int) + for i in range(a.shape[0]): + ret[i, :] = np.searchsorted(a[i, :], v[i, :], side) + return ret + + def flip(self, a, axis=None): + return np.flip(a, axis) + + def outer(self, a, b): + return np.outer(a, b) + + def clip(self, a, a_min, a_max): + return np.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return np.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return np.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return np.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return np.pad(a, pad_width) + + def argmax(self, a, axis=None): + return np.argmax(a, axis=axis) + + 
def mean(self, a, axis=None): + return np.mean(a, axis=axis) + + def std(self, a, axis=None): + return np.std(a, axis=axis) + + def linspace(self, start, stop, num): + return np.linspace(start, stop, num) + + def meshgrid(self, a, b): + return np.meshgrid(a, b) + + def diag(self, a, k=0): + return np.diag(a, k) + + def unique(self, a): + return np.unique(a) + + def logsumexp(self, a, axis=None): + return scipy.logsumexp(a, axis=axis) + + def stack(self, arrays, axis=0): + return np.stack(arrays, axis) + + def reshape(self, a, shape): + return np.reshape(a, shape) + + def seed(self, seed=None): + if seed is not None: + self.rng_.seed(seed) + + def rand(self, *size, type_as=None): + return self.rng_.rand(*size) + + def randn(self, *size, type_as=None): + return self.rng_.randn(*size) + + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + if type_as is None: + return coo_matrix((data, (rows, cols)), shape=shape) + else: + return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype) + + def issparse(self, a): + return issparse(a) + + def tocsr(self, a): + if self.issparse(a): + return a.tocsr() + else: + return csr_matrix(a) + + def eliminate_zeros(self, a, threshold=0.): + if threshold > 0: + if self.issparse(a): + a.data[self.abs(a.data) <= threshold] = 0 + else: + a[self.abs(a) <= threshold] = 0 + if self.issparse(a): + a.eliminate_zeros() + return a + + def todense(self, a): + if self.issparse(a): + return a.toarray() + else: + return a + + def where(self, condition, x, y): + return np.where(condition, x, y) + + def copy(self, a): + return a.copy() + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + def dtype_device(self, a): + if hasattr(a, "dtype"): + return a.dtype, "cpu" + else: + return type(a), "cpu" + + def assert_same_dtype_device(self, a, b): + # numpy has implicit type conversion so we automatically validate the test + pass + + +class JaxBackend(Backend): + """ + JAX implementation of the backend + + - `__name__` is "jax" + - `__type__` is jax.numpy.ndarray + """ + + __name__ = 'jax' + __type__ = jax_type + __type_list__ = None + + rng_ = None + + def __init__(self): + self.rng_ = jax.random.PRNGKey(42) + + for d in jax.devices(): + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d), + jax.device_put(jnp.array(1, dtype=jnp.float64), d)] + + def to_numpy(self, a): + return np.array(a) + + def _change_device(self, a, type_as): + return jax.device_put(a, type_as.device_buffer.device()) + + def from_numpy(self, a, type_as=None): + if type_as is None: + return jnp.array(a) + else: + return self._change_device(jnp.array(a).astype(type_as.dtype), type_as) + + def set_gradients(self, val, inputs, grads): + from jax.flatten_util import ravel_pytree + val, = jax.lax.stop_gradient((val,)) + + ravelled_inputs, _ = ravel_pytree(inputs) + ravelled_grads, _ = ravel_pytree(grads) + + aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2 + aux = aux - jax.lax.stop_gradient(aux) + + val, = jax.tree_map(lambda z: z + aux, (val,)) + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return jnp.zeros(shape) + else: + return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as) + + def ones(self, shape, type_as=None): + if type_as is None: + return jnp.ones(shape) + else: + return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as) + + def arange(self, stop, start=0, step=1, type_as=None): + return 
jnp.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return jnp.full(shape, fill_value) + else: + return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return jnp.eye(N, M) + else: + return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as) + + def sum(self, a, axis=None, keepdims=False): + return jnp.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return jnp.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return jnp.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return jnp.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return jnp.maximum(a, b) + + def minimum(self, a, b): + return jnp.minimum(a, b) + + def dot(self, a, b): + return jnp.dot(a, b) + + def abs(self, a): + return jnp.abs(a) + + def exp(self, a): + return jnp.exp(a) + + def log(self, a): + return jnp.log(a) + + def sqrt(self, a): + return jnp.sqrt(a) + + def power(self, a, exponents): + return jnp.power(a, exponents) + + def norm(self, a): + return jnp.sqrt(jnp.sum(jnp.square(a))) + + def any(self, a): + return jnp.any(a) + + def isnan(self, a): + return jnp.isnan(a) + + def isinf(self, a): + return jnp.isinf(a) + + def einsum(self, subscripts, *operands): + return jnp.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return jnp.sort(a, axis) + + def argsort(self, a, axis=-1): + return jnp.argsort(a, axis) + + def searchsorted(self, a, v, side='left'): + if a.ndim == 1: + return jnp.searchsorted(a, v, side) + else: + # this is a not very efficient way to make jax numpy + # searchsorted work on 2d arrays + return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])]) + + def flip(self, a, axis=None): + return jnp.flip(a, axis) + + def outer(self, a, b): + return jnp.outer(a, b) + + def clip(self, a, a_min, a_max): + return jnp.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return jnp.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return jnp.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return jnp.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return jnp.pad(a, pad_width) + + def argmax(self, a, axis=None): + return jnp.argmax(a, axis=axis) + + def mean(self, a, axis=None): + return jnp.mean(a, axis=axis) + + def std(self, a, axis=None): + return jnp.std(a, axis=axis) + + def linspace(self, start, stop, num): + return jnp.linspace(start, stop, num) + + def meshgrid(self, a, b): + return jnp.meshgrid(a, b) + + def diag(self, a, k=0): + return jnp.diag(a, k) + + def unique(self, a): + return jnp.unique(a) + + def logsumexp(self, a, axis=None): + return jscipy.logsumexp(a, axis=axis) + + def stack(self, arrays, axis=0): + return jnp.stack(arrays, axis) + + def reshape(self, a, shape): + return jnp.reshape(a, shape) + + def seed(self, seed=None): + if seed is not None: + self.rng_ = jax.random.PRNGKey(seed) + + def rand(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.uniform(subkey, shape=size) + + def randn(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.normal(subkey, shape=size, dtype=type_as.dtype) + 
else: + return jax.random.normal(subkey, shape=size) + + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + # Currently, JAX does not support sparse matrices + data = self.to_numpy(data) + rows = self.to_numpy(rows) + cols = self.to_numpy(cols) + nx = NumpyBackend() + coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as) + matrix = nx.todense(coo_matrix) + return self.from_numpy(matrix) + + def issparse(self, a): + # Currently, JAX does not support sparse matrices + return False + + def tocsr(self, a): + # Currently, JAX does not support sparse matrices + return a + + def eliminate_zeros(self, a, threshold=0.): + # Currently, JAX does not support sparse matrices + if threshold > 0: + return self.where( + self.abs(a) <= threshold, + self.zeros((1,), type_as=a), + a + ) + return a + + def todense(self, a): + # Currently, JAX does not support sparse matrices + return a + + def where(self, condition, x, y): + return jnp.where(condition, x, y) + + def copy(self, a): + # No need to copy, JAX arrays are immutable + return a + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + def dtype_device(self, a): + return a.dtype, a.device_buffer.device() + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + assert a_dtype == b_dtype, "Dtype discrepancy" + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + + +class TorchBackend(Backend): + """ + PyTorch implementation of the backend + + - `__name__` is "torch" + - `__type__` is torch.Tensor + """ + + __name__ = 'torch' + __type__ = torch_type + __type_list__ = None + + rng_ = None + + def __init__(self): + + self.rng_ = torch.Generator() + self.rng_.seed() + + self.__type_list__ = [torch.tensor(1, dtype=torch.float32), + torch.tensor(1, dtype=torch.float64)] + + if torch.cuda.is_available(): + self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) + self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + + from torch.autograd import Function + + # define a function that takes inputs val and grads + # ad returns a val tensor with proper gradients + class ValFunction(Function): + + @staticmethod + def forward(ctx, val, grads, *inputs): + ctx.grads = grads + return val + + @staticmethod + def backward(ctx, grad_output): + # the gradients are grad + return (None, None) + ctx.grads + + self.ValFunction = ValFunction + + def to_numpy(self, a): + return a.cpu().detach().numpy() + + def from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) + if type_as is None: + return torch.from_numpy(a) + else: + return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device) + + def set_gradients(self, val, inputs, grads): + + Func = self.ValFunction() + + res = Func.apply(val, grads, *inputs) + + return res + + def zeros(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) + if type_as is None: + return torch.zeros(shape) + else: + return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) + + def ones(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) + if type_as is None: + return torch.ones(shape) + else: + return torch.ones(shape, dtype=type_as.dtype, device=type_as.device) + + def arange(self, stop, start=0, step=1, type_as=None): + if type_as is 
None: + return torch.arange(start, stop, step) + else: + return torch.arange(start, stop, step, device=type_as.device) + + def full(self, shape, fill_value, type_as=None): + if isinstance(shape, int): + shape = (shape,) + if type_as is None: + return torch.full(shape, fill_value) + else: + return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device) + + def eye(self, N, M=None, type_as=None): + if M is None: + M = N + if type_as is None: + return torch.eye(N, m=M) + else: + return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device) + + def sum(self, a, axis=None, keepdims=False): + if axis is None: + return torch.sum(a) + else: + return torch.sum(a, axis, keepdim=keepdims) + + def cumsum(self, a, axis=None): + if axis is None: + return torch.cumsum(a.flatten(), 0) + else: + return torch.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + if axis is None: + return torch.max(a) + else: + return torch.max(a, axis, keepdim=keepdims)[0] + + def min(self, a, axis=None, keepdims=False): + if axis is None: + return torch.min(a) + else: + return torch.min(a, axis, keepdim=keepdims)[0] + + def maximum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + if hasattr(torch, "maximum"): + return torch.maximum(a, b) + else: + return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] + + def minimum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + if hasattr(torch, "minimum"): + return torch.minimum(a, b) + else: + return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0] + + def dot(self, a, b): + return torch.matmul(a, b) + + def abs(self, a): + return torch.abs(a) + + def exp(self, a): + return torch.exp(a) + + def log(self, a): + return torch.log(a) + + def sqrt(self, a): + return torch.sqrt(a) + + def power(self, a, exponents): + return torch.pow(a, exponents) + + def norm(self, a): + return torch.sqrt(torch.sum(torch.square(a))) + + def any(self, a): + return torch.any(a) + + def isnan(self, a): + return torch.isnan(a) + + def isinf(self, a): + return torch.isinf(a) + + def einsum(self, subscripts, *operands): + return torch.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + sorted0, indices = torch.sort(a, dim=axis) + return sorted0 + + def argsort(self, a, axis=-1): + sorted, indices = torch.sort(a, dim=axis) + return indices + + def searchsorted(self, a, v, side='left'): + right = (side != 'left') + return torch.searchsorted(a, v, right=right) + + def flip(self, a, axis=None): + if axis is None: + return torch.flip(a, tuple(i for i in range(len(a.shape)))) + if isinstance(axis, int): + return torch.flip(a, (axis,)) + else: + return torch.flip(a, dims=axis) + + def outer(self, a, b): + return torch.outer(a, b) + + def clip(self, a, a_min, a_max): + return torch.clamp(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return torch.repeat_interleave(a, repeats, dim=axis) + + def take_along_axis(self, arr, indices, axis): + return torch.gather(arr, axis, indices) + + def concatenate(self, arrays, axis=0): + return torch.cat(arrays, dim=axis) + + def zero_pad(self, a, pad_width): + from torch.nn.functional import pad + # pad_width is 
an array of ndim tuples indicating how many 0 before and after + # we need to add. We first need to make it compliant with torch syntax, that + # starts with the last dim, then second last, etc. + how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl) + return pad(a, how_pad) + + def argmax(self, a, axis=None): + return torch.argmax(a, dim=axis) + + def mean(self, a, axis=None): + if axis is not None: + return torch.mean(a, dim=axis) + else: + return torch.mean(a) + + def std(self, a, axis=None): + if axis is not None: + return torch.std(a, dim=axis, unbiased=False) + else: + return torch.std(a, unbiased=False) + + def linspace(self, start, stop, num): + return torch.linspace(start, stop, num, dtype=torch.float64) + + def meshgrid(self, a, b): + X, Y = torch.meshgrid(a, b) + return X.T, Y.T + + def diag(self, a, k=0): + return torch.diag(a, diagonal=k) + + def unique(self, a): + return torch.unique(a) + + def logsumexp(self, a, axis=None): + if axis is not None: + return torch.logsumexp(a, dim=axis) + else: + return torch.logsumexp(a, dim=tuple(range(len(a.shape)))) + + def stack(self, arrays, axis=0): + return torch.stack(arrays, dim=axis) + + def reshape(self, a, shape): + return torch.reshape(a, shape) + + def seed(self, seed=None): + if isinstance(seed, int): + self.rng_.manual_seed(seed) + elif isinstance(seed, torch.Generator): + self.rng_ = seed + else: + raise ValueError("Non compatible seed : {}".format(seed)) + + def rand(self, *size, type_as=None): + if type_as is not None: + return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + else: + return torch.rand(size=size, generator=self.rng_) + + def randn(self, *size, type_as=None): + if type_as is not None: + return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + else: + return torch.randn(size=size, generator=self.rng_) + + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + if type_as is None: + return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) + else: + return torch.sparse_coo_tensor( + torch.stack([rows, cols]), data, size=shape, + dtype=type_as.dtype, device=type_as.device + ) + + def issparse(self, a): + return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False) + + def tocsr(self, a): + # Versions older than 1.9 do not support CSR tensors. 
PyTorch 1.9 and 1.10 offer a very limited support + return self.todense(a) + + def eliminate_zeros(self, a, threshold=0.): + if self.issparse(a): + if threshold > 0: + mask = self.abs(a) <= threshold + mask = ~mask + mask = mask.nonzero() + else: + mask = a._values().nonzero() + nv = a._values().index_select(0, mask.view(-1)) + ni = a._indices().index_select(1, mask.view(-1)) + return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a) + else: + if threshold > 0: + a[self.abs(a) <= threshold] = 0 + return a + + def todense(self, a): + if self.issparse(a): + return a.to_dense() + else: + return a + + def where(self, condition, x, y): + return torch.where(condition, x, y) + + def copy(self, a): + return torch.clone(a) + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + def dtype_device(self, a): + return a.dtype, a.device + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + assert a_dtype == b_dtype, "Dtype discrepancy" + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" diff --git a/ot/bregman.py b/ot/bregman.py index f1f8437..cce52e2 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -7,70 +7,104 @@ Bregman projections solvers for entropic regularized OT # Nicolas Courty <ncourty@irisa.fr> # Kilian Fatras <kilian.fatras@irisa.fr> # Titouan Vayer <titouan.vayer@irisa.fr> -# Hicham Janati <hicham.janati@inria.fr> +# Hicham Janati <hicham.janati100@gmail.com> # Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> # Alexander Tong <alexander.tong@yale.edu> # Ievgen Redko <ievgen.redko@univ-st-etienne.fr> +# Quang Huy Tran <quang-huy.tran@univ-ubs.fr> # # License: MIT License -import numpy as np import warnings -from .utils import unif, dist + +import numpy as np from scipy.optimize import fmin_l_bfgs_b +from ot.utils import unif, dist, list_to_array +from .backend import get_backend + def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - s.t. \gamma 1 = a + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma^T 1= b + \gamma &\geq 0 - \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. 
+ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] <references-sinkhorn>` + + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :py:func:`ot.bregman.sinkhorn_epsilon_scaling`, which relies on iterating the value + of the regularization (and using warm start), sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with a small number of iterations we strongly recommend the + :py:func:`ot.bregman.sinkhorn_log` solver that will not need to check for + numerical problems. Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + method used for the solver either 'sinkhorn','sinkhorn_log', + 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see + those functions for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True - + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -86,102 +120,152 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) - + .. _references-sinkhorn: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. 
arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] + ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn + :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>` + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling + :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>` """ if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, + **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log) + stopThr=stopThr, verbose=verbose, log=log, + warn=warn) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) else: raise ValueError("Unknown method '%s'." % method) def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - s.t. \gamma 1 = a + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma^T 1= b + \gamma &\geq 0 - \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + .. 
note:: This function is backend-compatible and will work on arrays + from all compatible backends. + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>` + + + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :py:func:`ot.bregman.sinkhorn_log` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :py:func:`ot.bregman.sinkhorn_epsilon_scaling`, which relies on iterating the value + of the regularization (and using warm start), sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with a small number of iterations we strongly recommend the + :py:func:`ot.bregman.sinkhorn_log` solver that will not need to check for + numerical problems. Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 method : str - method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or - 'sinkhorn_epsilon_scaling', see those function for specific parameters + method used for the solver either 'sinkhorn','sinkhorn_log', + 'sinkhorn_stabilized', see those functions for specific parameters numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- - W : (n_hists) ndarray or float + W : (n_hists) float/array-like Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters + Examples -------- @@ -190,99 +274,142 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, >>> b=[.5, .5] >>> M=[[0., 1.], [1., 0.]] >>> ot.sinkhorn2(a, b, M, 1) - array([0.26894142]) - + 0.26894142136999516 + .. _references-sinkhorn2: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). 
Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. - [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation + algorithms for optimal transport via Sinkhorn iteration, + Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). + Interpolating between optimal transport and MMD using Sinkhorn + divergences. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2681-2690). PMLR. See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] - ot.bregman.greenkhorn : Greenkhorn [21] - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] - + ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>` + ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn + :ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>` """ - b = np.asarray(b, dtype=np.float64) + + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + if len(b.shape) < 2: - b = b[:, None] - if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - elif method.lower() == 'sinkhorn_epsilon_scaling': - return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + if method.lower() == 'sinkhorn': + res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + if log: + return nx.sum(M * res[0]), res[1] + else: + return nx.sum(M * res) + else: - raise ValueError("Unknown method '%s'." 
% method) + + if method.lower() == 'sinkhorn': + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp + matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>` Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -299,10 +426,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, [0.13447071, 0.36552929]]) + .. _references-sinkhorn-knopp: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. 
Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 See Also @@ -312,18 +442,18 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) # init data dim_a = len(a) - dim_b = len(b) + dim_b = b.shape[0] if len(b.shape) > 1: n_hists = b.shape[1] @@ -336,66 +466,64 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) + K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K - cpt = 0 + err = 1 - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): uprev = u vprev = v + KtransposeU = nx.dot(K.T, u) + v = b / KtransposeU + u = 1. / nx.dot(Kp, v) - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) - - if (np.any(KtransposeU == 0) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(KtransposeU == 0) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Warning: numerical errors at iteration %d' % ii) u = uprev v = vprev break - if cpt % 10 == 0: + if ii % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) + tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - np.einsum('i,ij,j->j', u, K, v, out=tmp2) - err = np.linalg.norm(tmp2 - b) # violation of marginal + tmp2 = nx.einsum('i,ij,j->j', u, K, v) + err = nx.norm(tmp2 - b) # violation of marginal if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log['niter'] = ii log['u'] = u log['v'] = v if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: @@ -409,58 +537,259 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - log=False): +def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, + log=False, warn=True, **kwargs): r""" - Solve the entropic regularization optimal transport problem and return the OT matrix + Solve the entropic regularization optimal transport problem in log space + and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the + implementation from :ref:`[34] <references-sinkhorn-log>` + + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + .. _references-sinkhorn-log: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). Interpolating between + optimal transport and MMD using Sinkhorn divergences. In The + 22nd International Conference on Artificial Intelligence and + Statistics (pp. 2681-2690). PMLR. 
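A note on why the log-domain solver matters in practice: for a small `reg` the kernel :math:`\mathbf{K} = e^{-\mathbf{M}/\mathrm{reg}}` under- or overflows in the plain Sinkhorn updates, whereas `sinkhorn_log` only touches the dual potentials through `logsumexp`. A minimal doctest-style sketch, assuming the :py:func:`ot.bregman.sinkhorn_log` entry point introduced in this diff (the data are illustrative):

>>> import numpy as np
>>> import ot
>>> a = np.array([.5, .5])
>>> b = np.array([.5, .5])
>>> M = np.array([[0., 1.], [1., 0.]])
>>> # exp(-M/reg) underflows off-diagonal at this reg, but the log solver is unaffected
>>> G = ot.bregman.sinkhorn_log(a, b, M, reg=1e-3)
>>> np.allclose(G.sum(axis=1), a)  # row marginals still match the source weights
True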
+ + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + + if len(a) == 0: + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) + if len(b) == 0: + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) + + # init data + dim_a = len(a) + dim_b = b.shape[0] + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 - The algorithm used is based on the paper + if n_hists: # we do not want to use tensors sor we do a loop - Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration - by Jason Altschuler, Jonathan Weed, Philippe Rigollet - appeared at NIPS 2017 + lst_loss = [] + lst_u = [] + lst_v = [] - which is a stochastic version of the Sinkhorn-Knopp algorithm [2]. + for k in range(n_hists): + res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + if log: + lst_loss.append(nx.sum(M * res[0])) + lst_u.append(res[1]['log_u']) + lst_v.append(res[1]['log_v']) + else: + lst_loss.append(nx.sum(M * res)) + res = nx.stack(lst_loss) + if log: + log = {'log_u': nx.stack(lst_u, 1), + 'log_v': nx.stack(lst_v, 1), } + log['u'] = nx.exp(log['log_u']) + log['v'] = nx.exp(log['log_v']) + return res, log + else: + return res + + else: + + if log: + log = {'err': []} + + Mr = - M / reg + + # we assume that no distances are null except those of the diagonal of + # distances + + u = nx.zeros(dim_a, type_as=M) + v = nx.zeros(dim_b, type_as=M) + + def get_logT(u, v): + if n_hists: + return Mr[:, :, None] + u + v + else: + return Mr + u[:, None] + v[None, :] + + loga = nx.log(a) + logb = nx.log(b) + + err = 1 + for ii in range(numItermax): + + v = logb - nx.logsumexp(Mr + u[:, None], 0) + u = loga - nx.logsumexp(Mr + v[None, :], 1) + + if ii % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0) + err = nx.norm(tmp2 - b) # violation of marginal + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + + if log: + log['niter'] = ii + log['log_u'] = u + log['log_v'] = v + log['u'] = nx.exp(u) + log['v'] = nx.exp(v) + + return nx.exp(get_logT(u, v)), log + + else: + return nx.exp(get_logT(u, v)) + + +def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, + log=False, warn=True): + r""" + Solve the entropic regularization optimal transport problem and return the OT matrix + + The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>` + which is a stochastic version of the Sinkhorn-Knopp + algorithm :ref:`[2] <references-greenkhorn>` The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) - + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -477,11 +806,18 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, [0.13447071, 0.36552929]]) + .. _references-greenkhorn: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 + + .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time + approximation algorithms for optimal transport via Sinkhorn + iteration, Advances in Neural Information Processing + Systems (NIPS) 31, 2017 See Also @@ -491,68 +827,70 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + if nx.__name__ == "jax": + raise TypeError("JAX arrays have been received. 
Greenkhorn is not " + "compatible with JAX") if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] dim_a = a.shape[0] dim_b = b.shape[0] - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty_like(M) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) - u = np.full(dim_a, 1. / dim_a) - v = np.full(dim_b, 1. / dim_b) - G = u[:, np.newaxis] * K * v[np.newaxis, :] + u = nx.full((dim_a,), 1. / dim_a, type_as=K) + v = nx.full((dim_b,), 1. / dim_b, type_as=K) + G = u[:, None] * K * v[None, :] - viol = G.sum(1) - a - viol_2 = G.sum(0) - b + viol = nx.sum(G, axis=1) - a + viol_2 = nx.sum(G, axis=0) - b stopThr_val = 1 - if log: log = dict() log['u'] = u log['v'] = v - for i in range(numItermax): - i_1 = np.argmax(np.abs(viol)) - i_2 = np.argmax(np.abs(viol_2)) - m_viol_1 = np.abs(viol[i_1]) - m_viol_2 = np.abs(viol_2[i_2]) - stopThr_val = np.maximum(m_viol_1, m_viol_2) + for ii in range(numItermax): + i_1 = nx.argmax(nx.abs(viol)) + i_2 = nx.argmax(nx.abs(viol_2)) + m_viol_1 = nx.abs(viol[i_1]) + m_viol_2 = nx.abs(viol_2[i_2]) + stopThr_val = nx.maximum(m_viol_1, m_viol_2) if m_viol_1 > m_viol_2: old_u = u[i_1] - u[i_1] = a[i_1] / (K[i_1, :].dot(v)) - G[i_1, :] = u[i_1] * K[i_1, :] * v - - viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1] - viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v) + new_u = a[i_1] / (K[i_1, :].dot(v)) + G[i_1, :] = new_u * K[i_1, :] * v + viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1] + viol_2 += (K[i_1, :].T * (new_u - old_u) * v) + u[i_1] = new_u else: old_v = v[i_2] - v[i_2] = b[i_2] / (K[:, i_2].T.dot(u)) - G[:, i_2] = u * K[:, i_2] * v[i_2] + new_v = b[i_2] / (K[:, i_2].T.dot(u)) + G[:, i_2] = u * K[:, i_2] * new_v # aviol = (G@one_m - a) # aviol_2 = (G.T@one_n - b) - viol += (-old_v + v[i_2]) * K[:, i_2] * u - viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2] - - # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) + viol += (-old_v + new_v) * K[:, i_2] * u + viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2] + v[i_2] = new_v if stopThr_val <= stopThr: break else: - print('Warning: Algorithm did not converge') + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log["n_iter"] = ii log['u'] = u log['v'] = v @@ -564,58 +902,66 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization OT problem with log stabilization The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in [2]_ but with the log stabilization - proposed in [10]_ an defined in [9]_ (Algo 3.1) . + scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>` + but with the log stabilization + proposed in :ref:`[10] <references-sinkhorn-stabilized>` and defined in + :ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1). Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) + b : array-like, shape (dim_b,) samples in the target domain - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 tau : float - thershold for max value in u or v for log scaling - warmstart : tible of vectors - if given then sarting values for alpha an beta log scalings + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling + warmstart : tuple of vectors + if given then starting values for alpha and beta log scalings numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -632,14 +978,21 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, [0.13447071, 0.36552929]]) + .. _references-sinkhorn-stabilized: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816.
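The `warmstart` round-trip documented above (the solver both accepts an `(alpha, beta)` pair and returns one in `log['warmstart']`) is what `sinkhorn_epsilon_scaling` relies on: each inner solve restarts from the duals obtained at the previous, larger regularization. A short sketch of that reuse, assuming the usual POT helpers `ot.dist` and `ot.unif` (the data are made up):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> M = ot.dist(rng.randn(8, 2), rng.randn(9, 2))  # pairwise squared distances
>>> a, b = ot.unif(8), ot.unif(9)  # uniform histograms
>>> G0, log0 = ot.bregman.sinkhorn_stabilized(a, b, M, reg=1e-1, log=True)
>>> G1 = ot.bregman.sinkhorn_stabilized(a, b, M, reg=5e-2,
...                                     warmstart=log0['warmstart'])  # restart from duals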
See Also @@ -649,19 +1002,19 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] # test if multiple target if len(b.shape) > 1: n_hists = b.shape[1] - a = a[:, np.newaxis] + a = a[:, None] else: n_hists = 0 @@ -669,123 +1022,123 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, dim_a = len(a) dim_b = len(b) - cpt = 0 if log: log = {'err': []} # we assume that no distances are null except those of the diagonal of # distances if warmstart is None: - alpha, beta = np.zeros(dim_a), np.zeros(dim_b) + alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) else: alpha, beta = warmstart if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b + u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M) + u /= dim_a + v /= dim_b def get_K(alpha, beta): """log space computation""" - return np.exp(-(M - alpha.reshape((dim_a, 1)) + return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg) def get_Gamma(alpha, beta, u, v): """log space gamma computation""" - return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) - / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b)))) - - # print(np.min(K)) + return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) + / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) K = get_K(alpha, beta) transp = K - loop = 1 - cpt = 0 err = 1 - while loop: + for ii in range(numItermax): uprev = u vprev = v # sinkhorn update - v = b / (np.dot(K.T, u) + 1e-16) - u = a / (np.dot(K, v) + 1e-16) + v = b / (nx.dot(K.T, u)) + u = a / (nx.dot(K, v)) # remove numerical problems and store them in K - if np.abs(u).max() > tau or np.abs(v).max() > tau: + if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: if n_hists: - alpha, beta = alpha + reg * \ - np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) else: - alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) + alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) if n_hists: - u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b K = get_K(alpha, beta) - if cpt % print_period == 0: + if ii % print_period == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - err_u = abs(u - uprev).max() - err_u /= max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() - err_v /= max(abs(v).max(), abs(vprev).max(), 1.) 
+ err_u = nx.max(nx.abs(u - uprev)) + err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0) + err_v = nx.max(nx.abs(v - vprev)) + err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0) err = 0.5 * (err_u + err_v) else: transp = get_Gamma(alpha, beta, u, v) - err = np.linalg.norm((np.sum(transp, axis=0) - b)) + err = nx.norm(nx.sum(transp, axis=0) - b) if log: log['err'].append(err) if verbose: - if cpt % (print_period * 20) == 0: + if ii % (print_period * 20) == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) if err <= stopThr: - loop = False - - if cpt >= numItermax: - loop = False + break - if np.any(np.isnan(u)) or np.any(np.isnan(v)): + if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %d' % ii) u = uprev v = vprev break - - cpt = cpt + 1 - + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: if n_hists: alpha = alpha[:, None] beta = beta[:, None] - logu = alpha / reg + np.log(u) - logv = beta / reg + np.log(v) + logu = alpha / reg + nx.log(u) + logv = beta / reg + nx.log(v) + log["n_iter"] = ii log['logu'] = logu log['logv'] = logv - log['alpha'] = alpha + reg * np.log(u) - log['beta'] = beta + reg * np.log(v) + log['alpha'] = alpha + reg * nx.log(u) + log['beta'] = beta + reg * nx.log(v) log['warmstart'] = (log['alpha'], log['beta']) if n_hists: - res = np.zeros((n_hists)) - for i in range(n_hists): - res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + res = nx.stack([ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ]) return res, log else: return get_Gamma(alpha, beta, u, v), log else: if n_hists: - res = np.zeros((n_hists)) - for i in range(n_hists): - res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + res = nx.stack([ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ]) return res else: return get_Gamma(alpha, beta, u, v) @@ -794,70 +1147,73 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. - The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 - where : + \gamma &\geq 0 - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (histograms, both sum to 1) + where : + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in [2]_ but with the log stabilization - proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2 - + scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>` + but with the log stabilization + proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling + proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2 Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) + b : array-like, shape (dim_b,) samples in the target domain - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 tau : float - thershold for max value in u or v for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling warmstart : tuple of vectors - if given then sarting values for alpha an beta log scalings + if given then starting values for alpha and beta log scalings numItermax : int, optional Max number of iterations numInnerItermax : int, optional - Max number of iterationsin the inner slog stabilized sinkhorn + Max number of iterations in the inner log-stabilized sinkhorn epsilon0 : int, optional first epsilon regularization value (then exponential decrease to reg) stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't converge. Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters - Examples -------- - >>> import ot >>> a=[.5, .5] >>> b=[.5, .5] @@ -866,29 +1222,32 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, array([[0.36552929, 0.13447071], [0.13447071, 0.36552929]]) - + .. _references-sinkhorn-epsilon-scaling: References ---------- + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + ..
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] # init data dim_a = len(a) @@ -898,14 +1257,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numItermin = 35 numItermax = max(numItermin, numItermax) # ensure that last velue is exact - cpt = 0 + ii = 0 if log: log = {'err': []} # we assume that no distances are null except those of the diagonal of # distances if warmstart is None: - alpha, beta = np.zeros(dim_a), np.zeros(dim_b) + alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) else: alpha, beta = warmstart @@ -913,12 +1272,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def get_reg(n): # exponential decreasing return (epsilon0 - reg) * np.exp(-n) + reg - loop = 1 - cpt = 0 err = 1 - while loop: + for ii in range(numItermax): - regi = get_reg(cpt) + regi = get_reg(ii) G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, @@ -928,33 +1285,31 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, alpha = logi['alpha'] beta = logi['beta'] - if cpt >= numItermax: - loop = False - - if cpt % (print_period) == 0: # spsion nearly converged + if ii % (print_period) == 0: # spsion nearly converged # we can speed up the process by checking for the error only all # the 10th iterations transp = G - err = np.linalg.norm( - (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2 + err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2 if log: log['err'].append(err) if verbose: - if cpt % (print_period * 10) == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - if err <= stopThr and cpt > numItermin: - loop = False + if ii % (print_period * 10) == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) - cpt = cpt + 1 - # print('err=',err,' cpt=',cpt) + if err <= stopThr and ii > numItermin: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['alpha'] = alpha log['beta'] = beta log['warmstart'] = (log['alpha'], log['beta']) + log['niter'] = ii return G, log else: return G @@ -962,76 +1317,94 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def geometricBar(weights, alldistribT): """return the weighted geometric mean of distributions""" + weights, alldistribT = list_to_array(weights, alldistribT) + nx = get_backend(weights, alldistribT) assert (len(weights) == alldistribT.shape[1]) - return np.exp(np.dot(np.log(alldistribT), weights.T)) + return nx.exp(nx.dot(nx.log(alldistribT), weights.T)) def geometricMean(alldistribT): """return the geometric mean of distributions""" - return np.exp(np.mean(np.log(alldistribT), axis=1)) + alldistribT = list_to_array(alldistribT) + nx = get_backend(alldistribT) + return nx.exp(nx.mean(nx.log(alldistribT), axis=1)) def projR(gamma, p): """return the KL projection on the row constrints """ - return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T + gamma, p = list_to_array(gamma, p) + nx = get_backend(gamma, p) + return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T def projC(gamma, q): """return the KL projection on the column constrints """ - return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10)) + gamma, q = list_to_array(gamma, q) + nx = get_backend(gamma, q) + return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10) def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, **kwargs): - r"""Compute the entropic regularized wasserstein barycenter of distributions A + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: .. math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`. 
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] <references-barycenter>` Parameters ---------- - A : ndarray, shape (dim, n_hists) - n_hists training distributions a_i of size dim - M : ndarray, shape (dim, dim) + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` + M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 method : str (optional) - method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' - weights : ndarray, shape (n_hists,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log' + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-barycenter: References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1039,232 +1412,327 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, return barycenter_sinkhorn(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return barycenter_stabilized(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_sinkhorn_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[3]<references-barycenter-sinkhorn>`. Parameters ---------- - A : ndarray, shape (dim, n_hists) - n_hists training distributions a_i of size dim - M : ndarray, shape (dim, dim) + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` + M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 - weights : ndarray, shape (n_hists,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-barycenter-sinkhorn: References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + if weights is None: - weights = np.ones(A.shape[1]) / A.shape[1] + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] else: assert (len(weights) == A.shape[1]) if log: log = {'err': []} - # M = M/np.median(M) # suggested by G. 
Peyre - K = np.exp(-M / reg) + K = nx.exp(-M / reg) - cpt = 0 err = 1 - UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T) + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + u = (geometricMean(UKv) / UKv.T).T - while (err > stopThr and cpt < numItermax): - cpt = cpt + 1 - UKv = u * np.dot(K, np.divide(A, np.dot(K, u))) + for ii in range(numItermax): + + UKv = u * nx.dot(K, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv - if cpt % 10 == 1: - err = np.sum(np.std(UKv, axis=1)) + if ii % 10 == 1: + err = nx.sum(nx.std(UKv, axis=1)) # log and verbose print if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt + log['niter'] = ii return geometricBar(weights, UKv), log else: return geometricBar(weights, UKv) +def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the entropic wasserstein barycenter in log-domain + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar = log_bar + weights[k] * log_KU[:, k] + + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - with stabilization. + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_ + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] <references-barycenter-stabilized>` Parameters ---------- - A : ndarray, shape (dim, n_hists) - n_hists training distributions a_i of size dim - M : ndarray, shape (dim, dim) + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` + M : array-like, shape (dim, dim) loss matrix for OT reg : float Regularization term > 0 tau : float - thershold for max value in u or v for log scaling - weights : ndarray, shape (n_hists,) - Weights of each histogram a_i on the simplex (barycentric coodinates) + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + .. _references-barycenter-stabilized: References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones((n_hists,), type_as=M) / n_hists else: assert (len(weights) == A.shape[1]) if log: log = {'err': []} - u = np.ones((dim, n_hists)) / dim - v = np.ones((dim, n_hists)) / dim + u = nx.ones((dim, n_hists), type_as=M) / dim + v = nx.ones((dim, n_hists), type_as=M) / dim - # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) - cpt = 0 err = 1. 
- alpha = np.zeros(dim) - beta = np.zeros(dim) - q = np.ones(dim) / dim - while (err > stopThr and cpt < numItermax): + alpha = nx.zeros((dim,), type_as=M) + beta = nx.zeros((dim,), type_as=M) + q = nx.ones((dim,), type_as=M) / dim + for ii in range(numItermax): qprev = q - Kv = K.dot(v) - u = A / (Kv + 1e-16) - Ktu = K.T.dot(u) + Kv = nx.dot(K, v) + u = A / Kv + Ktu = nx.dot(K.T, u) q = geometricBar(weights, Ktu) Q = q[:, None] - v = Q / (Ktu + 1e-16) + v = Q / Ktu absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha += reg * nx.log(nx.max(u, 1)) + beta += reg * nx.log(nx.max(v, 1)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(tuple(v.shape), type_as=v) + Kv = nx.dot(K, v) + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % ii) q = qprev break - if (cpt % 10 == 0 and not absorbing) or cpt == 0: + if (ii % 10 == 0 and not absorbing) or ii == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(u * Kv - A).max() + err = nx.max(nx.abs(u * Kv - A)) if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 50 == 0: + if ii % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) - cpt += 1 - if err > stopThr: - warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + - "Try a larger entropy `reg`" + - "Or a larger absorption threshold `tau`.") + else: + if warn: + warnings.warn("Stabilized Sinkhorn did not converge." + + "Try a larger entropy `reg`" + + "Or a larger absorption threshold `tau`.") if log: - log['niter'] = cpt + log['niter'] = ii log['logu'] = np.log(u + 1e-16) log['logv'] = np.log(v + 1e-16) return q, log @@ -1272,157 +1740,717 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, return q -def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, - stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images. +def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): + r"""Compute the debiased Sinkhorn barycenter of distributions A The function solves the following optimization problem: .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` - - reg is the regularization strength scalar value + - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence + (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_ + The algorithm used for solving the problem is the debiased Sinkhorn + algorithm as proposed in :ref:`[37] <references-barycenter-debiased>` Parameters ---------- - A : ndarray, shape (n_hists, width, height) - n distributions (2D images) of size width x height + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + method : str (optional) + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + Returns + ------- + a : (dim,) array-like + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-barycenter-debiased: + References + ---------- + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _barycenter_debiased(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_debiased_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the debiased sinkhorn barycenter of distributions A. 
+ """ + + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + + if weights is None: + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + K = nx.exp(-M / reg) + + err = 1 + + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + + u = (geometricMean(UKv) / UKv.T).T + c = nx.ones(A.shape[0], type_as=A) + bar = nx.ones(A.shape[0], type_as=A) + + for ii in range(numItermax): + bold = bar + UKv = nx.dot(K, A / nx.dot(K, u)) + bar = c * geometricBar(weights, UKv) + u = bar[:, None] / UKv + c = (c * bar / nx.dot(K, c)) ** 0.5 + + if ii % 10 == 9: + err = abs(bar - bold).max() / max(bar.max(), 1.) + + # log and verbose print + if log: + log['err'].append(err) + + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return bar, log + else: + return bar + + +def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, + warn=True): + r"""Compute the debiased sinkhorn barycenter in log domain. + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + c = nx.zeros(dim, type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar += weights[k] * log_KU[:, k] + log_bar += c + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + for _ in range(10): + c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, + warn=True, **kwargs): + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` + where :math:`\mathbf{A}` is a collection of 2D images. + + The function solves the following optimization problem: + + .. 
math::
+        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+    where :
+
+    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+      distance (see :py:func:`ot.bregman.sinkhorn`)
+    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions
+      of matrix :math:`\mathbf{A}`
+    - `reg` is the regularization strength scalar value
+
+    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm
+    as proposed in :ref:`[21] <references-convolutional-barycenter-2d>`
+
+    Parameters
+    ----------
+    A : array-like, shape (n_hists, width, height)
+        `n` distributions (2D images) of size `width` x `height`
     reg : float
         Regularization term >0
-    weights : ndarray, shape (n_hists,)
+    weights : array-like, shape (n_hists,)
         Weights of each image on the simplex (barycentric coodinates)
+    method : string, optional
+        method used for the solver either 'sinkhorn' or 'sinkhorn_log'
     numItermax : int, optional
         Max number of iterations
     stopThr : float, optional
-        Stop threshol on error (> 0)
+        Stop threshold on error (> 0)
     stabThr : float, optional
         Stabilization threshold to avoid numerical precision issue
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
         record log if True
+    warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.

     Returns
     -------
-    a : ndarray, shape (width, height)
+    a : array-like, shape (width, height)
         2D Wasserstein barycenter
     log : dict
        log dictionary return only if log==True in parameters

+
+    .. _references-convolutional-barycenter-2d:
     References
     ----------

-    .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
-        Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
-        ACM Transactions on Graphics (TOG), 34(4), 66
+    .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher,
+        A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances:
+        Efficient optimal transportation on geometric domains. ACM Transactions
+        on Graphics (TOG), 34(4), 66
+
+    .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
+        International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+    """

+    if method.lower() == 'sinkhorn':
+        return _convolutional_barycenter2d(A, reg, weights=weights,
+                                           numItermax=numItermax,
+                                           stopThr=stopThr, verbose=verbose,
+                                           log=log, warn=warn,
+                                           **kwargs)
+    elif method.lower() == 'sinkhorn_log':
+        return _convolutional_barycenter2d_log(A, reg, weights=weights,
+                                               numItermax=numItermax,
+                                               stopThr=stopThr, verbose=verbose,
+                                               log=log, warn=warn,
+                                               **kwargs)
+    else:
+        raise ValueError("Unknown method '%s'." % method)
+
+
+def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
+                                stopThr=1e-9, stabThr=1e-30, verbose=False,
+                                log=False, warn=True):
+    r"""Compute the entropic regularized wasserstein barycenter of distributions A
+    where A is a collection of 2D images. 
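+
+    The Gibbs kernel is never assembled in full: the squared cost is
+    separable, so it factors into two 1D kernels ``K1`` (width axis) and
+    ``K2`` (height axis), and each ``convol_imgs`` call amounts to
+    ``K1 @ img @ K2.T`` applied to every image of the batch (an illustrative
+    numpy reading of the ``einsum`` calls below).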
""" + A = list_to_array(A) + + nx = get_backend(A) + if weights is None: - weights = np.ones(A.shape[0]) / A.shape[0] + weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] else: assert (len(weights) == A.shape[0]) if log: log = {'err': []} - b = np.zeros_like(A[0, :, :]) - U = np.ones_like(A) - KV = np.ones_like(A) - - cpt = 0 + bar = nx.ones(A.shape[1:], type_as=A) + bar /= bar.sum() + U = nx.ones(A.shape, type_as=A) + V = nx.ones(A.shape, type_as=A) err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions - t = np.linspace(0, 1, A.shape[1]) - [Y, X] = np.meshgrid(t, t) - xi1 = np.exp(-(X - Y) ** 2 / reg) + t = nx.linspace(0, 1, A.shape[1]) + [Y, X] = nx.meshgrid(t, t) + K1 = nx.exp(-(X - Y) ** 2 / reg) + + t = nx.linspace(0, 1, A.shape[2]) + [Y, X] = nx.meshgrid(t, t) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KU = convol_imgs(U) + for ii in range(numItermax): + V = bar[None] / KU + KV = convol_imgs(V) + U = A / KV + KU = convol_imgs(U) + bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + if ii % 10 == 9: + err = (V * KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + log['U'] = U + return bar, log + else: + return bar + + +def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images in log-domain. + """ + + A = list_to_array(A) + + nx = get_backend(A) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. 
Use numpy or torch arrays instead.") + + n_hists, width, height = A.shape + + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == n_hists) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar = log_bar + weights[k] * log_KU[k] + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + G = log_bar[None, :, :] - log_KU + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", + numItermax=10000, stopThr=1e-3, + verbose=False, log=False, warn=True, + **kwargs): + r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}` + where :math:`\mathbf{A}` is a collection of 2D images. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.barycenter_debiased`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two + dimensions of matrix :math:`\mathbf{A}` + - `reg` is the regularization strength scalar value + + The algorithm used for solving the problem is the debiased Sinkhorn scaling + algorithm as proposed in :ref:`[37] <references-convolutional-barycenter2d-debiased>` + + Parameters + ---------- + A : array-like, shape (n_hists, width, height) + `n` distributions (2D images) of size `width` x `height` + reg : float + Regularization term >0 + weights : array-like, shape (n_hists,) + Weights of each image on the simplex (barycentric coodinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + Returns + ------- + a : array-like, shape (width, height) + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. 
_references-convolutional-barycenter2d-debiased: + References + ---------- + + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ - t = np.linspace(0, 1, A.shape[2]) - [Y, X] = np.meshgrid(t, t) - xi2 = np.exp(-(X - Y) ** 2 / reg) + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d_debiased(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) - def K(x): - return np.dot(np.dot(xi1, x), xi2) - while (err > stopThr and cpt < numItermax): +def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, + stopThr=1e-3, stabThr=1e-15, verbose=False, + log=False, warn=True): + r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions. + """ - bold = b - cpt = cpt + 1 + A = list_to_array(A) + n_hists, width, height = A.shape - b = np.zeros_like(A[0, :, :]) - for r in range(A.shape[0]): - KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :]))) - b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :])) - b = np.exp(b) - for r in range(A.shape[0]): - U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :]) + nx = get_backend(A) - if cpt % 10 == 1: - err = np.sum(np.abs(bold - b)) + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == n_hists) + + if log: + log = {'err': []} + + bar = nx.ones((width, height), type_as=A) + bar /= width * height + U = nx.ones(A.shape, type_as=A) + V = nx.ones(A.shape, type_as=A) + c = nx.ones(A.shape[1:], type_as=A) + err = 1 + + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + K1 = nx.exp(-(X - Y) ** 2 / reg) + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KU = convol_imgs(U) + for ii in range(numItermax): + V = bar[None] / KU + KV = convol_imgs(V) + U = A / KV + KU = convol_imgs(U) + bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + + for _ in range(10): + c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5 + + if ii % 10 == 9: + err = (V * KU).std(axis=0).sum() # log and verbose print if log: log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt + log['niter'] = ii log['U'] = U - return b, log + return bar, log + else: + return bar + + +def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-3, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the debiased barycenter of 2D images in log-domain. + """ + + A = list_to_array(A) + n_hists, width, height = A.shape + nx = get_backend(A) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_bar, c = nx.zeros((2, width, height), type_as=A) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar = log_bar + weights[k] * log_KU[k] + log_bar += c + for _ in range(10): + c = 0.5 * (c + log_bar - convol_img(c)) + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr and ii > 20: + break + G = log_bar[None, :, :] - log_KU + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log else: - return b + return nx.exp(log_bar) def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, - stopThr=1e-3, verbose=False, log=False): + stopThr=1e-3, verbose=False, log=False, warn=True): r""" Compute the unmixing of an observation with a given dictionary using Wasserstein distance The function solve the following optimization problem: .. 
math:: - \mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h}) + + \mathbf{h} = \mathop{\arg \min}_\mathbf{h} \quad + (1 - \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) + + \alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h}) where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn) - - :math: `\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, + its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` - - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior` - - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting - - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization - - :math:`\\alpha`weight data fitting and regularization + - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the + cost matrix (`dim_a`, `dim_a`) for OT data fitting + - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization + term and the cost matrix (`dim_prior`, `n_atoms`) regularization + - :math:`\alpha` weight data fitting and regularization - The optimization problem is solved suing the algorithm described in [4] + The optimization problem is solved following the algorithm described + in :ref:`[4] <references-unmix>` Parameters ---------- - a : ndarray, shape (dim_a) + a : array-like, shape (dim_a) observed distribution (histogram, sums to 1) - D : ndarray, shape (dim_a, n_atoms) + D : array-like, shape (dim_a, n_atoms) dictionary matrix - M : ndarray, shape (dim_a, dim_a) + M : array-like, shape (dim_a, dim_a) loss matrix - M0 : ndarray, shape (n_atoms, dim_prior) + M0 : array-like, shape (n_atoms, dim_prior) loss matrix - h0 : ndarray, shape (n_atoms,) + h0 : array-like, shape (n_atoms,) prior on the estimated unmixing h reg : float Regularization term >0 (Wasserstein data fitting) @@ -1433,105 +2461,125 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True - + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- - h : ndarray, shape (n_atoms,) + h : array-like, shape (n_atoms,) Wasserstein barycenter log : dict log dictionary return only if log==True in parameters + + .. _references-unmix: References ---------- - .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. - + .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. 
Corpetti, + Supervised planetary unmixing with optimal transport, Whorkshop + on Hyperspectral Image and Signal Processing : + Evolution in Remote Sensing (WHISPERS), 2016. """ + a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0) + + nx = get_backend(a, D, M, M0, h0) + # M = M/np.median(M) - K = np.exp(-M / reg) + K = nx.exp(-M / reg) # M0 = M0/np.median(M0) - K0 = np.exp(-M0 / reg0) + K0 = nx.exp(-M0 / reg0) old = h0 err = 1 - cpt = 0 # log = {'niter':0, 'all_err':[]} if log: log = {'err': []} - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): K = projC(K, a) K0 = projC(K0, h0) - new = np.sum(K0, axis=1) + new = nx.sum(K0, axis=1) # we recombine the current selection from dictionnary - inv_new = np.dot(D, new) - other = np.sum(K, axis=1) + inv_new = nx.dot(D, new) + other = nx.sum(K, axis=1) # geometric interpolation - delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new)) + delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new)) K = projR(K, delta) - K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0) + K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0) - err = np.linalg.norm(np.sum(K0, axis=1) - old) + err = nx.norm(nx.sum(K0, axis=1) - old) old = new if log: log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt = cpt + 1 - + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + if warn: + warnings.warn("Unmixing algorithm did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt - return np.sum(K0, axis=1), log + log['niter'] = ii + return nx.sum(K0, axis=1), log else: - return np.sum(K0, axis=1) + return nx.sum(K0, axis=1) def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, - stopThr=1e-6, verbose=False, log=False, **kwargs): - r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] + stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs): + r'''Joint OT and proportion estimation for multi-source target shift as + proposed in :ref:`[27] <references-jcpot-barycenter>` The function solves the following optimization problem: .. math:: - \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k + \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \quad \sum_{k=1}^{K} \lambda_k W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} where : - - :math:`\lambda_k` is the weight of k-th source domain - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) - - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes - - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C + - :math:`\lambda_k` is the weight of `k`-th source domain + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain + defined as in [p. 
5, :ref:`27 <references-jcpot-barycenter>`], its expected shape + is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source + domain and `C` is the number of classes + - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C` - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` - - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in + [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)` - The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. + The problem consist in solving a Wasserstein barycenter problem to estimate + the proportions :math:`\mathbf{h}` in the target domain. The algorithm used for solving the problem is the Iterative Bregman projections algorithm - with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution. + with two sets of marginal constraints related to the unknown vector + :math:`\mathbf{h}` and uniform target distribution. Parameters ---------- - Xs : list of K np.ndarray(nsk,d) + Xs : list of K array-like(nsk,d) features of all source domains' samples - Ys : list of K np.ndarray(nsk,) + Ys : list of K array-like(nsk,) labels of all source domains' samples - Xt : np.ndarray (nt,d) + Xt : array-like (nt,d) samples in the target domain reg : float Regularization term > 0 @@ -1541,28 +2589,37 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Max number of iterations stopThr : float, optional Stop threshold on relative change in the barycenter (>0) - log : bool, optional - record log if True verbose : bool, optional (default=False) Controls the verbosity of the optimization algorithm + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- - h : (C,) ndarray + h : (C,) array-like proportion estimation in the target domain log : dict log dictionary return only if log==True in parameters + .. _references-jcpot-barycenter: References ---------- .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia - "Optimal transport for multi-source domain adaptation under target shift", - International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. - + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
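+
+    Examples
+    --------
+    A minimal illustrative call on hypothetical toy data (two source domains
+    with binary labels and a uniform target); the output values are not
+    checked here:
+
+    >>> import numpy as np
+    >>> rng = np.random.RandomState(42)
+    >>> Xs = [rng.randn(20, 2), rng.randn(30, 2)]
+    >>> Ys = [rng.randint(0, 2, 20), rng.randint(0, 2, 30)]
+    >>> Xt = rng.randn(25, 2)
+    >>> h = jcpot_barycenter(Xs, Ys, Xt, reg=1.)  # doctest: +SKIP
+    >>> float(h.sum())  # estimated class proportions sum to one  # doctest: +SKIP
+    1.0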
''' - nbclasses = len(np.unique(Ys[0])) + + Xs = list_to_array(*Xs) + Ys = list_to_array(*Ys) + Xt = list_to_array(Xt) + + nx = get_backend(*Xs, *Ys, Xt) + + nbclasses = len(nx.unique(Ys[0])) nbdomains = len(Xs) # log dictionary @@ -1579,19 +2636,19 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, dom = {} nsk = Xs[d].shape[0] # get number of elements for this domain dom['nbelem'] = nsk - classes = np.unique(Ys[d]) # get number of classes for this domain + classes = nx.unique(Ys[d]) # get number of classes for this domain # format classes to start from 0 for convenience - if np.min(classes) != 0: - Ys[d] = Ys[d] - np.min(classes) - classes = np.unique(Ys[d]) + if nx.min(classes) != 0: + Ys[d] -= nx.min(classes) + classes = nx.unique(Ys[d]) # build the corresponding D_1 and D_2 matrices - Dtmp1 = np.zeros((nbclasses, nsk)) - Dtmp2 = np.zeros((nbclasses, nsk)) + Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) + Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) for c in classes: - nbelemperclass = np.sum(Ys[d] == c) + nbelemperclass = nx.sum(Ys[d] == c) if nbelemperclass != 0: Dtmp1[int(c), Ys[d] == c] = 1. Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass) @@ -1602,51 +2659,54 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Mtmp = dist(Xs[d], Xt, metric=metric) M.append(Mtmp) - Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype) - np.divide(Mtmp, -reg, out=Ktmp) - np.exp(Ktmp, out=Ktmp) + Ktmp = nx.exp(-Mtmp / reg) K.append(Ktmp) # uniform target distribution - a = unif(np.shape(Xt)[0]) + a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0]) - cpt = 0 # iterations count err = 1 - old_bary = np.ones((nbclasses)) + old_bary = nx.ones((nbclasses,), type_as=Xs[0]) - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): - bary = np.zeros((nbclasses)) + bary = nx.zeros((nbclasses,), type_as=Xs[0]) # update coupling matrices for marginal constraints w.r.t. uniform target distribution for d in range(nbdomains): K[d] = projC(K[d], a) - other = np.sum(K[d], axis=1) - bary = bary + np.log(np.dot(D1[d], other)) / nbdomains + other = nx.sum(K[d], axis=1) + bary += nx.log(nx.dot(D1[d], other)) / nbdomains - bary = np.exp(bary) + bary = nx.exp(bary) # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] for d in range(nbdomains): - new = np.dot(D2[d].T, bary) + new = nx.dot(D2[d].T, bary) K[d] = projR(K[d], new) - err = np.linalg.norm(bary - old_bary) - cpt = cpt + 1 + err = nx.norm(bary - old_bary) + old_bary = bary if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - bary = bary / np.sum(bary) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Algorithm did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + bary = bary / nx.sum(bary) if log: - log['niter'] = cpt + log['niter'] = ii log['M'] = M log['D1'] = D1 log['D2'] = D2 @@ -1657,8 +2717,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, verbose=False, - log=False, **kwargs): + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, + log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the OT matrix from empirical data @@ -1666,45 +2726,56 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`a` and :math:`b` are source and target weights (sum to 1) + - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (n_samples_a, dim) + X_s : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (n_samples_b, dim) + X_t : array-like, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (n_samples_a,) + a : array-like, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : array-like, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) + isLazy: boolean, optional + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate full + cost matrix and return outputs of sinkhorn function. + batchSize: int or tuple of 2 int, optional + Size of the batches used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. 
Returns ------- - gamma : ndarray, shape (n_samples_a, n_samples_b) + gamma : array-like, shape (n_samples_a, n_samples_b) Regularized optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1715,9 +2786,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', >>> n_samples_a = 2 >>> n_samples_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) - >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) + >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE array([[4.99977301e-01, 2.26989344e-05], [2.26989344e-05, 4.99977301e-01]]) @@ -1725,30 +2796,115 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
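+
+    A sketch of the lazy mode (illustrative): with ``isLazy=True`` the cost
+    matrix is only ever built by blocks of ``batchSize`` rows or columns and
+    the dual potentials ``(f, g)`` are returned instead of the coupling, so
+    the full (n_samples_a, n_samples_b) matrix is never stored:
+
+    >>> f, g = empirical_sinkhorn(X_s, X_t, reg=reg, isLazy=True,
+    ...                           batchSize=1)  # doctest: +SKIP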
''' + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(np.shape(X_s)[0]) + a = nx.from_numpy(unif(ns), type_as=X_s) if b is None: - b = unif(np.shape(X_t)[0]) + b = nx.from_numpy(unif(nt), type_as=X_s) + + if isLazy: + if log: + dict_log = {"err": []} - M = dist(X_s, X_t, metric=metric) + log_a, log_b = nx.log(a), nx.log(b) + f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) + + if isinstance(batchSize, int): + bs, bt = batchSize, batchSize + elif isinstance(batchSize, tuple) and len(batchSize) == 2: + bs, bt = batchSize[0], batchSize[1] + else: + raise ValueError("Batch size must be in integer or a tuple of two integers") + + range_s, range_t = range(0, ns, bs), range(0, nt, bt) + + lse_f = nx.zeros((ns,), type_as=a) + lse_g = nx.zeros((nt,), type_as=a) + + X_s_np = nx.to_numpy(X_s) + X_t_np = nx.to_numpy(X_t) + + for i_ot in range(numIterMax): + + lse_f_cols = [] + for i in range_s: + M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = nx.from_numpy(M, type_as=a) + lse_f_cols.append( + nx.logsumexp(g[None, :] - M / reg, axis=1) + ) + lse_f = nx.concatenate(lse_f_cols, axis=0) + f = log_a - lse_f + + lse_g_cols = [] + for j in range_t: + M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric) + M = nx.from_numpy(M, type_as=a) + lse_g_cols.append( + nx.logsumexp(f[:, None] - M / reg, axis=0) + ) + lse_g = nx.concatenate(lse_g_cols, axis=0) + g = log_b - lse_g + + if (i_ot + 1) % 10 == 0: + m1_cols = [] + for i in range_s: + M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = nx.from_numpy(M, type_as=a) + m1_cols.append( + nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1) + ) + m1 = nx.concatenate(m1_cols, axis=0) + err = nx.sum(nx.abs(m1 - a)) + if log: + dict_log["err"].append(err) + + if verbose and (i_ot + 1) % 100 == 0: + print("Error in marginal at iteration {} = {}".format(i_ot + 1, err)) + + if err <= stopThr: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + dict_log["u"] = f + dict_log["v"] = g + return (f, g, dict_log) + else: + return (f, g) - if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) - return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) - return pi + M = dist(X_s, X_t, metric=metric) + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=True, **kwargs) + return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=False, **kwargs) + return pi -def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, **kwargs): +def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, isLazy=False, + batchSize=100, verbose=False, log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -1757,46 +2913,57 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num The function solves the following optimization problem: .. 
math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`a` and :math:`b` are source and target weights (sum to 1) + - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (n_samples_a, dim) + X_s : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (n_samples_b, dim) + X_t : array-like, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (n_samples_a,) + a : array-like, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : array-like, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) + isLazy: boolean, optional + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate + full cost matrix and return outputs of sinkhorn function. + batchSize: int or tuple of 2 int, optional + Size of the batches used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- - gamma : ndarray, shape (n_samples_a, n_samples_b) - Regularized optimal transportation matrix for the given parameters + W : (n_hists) array-like or float + Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1806,41 +2973,94 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num >>> n_samples_a = 2 >>> n_samples_b = 2 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) - >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False) - array([4.53978687e-05]) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) + >>> b = np.full((n_samples_b, 3), 1/n_samples_b) + >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False) + array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05]) References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). 
Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling + Algorithms for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. ''' + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = unif(np.shape(X_s)[0]) + a = nx.from_numpy(unif(ns), type_as=X_s) if b is None: - b = unif(np.shape(X_t)[0]) + b = nx.from_numpy(unif(nt), type_as=X_s) - M = dist(X_s, X_t, metric=metric) + if isLazy: + if log: + f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, + stopThr=stopThr, + isLazy=isLazy, + batchSize=batchSize, + verbose=verbose, log=log, + warn=warn) + else: + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, + verbose=verbose, log=log, + warn=warn) + + bs = batchSize if isinstance(batchSize, int) else batchSize[0] + range_s = range(0, ns, bs) + + loss = 0 + + X_s_np = nx.to_numpy(X_s) + X_t_np = nx.to_numpy(X_t) + + for i in range_s: + M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M_block = nx.from_numpy(M_block, type_as=a) + pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) + loss += nx.sum(M_block * pi_block) + + if log: + return loss, dict_log + else: + return loss - if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) - return sinkhorn_loss + M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric) + M = nx.from_numpy(M, type_as=a) + + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) + return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) + return sinkhorn_loss -def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, **kwargs): +def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, + verbose=False, log=False, warn=True, + **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -1849,64 +3069,72 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli .. 
math:: - W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + W &= \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a) + W_a &= \min_{\gamma_a} \quad \langle \gamma_a, \mathbf{M_a} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma_a) - W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b) + W_b &= \min_{\gamma_b} \quad \langle \gamma_b, \mathbf{M_b} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma_b) - S &= W - 1/2 * (W_a + W_b) + S &= W - \frac{W_a + W_b}{2} .. math:: - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 - \gamma_a 1 = a + \gamma_a \mathbf{1} &= \mathbf{a} - \gamma_a^T 1= a + \gamma_a^T \mathbf{1} &= \mathbf{a} - \gamma_a\geq 0 + \gamma_a &\geq 0 - \gamma_b 1 = b + \gamma_b \mathbf{1} &= \mathbf{b} - \gamma_b^T 1= b + \gamma_b^T \mathbf{1} &= \mathbf{b} - \gamma_b\geq 0 + \gamma_b &\geq 0 where : - - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b)) - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`a` and :math:`b` are source and target weights (sum to 1) + - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) + is the (`n_samples_a`, `n_samples_b`) metric cost matrix + (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- - X_s : ndarray, shape (n_samples_a, dim) + X_s : array-like, shape (n_samples_a, dim) samples in the source domain - X_t : ndarray, shape (n_samples_b, dim) + X_t : array-like, shape (n_samples_b, dim) samples in the target domain reg : float Regularization term >0 - a : ndarray, shape (n_samples_a,) + a : array-like, shape (n_samples_a,) samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : array-like, shape (n_samples_b,) samples weights in the target domain numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- - gamma : ndarray, shape (n_samples_a, n_samples_b) - Regularized optimal transportation matrix for the given parameters + W : (1,) array-like + Optimal transportation symmetrized loss for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1915,27 +3143,36 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli >>> n_samples_a = 2 >>> n_samples_b = 4 >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1)) + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS - array([1.499...]) + 1.499887176049052 References ---------- - .. 
[23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative
+        Models with Sinkhorn Divergences, Proceedings of the Twenty-First
+        International Conference on Artificial Intelligence and Statistics,
+        (AISTATS) 21, 2018
     '''
     if log:
-        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
-                                                       stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+                                                       numIterMax=numIterMax,
+                                                       stopThr=1e-9, verbose=verbose,
+                                                       log=log, warn=warn, **kwargs)

-        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
-                                                     stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+                                                     numIterMax=numIterMax,
+                                                     stopThr=1e-9, verbose=verbose,
+                                                     log=log, warn=warn, **kwargs)

-        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
-                                                     stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+                                                     numIterMax=numIterMax,
+                                                     stopThr=1e-9, verbose=verbose,
+                                                     log=log, warn=warn, **kwargs)

-        sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+        sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)

         log = {}
         log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -1948,99 +3185,119 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         return max(0, sinkhorn_div), log

     else:
-        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
-                                               verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+                                               numIterMax=numIterMax, stopThr=1e-9,
+                                               verbose=verbose, log=log,
+                                               warn=warn, **kwargs)
+
+        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+                                              numIterMax=numIterMax, stopThr=1e-9,
+                                              verbose=verbose, log=log,
+                                              warn=warn, **kwargs)
+
+        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+                                              numIterMax=numIterMax, stopThr=1e-9,
+                                              verbose=verbose, log=log,
+                                              warn=warn, **kwargs)
+
+        sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+        return max(0, sinkhorn_div)

-        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
-                                              verbose=verbose, log=log, **kwargs)

-        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
-                                              verbose=verbose, log=log, **kwargs)
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
+                restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09,
+                verbose=False, log=False):
+    r"""
+    Screening Sinkhorn Algorithm for Regularized Optimal Transport

-        sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
-        return max(0, sinkhorn_div)
+    The function solves an approximate dual of Sinkhorn divergence :ref:`[2]
+    <references-screenkhorn>` which is written as the following optimization problem:
+
+    .. 
math:: + (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \quad + \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} - + \langle \kappa \mathbf{u}, \mathbf{a} \rangle - + \langle \frac{1}{\kappa} \mathbf{v}, \mathbf{b} \rangle -def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, - maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): - r"""" - Screening Sinkhorn Algorithm for Regularized Optimal Transport + where: - The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem: + .. math:: - ..math:: - (u, v) = \argmin_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - <v/\kappa, b> + \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and} - where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and + .. math:: - s.t. e^{u_i} \geq \epsilon / \kappa, for all i \in {1, ..., ns} + s.t. \ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} - e^{v_j} \geq \epsilon \kappa, for all j \in {1, ..., nt} + e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} - The parameters \kappa and \epsilon are determined w.r.t the couple number budget of points (ns_budget, nt_budget), see Equation (5) in [26] + The parameters `kappa` and `epsilon` are determined w.r.t the couple number + budget of points (`ns_budget`, `nt_budget`), see Equation (5) + in :ref:`[26] <references-screenkhorn>` Parameters ---------- - a : `numpy.ndarray`, shape=(ns,) + a: array-like, shape=(ns,) samples weights in the source domain - - b : `numpy.ndarray`, shape=(nt,) + b: array-like, shape=(nt,) samples weights in the target domain - - M : `numpy.ndarray`, shape=(ns, nt) + M: array-like, shape=(ns, nt) Cost matrix - - reg : `float` + reg: `float` Level of the entropy regularisation - - ns_budget : `int`, deafult=None - Number budget of points to be keeped in the source domain - If it is None then 50% of the source sample points will be keeped - - nt_budget : `int`, deafult=None - Number budget of points to be keeped in the target domain - If it is None then 50% of the target sample points will be keeped - - uniform : `bool`, default=False - If `True`, the source and target distribution are supposed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt - + ns_budget: `int`, default=None + Number budget of points to be kept in the source domain. + If it is None then 50% of the source sample points will be kept + nt_budget: `int`, default=None + Number budget of points to be kept in the target domain. 
+ If it is None then 50% of the target sample points will be kept + uniform: `bool`, default=False + If `True`, the source and target distribution are supposed to be uniform, + i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` restricted : `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver using a restricted Sinkhorn algorithm with at most 5 iterations - - maxiter : `int`, default=10000 + maxiter: `int`, default=10000 Maximum number of iterations in LBFGS solver + maxfun: `int`, default=10000 + Maximum number of function evaluations in LBFGS solver + pgtol: `float`, default=1e-09 + Final objective function accuracy in LBFGS solver + verbose: `bool`, default=False + If `True`, display informations about the cardinals of the active sets + and the parameters kappa and epsilon - maxfun : `int`, default=10000 - Maximum number of function evaluations in LBFGS solver - pgtol : `float`, default=1e-09 - Final objective function accuracy in LBFGS solver + .. admonition:: Dependency - verbose : `bool`, default=False - If `True`, dispaly informations about the cardinals of the active sets and the paramerters kappa - and epsilon + To gain more efficiency, :py:func:`ot.bregman.screenkhorn` needs to call the "Bottleneck" + package (https://pypi.org/project/Bottleneck/) in the screening pre-processing step. - Dependency - ---------- - To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/) - in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears: - "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" + If Bottleneck isn't installed, the following error message appears: + + "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" Returns ------- - gamma : `numpy.ndarray`, shape=(ns, nt) + gamma : array-like, shape=(ns, nt) Screened optimal transportation matrix for the given parameters log : `dict`, default=False Log dictionary return only if log==True in parameters + .. _references-screenkhorn: References ----------- - .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, + Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). + Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 """ # check if bottleneck module exists @@ -2048,12 +3305,17 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res import bottleneck except ImportError: warnings.warn( - "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") + "Bottleneck module is not installed. 
Install it from" + " https://pypi.org/project/Bottleneck/ for better performance.") bottleneck = np - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + if nx.__name__ == "jax": + raise TypeError("JAX arrays have been received but screenkhorn is not " + "compatible with JAX.") + ns, nt = M.shape # by default, we keep only 50% of the sample data points @@ -2063,9 +3325,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res nt_budget = int(np.floor(0.5 * nt)) # calculate the Gibbs kernel - K = np.empty_like(M) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) def projection(u, epsilon): u[u <= epsilon] = epsilon @@ -2077,8 +3337,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res if ns_budget == ns and nt_budget == nt: # full number of budget points (ns, nt) = (ns_budget, nt_budget) - Isel = np.ones(ns, dtype=bool) - Jsel = np.ones(nt, dtype=bool) + Isel = nx.from_numpy(np.ones(ns, dtype=bool)) + Jsel = nx.from_numpy(np.ones(nt, dtype=bool)) epsilon = 0.0 kappa = 1.0 @@ -2094,57 +3354,63 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res K_IJc = [] K_IcJ = [] - vec_eps_IJc = np.zeros(nt) - vec_eps_IcJ = np.zeros(ns) + vec_eps_IJc = nx.zeros((nt,), type_as=M) + vec_eps_IcJ = nx.zeros((ns,), type_as=M) else: # sum of rows and columns of K - K_sum_cols = K.sum(axis=1) - K_sum_rows = K.sum(axis=0) + K_sum_cols = nx.sum(K, axis=1) + K_sum_rows = nx.sum(K, axis=0) if uniform: if ns / ns_budget < 4: - aK_sort = np.sort(K_sum_cols) + aK_sort = nx.sort(K_sum_cols) epsilon_u_square = a[0] / aK_sort[ns_budget - 1] else: - aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1] + aK_sort = nx.from_numpy( + bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1], + type_as=M + ) epsilon_u_square = a[0] / aK_sort if nt / nt_budget < 4: - bK_sort = np.sort(K_sum_rows) + bK_sort = nx.sort(K_sum_rows) epsilon_v_square = b[0] / bK_sort[nt_budget - 1] else: - bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1] + bK_sort = nx.from_numpy( + bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1], + type_as=M + ) epsilon_v_square = b[0] / bK_sort else: aK = a / K_sum_cols bK = b / K_sum_rows - aK_sort = np.sort(aK)[::-1] + aK_sort = nx.flip(nx.sort(aK), axis=0) epsilon_u_square = aK_sort[ns_budget - 1] - bK_sort = np.sort(bK)[::-1] + bK_sort = nx.flip(nx.sort(bK), axis=0) epsilon_v_square = bK_sort[nt_budget - 1] # active sets I and J (see Lemma 1 in [26]) Isel = a >= epsilon_u_square * K_sum_cols Jsel = b >= epsilon_v_square * K_sum_rows - if sum(Isel) != ns_budget: + if nx.sum(Isel) != ns_budget: if uniform: aK = a / K_sum_cols - aK_sort = np.sort(aK)[::-1] - epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean() + aK_sort = nx.flip(nx.sort(aK), axis=0) + epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1]) Isel = a >= epsilon_u_square * K_sum_cols - ns_budget = sum(Isel) + ns_budget = nx.sum(Isel) - if sum(Jsel) != nt_budget: + if nx.sum(Jsel) != nt_budget: if uniform: bK = b / K_sum_rows - bK_sort = np.sort(bK)[::-1] - epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean() + bK_sort = nx.flip(nx.sort(bK), axis=0) + epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1]) Jsel = b >= epsilon_v_square * K_sum_rows - nt_budget = sum(Jsel) + nt_budget = 
nx.sum(Jsel) epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4) kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2) @@ -2152,7 +3418,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res if verbose: print("epsilon = %s\n" % epsilon) print("kappa = %s\n" % kappa) - print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel))) + print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' + % (sum(Isel), sum(Jsel))) # Ic, Jc: complementary of the active sets I and J Ic = ~Isel @@ -2162,18 +3429,18 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res K_IcJ = K[np.ix_(Ic, Jsel)] K_IJc = K[np.ix_(Isel, Jc)] - K_min = K_IJ.min() + K_min = nx.min(K_IJ) if K_min == 0: - K_min = np.finfo(float).tiny + K_min = float(np.finfo(float).tiny) # a_I, b_J, a_Ic, b_Jc a_I = a[Isel] b_J = b[Jsel] if not uniform: - a_I_min = a_I.min() - a_I_max = a_I.max() - b_J_max = b_J.max() - b_J_min = b_J.min() + a_I_min = nx.min(a_I) + a_I_max = nx.max(a_I) + b_J_max = nx.max(b_J) + b_J_min = nx.min(b_J) else: a_I_min = a_I[0] a_I_max = a_I[0] @@ -2182,33 +3449,37 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res # box constraints in L-BFGS-B (see Proposition 1 in [26]) bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( - ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget + ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget bounds_v = [( - max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), - epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), + epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget # pre-calculated constants for the objective - vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) - vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0) + vec_eps_IJc = epsilon * kappa * nx.sum( + K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :], + axis=1 + ) + vec_eps_IcJ = (epsilon / kappa) * nx.sum( + nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ, + axis=0 + ) # initialisation - u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa) - v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa) + u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M) + v0 = nx.full((nt_budget,), 1. 
/ nt_budget + epsilon * kappa, type_as=M) # pre-calculated constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26]) if restricted: if ns_budget != ns or nt_budget != nt: - cst_u = kappa * epsilon * K_IJc.sum(axis=1) - cst_v = epsilon * K_IcJ.sum(axis=0) / kappa + cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1) + cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa - cpt = 1 - while cpt < 5: # 5 iterations - K_IJ_v = np.dot(K_IJ.T, u0) + cst_v + for _ in range(5): # 5 iterations + K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v v0 = b_J / (kappa * K_IJ_v) - KIJ_u = np.dot(K_IJ, v0) + cst_u + KIJ_u = nx.dot(K_IJ, v0) + cst_u u0 = (kappa * a_I) / KIJ_u - cpt += 1 u0 = projection(u0, epsilon / kappa) v0 = projection(v0, epsilon * kappa) @@ -2219,15 +3490,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res def restricted_sinkhorn(usc, vsc, max_iter=5): """ - Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26]) + Restricted Sinkhorn algorithm as a warm-start initialization point for L-BFGS-B (see Algorithm 1 in the supplementary of [26]) """ - cpt = 1 - while cpt < max_iter: - K_IJ_v = np.dot(K_IJ.T, usc) + cst_v + for _ in range(max_iter): + K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v vsc = b_J / (kappa * K_IJ_v) - KIJ_u = np.dot(K_IJ, vsc) + cst_u + KIJ_u = nx.dot(K_IJ, vsc) + cst_u usc = (kappa * a_I) / KIJ_u - cpt += 1 usc = projection(usc, epsilon / kappa) vsc = projection(vsc, epsilon * kappa) @@ -2235,17 +3504,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res return usc, vsc def screened_obj(usc, vsc): - part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, - np.log(vsc)) - part_IJc = np.dot(usc, vec_eps_IJc) - part_IcJ = np.dot(vec_eps_IcJ, vsc) + part_IJ = ( + nx.dot(nx.dot(usc, K_IJ), vsc) + - kappa * nx.dot(a_I, nx.log(usc)) + - (1. / kappa) * nx.dot(b_J, nx.log(vsc)) + ) + part_IJc = nx.dot(usc, vec_eps_IJc) + part_IcJ = nx.dot(vec_eps_IcJ, vsc) psi_epsilon = part_IJ + part_IJc + part_IcJ return psi_epsilon def screened_grad(usc, vsc): # gradients of Psi_(kappa,epsilon) w.r.t u and v - grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc - grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc + grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc + grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1.
/ kappa) * b_J / vsc return grad_u, grad_v def bfgspost(theta): @@ -2255,20 +3527,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res f = screened_obj(u, v) # gradient g_u, g_v = screened_grad(u, v) - g = np.hstack([g_u, g_v]) - return f, g + g = nx.concatenate([g_u, g_v], axis=0) + return nx.to_numpy(f), nx.to_numpy(g) # ----------------------------------------------------------------------------------------------------------------# # Step 2: L-BFGS-B solver # # ----------------------------------------------------------------------------------------------------------------# u0, v0 = restricted_sinkhorn(u0, v0) - theta0 = np.hstack([u0, v0]) + theta0 = nx.concatenate([u0, v0], axis=0) bounds = bounds_u + bounds_v # constraint bounds def obj(theta): - return bfgspost(theta) + return bfgspost(nx.from_numpy(theta, type_as=M)) theta, _, _ = fmin_l_bfgs_b(func=obj, x0=theta0, @@ -2276,12 +3548,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res maxfun=maxfun, pgtol=pgtol, maxiter=maxiter) + theta = nx.from_numpy(theta, type_as=M) usc = theta[:ns_budget] vsc = theta[ns_budget:] - usc_full = np.full(ns, epsilon / kappa) - vsc_full = np.full(nt, epsilon * kappa) + usc_full = nx.full((ns,), epsilon / kappa, type_as=M) + vsc_full = nx.full((nt,), epsilon * kappa, type_as=M) usc_full[Isel] = usc vsc_full[Jsel] = vsc @@ -2293,7 +3566,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res log['Jsel'] = Jsel gamma = usc_full[:, None] * K * vsc_full[None, :] - gamma = gamma / gamma.sum() + gamma = gamma / nx.sum(gamma) if log: return gamma, log @@ -26,34 +26,36 @@ from .optim import gcg def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False): - """ + r""" Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma) - + \eta \Omega_g(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega_e(\gamma) + \eta \ \Omega_g(\gamma) + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} - s.t. \gamma 1 = a + \gamma \geq 0 - \gamma^T 1= b - \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e (\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\Omega_g` is the group lasso regularization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` - where :math:`\mathcal{I}_c` are the index of samples from class c + where :math:`\mathcal{I}_c` are the index of samples from class `c` in the source domain. - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the generalized conditional - gradient as proposed in [5]_ [7]_ + gradient as proposed in :ref:`[5, 7] <references-sinkhorn-lpl1-mm>`. Parameters @@ -84,19 +86,20 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters + .. 
_references-sinkhorn-lpl1-mm: References ---------- - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. @@ -137,34 +140,36 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False): - """ + r""" Solve the entropic regularization optimal transport problem with group lasso regularization The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ - \eta \Omega_g(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega_e(\gamma) + \eta \ \Omega_g(\gamma) + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} - s.t. \gamma 1 = a + \gamma \geq 0 - \gamma^T 1= b - \gamma\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2` where :math:`\mathcal{I}_c` are the index of samples from class - c in the source domain. - - a and b are source and target weights (sum to 1) + `c` in the source domain. + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the generalised conditional - gradient as proposed in [5]_ [7]_ + gradient as proposed in :ref:`[5, 7] <references-sinkhorn-l1l2-gl>`. Parameters @@ -195,18 +200,19 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters + .. _references-sinkhorn-l1l2-gl: References ---------- - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. @@ -245,38 +251,40 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs): - """Joint OT and linear mapping estimation as proposed in [8] + r"""Joint OT and linear mapping estimation as proposed in + :ref:`[8] <references-joint-OT-mapping-linear>`. The function solves the following optimization problem: .. math:: - \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + - \mu<\gamma,M>_F + \eta \|L -I\|^2_F + \min_{\gamma,L}\quad \|L(\mathbf{X_s}) - n_s\gamma \mathbf{X_t} \|^2_F + + \mu \langle \gamma, \mathbf{M} \rangle_F + \eta \|L - \mathbf{I}\|^2_F - s.t. \gamma 1 = a + s.t. 
\ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 - \gamma\geq 0 where : - - M is the (ns,nt) squared euclidean cost matrix between samples in - Xs and Xt (scaled by ns) - - :math:`L` is a dxd linear operator that approximates the barycentric + - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in + :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`) + - :math:`L` is a :math:`d\times d` linear operator that approximates the barycentric mapping - - :math:`I` is the identity matrix (neutral linear mapping) - - a and b are uniform source and target weights + - :math:`\mathbf{I}` is the identity matrix (neutral linear mapping) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights The problem consist in solving jointly an optimal transport matrix :math:`\gamma` and a linear mapping that fits the barycentric mapping - :math:`n_s\gamma X_t`. + :math:`n_s\gamma \mathbf{X_t}`. One can also estimate a mapping with constant bias (see supplementary - material of [8]) using the bias optional argument. + material of :ref:`[8] <references-joint-OT-mapping-linear>`) using the bias optional argument. The algorithm used for solving the problem is the block coordinate - descent that alternates between updates of G (using conditionnal gradient) - and the update of L using a classical least square solver. + descent that alternates between updates of :math:`\mathbf{G}` (using conditionnal gradient) + and the update of :math:`\mathbf{L}` using a classical least square solver. Parameters @@ -307,17 +315,17 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters - L : (d x d) ndarray - Linear mapping matrix (d+1 x d if bias) + L : (d, d) ndarray + Linear mapping matrix ((:math:`d+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters + .. _references-joint-OT-mapping-linear: References ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. @@ -434,37 +442,41 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs): - """Joint OT and nonlinear mapping estimation with kernels as proposed in [8] + r"""Joint OT and nonlinear mapping estimation with kernels as proposed in + :ref:`[8] <references-joint-OT-mapping-kernel>`. The function solves the following optimization problem: .. math:: - \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) - - n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H} + \min_{\gamma, L\in\mathcal{H}}\quad \|L(\mathbf{X_s}) - + n_s\gamma \mathbf{X_t}\|^2_F + \mu \langle \gamma, \mathbf{M} \rangle_F + + \eta \|L\|^2_\mathcal{H} + + s.t. \ \gamma \mathbf{1} = \mathbf{a} - s.t. 
\gamma 1 = a + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 - \gamma^T 1= b - \gamma\geq 0 where : - - M is the (ns,nt) squared euclidean cost matrix between samples in - Xs and Xt (scaled by ns) - - :math:`L` is a ns x d linear operator on a kernel matrix that + - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in + :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`) + - :math:`L` is a :math:`n_s \times d` linear operator on a kernel matrix that approximates the barycentric mapping - - a and b are uniform source and target weights + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights The problem consist in solving jointly an optimal transport matrix :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping - :math:`n_s\gamma X_t`. + :math:`n_s\gamma \mathbf{X_t}`. One can also estimate a mapping with constant bias (see supplementary - material of [8]) using the bias optional argument. + material of :ref:`[8] <references-joint-OT-mapping-kernel>`) using the bias optional argument. The algorithm used for solving the problem is the block coordinate - descent that alternates between updates of G (using conditionnal gradient) - and the update of L using a classical kernel least square solver. + descent that alternates between updates of :math:`\mathbf{G}` (using conditionnal gradient) + and the update of :math:`\mathbf{L}` using a classical kernel least square solver. Parameters @@ -478,7 +490,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', eta : float, optional Regularization term for the linear mapping L (>0) kerneltype : str,optional - kernel used by calling function ot.utils.kernel (gaussian by default) + kernel used by calling function :py:func:`ot.utils.kernel` (gaussian by default) sigma : float, optional Gaussian kernel bandwidth. bias : bool,optional @@ -501,17 +513,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters - L : (ns x d) ndarray - Nonlinear mapping matrix (ns+1 x d if bias) + L : (ns, d) ndarray + Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters + .. _references-joint-OT-mapping-kernel: References ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. @@ -645,26 +657,27 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False): - """ return OT linear operator between samples + r"""Return OT linear operator between samples. The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)` - and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark - 2.29 in [15]. + form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` + and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in + :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in + :ref:`[15] <references-OT-mapping-linear>`. The linear operator from source to target :math:`M` .. 
math:: - M(x)=Ax+b + M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} where : .. math:: - A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} \Sigma_s^{-1/2} - .. math:: - b=\mu_t-A\mu_s + + \mathbf{b} &= \mu_t - \mathbf{A} \mu_s Parameters ---------- @@ -673,35 +686,35 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, xt : np.ndarray (nt,d) samples in the target domain reg : float,optional - regularization added to the diagonals of convariances (>0) + regularization added to the diagonals of covariances (>0) ws : np.ndarray (ns,1), optional weights for the source samples wt : np.ndarray (ns,1), optional weights for the target samples bias: boolean, optional - estimate bias b else b=0 (default:True) + estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) log : bool, optional record log if True Returns ------- - A : (d x d) ndarray + A : (d, d) ndarray Linear operator - b : (1 x d) ndarray + b : (1, d) ndarray bias log : dict log dictionary return only if log==True in parameters + .. _references-OT-mapping-linear: References ---------- - .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of distributions", Journal of Optimization Theory and Applications Vol 43, 1984 - .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal Transport", 2018. @@ -754,24 +767,34 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al r"""Solve the optimal transport problem (OT) with Laplacian regularization .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + eta\Omega_\alpha(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \eta \cdot \Omega_\alpha(\gamma) - s.t.\ \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} - \gamma\geq 0 + \gamma \geq 0 where: - - a and b are source and target weights (sum to 1) - - xs and xt are source and target samples - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + - :math:`\mathbf{x_s}` and :math:`\mathbf{x_t}` are source and target samples + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega_\alpha` is the Laplacian regularization term - :math:`\Omega_\alpha = (1-\alpha)/n_s^2\sum_{i,j}S^s_{i,j}\|T(\mathbf{x}^s_i)-T(\mathbf{x}^s_j)\|^2+\alpha/n_t^2\sum_{i,j}S^t_{i,j}^'\|T(\mathbf{x}^t_i)-T(\mathbf{x}^t_j)\|^2` - with :math:`S^s_{i,j}, S^t_{i,j}` denoting source and target similarity matrices and :math:`T(\cdot)` being a barycentric mapping - The algorithm used for solving the problem is the conditional gradient algorithm as proposed in [5]. + .. math:: + \Omega_\alpha = \frac{1 - \alpha}{n_s^2} \sum_{i,j} + \mathbf{S^s}_{i,j} \|T(\mathbf{x}^s_i) - T(\mathbf{x}^s_j) \|^2 + + \frac{\alpha}{n_t^2} \sum_{i,j} + \mathbf{S^t}_{i,j} \|T(\mathbf{x}^t_i) - T(\mathbf{x}^t_j) \|^2 + + + with :math:`\mathbf{S^s}_{i,j}, \mathbf{S^t}_{i,j}` denoting source and target similarity + matrices and :math:`T(\cdot)` being a barycentric mapping. + + The algorithm used for solving the problem is the conditional gradient algorithm as proposed in + :ref:`[5] <references-emd-laplace>`. 
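For orientation, a minimal usage sketch of `emd_laplace` (toy random data; `ot.unif` builds the uniform weights and `ot.dist` the squared euclidean cost matrix; the `sim_param=3` k-NN graph size is an illustrative choice, not a recommended setting):

>>> import numpy as np
>>> import ot
>>> xs, xt = np.random.randn(10, 2), np.random.randn(12, 2)  # source / target samples
>>> a, b = ot.unif(10), ot.unif(12)                          # uniform weights
>>> M = ot.dist(xs, xt)                                      # (ns, nt) cost matrix
>>> G = ot.da.emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=3, eta=1.)  # coupling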
Parameters ---------- @@ -811,22 +834,23 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters + .. _references-emd-laplace: References ---------- - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE - Transactions on Pattern Analysis and Machine Intelligence , + Transactions on Pattern Analysis and Machine Intelligence, vol.PP, no.99, pp.1-1 + .. [30] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy, "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching," - in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. See Also -------- @@ -882,7 +906,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al def distribution_estimation_uniform(X): - """estimates a uniform distribution from an array of samples X + """estimates a uniform distribution from an array of samples :math:`\mathbf{X}` Parameters ---------- @@ -892,7 +916,7 @@ def distribution_estimation_uniform(X): Returns ------- mu : array-like, shape (n_samples,) - The uniform distribution estimated from X + The uniform distribution estimated from :math:`\mathbf{X}` """ return unif(X.shape[0]) @@ -902,32 +926,32 @@ class BaseTransport(BaseEstimator): """Base class for OTDA objects - Notes - ----- - All estimators should specify all the parameters that can be set - at the class level in their ``__init__`` as explicit keyword - arguments (no ``*args`` or ``**kwargs``). + .. note:: + All estimators should specify all the parameters that can be set + at the class level in their ``__init__`` as explicit keyword + arguments (no ``*args`` or ``**kwargs``). - the fit method should: + The fit method should: - estimate a cost matrix and store it in a `cost_` attribute - - estimate a coupling matrix and store it in a `coupling_` - attribute + - estimate a coupling matrix and store it in a `coupling_` attribute - estimate distributions from source and target data and store them in - mu_s and mu_t attributes - - store Xs and Xt in attributes to be used later on in transform and - inverse_transform methods + `mu_s` and `mu_t` attributes + - store `Xs` and `Xt` in attributes to be used later on in `transform` and + `inverse_transform` methods + + `transform` method should always get as input a `Xs` parameter + + `inverse_transform` method should always get as input a `Xt` parameter - transform method should always get as input a Xs parameter - inverse_transform method should always get as input a Xt parameter + `transform_labels` method should always get as input a `ys` parameter - transform_labels method should always get as input a ys parameter - inverse_transform_labels method should always get as input a yt parameter + `inverse_transform_labels` method should always get as input a `yt` parameter """ def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -938,8 +962,8 @@ class BaseTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The training input samples. 
yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -987,8 +1011,8 @@ class BaseTransport(BaseEstimator): def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) and transports source samples Xs onto target - ones Xt + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` + and transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -999,8 +1023,8 @@ class BaseTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1014,7 +1038,7 @@ class BaseTransport(BaseEstimator): return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt) def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples Xs onto target ones Xt + """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -1025,8 +1049,8 @@ class BaseTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The target input samples. yt : array-like, shape (n_target_samples,) - The class labels for target. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels for target. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1081,7 +1105,8 @@ class BaseTransport(BaseEstimator): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels ys to obtain estimated target labels as in [27] + """Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in + :ref:`[27] <references-basetransport-transform-labels>`. Parameters ---------- @@ -1093,9 +1118,10 @@ class BaseTransport(BaseEstimator): transp_ys : array-like, shape (n_target_samples, nb_classes) Estimated soft target labels. + + .. _references-basetransport-transform-labels: References ---------- - .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
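In essence, this propagation one-hot encodes the source labels, column-normalizes the coupling, and pushes the labels through it; a self-contained NumPy sketch of just that step follows (the class method additionally masks unlabelled (-1) entries, and the hunk below shows the corresponding normalization in the code):

>>> import numpy as np
>>> coupling = np.array([[0.2, 0.0], [0.1, 0.7]])   # toy (ns, nt) coupling
>>> ys = np.array([0, 1])                           # source labels
>>> D1 = (np.unique(ys)[:, None] == ys[None, :]).astype(float)  # (n_classes, ns) one-hot
>>> transp = coupling / np.sum(coupling, 0, keepdims=True)      # column-normalized coupling
>>> transp[~np.isfinite(transp)] = 0                            # guard against empty columns
>>> transp_ys = (D1 @ transp).T                     # (nt, n_classes) soft target labels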
@@ -1111,7 +1137,7 @@ class BaseTransport(BaseEstimator): D1 = np.zeros((n, len(ysTemp))) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True) # set nans to 0 transp[~ np.isfinite(transp)] = 0 @@ -1126,7 +1152,7 @@ class BaseTransport(BaseEstimator): def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports target samples Xt onto source samples Xs + """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters ---------- @@ -1137,8 +1163,8 @@ class BaseTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The target input samples. yt : array-like, shape (n_target_samples,) - The target class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The target class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1192,7 +1218,8 @@ class BaseTransport(BaseEstimator): return transp_Xt def inverse_transform_labels(self, yt=None): - """Propagate target labels yt to obtain estimated source labels ys + """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels + :math:`\mathbf{y_s}` Parameters ---------- @@ -1228,39 +1255,41 @@ class LinearTransport(BaseTransport): - """ OT linear operator between empirical distributions + r""" OT linear operator between empirical distributions The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)` - and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in - remark 2.29 in [15]. + form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` + and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in + :ref:`[14] <references-lineartransport>` and discussed in remark 2.29 in + :ref:`[15] <references-lineartransport>`. The linear operator from source to target :math:`M` .. math:: - M(x)=Ax+b + M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} where : .. math:: - A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} \Sigma_s^{-1/2} - .. math:: - b=\mu_t-A\mu_s + + \mathbf{b} &= \mu_t - \mathbf{A} \mu_s Parameters ---------- reg : float,optional - regularization added to the daigonals of convariances (>0) + regularization added to the diagonals of covariances (>0) bias: boolean, optional - estimate bias b else b=0 (default:True) + estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) log : bool, optional record log if True + + .. _references-lineartransport: References ---------- - .. [14] Knott, M. and Smith, C. S.
"On the optimal mapping of distributions", Journal of Optimization Theory and Applications Vol 43, 1984 @@ -1279,7 +1308,7 @@ class LinearTransport(BaseTransport): def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1290,8 +1319,8 @@ class LinearTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1325,7 +1354,7 @@ class LinearTransport(BaseTransport): return self def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples Xs onto target ones Xt + """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -1336,8 +1365,8 @@ class LinearTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1358,7 +1387,7 @@ class LinearTransport(BaseTransport): def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports target samples Xt onto target samples Xs + """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters ---------- @@ -1369,8 +1398,8 @@ class LinearTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1392,7 +1421,7 @@ class LinearTransport(BaseTransport): class SinkhornTransport(BaseTransport): - """Domain Adapatation OT method based on Sinkhorn Algorithm + """Domain Adaptation OT method based on Sinkhorn Algorithm Parameters ---------- @@ -1400,7 +1429,7 @@ class SinkhornTransport(BaseTransport): Entropic regularization parameter max_iter : int, float, optional (default=1000) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged tol : float, optional (default=10e-9) The precision required to stop the optimization algorithm. verbose : bool, optional (default=False) @@ -1417,8 +1446,8 @@ class SinkhornTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. - limit_max: float, optional (defaul=np.infty) + "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhorntransport>`. 
+ limit_max: float, optional (default=np.infty) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an cost defined by this variable @@ -1428,16 +1457,20 @@ class SinkhornTransport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + + .. _references-sinkhorntransport: References ---------- .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. @@ -1461,7 +1494,7 @@ class SinkhornTransport(BaseTransport): def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1472,8 +1505,8 @@ class SinkhornTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1504,7 +1537,7 @@ class SinkhornTransport(BaseTransport): class EMDTransport(BaseTransport): - """Domain Adapatation OT method based on Earth Mover's Distance + """Domain Adaptation OT method based on Earth Mover's Distance Parameters ---------- @@ -1520,7 +1553,7 @@ class EMDTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] <references-emdtransport>`. limit_max: float, optional (default=10) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an infinite cost @@ -1534,14 +1567,16 @@ class EMDTransport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling + + .. _references-emdtransport: References ---------- .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, - "Optimal Transport for Domain Adaptation," in IEEE Transactions - on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + "Optimal Transport for Domain Adaptation," in IEEE Transactions + on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). - Regularized discrete optimal transport. SIAM Journal on Imaging - Sciences, 7(3), 1853-1882. + Regularized discrete optimal transport. SIAM Journal on Imaging + Sciences, 7(3), 1853-1882. 
""" def __init__(self, metric="sqeuclidean", norm=None, log=False, @@ -1558,7 +1593,7 @@ class EMDTransport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1569,8 +1604,8 @@ class EMDTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1597,8 +1632,7 @@ class EMDTransport(BaseTransport): class SinkhornLpl1Transport(BaseTransport): - - """Domain Adapatation OT method based on sinkhorn algorithm + + r"""Domain Adaptation OT method based on sinkhorn algorithm + LpL1 class regularization. Parameters @@ -1609,7 +1643,7 @@ class SinkhornLpl1Transport(BaseTransport): Class regularization parameter max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged max_inner_iter : int, float, optional (default=200) The number of iteration in the inner loop log : bool, optional (default=False) @@ -1628,8 +1662,8 @@ class SinkhornLpl1Transport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. - limit_max: float, optional (defaul=np.infty) + "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhornlpl1transport>`. + limit_max: float, optional (default=np.infty) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit a cost defined by limit_max. @@ -1639,16 +1673,19 @@ class SinkhornLpl1Transport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling + + .. _references-sinkhornlpl1transport: References ---------- - .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. @@ -1675,7 +1712,7 @@ class SinkhornLpl1Transport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1686,8 +1723,8 @@ class SinkhornLpl1Transport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. 
If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1719,13 +1756,14 @@ class SinkhornLpl1Transport(BaseTransport): class EMDLaplaceTransport(BaseTransport): - """Domain Adapatation OT method based on Earth Mover's Distance with Laplacian regularization + """Domain Adaptation OT method based on Earth Mover's Distance with Laplacian regularization Parameters ---------- reg_type : string optional (default='pos') Type of the regularization term: 'pos' and 'disp' for - regularization term defined in [2] and [6], respectively. + regularization term defined in :ref:`[2] <references-emdlaplacetransport>` and + :ref:`[6] <references-emdlaplacetransport>`, respectively. reg_lap : float, optional (default=1) Laplacian regularization parameter reg_src : float, optional (default=0.5) @@ -1756,24 +1794,27 @@ class EMDLaplaceTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] <references-emdlaplacetransport>`. Attributes ---------- coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling + + .. _references-emdlaplacetransport: References ---------- .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [2] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy, "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching," - in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). - Regularized discrete optimal transport. SIAM Journal on Imaging - Sciences, 7(3), 1853-1882. + Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. """ def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean", @@ -1799,7 +1840,7 @@ class EMDLaplaceTransport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1810,8 +1851,8 @@ class EMDLaplaceTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1840,8 +1881,8 @@ class EMDLaplaceTransport(BaseTransport): class SinkhornL1l2Transport(BaseTransport): - """Domain Adapatation OT method based on sinkhorn algorithm + - l1l2 class regularization. + """Domain Adaptation OT method based on sinkhorn algorithm + + L1L2 class regularization. 
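A typical fit/transform round trip with this estimator, as a hedged sketch on toy data (`reg_e` and `reg_cl` are assumed here to be the entropic and class-regularization keywords described in the Parameters section below):

>>> import numpy as np
>>> import ot
>>> Xs, ys = np.random.randn(20, 2), np.random.randint(0, 2, 20)  # labels drive the L1L2 term
>>> Xt = np.random.randn(30, 2)
>>> otda = ot.da.SinkhornL1l2Transport(reg_e=1., reg_cl=0.1)
>>> transp_Xs = otda.fit(Xs=Xs, ys=ys, Xt=Xt).transform(Xs=Xs)    # transported source samples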
Parameters ---------- @@ -1851,7 +1892,7 @@ class SinkhornL1l2Transport(BaseTransport): Class regularization parameter max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged max_inner_iter : int, float, optional (default=200) The number of iteration in the inner loop tol : float, optional (default=10e-9) @@ -1870,7 +1911,7 @@ class SinkhornL1l2Transport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhornl1l2transport>`. limit_max: float, optional (default=10) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an infinite cost @@ -1881,18 +1922,21 @@ class SinkhornL1l2Transport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + + .. _references-sinkhornl1l2transport: References ---------- - .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. @@ -1919,7 +1963,7 @@ class SinkhornL1l2Transport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -1930,8 +1974,8 @@ class SinkhornL1l2Transport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. 
Warning: Note that, due to this convention -1 cannot be used as a class label @@ -1973,7 +2017,7 @@ class MappingTransport(BaseEstimator): mu : float, optional (default=1) Weight for the linear OT loss (>0) eta : float, optional (default=0.001) - Regularization term for the linear mapping L (>0) + Regularization term for the linear mapping `L` (>0) bias : bool, optional (default=False) Estimate linear mapping with constant bias metric : string, optional (default="sqeuclidean") @@ -2004,17 +2048,20 @@ class MappingTransport(BaseEstimator): ---------- coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling - mapping_ : array-like, shape (n_features (+ 1), n_features) - (if bias) for kernel == linear + mapping_ : The associated mapping - array-like, shape (n_source_samples (+ 1), n_features) - (if bias) for kernel == gaussian + + - array-like, shape (`n_features` (+ 1), `n_features`), + (if bias) for kernel == linear + + - array-like, shape (`n_source_samples` (+ 1), `n_features`), + (if bias) for kernel == gaussian log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + References ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. @@ -2042,7 +2089,8 @@ class MappingTransport(BaseEstimator): def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Builds an optimal coupling and estimates the associated mapping - from source and target sets of samples (Xs, ys) and (Xt, yt) + from source and target sets of samples + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -2053,8 +2101,8 @@ class MappingTransport(BaseEstimator): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -2098,7 +2146,7 @@ class MappingTransport(BaseEstimator): return self def transform(self, Xs): - """Transports source samples Xs onto target ones Xt + """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -2138,7 +2186,7 @@ class MappingTransport(BaseEstimator): class UnbalancedSinkhornTransport(BaseTransport): - """Domain Adapatation unbalanced OT method based on sinkhorn algorithm + """Domain Adaptation unbalanced OT method based on sinkhorn algorithm Parameters ---------- @@ -2151,7 +2199,7 @@ class UnbalancedSinkhornTransport(BaseTransport): 'sinkhorn_epsilon_scaling', see those function for specific parameters max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged tol : float, optional (default=10e-9) Stop threshold on error (inner sinkhorn solver) (>0) verbose : bool, optional (default=False) @@ -2168,7 +2216,7 @@ class UnbalancedSinkhornTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. 
Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] <references-unbalancedsinkhorntransport>`. limit_max: float, optional (default=10) Controls the semi supervised mode. Transport between labeled source and target samples of different classes will exhibit an infinite cost @@ -2179,14 +2227,16 @@ class UnbalancedSinkhornTransport(BaseTransport): coupling_ : array-like, shape (n_source_samples, n_target_samples) The optimal coupling log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + + .. _references-unbalancedsinkhorntransport: References ---------- - .. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. arXiv preprint - arXiv:1607.05816. + Scaling algorithms for unbalanced transport problems. arXiv preprint + arXiv:1607.05816. + .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. @@ -2212,7 +2262,7 @@ class UnbalancedSinkhornTransport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -2223,8 +2273,8 @@ class UnbalancedSinkhornTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -2258,7 +2308,7 @@ class UnbalancedSinkhornTransport(BaseTransport): class JCPOTTransport(BaseTransport): - """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm. + """Domain Adaptation OT method for multi-source target shift based on Wasserstein barycenter algorithm. Parameters ---------- @@ -2266,7 +2316,7 @@ class JCPOTTransport(BaseTransport): Entropic regularization parameter max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization - algorithm if no it has not converged + algorithm if it has not converged tol : float, optional (default=10e-9) Stop threshold on error (inner sinkhorn solver) (>0) verbose : bool, optional (default=False) @@ -2283,7 +2333,7 @@ class JCPOTTransport(BaseTransport): out_of_sample_map : string, optional (default="ferradans") The kind of out of sample mapping to apply to transport samples from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. + "ferradans" which uses the method proposed in :ref:`[6] <references-jcpottransport>`. Attributes ---------- @@ -2292,11 +2342,12 @@ class JCPOTTransport(BaseTransport): proportions_ : array-like, shape (n_classes,) Estimated class proportions in the target domain log_ : dictionary - The dictionary of log, empty dic if parameter log is not True + The dictionary of log, empty dict if parameter log is not True + + .. _references-jcpottransport: References ---------- - .. 
[1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), @@ -2323,7 +2374,7 @@ class JCPOTTransport(BaseTransport): def fit(self, Xs, ys=None, Xt=None, yt=None): """Building coupling matrices from a list of source and target sets of samples - (Xs, ys) and (Xt, yt) + :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters ---------- @@ -2334,8 +2385,8 @@ class JCPOTTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -2368,7 +2419,7 @@ class JCPOTTransport(BaseTransport): return self def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples Xs onto target ones Xt + """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -2379,8 +2430,8 @@ class JCPOTTransport(BaseTransport): Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. + The class labels. If some target samples are unlabelled, fill the + :math:`\mathbf{y_t}`'s elements with -1. Warning: Note that, due to this convention -1 cannot be used as a class label @@ -2440,7 +2491,8 @@ class JCPOTTransport(BaseTransport): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels ys to obtain target labels as in [27] + """Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in + :ref:`[27] <references-jcpottransport-transform-labels>` Parameters ---------- @@ -2451,6 +2503,14 @@ class JCPOTTransport(BaseTransport): ------- yt : array-like, shape (n_target_samples, nb_classes) Estimated soft target labels. + + + .. _references-jcpottransport-transform-labels: + References + ---------- + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
""" # check the necessary inputs parameters are here @@ -2482,11 +2542,12 @@ class JCPOTTransport(BaseTransport): return yt.T def inverse_transform_labels(self, yt=None): - """Propagate source labels ys to obtain target labels + """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels + :math:`\mathbf{y_s}` Parameters ---------- - yt : array-like, shape (n_source_samples,) + yt : array-like, shape (n_target_samples,) The target class labels Returns diff --git a/ot/datasets.py b/ot/datasets.py index b86ef3b..ad6390c 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -13,7 +13,7 @@ from .utils import check_random_state, deprecated def make_1D_gauss(n, m, s): - """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) + """return a 1D histogram for a gaussian distribution (`n` bins, mean `m` and std `s`) Parameters ---------- @@ -26,7 +26,7 @@ def make_1D_gauss(n, m, s): Returns ------- - h : ndarray (n,) + h : ndarray (`n`,) 1D histogram for a gaussian distribution """ x = np.arange(n, dtype=np.float64) @@ -41,7 +41,7 @@ def get_1D_gauss(n, m, sigma): def make_2D_samples_gauss(n, m, sigma, random_state=None): - """Return n samples drawn from 2D gaussian N(m,sigma) + """Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)` Parameters ---------- @@ -59,8 +59,8 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None): Returns ------- - X : ndarray, shape (n, 2) - n samples drawn from N(m, sigma). + X : ndarray, shape (`n`, 2) + n samples drawn from :math:`\mathcal{N}(m, \sigma)`. """ generator = check_random_state(random_state) @@ -102,7 +102,7 @@ def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwa Returns ------- X : ndarray, shape (n, d) - n observation of size d + `n` observation of size `d` y : ndarray, shape (n,) labels of the samples. """ @@ -10,6 +10,7 @@ Dimension reduction with OT """ # Author: Remi Flamary <remi.flamary@unice.fr> +# Minhui Huang <mhhuang@ucdavis.edu> # # License: MIT License @@ -21,7 +22,7 @@ from pymanopt.solvers import SteepestDescent, TrustRegions def dist(x1, x2): - """ Compute squared euclidean distance between samples (autograd) + r""" Compute squared euclidean distance between samples (autograd) """ x1p2 = np.sum(np.square(x1), 1) x2p2 = np.sum(np.square(x2), 1) @@ -29,7 +30,7 @@ def dist(x1, x2): def sinkhorn(w1, w2, M, reg, k): - """Sinkhorn algorithm with fixed number of iteration (autograd) + r"""Sinkhorn algorithm with fixed number of iteration (autograd) """ K = np.exp(-M / reg) ui = np.ones((M.shape[0],)) @@ -42,14 +43,14 @@ def sinkhorn(w1, w2, M, reg, k): def split_classes(X, y): - """split samples in X by classes in y + r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}` """ lstsclass = np.unique(y) return [X[y == i, :].astype(np.float32) for i in lstsclass] def fda(X, y, p=2, reg=1e-16): - """Fisher Discriminant Analysis + r"""Fisher Discriminant Analysis Parameters ---------- @@ -108,20 +109,21 @@ def fda(X, y, p=2, reg=1e-16): return Popt, proj -def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None): - """ - Wasserstein Discriminant Analysis [11]_ +def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False): + r""" + Wasserstein Discriminant Analysis :ref:`[11] <references-wda>` The function solves the following optimization problem: .. 
math::
-        P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
+        \mathbf{P} = \mathop{\arg \min}_\mathbf{P} \quad
+        \frac{\sum\limits_i W(P \mathbf{X}^i, P \mathbf{X}^i)}{\sum\limits_{i, j \neq i} W(P \mathbf{X}^i, P \mathbf{X}^j)}
 
     where :
 
-    - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
+    - :math:`P` is a linear projection operator in the Stiefel(`p`, `d`) manifold
     - :math:`W` is entropic regularized Wasserstein distances
-    - :math:`X^i` are samples in the dataset corresponding to class i
+    - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
 
     Parameters
     ----------
@@ -138,6 +140,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
         else should be a pymanopt.solvers
     P0 : ndarray, shape (d, p)
         Initial starting point for projection.
+    normalize : bool, optional
+        Normalise the Wasserstein distance by the average distance on P0 (default: False)
     verbose : int, optional
         Print information along iterations.
 
@@ -148,6 +152,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
     proj : callable
         Projection function including mean centering.
 
+
+    .. _references-wda:
     References
     ----------
     .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
@@ -163,6 +169,18 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
     # compute uniform weights
     wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]
 
+    # pre-compute reg_c,c'
+    if P0 is not None and normalize:
+        regmean = np.zeros((len(xc), len(xc)))
+        for i, xi in enumerate(xc):
+            xi = np.dot(xi, P0)
+            for j, xj in enumerate(xc[i:]):
+                xj = np.dot(xj, P0)
+                M = dist(xi, xj)
+                regmean[i, j] = np.sum(M) / (len(xi) * len(xj))
+    else:
+        regmean = np.ones((len(xc), len(xc)))
+
     def cost(P):
         # wda loss
         loss_b = 0
@@ -173,7 +191,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
             for j, xj in enumerate(xc[i:]):
                 xj = np.dot(xj, P)
                 M = dist(xi, xj)
-                G = sinkhorn(wc[i], wc[j + i], M, reg, k)
+                G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
                 if j == 0:
                     loss_w += np.sum(G * M)
                 else:
@@ -198,3 +216,119 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
         return (X - mx.reshape((1, -1))).dot(Popt)
 
     return Popt, proj
+
+
+def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
+    r"""
+    Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>`
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \max_{U \in St(d, k)} \ \min_{\pi \in \Pi(\mu,\nu)} \quad \sum_{i,j} \pi_{i,j}
+        \|U^T(\mathbf{x}_i - \mathbf{y}_j)\|^2 - \mathrm{reg} \cdot H(\pi)
+
+    - :math:`U` is a linear projection operator in the Stiefel(`d`, `k`) manifold
+    - :math:`H(\pi)` is entropy regularizer
+    - :math:`\mathbf{x}_i`, :math:`\mathbf{y}_j` are samples of measures :math:`\mu` and :math:`\nu` respectively
+
+    Parameters
+    ----------
+    X : ndarray, shape (n, d)
+        Samples from measure :math:`\mu`
+    Y : ndarray, shape (n, d)
+        Samples from measure :math:`\nu`
+    a : ndarray, shape (n, )
+        weights for measure :math:`\mu`
+    b : ndarray, shape (n, )
+        weights for measure :math:`\nu`
+    tau : float
+        stepsize for Riemannian Gradient Descent
+    U0 : ndarray, shape (d, k)
+        Initial starting point for projection.
+    reg : float, optional
+        Regularization term >0 (entropic regularization)
+    k : int
+        Subspace dimension
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : int, optional
+        Print information along iterations.
+
+    Returns
+    -------
+    pi : ndarray, shape (n, n)
+        Optimal transportation matrix for the given parameters
+    U : ndarray, shape (d, k)
+        Projection operator.
+
+
+    .. _references-projection-robust-wasserstein:
+    References
+    ----------
+    .. [32] Huang, M. , Ma S. & Lai L. (2021).
+        A Riemannian Block Coordinate Descent Method for Computing
+        the Projection Robust Wasserstein Distance, ICML.
+    """  # noqa
+
+    # initialization
+    n, d = X.shape
+    m, d = Y.shape
+    a = np.asarray(a, dtype=np.float64)
+    b = np.asarray(b, dtype=np.float64)
+    u = np.ones(n) / n
+    v = np.ones(m) / m
+    ones = np.ones((n, m))
+
+    assert d > k
+
+    if U0 is None:
+        U = np.random.randn(d, k)
+        U, _ = np.linalg.qr(U)
+    else:
+        U = U0
+
+    def Vpi(X, Y, a, b, pi):
+        # Return the second order matrix of the displacements: sum_ij { (pi)_ij (X_i-Y_j)(X_i-Y_j)^T }.
+        A = X.T.dot(pi).dot(Y)
+        return X.T.dot(np.diag(a)).dot(X) + Y.T.dot(np.diag(np.sum(pi, 0))).dot(Y) - A - A.T
+
+    err = 1
+    iter = 0
+
+    while err > stopThr and iter < maxiter:
+
+        # Projected cost matrix
+        UUT = U.dot(U.T)
+        M = np.diag(np.diag(X.dot(UUT.dot(X.T)))).dot(ones) + ones.dot(
+            np.diag(np.diag(Y.dot(UUT.dot(Y.T))))) - 2 * X.dot(UUT.dot(Y.T))
+
+        A = np.empty(M.shape, dtype=M.dtype)
+        np.divide(M, -reg, out=A)
+        np.exp(A, out=A)
+
+        # Sinkhorn update
+        Ap = (1 / a).reshape(-1, 1) * A
+        AtransposeU = np.dot(A.T, u)
+        v = np.divide(b, AtransposeU)
+        u = 1. / np.dot(Ap, v)
+        pi = u.reshape((-1, 1)) * A * v.reshape((1, -1))
+
+        V = Vpi(X, Y, a, b, pi)
+
+        # Riemannian gradient descent
+        G = 2 / reg * V.dot(U)
+        GTU = G.T.dot(U)
+        xi = G - U.dot(GTU + GTU.T) / 2  # Riemannian gradient
+        U, _ = np.linalg.qr(U + tau * xi)  # Retraction by QR decomposition
+
+        grad_norm = np.linalg.norm(xi)
+        err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1))
+
+        f_val = np.trace(U.T.dot(V.dot(U)))
+        if verbose:
+            print('RBCD Iteration: ', iter, ' error', err, '\t fval: ', f_val)
+
+        iter = iter + 1
+
+    return pi, U
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index 7478fb9..12db605 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -7,7 +7,13 @@ The GPU backend is handled by `cupy <https://cupy.chainer.org/>`_.
 
 .. warning::
-    Note that by default the module is not import in :mod:`ot`. In order to
+    This module is now deprecated and will be removed in future releases. POT
+    now provides a backend mechanism that allows for solving problems on GPU
+    with the pytorch backend.
+
+
+.. warning::
+    Note that by default the module is not imported in :mod:`ot`. In order to
     use it you need to explicitly import :mod:`ot.gpu` .
 
 By default, the functions in this module accept and return numpy arrays
@@ -25,6 +31,8 @@ result of the function with parameter ``to_numpy=False``.
 #
 # License: MIT License
 
+import warnings
+
 from . import bregman
 from . import da
 from .bregman import sinkhorn
@@ -34,7 +42,7 @@ from . import utils
 from .utils import dist, to_gpu, to_np
 
-
+warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning)
 
 __all__ = ["utils", "dist", "sinkhorn",
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 2e2df83..76af00e 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -15,7 +15,7 @@ from . 
import utils def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, to_numpy=True, **kwargs): - """ + r""" Solve the entropic regularization optimal transport on GPU If the input matrix are in numpy format, they will be uploaded to the @@ -54,7 +54,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -148,13 +148,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, # we can speed up the process by checking for the error only all # the 10th iterations if nbb: - err = np.sum((u - uprev)**2) / np.sum((u)**2) + \ - np.sum((v - vprev)**2) / np.sum((v)**2) + err = np.sqrt( + np.sum((u - uprev)**2) / np.sum((u)**2) + + np.sum((v - vprev)**2) / np.sum((v)**2) + ) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 tmp2 = np.sum(u[:, None] * K * v[None, :], 0) #tmp2=np.einsum('i,ij,j->j', u, K, v) - err = np.linalg.norm(tmp2 - b)**2 # violation of marginal + err = np.linalg.norm(tmp2 - b) # violation of marginal if log: log['err'].append(err) diff --git a/ot/gpu/da.py b/ot/gpu/da.py index 4a98038..7adb830 100644 --- a/ot/gpu/da.py +++ b/ot/gpu/da.py @@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, labels_a2 = cp.asnumpy(labels_a) classes = npp.unique(labels_a2) for c in classes: - idxc, = utils.to_gpu(npp.where(labels_a2 == c)) + idxc = utils.to_gpu(*npp.where(labels_a2 == c)) indices_labels.append(idxc) W = np.zeros(M.shape) diff --git a/ot/gromov.py b/ot/gromov.py index 4427a96..ea667e4 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -14,63 +14,85 @@ import numpy as np from .bregman import sinkhorn
-from .utils import dist, UndefinedParameter
+from .utils import dist, UndefinedParameter, list_to_array
from .optim import cg
+from .lp import emd_1d, emd
+from .utils import check_random_state
+from .backend import get_backend
def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
- """Return loss matrices and tensors for Gromov-Wasserstein fast computation
+ r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation
- Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
- function as the loss function of Gromow-Wasserstein discrepancy.
+ Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
+ selected loss function as the loss function of Gromov-Wasserstein discrepancy.
- The matrices are computed as described in Proposition 1 in [12]
+ The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
Where :
- * C1 : Metric cost matrix in the source space
- * C2 : Metric cost matrix in the target space
- * T : A coupling between those two spaces
-
- The square-loss function L(a,b)=|a-b|^2 is read as :
- L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- * f1(a)=(a^2)
- * f2(b)=(b^2)
- * h1(a)=a
- * h2(b)=2*b
-
- The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
- L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- * f1(a)=a*log(a)-a
- * f2(b)=b
- * h1(a)=a
- * h2(b)=log(b)
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{T}`: A coupling between those two spaces
+
+ The square-loss function :math:`L(a, b) = |a - b|^2` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a^2
+
+ f_2(b) &= b^2
+
+ h_1(a) &= a
+
+ h_2(b) &= 2b
+
+ The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a \log(a) - a
+
+ f_2(b) &= b
+
+ h_1(a) &= a
+
+ h_2(b) &= \log(b)
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- T : ndarray, shape (ns, nt)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ T : array-like, shape (ns, nt)
Coupling between source and target spaces
- p : ndarray, shape (ns,)
+ p : array-like, shape (ns,)
Returns
-------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+
+ .. _references-init-matrix:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
if loss_fun == 'square_loss':
def f1(a):
@@ -86,7 +108,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
- return a * np.log(a + 1e-15) - a
+ return a * nx.log(a + 1e-15) - a
def f2(b):
return b
@@ -95,12 +117,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): return a
def h2(b):
- return np.log(b + 1e-15)
-
- constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)),
- np.ones(len(q)).reshape(1, -1))
- constC2 = np.dot(np.ones(len(p)).reshape(-1, 1),
- np.dot(q.reshape(1, -1), f2(C2).T))
+ return nx.log(b + 1e-15)
+
+ constC1 = nx.dot(
+ nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
+ nx.ones((1, len(q)), type_as=q)
+ )
+ constC2 = nx.dot(
+ nx.ones((len(p), 1), type_as=p),
+ nx.dot(nx.reshape(q, (1, -1)), f2(C2).T)
+ )
constC = constC1 + constC2
hC1 = h1(C1)
hC2 = h2(C2)
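As a quick sanity check of the square-loss decomposition above, here is a minimal NumPy sketch (not part of the patch) verifying that f1(a) + f2(b) - h1(a) h2(b) recovers |a - b|^2:

>>> import numpy as np
>>> a, b = 3.0, 5.0
>>> f1, f2 = a ** 2, b ** 2   # f1(a) = a^2, f2(b) = b^2
>>> h1, h2 = a, 2 * b         # h1(a) = a,   h2(b) = 2b
>>> np.isclose(f1 + f2 - h1 * h2, (a - b) ** 2)
True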
@@ -109,61 +135,70 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): def tensor_product(constC, hC1, hC2, T):
- """Return the tensor for Gromov-Wasserstein fast computation
+ r"""Return the tensor for Gromov-Wasserstein fast computation
- The tensor is computed as described in Proposition 1 Eq. (6) in [12].
+ The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-tensor-product>`
Parameters
----------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
Returns
-------
- tens : ndarray, shape (ns, nt)
- \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+ tens : array-like, shape (`ns`, `nt`)
+ :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result
+
+ .. _references-tensor-product:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
- A = -np.dot(hC1, T).dot(hC2.T)
+ constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
+ nx = get_backend(constC, hC1, hC2, T)
+
+ A = - nx.dot(
+ nx.dot(hC1, T), hC2.T
+ )
tens = constC + A
# tens -= tens.min()
return tens
def gwloss(constC, hC1, hC2, T):
- """Return the Loss for Gromov-Wasserstein
+ r"""Return the Loss for Gromov-Wasserstein
- The loss is computed as described in Proposition 1 Eq. (6) in [12].
+ The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
Parameters
----------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
- T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
Returns
-------
loss : float
Gromov Wasserstein loss
+
+ .. _references-gwloss:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -171,33 +206,38 @@ def gwloss(constC, hC1, hC2, T): tens = tensor_product(constC, hC1, hC2, T)
- return np.sum(tens * T)
+ tens, T = list_to_array(tens, T)
+ nx = get_backend(tens, T)
+
+ return nx.sum(tens * T)
def gwggrad(constC, hC1, hC2, T):
- """Return the gradient for Gromov-Wasserstein
+ r"""Return the gradient for Gromov-Wasserstein
- The gradient is computed as described in Proposition 2 in [12].
+ The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
Parameters
----------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
- T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
Returns
-------
- grad : ndarray, shape (ns, nt)
+ grad : array-like, shape (`ns`, `nt`)
Gromov Wasserstein gradient
+
+ .. _references-gwggrad:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
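How the three helpers above fit together: a minimal sketch (not part of the patch, assuming these helpers remain exposed in `ot.gromov` as in this release) evaluating the GW loss and gradient of a fixed coupling:

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> C1, C2 = ot.dist(rng.randn(5, 2)), ot.dist(rng.randn(4, 2))
>>> p, q = ot.unif(5), ot.unif(4)
>>> constC, hC1, hC2 = ot.gromov.init_matrix(C1, C2, p, q, 'square_loss')
>>> T = np.outer(p, q)  # independent coupling
>>> tens = ot.gromov.tensor_product(constC, hC1, hC2, T)
>>> loss = ot.gromov.gwloss(constC, hC1, hC2, T)   # equals np.sum(tens * T)
>>> grad = ot.gromov.gwggrad(constC, hC1, hC2, T)  # equals 2 * tens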
@@ -207,89 +247,109 @@ def gwggrad(constC, hC1, hC2, T): def update_square_loss(p, lambdas, T, Cs):
- """
- Updates C according to the L2 Loss kernel with the S Ts couplings
- calculated at each iteration
+ r"""
+ Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
+ couplings calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
Masses in the targeted barycenter.
lambdas : list of float
- List of the S spaces' weights.
- T : list of S np.ndarray of shape (ns,N)
- The S Ts couplings calculated at each iteration.
- Cs : list of S ndarray, shape(ns,ns)
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns, N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape (ns, ns)
Metric cost matrices.
Returns
----------
- C : ndarray, shape (nt, nt)
- Updated C matrix.
+ C : array-like, shape (`nt`, `nt`)
+ Updated :math:`\mathbf{C}` matrix.
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
- for s in range(len(T))])
- ppt = np.outer(p, p)
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
- return np.divide(tmpsum, ppt)
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+
+ return tmpsum / ppt
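A minimal sketch of this barycenter update (not part of the patch; the sizes below are illustrative): with two spaces of sizes 8 and 12 and a barycenter of size N = 6,

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> Cs = [ot.dist(rng.randn(8, 2)), ot.dist(rng.randn(12, 2))]
>>> p = ot.unif(6)  # barycenter weights
>>> T = [np.outer(ot.unif(8), p), np.outer(ot.unif(12), p)]  # trivial couplings
>>> C = ot.gromov.update_square_loss(p, [0.5, 0.5], T, Cs)
>>> C.shape
(6, 6)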
def update_kl_loss(p, lambdas, T, Cs):
- """
- Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration
+ r"""
+ Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
Weights in the targeted barycenter.
- lambdas : list of the S spaces' weights
- T : list of S np.ndarray of shape (ns,N)
- The S Ts couplings calculated at each iteration.
- Cs : list of S ndarray, shape(ns,ns)
+ lambdas : list of float
+ List of the `S` spaces' weights
+ T : list of S array-like of shape (ns, N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape (ns, ns)
Metric cost matrices.
Returns
----------
- C : ndarray, shape (ns,ns)
- updated C matrix
+ C : array-like, shape (`ns`, `ns`)
+ updated :math:`\mathbf{C}` matrix
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
- for s in range(len(T))])
- ppt = np.outer(p, p)
+ Cs = list_to_array(*Cs)
+ T = list_to_array(*T)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
- return np.exp(np.divide(tmpsum, ppt))
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+ return nx.exp(tmpsum / ppt)
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
- """
- Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
+
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
-
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -299,22 +359,23 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs log : bool, optional
record log if True
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an Armijo search. Else the closed form is used.
+ If there are convergence issues use False.
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
Returns
-------
- T : ndarray, shape (ns, nt)
- Doupling between the two spaces that minimizes:
- \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ T : array-like, shape (`ns`, `nt`)
+ Coupling between the two spaces that minimizes:
+
+ :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
log : dict
Convergence information and loss.
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -323,6 +384,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs mathematics 11.4 (2011): 417-487.
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -336,37 +406,45 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs if log:
res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log['gw_dist'] = gwloss(constC, hC1, hC2, res)
- return res, log
+ log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
else:
- return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
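A minimal usage sketch for the function above (not part of the patch), matching the new default `loss_fun='square_loss'`:

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> xs, xt = rng.randn(20, 2), rng.randn(30, 3)  # spaces of different dimension
>>> C1, C2 = ot.dist(xs), ot.dist(xt)
>>> C1, C2 = C1 / C1.max(), C2 / C2.max()  # normalizing the costs helps the solver
>>> p, q = ot.unif(20), ot.unif(30)
>>> T = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
>>> T.shape
(20, 30)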
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
- """
- Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+ r"""
+ Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
+ C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : ndarray, shape (ns,)
+ p : array-like, shape (ns,)
Distribution in the source space.
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
@@ -379,8 +457,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg log : bool, optional
record log if True
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an Armijo search. Else the closed form is used.
+ If there are convergence issues use False.
Returns
-------
@@ -391,7 +469,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -399,7 +477,20 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg metric approach to object matching. Foundations of computational
mathematics 11.4 (2011): 417-487.
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -410,53 +501,71 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
- log_gw['T'] = res
+
+ T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10)
+ log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10)
+ log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10)
+ log_gw['T'] = T0
+
+ gw = log_gw['gw_dist']
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gw = nx.set_gradients(gw, (p0, q0, C10, C20),
+ (log_gw['u'], log_gw['v'], gC1, gC2))
+
if log:
- return log_gw['gw_dist'], log_gw
+ return gw, log_gw
else:
- return log_gw['gw_dist']
+ return gw
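A hedged sketch of the differentiability mentioned above (it assumes the optional PyTorch backend is installed; only the quadratic loss is covered by the gradients from [38]):

>>> import torch
>>> import ot
>>> C1 = torch.rand(5, 5, requires_grad=True)
>>> C2 = torch.rand(6, 6)
>>> p, q = torch.ones(5) / 5, torch.ones(6) / 6
>>> gw = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss')
>>> gw.backward()
>>> g = C1.grad  # gradient of the GW value w.r.t. the source cost matrix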
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
- """
- Computes the FGW transport between two graphs see [24]
+ r"""
+ Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
.. math::
- \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
- L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \gamma 1 = p
- \gamma^T 1= q
- \gamma\geq 0
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - p and q are source and target weights (sum to 1)
- - L is a loss function to account for the misfit between the similarity matrices
- The algorithm used for solving the problem is conditional gradient as discussed in [24]_
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
----------
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
Metric cost matrix between features across domains
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix representative of the structure in the source space
- C2 : ndarray, shape (nt, nt)
+ C2 : array-like, shape (nt, nt)
Metric cost matrix representative of the structure in the target space
- p : ndarray, shape (ns,)
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : str, optional
Loss function used for the solver
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an Armijo search. Else the closed form is used.
+ If there are convergence issues use False.
log : bool, optional
record log if True
**kwargs : dict
@@ -464,18 +573,30 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : array-like, shape (`ns`, `nt`)
Optimal transportation matrix for the given parameters.
log : dict
Log dictionary return only if log==True in parameters.
+
+ .. _references-fused-gromov-wasserstein:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas "Optimal Transport for structured data with
application on graphs", International Conference on Machine Learning
(ICML). 2019.
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -489,69 +610,98 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
- log['fgw_dist'] = log['loss'][::-1][0]
- return res, log
+
+ fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
+
+ log['fgw_dist'] = fgw_dist
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
+
else:
- return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
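A minimal FGW usage sketch (not part of the patch; the feature matrices below are illustrative):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> F1, F2 = rng.randn(15, 3), rng.randn(10, 3)  # node features
>>> C1, C2 = ot.dist(rng.randn(15, 2)), ot.dist(rng.randn(10, 2))  # structure costs
>>> M = ot.dist(F1, F2)  # feature cost across domains
>>> p, q = ot.unif(15), ot.unif(10)
>>> T = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, alpha=0.5)
>>> T.shape
(15, 10)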
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
- """
- Computes the FGW distance between two graphs see [24]
+ r"""
+ Computes the FGW distance between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
.. math::
- \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
- L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
- s.t. \gamma 1 = p
- \gamma^T 1= q
- \gamma\geq 0
+ \mathbf{\gamma} &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - p and q are source and target weights (sum to 1)
- - L is a loss function to account for the misfit between the similarity matrices
- The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as
+ discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
Metric cost matrix between features across domains
- C1 : ndarray, shape (ns, ns)
- Metric cost matrix respresentative of the structure in the source space.
- C2 : ndarray, shape (nt, nt)
- Metric cost matrix espresentative of the structure in the target space.
- p : ndarray, shape (ns,)
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : array-like, shape (ns,)
Distribution in the source space.
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space.
loss_fun : str, optional
Loss function used for the solver.
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research.
- Else closed form is used. If there is convergence issues use False.
+ If True the step of the line-search is found via an Armijo search.
+ Else the closed form is used. If there are convergence issues use False.
log : bool, optional
Record log if True.
**kwargs : dict
- Parameters can be directly pased to the ot.optim.cg solver.
+ Parameters can be directly passed to the ot.optim.cg solver.
Returns
-------
- gamma : ndarray, shape (ns, nt)
- Optimal transportation matrix for the given parameters.
+ fgw-distance : float
+ Fused gromov wasserstein distance for the given parameters.
log : dict
Log dictionary return only if log==True in parameters.
+
+ .. _references-fused-gromov-wasserstein2:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
+
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -563,50 +713,462 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+
+ fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_fgw['fgw_dist'] = fgw_dist
+ log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10)
+ log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10)
+ log_fgw['T'] = T0
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
+ (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+
if log:
- log['fgw_dist'] = log['loss'][::-1][0]
- log['T'] = res
- return log['fgw_dist'], log
+ return fgw_dist, log_fgw
else:
- return log['fgw_dist']
+ return fgw_dist
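Note that, unlike `fused_gromov_wasserstein`, this function now returns the scalar distance; the coupling remains available through the log. A short sketch (not part of the patch, same illustrative data as above):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> C1, C2 = ot.dist(rng.randn(15, 2)), ot.dist(rng.randn(10, 2))
>>> M = ot.dist(rng.randn(15, 3), rng.randn(10, 3))
>>> p, q = ot.unif(15), ot.unif(10)
>>> d, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5, log=True)
>>> plan = log['T']  # the optimal coupling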
-def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
+def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
+ nb_samples_p=None, nb_samples_q=None, std=True, random_state=None):
+ r"""
+ Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ with a fixed transport plan :math:`\mathbf{T}`.
+
+ The function gives an unbiased approximation of the following equation:
+
+ .. math::
+
+ GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - `L` : Loss function to account for the misfit between the similarity matrices
+ - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}`
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ T : csr or array-like, shape (ns, nt)
+ Transport plan matrix, either a sparse csr or a dense matrix
+ nb_samples_p : int, optional
+ `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}`
+ nb_samples_q : int, optional
+ `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first
+ std : bool, optional
+ If True, also return the standard deviation associated with the prediction of the gromov-wasserstein cost
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ : float
+ Gromov-wasserstein cost
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ generator = check_random_state(random_state)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ # It is always better to sample from the biggest distribution first.
+ if len_p < len_q:
+ p, q = q, p
+ len_p, len_q = len_q, len_p
+ C1, C2 = C2, C1
+ T = T.T
+
+ if nb_samples_p is None:
+ if nx.issparse(T):
+ # If T is sparse, it probably means that PoGroW was used, thus the number of samples is reduced
+ nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p)
+ else:
+ nb_samples_p = len_p
+ else:
+ # Samples along the first dimension are drawn without replacement.
+ nb_samples_p = min(nb_samples_p, len_p)
+ if nb_samples_q is None:
+ nb_samples_q = 1
+ if std:
+ nb_samples_q = max(2, nb_samples_q)
+
+ index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+ index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+
+ index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+ index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+
+ for i in range(nb_samples_p):
+ if nx.issparse(T):
+ T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
+ T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
+ else:
+ T_indexi = T[index_i[i], :]
+ T_indexj = T[index_j[i], :]
+ # For each sampled row, a column is sampled in turn.
+ index_k[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=T_indexi / nx.sum(T_indexi),
+ replace=True
+ )
+ index_l[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=T_indexj / nx.sum(T_indexj),
+ replace=True
+ )
+
+ list_value_sample = nx.stack([
+ loss_fun(
+ C1[np.ix_(index_i, index_j)],
+ C2[np.ix_(index_k[:, n], index_l[:, n])]
+ ) for n in range(nb_samples_q)
+ ], axis=2)
+
+ if std:
+ std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5
+ return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p)
+ else:
+ return nx.mean(list_value_sample)
+
+
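A minimal sketch of the estimator above (not part of the patch): estimating the GW cost of a fixed dense plan with an element-wise squared loss,

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(42)
>>> C1, C2 = ot.dist(rng.randn(50, 2)), ot.dist(rng.randn(60, 2))
>>> p, q = ot.unif(50), ot.unif(60)
>>> T = np.outer(p, q)  # any fixed plan with marginals p and q
>>> def loss(a, b): return (a - b) ** 2
>>> est, std = ot.gromov.GW_distance_estimation(C1, C2, p, q, loss, T, random_state=42)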
+def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ alpha : float
+ Step of the Frank-Wolfe algorithm, should be between 0 and 1
+ max_iter : int, optional
+ Max number of iterations
+ threshold_plan : float, optional
+ Threshold below which values of the transport plan are deleted. If above zero, the marginal constraints are slightly violated.
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Gives the distance estimated and the standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
"""
- Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ index = np.zeros(2, dtype=int)
- (C1,p) and (C2,q)
+ # Initialize with default marginal
+ index[0] = generator.choice(len_p, size=1, p=p)
+ index[1] = generator.choice(len_q, size=1, p=q)
+ T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
+
+ best_gw_dist_estimated = np.inf
+ for cpt in range(max_iter):
+ index[0] = generator.choice(len_p, size=1, p=p)
+ T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
+ index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+
+ if alpha == 1:
+ T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ else:
+ new_T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ T = (1 - alpha) * T + alpha * new_T
+ # To limit the number of non 0, the values below the threshold are set to 0.
+ T = nx.eliminate_zeros(T, threshold=threshold_plan)
+
+ if cpt % 10 == 0 or cpt == (max_iter - 1):
+ gw_dist_estimated = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, std=False, random_state=generator
+ )
+
+ if gw_dist_estimated < best_gw_dist_estimated:
+ best_gw_dist_estimated = gw_dist_estimated
+ best_T = nx.copy(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated))
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=best_T, random_state=generator
+ )
+ return best_T, log
+ return best_T
+
+
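A short usage sketch for the stochastic Frank-Wolfe solver above (not part of the patch; the loss must be an element-wise function, as documented):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(42)
>>> C1, C2 = ot.dist(rng.randn(30, 2)), ot.dist(rng.randn(40, 2))
>>> p, q = ot.unif(30), ot.unif(40)
>>> def loss(a, b): return (a - b) ** 2
>>> T, log = ot.gromov.pointwise_gromov_wasserstein(
...     C1, C2, p, q, loss, max_iter=50, log=True, random_state=42)
>>> est = log['gw_dist_estimated']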
+def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False,
+ random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver.
The function solves the following optimization problem:
.. math::
- GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. T 1 = p
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- T^T 1= q
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
- T\geq 0
+ \mathbf{T} &\geq 0
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
- - H : entropy
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ nb_samples_grad : int
+ Number of samples to approximate the gradient
+ epsilon : float
+ Weight of the Kullback-Leibler regularization
+ max_iter : int, optional
+ Max number of iterations
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Gives the distance estimated and the standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ # The most natural way to define nb_samples_grad is with a simple integer.
+ if isinstance(nb_samples_grad, int):
+ if nb_samples_grad > len_p:
+ # As the sampling along the first dimension is done without replacement, the remainder is carried
+ # over to the second dimension.
+ nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad
+ T = nx.outer(p, q)
+ # continue_loop allows stopping the loop if there are several successive small modifications of T.
+ continue_loop = 0
+
+ # The gradient of GW is more complex if the two matrices are not symmetric.
+ C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
+
+ for cpt in range(max_iter):
+ index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+ Lik = 0
+ for i, index0_i in enumerate(index0):
+ index1 = generator.choice(len_q,
+ size=nb_samples_grad_q,
+ p=T[index0_i, :] / nx.sum(T[index0_i, :]),
+ replace=False)
+ # If the matrices C are not symmetric, the gradient has 2 terms, thus one of them is chosen randomly.
+ if (not C_are_symmetric) and generator.rand(1) > 0.5:
+ Lik += nx.mean(loss_fun(
+ C1[:, [index0[i]] * nb_samples_grad_q][:, None, :],
+ C2[:, index1][None, :, :]
+ ), axis=2)
+ else:
+ Lik += nx.mean(loss_fun(
+ C1[[index0[i]] * nb_samples_grad_q, :][:, :, None],
+ C2[index1, :][:, None, :]
+ ), axis=0)
+
+ max_Lik = nx.max(Lik)
+ if max_Lik == 0:
+ continue
+ # This division by the max is here to facilitate the choice of epsilon.
+ Lik /= max_Lik
+
+ if epsilon > 0:
+ # Clip values below exp(-200) before taking the log, then send them to -inf to avoid log(0).
+ log_T = nx.log(nx.clip(T, np.exp(-200), 1))
+ log_T = nx.where(log_T == -200, -np.inf, log_T)
+ Lik = Lik - epsilon * log_T
+
+ try:
+ new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon)
+ except (RuntimeWarning, UserWarning):
+ print("Warning catched in Sinkhorn: Return last stable T")
+ break
+ else:
+ new_T = emd(a=p, b=q, M=Lik)
+
+ change_T = nx.mean((T - new_T) ** 2)
+ if change_T <= 10e-20:
+ continue_loop += 1
+ if continue_loop > 100:  # max number of consecutive small modifications of T
+ T = nx.copy(new_T)
+ break
+ else:
+ continue_loop = 0
+
+ if verbose and cpt % 10 == 0:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, change_T))
+ T = nx.copy(new_T)
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, random_state=generator
+ )
+ return T, log
+ return T
+
+
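For orientation, here is a minimal usage sketch of the sampled solver above (a hand-made toy setup, not from the POT test suite; sizes and the squared loss are arbitrary choices, and `ot.gromov.sampled_gromov_wasserstein` is the name used in this module):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(42)
>>> C1 = ot.dist(rng.randn(20, 2))  # source structure matrix
>>> C2 = ot.dist(rng.randn(30, 2))  # target structure matrix
>>> p, q = ot.unif(20), ot.unif(30)
>>> T = ot.gromov.sampled_gromov_wasserstein(
...     C1, C2, p, q, loss_fun=lambda x, y: (x - y) ** 2,
...     nb_samples_grad=10, epsilon=1, random_state=rng)
>>> T.shape
(20, 30)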
+def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ r"""
+ Returns the Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : string
Loss function used for the solver either 'square_loss' or 'kl_loss'
@@ -623,21 +1185,20 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
Returns
-------
- T : ndarray, shape (ns, nt)
+ T : array-like, shape (`ns`, `nt`)
Optimal coupling between the two spaces
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
- C1 = np.asarray(C1, dtype=np.float64)
- C2 = np.asarray(C2, dtype=np.float64)
-
- T = np.outer(p, q) # Initialization
+ T = nx.outer(p, q)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -654,12 +1215,12 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
# compute the gradient
tens = gwggrad(constC, hC1, hC2, T)
- T = sinkhorn(p, q, tens, epsilon)
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
if cpt % 10 == 0:
# we can speed up the process by checking for the error only
# every 10th iteration
- err = np.linalg.norm(T - Tprev)
+ err = nx.norm(T - Tprev)
if log:
log['err'].append(err)
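A hedged usage sketch for the entropic solver above (toy sizes; the structure matrices are normalized so that the Sinkhorn projections stay numerically stable, as is usual practice in the POT examples):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> C1, C2 = ot.dist(rng.randn(10, 2)), ot.dist(rng.randn(12, 2))
>>> C1, C2 = C1 / C1.max(), C2 / C2.max()
>>> p, q = ot.unif(10), ot.unif(12)
>>> T = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=1e-1)
>>> T.shape
(10, 12)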
@@ -681,33 +1242,33 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
- """
- Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices
-
- (C1,p) and (C2,q)
+ r"""
+ Returns the entropic Gromov-Wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})
+ \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
- - H : entropy
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : str
Loss function used for the solver either 'square_loss' or 'kl_loss'
@@ -729,7 +1290,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
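The `*2` variant returns the discrepancy value instead of the plan; reusing the toy setup from the previous sketch:

>>> gw_dist, log = ot.gromov.entropic_gromov_wasserstein2(
...     C1, C2, p, q, 'square_loss', epsilon=1e-1, log=True)
>>> gw_dist >= 0  # the square loss yields a non-negative value
True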
@@ -746,76 +1307,79 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
- """
- Returns the gromov-wasserstein barycenters of S measured similarity matrices
-
- (Cs)_{s=1}^{s=S}
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
+ r"""
+ Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
The function solves the following optimization problem:
.. math::
- C = argmin_{C\in R^{NxN}} \sum_s \lambda_s GW(C,C_s,p,p_s)
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
Where :
- - :math:`C_s` : metric cost matrix
- - :math:`p_s` : distribution
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
Parameters
----------
N : int
Size of the targeted barycenter
- Cs : list of S np.ndarray of shape (ns,ns)
+ Cs : list of S array-like of shape (ns,ns)
Metric cost matrices
- ps : list of S np.ndarray of shape (ns,)
- Sample weights in the S spaces
- p : ndarray, shape(N,)
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape(N,)
Weights in the targeted barycenter
lambdas : list of float
- List of the S spaces' weights.
+ List of the `S` spaces' weights.
loss_fun : callable
Tensor-matrix multiplication function based on specific loss function.
update : callable
- function(p,lambdas,T,Cs) that updates C according to a specific Kernel
- with the S Ts couplings calculated at each iteration
+ function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
+ :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
+ calculated at each iteration
epsilon : float
Regularization term >0
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
- init_C : bool | ndarray, shape (N, N)
- Random initial value for the C matrix provided by user.
+ init_C : bool | array-like, shape (N, N)
+ Random initial value for the :math:`\mathbf{C}` matrix provided by user.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
Returns
-------
- C : ndarray, shape (N, N)
+ C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permuted arbitrarily)
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
S = len(Cs)
- Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
- lambdas = np.asarray(lambdas, dtype=np.float64)
-
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
- # XXX use random state
- xalea = np.random.randn(N, 2)
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
else:
C = init_C
@@ -828,7 +1392,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-5, verbose, log) for s in range(S)]
+ max_iter, 1e-4, verbose, log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -838,7 +1402,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only
# every 10th iteration
- err = np.linalg.norm(C - Cprev)
+ err = nx.norm(C - Cprev)
error.append(err)
if log:
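A hedged usage sketch for the entropic barycenter solver above (two toy spaces with equal weights; normalization again keeps Sinkhorn stable, and all sizes are illustrative):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> Cs = [ot.dist(rng.randn(n, 2)) for n in (8, 10)]
>>> Cs = [C / C.max() for C in Cs]
>>> ps = [ot.unif(8), ot.unif(10)]
>>> C = ot.gromov.entropic_gromov_barycenters(
...     6, Cs, ps, ot.unif(6), [0.5, 0.5], 'square_loss',
...     epsilon=1e-1, max_iter=10, random_state=0)
>>> C.shape
(6, 6)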
@@ -856,72 +1420,78 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
- """
- Returns the gromov-wasserstein barycenters of S measured similarity matrices
-
- (Cs)_{s=1}^{s=S}
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
+ r"""
+ Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
- The function solves the following optimization problem with block
- coordinate descent:
+ The function solves the following optimization problem with block coordinate descent:
.. math::
- C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
+
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
Where :
- - Cs : metric cost matrix
- - ps : distribution
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
Parameters
----------
N : int
Size of the targeted barycenter
- Cs : list of S np.ndarray of shape (ns, ns)
+ Cs : list of S array-like of shape (ns, ns)
Metric cost matrices
- ps : list of S np.ndarray of shape (ns,)
- Sample weights in the S spaces
- p : ndarray, shape (N,)
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape (N,)
Weights in the targeted barycenter
lambdas : list of float
- List of the S spaces' weights
- loss_fun : tensor-matrix multiplication function based on specific loss function
- update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
- with the S Ts couplings calculated at each iteration
+ List of the `S` spaces' weights
+ loss_fun : callable
+ tensor-matrix multiplication function based on specific loss function
+ update : callable
+ function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
+ :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
+ calculated at each iteration
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0).
+ Stop threshold on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
- init_C : bool | ndarray, shape(N,N)
- Random initial value for the C matrix provided by user.
+ init_C : bool | array-like, shape(N,N)
+ Random initial value for the :math:`\mathbf{C}` matrix provided by user.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
Returns
-------
- C : ndarray, shape (N, N)
+ C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permuted arbitrarily)
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
- S = len(Cs)
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
- Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
- lambdas = np.asarray(lambdas, dtype=np.float64)
+ S = len(Cs)
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
- # XXX : should use a random state and not use the global seed
- xalea = np.random.randn(N, 2)
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
else:
C = init_C
@@ -944,7 +1514,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only
# every 10th iteration
- err = np.linalg.norm(C - Cprev)
+ err = nx.norm(C - Cprev)
error.append(err)
if log:
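The unregularized variant has the same calling convention, minus `epsilon`; reusing `Cs`, `ps` and the imports from the sketch above:

>>> C = ot.gromov.gromov_barycenters(6, Cs, ps, ot.unif(6), [0.5, 0.5],
...                                  'square_loss', max_iter=10, random_state=0)
>>> C.shape
(6, 6)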
@@ -963,21 +1533,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
- verbose=False, log=False, init_C=None, init_X=None):
- """Compute the fgw barycenter as presented eq (5) in [24].
+ verbose=False, log=False, init_C=None, init_X=None, random_state=None):
+ r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
Parameters
----------
- N : integer
+ N : int
Desired number of samples of the target barycenter
- Ys: list of ndarray, each element has shape (ns,d)
+ Ys: list of array-like, each element has shape (ns,d)
Features of all samples
- Cs : list of ndarray, each element has shape (ns,ns)
+ Cs : list of array-like, each element has shape (ns,ns)
Structure matrices of all samples
- ps : list of ndarray, each element has shape (ns,)
+ ps : list of array-like, each element has shape (ns,)
Masses of all samples.
lambdas : list of float
- List of the S spaces' weights
+ List of the `S` spaces' weights
alpha : float
Alpha parameter for the fgw distance
fixed_structure : bool
@@ -989,46 +1559,51 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0).
+ Stop threshold on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
- init_C : ndarray, shape (N,N), optional
+ init_C : array-like, shape (N,N), optional
Initialization for the barycenters' structure matrix. If not set
a random init is used.
- init_X : ndarray, shape (N,d), optional
+ init_X : array-like, shape (N,d), optional
Initialization for the barycenters' features. If not set a
random init is used.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
Returns
-------
- X : ndarray, shape (N, d)
+ X : array-like, shape (`N`, `d`)
Barycenters' features
- C : ndarray, shape (N, N)
+ C : array-like, shape (`N`, `N`)
Barycenters' structure matrix
- log_: dict
+ log : dict
Only returned when log=True. It contains the keys:
- T : list of (N,ns) transport matrices
- Ms : all distance matrices between the feature of the barycenter and the
- other features dist(X,Ys) shape (N,ns)
+ - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
+ - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
+
+
+ .. _references-fgw-barycenters:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ Ys = list_to_array(*Ys)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *Ys, *ps)
+
S = len(Cs)
d = Ys[0].shape[1]  # dimension of the node features
if p is None:
- p = np.ones(N) / N
-
- Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
- Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
-
- lambdas = np.asarray(lambdas, dtype=np.float64)
+ p = nx.ones(N, type_as=Cs[0]) / N
if fixed_structure:
if init_C is None:
@@ -1037,8 +1612,10 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
C = init_C
else:
if init_C is None:
- xalea = np.random.randn(N, 2)
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
+ C = nx.from_numpy(C, type_as=ps[0])
else:
C = init_C
@@ -1049,13 +1626,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
X = init_X
else:
if init_X is None:
- X = np.zeros((N, d))
+ X = nx.zeros((N, d), type_as=ps[0])
else:
X = init_X
- T = [np.outer(p, q) for q in ps]
+ T = [nx.outer(p, q) for q in ps]
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
cpt = 0
err_feature = 1
@@ -1075,20 +1652,19 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
if not fixed_structure:
if loss_fun == 'square_loss':
T_temp = [t.T for t in T]
- C = update_sructure_matrix(p, lambdas, T_temp, Cs)
+ C = update_structure_matrix(p, lambdas, T_temp, Cs)
T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
- err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
- err_structure = np.linalg.norm(C - Cprev)
-
+ err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
+ err_structure = nx.norm(C - Cprev)
if log:
log_['err_feature'].append(err_feature)
log_['err_structure'].append(err_structure)
@@ -1114,64 +1690,80 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
return X, C
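A hedged end-to-end sketch for the FGW barycenter (toy graphs with 3-dimensional node features; all sizes are arbitrary):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(12)
>>> Ys = [rng.randn(8, 3), rng.randn(10, 3)]                    # node features
>>> Cs = [ot.dist(rng.randn(8, 2)), ot.dist(rng.randn(10, 2))]  # structures
>>> ps = [ot.unif(8), ot.unif(10)]
>>> X, C = ot.gromov.fgw_barycenters(6, Ys, Cs, ps, [0.5, 0.5], alpha=0.5,
...                                  max_iter=20, random_state=12)
>>> X.shape, C.shape
((6, 3), (6, 6))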
-def update_sructure_matrix(p, lambdas, T, Cs):
- """Updates C according to the L2 Loss kernel with the S Ts couplings.
+def update_structure_matrix(p, lambdas, T, Cs):
+ r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
It is calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
Masses in the targeted barycenter.
lambdas : list of float
- List of the S spaces' weights.
- T : list of S ndarray of shape (ns, N)
- The S Ts couplings calculated at each iteration.
- Cs : list of S ndarray, shape (ns, ns)
- Metric cost matrices.
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns, N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape (ns, ns)
+ Metric cost matrices.
Returns
-------
- C : ndarray, shape (nt, nt)
- Updated C matrix.
+ C : array-like, shape (`nt`, `nt`)
+ Updated :math:`\mathbf{C}` matrix.
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
- ppt = np.outer(p, p)
+ p = list_to_array(p)
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ nx = get_backend(*Cs, *T, p)
- return np.divide(tmpsum, ppt)
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+ return tmpsum / ppt
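As a sanity check of the closed-form update :math:`\mathbf{C} = \frac{\sum_s \lambda_s \mathbf{T}_s^T \mathbf{C}_s \mathbf{T}_s}{\mathbf{p} \mathbf{p}^T}`, here is a hand-made example with a single space (lambdas = [1.]) and a diagonal coupling, which should return the input structure unchanged:

>>> import numpy as np
>>> p = np.array([.5, .5])
>>> T = [np.array([[.5, 0.], [0., .5]])]
>>> Cs = [np.array([[0., 1.], [1., 0.]])]
>>> (T[0].T @ Cs[0] @ T[0]) / np.outer(p, p)
array([[0., 1.],
       [1., 0.]])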
def update_feature_matrix(lambdas, Ys, Ts, p):
- """Updates the feature with respect to the S Ts couplings.
+ r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
- in [24] calculated at each iteration
+ in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
masses in the targeted barycenter
lambdas : list of float
- List of the S spaces' weights
- Ts : list of S np.ndarray(ns,N)
- the S Ts couplings calculated at each iteration
- Ys : list of S ndarray, shape(d,ns)
+ List of the `S` spaces' weights
+ Ts : list of S array-like, shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
+ Ys : list of S array-like, shape (d,ns)
The features.
Returns
-------
- X : ndarray, shape (d, N)
+ X : array-like, shape (`d`, `N`)
+
+ .. _references-update-feature-matrix:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
- and Courty Nicolas
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
- p = np.array(1. / p).reshape(-1,)
-
- tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))])
-
+ p = list_to_array(p)
+ Ts = list_to_array(*Ts)
+ Ys = list_to_array(*Ys)
+ nx = get_backend(*Ys, *Ts, p)
+
+ p = 1. / p
+ tmpsum = sum([
+ lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
+ for s in range(len(Ts))
+ ])
return tmpsum
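Similarly, for the feature update :math:`\mathbf{X} = \sum_s \lambda_s \mathbf{Y}_s \mathbf{T}_s^T \mathrm{diag}(1 / \mathbf{p})`, a one-feature toy case with a diagonal coupling recovers the input features (a hand-made check, not from the POT test suite):

>>> import numpy as np
>>> p = np.array([.5, .5])
>>> Ts = [np.array([[.5, 0.], [0., .5]])]
>>> Ys = [np.array([[1., 2.]])]  # shape (d, ns) with d = 1
>>> (Ys[0] @ Ts[0].T) * (1. / p)[None, :]
array([[1., 2.]])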
diff --git a/ot/helpers/__init__.py b/ot/helpers/__init__.py new file mode 100644 index 0000000..b948671 --- /dev/null +++ b/ot/helpers/__init__.py @@ -0,0 +1,3 @@ +# Author: Remi Flamary <remi.flamary@unice.fr> +# +# License: MIT License diff --git a/ot/helpers/openmp_helpers.py b/ot/helpers/openmp_helpers.py new file mode 100644 index 0000000..a6ad38b --- /dev/null +++ b/ot/helpers/openmp_helpers.py @@ -0,0 +1,85 @@ +"""Helpers for OpenMP support during the build.""" + +# This code is adapted for a large part from the astropy openmp helpers, which +# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa + + +import os +import sys +import textwrap +import subprocess + +from distutils.errors import CompileError, LinkError + +from pre_build_helpers import compile_test_program + + +def get_openmp_flag(compiler): + """Get openmp flags for a given compiler""" + + if hasattr(compiler, 'compiler'): + compiler = compiler.compiler[0] + else: + compiler = compiler.__class__.__name__ + + if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler): + omp_flag = ['/Qopenmp'] + elif sys.platform == "win32": + omp_flag = ['/openmp'] + elif sys.platform in ("darwin", "linux") and "icc" in compiler: + omp_flag = ['-qopenmp'] + elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''): + omp_flag = [] + else: + # Default flag for GCC and clang: + omp_flag = ['-fopenmp'] + if sys.platform.startswith("darwin"): + omp_flag += ["-Xpreprocessor", "-lomp"] + return omp_flag + + +def check_openmp_support(): + """Check whether OpenMP test code can be compiled and run""" + + code = textwrap.dedent( + """\ + #include <omp.h> + #include <stdio.h> + int main(void) { + #pragma omp parallel + printf("nthreads=%d\\n", omp_get_num_threads()); + return 0; + } + """) + + extra_preargs = os.getenv('LDFLAGS', None) + if extra_preargs is not None: + extra_preargs = extra_preargs.strip().split(" ") + extra_preargs = [ + flag for flag in extra_preargs + if flag.startswith(('-L', '-Wl,-rpath', '-l'))] + + extra_postargs = get_openmp_flag + + try: + output, compile_flags = compile_test_program( + code, + extra_preargs=extra_preargs, + extra_postargs=extra_postargs + ) + + if output and 'nthreads=' in output[0]: + nthreads = int(output[0].strip().split('=')[1]) + openmp_supported = len(output) == nthreads + elif "PYTHON_CROSSENV" in os.environ: + # Since we can't run the test program when cross-compiling + # assume that openmp is supported if the program can be + # compiled. 
+ openmp_supported = True + else: + openmp_supported = False + + except (CompileError, LinkError, subprocess.CalledProcessError): + openmp_supported = False + compile_flags = [] + return openmp_supported, compile_flags diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py new file mode 100644 index 0000000..93ecd6a --- /dev/null +++ b/ot/helpers/pre_build_helpers.py @@ -0,0 +1,87 @@ +"""Helpers to check build environment before actual build of POT""" + +import os +import sys +import glob +import tempfile +import setuptools # noqa +import subprocess + +from distutils.dist import Distribution +from distutils.sysconfig import customize_compiler +from numpy.distutils.ccompiler import new_compiler +from numpy.distutils.command.config_compiler import config_cc + + +def _get_compiler(): + """Get a compiler equivalent to the one that will be used to build POT + Handles compiler specified as follows: + - python setup.py build_ext --compiler=<compiler> + - CC=<compiler> python setup.py build_ext + """ + dist = Distribution({'script_name': os.path.basename(sys.argv[0]), + 'script_args': sys.argv[1:], + 'cmdclass': {'config_cc': config_cc}}) + + cmd_opts = dist.command_options.get('build_ext') + if cmd_opts is not None and 'compiler' in cmd_opts: + compiler = cmd_opts['compiler'][1] + else: + compiler = None + + ccompiler = new_compiler(compiler=compiler) + customize_compiler(ccompiler) + + return ccompiler + + +def compile_test_program(code, extra_preargs=[], extra_postargs=[]): + """Check that some C code can be compiled and run""" + ccompiler = _get_compiler() + + # extra_(pre/post)args can be a callable to make it possible to get its + # value from the compiler + if callable(extra_preargs): + extra_preargs = extra_preargs(ccompiler) + if callable(extra_postargs): + extra_postargs = extra_postargs(ccompiler) + + start_dir = os.path.abspath('.') + + with tempfile.TemporaryDirectory() as tmp_dir: + try: + os.chdir(tmp_dir) + + # Write test program + with open('test_program.c', 'w') as f: + f.write(code) + + os.mkdir('objects') + + # Compile, test program + ccompiler.compile(['test_program.c'], output_dir='objects', + extra_postargs=extra_postargs) + + # Link test program + objects = glob.glob( + os.path.join('objects', '*' + ccompiler.obj_extension)) + ccompiler.link_executable(objects, 'test_program', + extra_preargs=extra_preargs, + extra_postargs=extra_postargs) + + if "PYTHON_CROSSENV" not in os.environ: + # Run test program if not cross compiling + # will raise a CalledProcessError if return code was non-zero + output = subprocess.check_output('./test_program') + output = output.decode( + sys.stdout.encoding or 'utf-8').splitlines() + else: + # Return an empty output if we are cross compiling + # as we cannot run the test_program + output = [] + except Exception: + raise + finally: + os.chdir(start_dir) + + return output, extra_postargs diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index c0fe7a3..8a1f9ac 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -18,19 +18,18 @@ #include <iostream> #include <vector> -#include "network_simplex_simple.h" -using namespace lemon; typedef unsigned int node_id_type; enum ProblemType { INFEASIBLE, OPTIMAL, UNBOUNDED, - MAX_ITER_REACHED + MAX_ITER_REACHED }; int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter); +int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads); diff --git 
a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index bc873ed..2bdc172 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -12,16 +12,22 @@ * */ + +#include "network_simplex_simple.h" +#include "network_simplex_simple_omp.h" #include "EMD.h" +#include <cstdint> int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) { - // beware M and C anre strored in row major C style!!! - int n, m, i, cur; + // beware M and C are stored in row major C style!!! + + using namespace lemon; + int n, m, cur; typedef FullBipartiteDigraph Digraph; - DIGRAPH_TYPEDEFS(FullBipartiteDigraph); + DIGRAPH_TYPEDEFS(Digraph); // Get the number of non zero coordinates for r and c n=0; @@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, std::vector<int> indI(n), indJ(m); std::vector<double> weights1(n), weights2(m); Digraph di(n, m); - NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter); // Set supply and demand, don't account for 0 values (faster) @@ -76,10 +82,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, net.supplyMap(&weights1[0], n, &weights2[0], m); // Set the cost of each edge + int64_t idarc = 0; for (int i=0; i<n; i++) { for (int j=0; j<m; j++) { double val=*(D+indI[i]*n2+indJ[j]); - net.setCost(di.arcFromId(i*m+j), val); + net.setCost(di.arcFromId(idarc), val); + ++idarc; } } @@ -87,12 +95,13 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, // Solve the problem with the network simplex algorithm int ret=net.run(); + int i, j; if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { *cost = 0; Arc a; di.first(a); for (; a != INVALID; di.next(a)) { - int i = di.source(a); - int j = di.target(a); + i = di.source(a); + j = di.target(a); double flow = net.flow(a); *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); *(G+indI[i]*n2+indJ[j-n]) = flow; @@ -106,3 +115,104 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, return ret; } + + + + + + +int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, int maxIter, int numThreads) { + // beware M and C are stored in row major C style!!! + + using namespace lemon_omp; + int n, m, cur; + + typedef FullBipartiteDigraph Digraph; + DIGRAPH_TYPEDEFS(Digraph); + + // Get the number of non zero coordinates for r and c + n=0; + for (int i=0; i<n1; i++) { + double val=*(X+i); + if (val>0) { + n++; + }else if(val<0){ + return INFEASIBLE; + } + } + m=0; + for (int i=0; i<n2; i++) { + double val=*(Y+i); + if (val>0) { + m++; + }else if(val<0){ + return INFEASIBLE; + } + } + + // Define the graph + + std::vector<int> indI(n), indJ(m); + std::vector<double> weights1(n), weights2(m); + Digraph di(n, m); + NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads); + + // Set supply and demand, don't account for 0 values (faster) + + cur=0; + for (int i=0; i<n1; i++) { + double val=*(X+i); + if (val>0) { + weights1[ cur ] = val; + indI[cur++]=i; + } + } + + // Demand is actually negative supply... 
+ + cur=0; + for (int i=0; i<n2; i++) { + double val=*(Y+i); + if (val>0) { + weights2[ cur ] = -val; + indJ[cur++]=i; + } + } + + + net.supplyMap(&weights1[0], n, &weights2[0], m); + + // Set the cost of each edge + int64_t idarc = 0; + for (int i=0; i<n; i++) { + for (int j=0; j<m; j++) { + double val=*(D+indI[i]*n2+indJ[j]); + net.setCost(di.arcFromId(idarc), val); + ++idarc; + } + } + + + // Solve the problem with the network simplex algorithm + + int ret=net.run(); + int i, j; + if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { + *cost = 0; + Arc a; di.first(a); + for (; a != INVALID; di.next(a)) { + i = di.source(a); + j = di.target(a); + double flow = net.flow(a); + *cost += flow * (*(D+indI[i]*n2+indJ[j-n])); + *(G+indI[i]*n2+indJ[j-n]) = flow; + *(alpha + indI[i]) = -net.potential(i); + *(beta + indJ[j-n]) = net.potential(j); + } + + } + + + return ret; +} diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 514a607..5da897d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -8,25 +8,50 @@ Solvers for the original linear program OT problem # # License: MIT License +import os import multiprocessing import sys import numpy as np -from scipy.sparse import coo_matrix +import warnings from . import cvx from .cvx import barycenter + # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from ..utils import dist +from .solver_1d import emd_1d, emd2_1d, wasserstein_1d + +from ..utils import dist, list_to_array from ..utils import parmap +from ..backend import get_backend -__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', +__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] +def check_number_threads(numThreads): + """Checks whether or not the requested number of threads has a valid value. + + Parameters + ---------- + numThreads : int or str + The requested number of threads, should either be a strictly positive integer or "max" or None + + Returns + ------- + numThreads : int + Corrected number of threads + """ + if (numThreads is None) or (isinstance(numThreads, str) and numThreads.lower() == 'max'): + return -1 + if (not isinstance(numThreads, int)) or numThreads < 1: + raise ValueError('numThreads should either be "max" or a strictly positive integer') + return numThreads + + def center_ot_dual(alpha0, beta0, a=None, b=None): - r"""Center dual OT potentials w.r.t. theirs weights + r"""Center dual OT potentials w.r.t. their weights The main idea of this function is to find unique dual potentials that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having @@ -37,7 +62,7 @@ def center_ot_dual(alpha0, beta0, a=None, b=None): is the following: .. math:: - \alpha^T a= \beta^T b + \alpha^T \mathbf{a} = \beta^T \mathbf{b} in addition to the OT problem constraints. @@ -45,11 +70,11 @@ def center_ot_dual(alpha0, beta0, a=None, b=None): a constant from both :math:`\alpha_0` and :math:`\beta_0`. .. math:: - c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta} + c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}} - \alpha=\alpha_0+c + \alpha &= \alpha_0 + c - \beta=\beta0+c + \beta &= \beta_0 + c Parameters ---------- @@ -92,35 +117,35 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): The feasible values are computed efficiently but rather coarsely. 
.. warning:: - This function is necessary because the C++ solver in emd_c - discards all samples in the distributions with - zeros weights. This means that while the primal variable (transport + This function is necessary because the C++ solver in `emd_c` + discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport matrix) is exact, the solver only returns feasible dual potentials - on the samples with weights different from zero. + on the samples with weights different from zero. First we compute the constraints violations: .. math:: - V=\alpha+\beta^T-M + \mathbf{V} = \alpha + \beta^T - \mathbf{M} - Next we compute the max amount of violation per row (alpha) and - columns (beta) + Next we compute the max amount of violation per row (:math:`\alpha`) and + columns (:math:`beta`) .. math:: - v^a_i=\max_j V_{i,j} + \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j} - v^b_j=\max_i V_{i,j} + \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j} Finally we update the dual potential with 0 weights if a constraint is violated .. math:: - \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0 + \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0 - \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0 + \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0 In the end the dual potentials are centered using function - :ref:`center_ot_dual`. + :py:func:`ot.lp.center_ot_dual`. Note that all those updates do not change the objective value of the solution but provide dual potentials that do not violate the constraints. @@ -172,54 +197,62 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): return center_ot_dual(alpha, beta, a, b) -def emd(a, b, M, numItermax=100000, log=False, center_dual=True): +def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): r"""Solves the Earth Movers distance problem and returns the OT matrix .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 - \gamma\geq 0 where : - - M is the metric cost matrix - - a and b are the sample weights + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order + numpy.array in float64 format. It will be converted if not in this + format + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. - Uses the algorithm proposed in [1]_ + Uses the algorithm proposed in :ref:`[1] <references-emd>`. Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 + b : (nt,) array-like, float Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. 
log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual - variables. Otherwise returns only the optimal transportation matrix. + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) If True, centers the dual potential using function :ref:`center_ot_dual`. + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. Returns ------- - gamma: (ns x nt) numpy.ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is true, a dictionary containing the cost and dual - variables and exit status + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status Examples @@ -232,26 +265,39 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd(a,b,M) + >>> ot.emd(a, b, M) array([[0.5, 0. ], [0. , 0.5]]) + + .. _references-emd: References ---------- - - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, + December). Displacement interpolation using Lagrangian mass transport. + In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. See Also -------- ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT""" + ot.optim.cg : General regularized OT + """ + + # convert to numpy if list + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) + + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -262,81 +308,91 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0), err_msg='a and b vector must have the same sum') + b = b * a.sum() / b.sum() + asel = a != 0 bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + numThreads = check_number_threads(numThreads) + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) if center_dual: u, v = center_ot_dual(u, v, a, b) if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) - + result_code_string = check_result(result_code) if log: log = {} log['cost'] = cost - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code - return G, log - return G + return nx.from_numpy(G, type_as=M0), log + return nx.from_numpy(G, type_as=M0) -def emd2(a, b, M, processes=multiprocessing.cpu_count(), +def emd2(a, b, M, processes=1, numItermax=100000, 
log=False, return_matrix=False, - center_dual=True): + center_dual=True, numThreads=1): r"""Solves the Earth Movers distance problem and returns the loss .. math:: - \min_\gamma <\gamma,M>_F + \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + s.t. \ \gamma \mathbf{1} = \mathbf{a} - s.t. \gamma 1 = a + \gamma^T \mathbf{1} = \mathbf{b} - \gamma^T 1= b + \gamma \geq 0 - \gamma\geq 0 where : - - M is the metric cost matrix - - a and b are the sample weights + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. - Uses the algorithm proposed in [1]_ + Uses the algorithm proposed in :ref:`[1] <references-emd2>`. Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float64 Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 + b : (nt,) array-like, float64 Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) - processes : int, optional (default=nb cpu) - Nb of processes used for multiple emd computation (not used on windows) + M : (ns,nt) array-like, float64 + Loss matrix (for numpy c-order array with type float64) + processes : int, optional (default=1) + Nb of processes used for multiple emd computation (deprecated) numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. log: boolean, optional (default=False) - If True, returns a dictionary containing the cost and dual + If True, returns a dictionary containing dual variables. Otherwise returns only the optimal transportation cost. return_matrix: boolean, optional (default=False) If True, returns the optimal transportation matrix in the log. center_dual: boolean, optional (default=True) If True, centers the dual potential using function :ref:`center_ot_dual`. + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. Returns ------- - gamma: (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log: dictnp - If input log is true, a dictionary containing the cost and dual + W: float, array-like + Optimal transportation loss for the given parameters + log: dict + If input log is true, a dictionary containing dual variables and exit status @@ -354,9 +410,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), >>> ot.emd2(a,b,M) 0.0 + + .. _references-emd2: References ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, December). Displacement interpolation using Lagrangian mass transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 
@@ -365,15 +422,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), See Also -------- ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT""" + ot.optim.cg : General regularized OT + """ + + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) - - # problem with pikling Forks - if sys.platform.endswith('win32'): - processes = 1 + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -386,11 +450,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), asel = a != 0 + numThreads = check_number_threads(numThreads) + if log or return_matrix: def f(b): bsel = b != 0 - - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) if center_dual: u, v = center_ot_dual(u, v, a, b) @@ -400,17 +466,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), result_code_string = check_result(result_code) log = {} + G = nx.from_numpy(G, type_as=M0) if return_matrix: log['G'] = G - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0, b0, M0), (log['u'], log['v'], G)) return [cost, log] else: def f(b): bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) if center_dual: u, v = center_ot_dual(u, v, a, b) @@ -418,6 +487,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) + G = nx.from_numpy(G, type_as=M0) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0, b0, M0), (nx.from_numpy(u, type_as=a0), + nx.from_numpy(v, type_as=b0), G)) + check_result(result_code) return cost @@ -426,35 +500,53 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), nb = b.shape[1] if processes > 1: - res = parmap(f, [b[:, i] for i in range(nb)], processes) - else: - res = list(map(f, [b[:, i].copy() for i in range(nb)])) + warnings.warn( + "The 'processes' parameter has been deprecated. " + "Multiprocessing should be done outside of POT." + ) + res = list(map(f, [b[:, i].copy() for i in range(nb)])) return res def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, - stopThr=1e-7, verbose=False, log=None): - """ - Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) + stopThr=1e-7, verbose=False, log=None, numThreads=1): + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: + + .. 
math:: + \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) + + where : + + - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one + - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` + - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations + - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + + This problem is considered in :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). + There are two differences with the following codes: - The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. - This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: - we do not optimize over the weights - - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. + - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in + :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete + implementation of the fixed-point algorithm of + :ref:`[2] <references-free-support-barycenter>` proposed in the continuous setting. Parameters ---------- - measures_locations : list of (k_i,d) numpy.ndarray - The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) - measures_weights : list of (k_i,) numpy.ndarray - Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure + measures_locations : list of N (k_i,d) numpy.ndarray + The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space + (:math:`k_i` can be different for each element of the list) + measures_weights : list of N (k_i,) numpy.ndarray + Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one + representing the weights of each discrete input measure X_init : (k,d) np.ndarray - Initialization of the support locations (on k atoms) of the barycenter + Initialization of the support locations (on `k` atoms) of the barycenter b : (k,) np.ndarray Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (k,) np.ndarray + weights : (N,) np.ndarray Initialization of the coefficients of the barycenter (non-negatives, sum to 1) numItermax : int, optional @@ -465,15 +557,20 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Print information along iterations log : bool, optional record log if True + numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) + If compiled with OpenMP, chooses the number of threads to parallelize. + "max" selects the highest number possible. + Returns ------- X : (k,d) np.ndarray Support locations (on k atoms) of the barycenter + + .. _references-free-support-barycenter: References ---------- - .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." 
Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. @@ -504,7 +601,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()): M_i = dist(X, measure_locations_i) - T_i = emd(b, measure_weights_i, M_i) + T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) displacement_square_norm = np.sum(np.square(T_sum - X)) @@ -523,287 +620,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X - - -def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the OT matrix - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - log: boolean, optional (default=False) - If True, returns a dictionary containing the cost. - Otherwise returns only the optimal transportation matrix. - - Returns - ------- - gamma: (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is True, a dictionary containing the cost - - - Examples - -------- - - Simple example with obvious solution. The function emd_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.emd_1d(x_a, x_b, a, b) - array([[0. , 0.5], - [0.5, 0. ]]) - >>> ot.emd_1d(x_a, x_b) - array([[0. , 0.5], - [0.5, 0. ]]) - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. 
- - See Also - -------- - ot.lp.emd : EMD for multidimensional distributions - ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the - transportation matrix) - """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - x_a = np.asarray(x_a, dtype=np.float64) - x_b = np.asarray(x_b, dtype=np.float64) - - assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - - # if empty array given then use uniform distributions - if a.ndim == 0 or len(a) == 0: - a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0] - if b.ndim == 0 or len(b) == 0: - b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] - - x_a_1d = x_a.reshape((-1,)) - x_b_1d = x_b.reshape((-1,)) - perm_a = np.argsort(x_a_1d) - perm_b = np.argsort(x_b_1d) - - G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b], - x_a_1d[perm_a], x_b_1d[perm_b], - metric=metric, p=p) - G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])), - shape=(a.shape[0], b.shape[0])) - if dense: - G = G.toarray() - if log: - log = {'cost': cost} - return G, log - return G - - -def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the loss - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Only used if log is set to True. Due to implementation details, - this function runs faster when dense is set to False. - log: boolean, optional (default=False) - If True, returns a dictionary containing the transportation matrix. - Otherwise returns only the loss. - - Returns - ------- - loss: float - Cost associated to the optimal transportation - log: dict - If input log is True, a dictionary containing the Optimal transportation - matrix for the given parameters - - - Examples - -------- - - Simple example with obvious solution. The function emd2_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] 
- >>> ot.emd2_1d(x_a, x_b, a, b) - 0.5 - >>> ot.emd2_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd2 : EMD for multidimensional distributions - ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix - instead of the cost) - """ - # If we do not return G (log==False), then we should not to cast it to dense - # (useless overhead) - G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, - dense=dense and log, log=True) - cost = log_emd['cost'] - if log: - log_emd = {'G': G} - return cost, log_emd - return cost - - -def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.): - r"""Solves the p-Wasserstein distance problem between 1d measures and returns - the distance - - .. math:: - \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p} - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - - where : - - - x_a and x_b are the samples - - a and b are the sample weights - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - p: float, optional (default=1.0) - The order of the p-Wasserstein distance to be computed - - Returns - ------- - dist: float - p-Wasserstein distance - - - Examples - -------- - - Simple example with obvious solution. The function wasserstein_1d accepts - lists and performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.wasserstein_1d(x_a, x_b, a, b) - 0.5 - >>> ot.wasserstein_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd_1d : EMD for 1d distributions - """ - cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, - dense=False, log=False) - return np.power(cost_emd, 1. / p) diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 8e763be..869d450 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A): def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): - """Compute the Wasserstein barycenter of distributions A + r"""Compute the Wasserstein barycenter of distributions A The function solves the following optimization problem [16]: @@ -76,7 +76,6 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924. 
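
    A usage sketch for this LP barycenter (illustrative only; the
    histograms-as-columns convention and the ot.dist helper are assumptions
    consistent with the rest of the library, not taken from this hunk):

    >>> import numpy as np
    >>> import ot
    >>> from ot.lp import cvx
    >>> x = np.arange(10, dtype=np.float64).reshape((-1, 1))
    >>> M = ot.dist(x, x)                         # ground cost between bins
    >>> A = np.zeros((10, 2))
    >>> A[2, 0] = A[7, 1] = 1.                    # two Dirac histograms as columns
    >>> bary = cvx.barycenter(A, M, weights=np.array([.5, .5]))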
- """ if weights is None: diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index c167964..42e08f4 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -20,6 +20,7 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil + int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -38,7 +39,7 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads): """ Solves the Earth Movers distance problem and returns the optimal transport matrix @@ -97,8 +98,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0]) cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0) - cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int) - cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int) if not len(a): a=np.ones((n1,))/n1 @@ -111,8 +110,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod # calling the function with nogil: - result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) - + if numThreads == 1: + result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter) + else: + result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads) return G, cost, alpha, beta, result_code @@ -157,22 +158,22 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cost associated to the optimal transportation """ cdef double cost = 0. - cdef int n = u_weights.shape[0] - cdef int m = v_weights.shape[0] + cdef Py_ssize_t n = u_weights.shape[0] + cdef Py_ssize_t m = v_weights.shape[0] - cdef int i = 0 + cdef Py_ssize_t i = 0 cdef double w_i = u_weights[0] - cdef int j = 0 + cdef Py_ssize_t j = 0 cdef double w_j = v_weights[0] cdef double m_ij = 0. 
cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ), dtype=np.float64) - cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), - dtype=np.int) - cdef int cur_idx = 0 - while i < n and j < m: + cdef np.ndarray[long long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), + dtype=np.int64) + cdef Py_ssize_t cur_idx = 0 + while True: if metric == 'sqeuclidean': m_ij = (u[i] - v[j]) * (u[i] - v[j]) elif metric == 'cityblock' or metric == 'euclidean': @@ -188,6 +189,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j i += 1 + if i == n: + break w_j -= w_i w_i = u_weights[i] else: @@ -196,7 +199,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j j += 1 + if j == m: + break w_i -= w_j w_j = v_weights[j] cur_idx += 1 + cur_idx += 1 return G[:cur_idx], indices[:cur_idx], cost diff --git a/ot/lp/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h index 87a1bec..713ccb5 100644 --- a/ot/lp/full_bipartitegraph.h +++ b/ot/lp/full_bipartitegraph.h @@ -23,10 +23,10 @@ * */ -#ifndef LEMON_FULL_BIPARTITE_GRAPH_H -#define LEMON_FULL_BIPARTITE_GRAPH_H +#pragma once #include "core.h" +#include <cstdint> ///\ingroup graphs ///\file @@ -44,16 +44,16 @@ namespace lemon { //class Node; typedef int Node; //class Arc; - typedef long long Arc; + typedef int64_t Arc; protected: int _node_num; - long long _arc_num; + int64_t _arc_num; FullBipartiteDigraphBase() {} - void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = n1 * n2; _n1=n1; _n2=n2;} + void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;} public: @@ -65,25 +65,25 @@ namespace lemon { Arc arc(const Node& s, const Node& t) const { if (s<_n1 && t>=_n1) - return Arc(s * _n2 + (t-_n1) ); + return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) ); else return Arc(-1); } int nodeNum() const { return _node_num; } - long long arcNum() const { return _arc_num; } + int64_t arcNum() const { return _arc_num; } int maxNodeId() const { return _node_num - 1; } - long long maxArcId() const { return _arc_num - 1; } + int64_t maxArcId() const { return _arc_num - 1; } Node source(Arc arc) const { return arc / _n2; } Node target(Arc arc) const { return (arc % _n2) + _n1; } static int id(Node node) { return node; } - static long long id(Arc arc) { return arc; } + static int64_t id(Arc arc) { return arc; } static Node nodeFromId(int id) { return Node(id);} - static Arc arcFromId(int id) { return Arc(id);} + static Arc arcFromId(int64_t id) { return Arc(id);} Arc findArc(Node s, Node t, Arc prev = -1) const { @@ -136,7 +136,7 @@ namespace lemon { /// /// \brief A directed full graph class. /// - /// FullBipartiteDigraph is a simple and fast implmenetation of directed full + /// FullBipartiteDigraph is a simple and fast implementation of directed full /// (complete) graphs. It contains an arc from each node to each node /// (including a loop for each node), therefore the number of arcs /// is the square of the number of nodes. @@ -203,13 +203,10 @@ namespace lemon { /// \brief Number of nodes. int nodeNum() const { return Parent::nodeNum(); } /// \brief Number of arcs. 
- long long arcNum() const { return Parent::arcNum(); } + int64_t arcNum() const { return Parent::arcNum(); } }; } //namespace lemon - - -#endif //LEMON_FULL_GRAPH_H diff --git a/ot/lp/full_bipartitegraph_omp.h b/ot/lp/full_bipartitegraph_omp.h new file mode 100644 index 0000000..8cbed0b --- /dev/null +++ b/ot/lp/full_bipartitegraph_omp.h @@ -0,0 +1,234 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- + * + * This file has been adapted by Nicolas Bonneel (2013), + * from full_graph.h from LEMON, a generic C++ optimization library, + * to implement a lightweight fully connected bipartite graph. A previous + * version of this file is used as part of the Displacement Interpolation + * project, + * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/ + * + * + **** Original file Copyright Notice : + * Copyright (C) 2003-2010 + * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport + * (Egervary Research Group on Combinatorial Optimization, EGRES). + * + * Permission to use, modify and distribute this software is granted + * provided that this copyright notice appears in all copies. For + * precise terms see the accompanying LICENSE file. + * + * This software is provided "AS IS" with no warranty of any kind, + * express or implied, and with no claim as to its suitability for any + * purpose. + * + */ + +#pragma once + +#include <cstdint> + +///\ingroup graphs +///\file +///\brief FullBipartiteDigraph and FullBipartiteGraph classes. + + +namespace lemon_omp { + + ///This \c \#define creates convenient type definitions for the following + ///types of \c Digraph: \c Node, \c NodeIt, \c Arc, \c ArcIt, \c InArcIt, + ///\c OutArcIt, \c BoolNodeMap, \c IntNodeMap, \c DoubleNodeMap, + ///\c BoolArcMap, \c IntArcMap, \c DoubleArcMap. + /// + ///\note If the graph type is a dependent type, ie. the graph type depend + ///on a template parameter, then use \c TEMPLATE_DIGRAPH_TYPEDEFS() + ///macro. +#define DIGRAPH_TYPEDEFS(Digraph) \ + typedef Digraph::Node Node; \ + typedef Digraph::Arc Arc; \ + + + ///Create convenience typedefs for the digraph types and iterators + + ///\see DIGRAPH_TYPEDEFS + /// + ///\note Use this macro, if the graph type is a dependent type, + ///ie. the graph type depend on a template parameter. 
+#define TEMPLATE_DIGRAPH_TYPEDEFS(Digraph) \
+ typedef typename Digraph::Node Node; \
+ typedef typename Digraph::Arc Arc; \
+
+
+ class FullBipartiteDigraphBase {
+ public:
+
+ typedef FullBipartiteDigraphBase Digraph;
+
+ //class Node;
+ typedef int Node;
+ //class Arc;
+ typedef int64_t Arc;
+
+ protected:
+
+ int _node_num;
+ int64_t _arc_num;
+
+ FullBipartiteDigraphBase() {}
+
+ void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;}
+
+ public:
+
+ int _n1, _n2;
+
+
+ Node operator()(int ix) const { return Node(ix); }
+ static int index(const Node& node) { return node; }
+
+ Arc arc(const Node& s, const Node& t) const {
+ if (s<_n1 && t>=_n1)
+ return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) );
+ else
+ return Arc(-1);
+ }
+
+ int nodeNum() const { return _node_num; }
+ int64_t arcNum() const { return _arc_num; }
+
+ int maxNodeId() const { return _node_num - 1; }
+ int64_t maxArcId() const { return _arc_num - 1; }
+
+ Node source(Arc arc) const { return arc / _n2; }
+ Node target(Arc arc) const { return (arc % _n2) + _n1; }
+
+ static int id(Node node) { return node; }
+ static int64_t id(Arc arc) { return arc; }
+
+ static Node nodeFromId(int id) { return Node(id);}
+ static Arc arcFromId(int64_t id) { return Arc(id);}
+
+
+ Arc findArc(Node s, Node t, Arc prev = -1) const {
+ return prev == -1 ? arc(s, t) : -1;
+ }
+
+ void first(Node& node) const {
+ node = _node_num - 1;
+ }
+
+ static void next(Node& node) {
+ --node;
+ }
+
+ void first(Arc& arc) const {
+ arc = _arc_num - 1;
+ }
+
+ static void next(Arc& arc) {
+ --arc;
+ }
+
+ void firstOut(Arc& arc, const Node& node) const {
+ if (node>=_n1)
+ arc = -1;
+ else
+ arc = (node + 1) * _n2 - 1;
+ }
+
+ void nextOut(Arc& arc) const {
+ if (arc % _n2 == 0) arc = 0;
+ --arc;
+ }
+
+ void firstIn(Arc& arc, const Node& node) const {
+ if (node<_n1)
+ arc = -1;
+ else
+ arc = _arc_num + node - _node_num;
+ }
+
+ void nextIn(Arc& arc) const {
+ arc -= _n2;
+ if (arc < 0) arc = -1;
+ }
+
+ };
+
+ /// \ingroup graphs
+ ///
+ /// \brief A directed full graph class.
+ ///
+ /// FullBipartiteDigraph is a simple and fast implementation of directed full
+ /// (complete) graphs. It contains an arc from each node to each node
+ /// (including a loop for each node), therefore the number of arcs
+ /// is the square of the number of nodes.
+ /// This class is completely static and it needs constant memory space.
+ /// Thus you can neither add nor delete nodes or arcs, however
+ /// the structure can be resized using resize().
+ ///
+ /// This type fully conforms to the \ref concepts::Digraph "Digraph concept".
+ /// Most of its member functions and nested classes are documented
+ /// only in the concept class.
+ ///
+ /// This class provides constant time counting for nodes and arcs.
+ ///
+ /// \note FullBipartiteDigraph and FullBipartiteGraph classes are very similar,
+ /// but there are two differences. While this class conforms only
+ /// to the \ref concepts::Digraph "Digraph" concept, FullBipartiteGraph
+ /// conforms to the \ref concepts::Graph "Graph" concept,
+ /// moreover FullBipartiteGraph does not contain a loop for each
+ /// node as this class does.
+ ///
+ /// \sa FullBipartiteGraph
+ class FullBipartiteDigraph : public FullBipartiteDigraphBase {
+ typedef FullBipartiteDigraphBase Parent;
+
+ public:
+
+ /// \brief Default constructor.
+ ///
+ /// Default constructor. The number of nodes and arcs will be zero.
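
A quick sanity check of the row-major arc encoding used by
FullBipartiteDigraphBase above, and of why the id type is widened to
int64_t (a dense problem with n1 = n2 = 50000 already has 2.5e9 arcs,
beyond the 32-bit range). Sketch for illustration only:

>>> n1, n2 = 3, 4
>>> for s in range(n1):                       # sources are nodes 0..n1-1
...     for t in range(n1, n1 + n2):          # targets are nodes n1..n1+n2-1
...         a = s * n2 + (t - n1)             # arc(s, t) from the code above
...         assert (a // n2, a % n2 + n1) == (s, t)   # source()/target()
>>> 50000 * 50000 > 2**31 - 1                 # n1*n2 overflows 32-bit ints
True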
+ FullBipartiteDigraph() { construct(0,0); } + + /// \brief Constructor + /// + /// Constructor. + /// \param n The number of the nodes. + FullBipartiteDigraph(int n1, int n2) { construct(n1, n2); } + + + /// \brief Returns the node with the given index. + /// + /// Returns the node with the given index. Since this structure is + /// completely static, the nodes can be indexed with integers from + /// the range <tt>[0..nodeNum()-1]</tt>. + /// The index of a node is the same as its ID. + /// \sa index() + Node operator()(int ix) const { return Parent::operator()(ix); } + + /// \brief Returns the index of the given node. + /// + /// Returns the index of the given node. Since this structure is + /// completely static, the nodes can be indexed with integers from + /// the range <tt>[0..nodeNum()-1]</tt>. + /// The index of a node is the same as its ID. + /// \sa operator()() + static int index(const Node& node) { return Parent::index(node); } + + /// \brief Returns the arc connecting the given nodes. + /// + /// Returns the arc connecting the given nodes. + /*Arc arc(Node u, Node v) const { + return Parent::arc(u, v); + }*/ + + /// \brief Number of nodes. + int nodeNum() const { return Parent::nodeNum(); } + /// \brief Number of arcs. + int64_t arcNum() const { return Parent::arcNum(); } + }; + + + + +} //namespace lemon_omp diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 5d93040..3b46b9b 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -25,15 +25,17 @@ * */ -#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H -#define LEMON_NETWORK_SIMPLEX_SIMPLE_H +#pragma once +#undef DEBUG_LVL #define DEBUG_LVL 0 #if DEBUG_LVL>0 #include <iomanip> #endif - +#undef EPSILON +#undef _EPSILON +#undef MAX_DEBUG_ITER #define EPSILON 2.2204460492503131e-15 #define _EPSILON 1e-8 #define MAX_DEBUG_ITER 100000 @@ -50,6 +52,7 @@ #include <vector> #include <limits> #include <algorithm> +#include <iostream> #include <cstdio> #ifdef HASHMAP #include <hash_map> @@ -63,6 +66,8 @@ //#include "sparse_array_n.h" #include "full_bipartitegraph.h" +#undef INVALIDNODE +#undef INVALID #define INVALIDNODE -1 #define INVALID (-1) @@ -76,16 +81,16 @@ namespace lemon { class SparseValueVector { public: - SparseValueVector(int n=0) + SparseValueVector(size_t n=0) { } - void resize(int n=0){}; - T operator[](const int id) const + void resize(size_t n=0){}; + T operator[](const size_t id) const { #ifdef HASHMAP - typename stdext::hash_map<int,T>::const_iterator it = data.find(id); + typename stdext::hash_map<size_t,T>::const_iterator it = data.find(id); #else - typename std::map<int,T>::const_iterator it = data.find(id); + typename std::map<size_t,T>::const_iterator it = data.find(id); #endif if (it==data.end()) return 0; @@ -93,16 +98,16 @@ namespace lemon { return it->second; } - ProxyObject<T> operator[](const int id) + ProxyObject<T> operator[](const size_t id) { return ProxyObject<T>( this, id ); } //private: #ifdef HASHMAP - stdext::hash_map<int,T> data; + stdext::hash_map<size_t,T> data; #else - std::map<int,T> data; + std::map<size_t,T> data; #endif }; @@ -110,7 +115,7 @@ namespace lemon { template <typename T> class ProxyObject { public: - ProxyObject( SparseValueVector<T> *v, int idx ){_v=v; _idx=idx;}; + ProxyObject( SparseValueVector<T> *v, size_t idx ){_v=v; _idx=idx;}; ProxyObject<T> & operator=( const T &v ) { // If we get here, we know that operator[] was called to perform a write access, // so we can insert an item in the vector if needed @@ -123,9 +128,9 @@ 
namespace lemon { // If we get here, we know that operator[] was called to perform a read access, // so we can simply return the existing object #ifdef HASHMAP - typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx); + typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx); #else - typename std::map<int,T>::iterator it = _v->data.find(_idx); + typename std::map<size_t,T>::iterator it = _v->data.find(_idx); #endif if (it==_v->data.end()) return 0; @@ -137,9 +142,9 @@ namespace lemon { { if (val==0) return; #ifdef HASHMAP - typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx); + typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx); #else - typename std::map<int,T>::iterator it = _v->data.find(_idx); + typename std::map<size_t,T>::iterator it = _v->data.find(_idx); #endif if (it==_v->data.end()) _v->data[_idx] = val; @@ -156,9 +161,9 @@ namespace lemon { { if (val==0) return; #ifdef HASHMAP - typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx); + typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx); #else - typename std::map<int,T>::iterator it = _v->data.find(_idx); + typename std::map<size_t,T>::iterator it = _v->data.find(_idx); #endif if (it==_v->data.end()) _v->data[_idx] = -val; @@ -173,7 +178,7 @@ namespace lemon { } SparseValueVector<T> *_v; - int _idx; + size_t _idx; }; @@ -204,7 +209,7 @@ namespace lemon { /// /// \tparam GR The digraph type the algorithm runs on. /// \tparam V The number type used for flow amounts, capacity bounds - /// and supply values in the algorithm. By default, it is \c int. + /// and supply values in the algorithm. By default, it is \c int64_t. /// \tparam C The number type used for costs and potentials in the /// algorithm. By default, it is the same as \c V. /// @@ -214,7 +219,7 @@ namespace lemon { /// \note %NetworkSimplexSimple provides five different pivot rule /// implementations, from which the most efficient one is used /// by default. For more information, see \ref PivotRule. - template <typename GR, typename V = int, typename C = V, typename NodesType = unsigned short int> + template <typename GR, typename V = int, typename C = V, typename NodesType = unsigned short int, typename ArcsType = int64_t> class NetworkSimplexSimple { public: @@ -228,7 +233,7 @@ namespace lemon { /// mixed order in the internal data structure. /// In special cases, it could lead to better overall performance, /// but it is usually slower. Therefore it is disabled by default. 
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) : + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) : _graph(graph), //_arc_id(graph), _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), MAX(std::numeric_limits<Value>::max()), @@ -288,11 +293,11 @@ namespace lemon { private: - int max_iter; + size_t max_iter; TEMPLATE_DIGRAPH_TYPEDEFS(GR); typedef std::vector<int> IntVector; - typedef std::vector<NodesType> UHalfIntVector; + typedef std::vector<ArcsType> ArcVector; typedef std::vector<Value> ValueVector; typedef std::vector<Cost> CostVector; // typedef SparseValueVector<Cost> CostVector; @@ -315,9 +320,9 @@ namespace lemon { // Data related to the underlying digraph const GR &_graph; int _node_num; - int _arc_num; - int _all_arc_num; - int _search_arc_num; + ArcsType _arc_num; + ArcsType _all_arc_num; + ArcsType _search_arc_num; // Parameters of the problem SupplyType _stype; @@ -325,9 +330,9 @@ namespace lemon { inline int _node_id(int n) const {return _node_num-n-1;} ; - //IntArcMap _arc_id; - UHalfIntVector _source; - UHalfIntVector _target; +// IntArcMap _arc_id; + IntVector _source; // keep nodes as integers + IntVector _target; bool _arc_mixing; public: // Node and arc data @@ -341,7 +346,7 @@ namespace lemon { private: // Data for storing the spanning tree structure IntVector _parent; - IntVector _pred; + ArcVector _pred; IntVector _thread; IntVector _rev_thread; IntVector _succ_num; @@ -349,17 +354,17 @@ namespace lemon { IntVector _dirty_revs; BoolVector _forward; StateVector _state; - int _root; + ArcsType _root; // Temporary data used in the current pivot iteration - int in_arc, join, u_in, v_in, u_out, v_out; - int first, second, right, last; - int stem, par_stem, new_stem; + ArcsType in_arc, join, u_in, v_in, u_out, v_out; + ArcsType first, second, right, last; + ArcsType stem, par_stem, new_stem; Value delta; const Value MAX; - int mixingCoeff; + ArcsType mixingCoeff; public: @@ -373,27 +378,27 @@ namespace lemon { private: // thank you to DVK and MizardX from StackOverflow for this function! 
- inline int sequence(int k) const { - int smallv = (k > num_total_big_subsequence_numbers) & 1; + inline ArcsType sequence(ArcsType k) const { + ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1; k -= num_total_big_subsequence_numbers * smallv; - int subsequence_length2 = subsequence_length- smallv; - int subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv; - int subsequence_offset = (k % subsequence_length2) * mixingCoeff; + ArcsType subsequence_length2 = subsequence_length- smallv; + ArcsType subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv; + ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff; return subsequence_offset + subsequence_num; } - int subsequence_length; - int num_big_subseqiences; - int num_total_big_subsequence_numbers; + ArcsType subsequence_length; + ArcsType num_big_subseqiences; + ArcsType num_total_big_subsequence_numbers; - inline int getArcID(const Arc &arc) const + inline ArcsType getArcID(const Arc &arc) const { //int n = _arc_num-arc._id-1; - int n = _arc_num-GR::id(arc)-1; + ArcsType n = _arc_num-GR::id(arc)-1; - //int a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; - //int b = _arc_id[arc]; + //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; + //ArcsType b = _arc_id[arc]; if (_arc_mixing) return sequence(n); else @@ -401,16 +406,16 @@ namespace lemon { } // finally unused because too slow - inline int getSource(const int arc) const + inline ArcsType getSource(const ArcsType arc) const { - //int a = _source[arc]; + //ArcsType a = _source[arc]; //return a; - int n = _arc_num-arc-1; + ArcsType n = _arc_num-arc-1; if (_arc_mixing) n = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; - int b; + ArcsType b; if (n>=0) b = _node_id(_graph.source(GR::arcFromId( n ) )); else @@ -436,17 +441,17 @@ namespace lemon { private: // References to the NetworkSimplexSimple class - const UHalfIntVector &_source; - const UHalfIntVector &_target; + const IntVector &_source; + const IntVector &_target; const CostVector &_cost; const StateVector &_state; const CostVector &_pi; - int &_in_arc; - int _search_arc_num; + ArcsType &_in_arc; + ArcsType _search_arc_num; // Pivot rule data - int _block_size; - int _next_arc; + ArcsType _block_size; + ArcsType _next_arc; NetworkSimplexSimple &_ns; public: @@ -460,17 +465,16 @@ namespace lemon { { // The main parameters of the pivot rule const double BLOCK_SIZE_FACTOR = 1.0; - const int MIN_BLOCK_SIZE = 10; + const ArcsType MIN_BLOCK_SIZE = 10; - _block_size = std::max( int(BLOCK_SIZE_FACTOR * - std::sqrt(double(_search_arc_num))), - MIN_BLOCK_SIZE ); + _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE); } + // Find next entering arc bool findEnteringArc() { Cost c, min = 0; - int e; - int cnt = _block_size; + ArcsType e; + ArcsType cnt = _block_size; double a; for (e = _next_arc; e != _search_arc_num; ++e) { c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); @@ -516,7 +520,7 @@ namespace lemon { int _init_nb_nodes; - long long _init_nb_arcs; + ArcsType _init_nb_arcs; /// \name Parameters /// The parameters of the algorithm can be specified using these @@ -736,7 +740,7 @@ namespace lemon { for (int i = 0; i != _node_num; ++i) { _supply[i] = 0; } - for (int i = 0; i != _arc_num; ++i) { + for (ArcsType i = 0; i != _arc_num; ++i) { _cost[i] = 1; } _stype = GEQ; @@ -745,7 +749,7 @@ namespace lemon { - int divid (int x, int y) + int64_t divid (int64_t x, int64_t y) { return (x-x%y)/y; } @@ -775,7 
+779,7 @@ namespace lemon { _node_num = _init_nb_nodes; _arc_num = _init_nb_arcs; int all_node_num = _node_num + 1; - int max_arc_num = _arc_num + 2 * _node_num; + ArcsType max_arc_num = _arc_num + 2 * _node_num; _source.resize(max_arc_num); _target.resize(max_arc_num); @@ -798,13 +802,13 @@ namespace lemon { //_arc_mixing=false; if (_arc_mixing) { // Store the arcs in a mixed order - int k = std::max(int(std::sqrt(double(_arc_num))), 10); + const ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10)); mixingCoeff = k; subsequence_length = _arc_num / mixingCoeff + 1; num_big_subseqiences = _arc_num % mixingCoeff; num_total_big_subsequence_numbers = subsequence_length * num_big_subseqiences; - int i = 0, j = 0; + ArcsType i = 0, j = 0; Arc a; _graph.first(a); for (; a != INVALID; _graph.next(a)) { _source[i] = _node_id(_graph.source(a)); @@ -814,7 +818,7 @@ namespace lemon { } } else { // Store the arcs in the original order - int i = 0; + ArcsType i = 0; Arc a; _graph.first(a); for (; a != INVALID; _graph.next(a), ++i) { _source[i] = _node_id(_graph.source(a)); @@ -856,7 +860,7 @@ namespace lemon { Number totalCost() const { Number c = 0; for (ArcIt a(_graph); a != INVALID; ++a) { - int i = getArcID(a); + int64_t i = getArcID(a); c += Number(_flow[i]) * Number(_cost[i]); } return c; @@ -867,15 +871,15 @@ namespace lemon { Number c = 0; /*#ifdef HASHMAP - typename stdext::hash_map<int, Value>::const_iterator it; + typename stdext::hash_map<int64_t, Value>::const_iterator it; #else - typename std::map<int, Value>::const_iterator it; + typename std::map<int64_t, Value>::const_iterator it; #endif for (it = _flow.data.begin(); it!=_flow.data.end(); ++it) c += Number(it->second) * Number(_cost[it->first]); return c;*/ - for (unsigned long i=0; i<_flow.size(); i++) + for (ArcsType i=0; i<_flow.size(); i++) c += _flow[i] * Number(_cost[i]); return c; @@ -944,14 +948,14 @@ namespace lemon { // Initialize internal data structures bool init() { if (_node_num == 0) return false; - + // Check the sum of supply values _sum_supply = 0; for (int i = 0; i != _node_num; ++i) { _sum_supply += _supply[i]; } if ( fabs(_sum_supply) > _EPSILON ) return false; - + _sum_supply = 0; // Initialize artifical cost @@ -960,14 +964,14 @@ namespace lemon { ART_COST = std::numeric_limits<Cost>::max() / 2 + 1; } else { ART_COST = 0; - for (int i = 0; i != _arc_num; ++i) { + for (ArcsType i = 0; i != _arc_num; ++i) { if (_cost[i] > ART_COST) ART_COST = _cost[i]; } ART_COST = (ART_COST + 1) * _node_num; } // Initialize arc maps - for (int i = 0; i != _arc_num; ++i) { + for (ArcsType i = 0; i != _arc_num; ++i) { //_flow[i] = 0; //by default, the sparse matrix is empty _state[i] = STATE_LOWER; } @@ -988,7 +992,7 @@ namespace lemon { // EQ supply constraints _search_arc_num = _arc_num; _all_arc_num = _arc_num + _node_num; - for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { _parent[u] = _root; _pred[u] = e; _thread[u] = u + 1; @@ -1016,8 +1020,8 @@ namespace lemon { else if (_sum_supply > 0) { // LEQ supply constraints _search_arc_num = _arc_num + _node_num; - int f = _arc_num + _node_num; - for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { _parent[u] = _root; _thread[u] = u + 1; _rev_thread[u + 1] = u; @@ -1054,8 +1058,8 @@ namespace lemon { else { // GEQ supply constraints _search_arc_num = _arc_num + _node_num; - int f = 
_arc_num + _node_num; - for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { _parent[u] = _root; _thread[u] = u + 1; _rev_thread[u + 1] = u; @@ -1120,9 +1124,9 @@ namespace lemon { second = _source[in_arc]; } delta = INF; - int result = 0; + char result = 0; Value d; - int e; + ArcsType e; // Search the cycle along the path form the first node to the root for (int u = first; u != join; u = _parent[u]) { @@ -1239,7 +1243,7 @@ namespace lemon { // Update _rev_thread using the new _thread values for (int i = 0; i != int(_dirty_revs.size()); ++i) { - u = _dirty_revs[i]; + int u = _dirty_revs[i]; _rev_thread[_thread[u]] = u; } @@ -1257,7 +1261,7 @@ namespace lemon { u = w; } _pred[u_in] = in_arc; - _forward[u_in] = ((unsigned int)u_in == _source[in_arc]); + _forward[u_in] = (u_in == _source[in_arc]); _succ_num[u_in] = old_succ_num; // Set limits for updating _last_succ form v_in and v_out @@ -1328,7 +1332,7 @@ namespace lemon { if (_sum_supply > 0) total -= _sum_supply; if (total <= 0) return true; - IntVector arc_vector; + ArcVector arc_vector; if (_sum_supply >= 0) { if (supply_nodes.size() == 1 && demand_nodes.size() == 1) { // Perform a reverse graph search from the sink to the source @@ -1345,7 +1349,7 @@ namespace lemon { Arc a; _graph.firstIn(a, v); for (; a != INVALID; _graph.nextIn(a)) { if (reached[u = _graph.source(a)]) continue; - int j = getArcID(a); + ArcsType j = getArcID(a); if (INF >= total) { arc_vector.push_back(j); reached[u] = true; @@ -1355,7 +1359,7 @@ namespace lemon { } } else { // Find the min. cost incomming arc for each demand node - for (int i = 0; i != int(demand_nodes.size()); ++i) { + for (int i = 0; i != demand_nodes.size(); ++i) { Node v = demand_nodes[i]; Cost c, min_cost = std::numeric_limits<Cost>::max(); Arc min_arc = INVALID; @@ -1393,7 +1397,7 @@ namespace lemon { } // Perform heuristic initial pivots - for (int i = 0; i != int(arc_vector.size()); ++i) { + for (ArcsType i = 0; i != arc_vector.size(); ++i) { in_arc = arc_vector[i]; // l'erreur est probablement ici... if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] - @@ -1423,7 +1427,7 @@ namespace lemon { // Perform heuristic initial pivots if (!initialPivots()) return UNBOUNDED; - int iter_number=0; + size_t iter_number=0; //pivot.setDantzig(true); // Execute the Network Simplex algorithm while (pivot.findEnteringArc()) { @@ -1443,7 +1447,7 @@ namespace lemon { double a; a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int i=0; i<_flow.size(); i++) { + for (int64_t i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; @@ -1482,12 +1486,12 @@ namespace lemon { double a; a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? 
fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); - for (int i=0; i<_flow.size(); i++) { + for (int64_t i=0; i<_flow.size(); i++) { sumFlow+=_state[i]*_flow[i]; } - + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; - + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; @@ -1505,9 +1509,9 @@ namespace lemon { #endif // Check feasibility if( retVal == OPTIMAL){ - for (int e = _search_arc_num; e != _all_arc_num; ++e) { + for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { if (_flow[e] != 0){ - if (abs(_flow[e]) > EPSILON) + if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126 return INFEASIBLE; else _flow[e]=0; @@ -1521,20 +1525,20 @@ namespace lemon { if (_sum_supply == 0) { if (_stype == GEQ) { Cost max_pot = -std::numeric_limits<Cost>::max(); - for (int i = 0; i != _node_num; ++i) { + for (ArcsType i = 0; i != _node_num; ++i) { if (_pi[i] > max_pot) max_pot = _pi[i]; } if (max_pot > 0) { - for (int i = 0; i != _node_num; ++i) + for (ArcsType i = 0; i != _node_num; ++i) _pi[i] -= max_pot; } } else { Cost min_pot = std::numeric_limits<Cost>::max(); - for (int i = 0; i != _node_num; ++i) { + for (ArcsType i = 0; i != _node_num; ++i) { if (_pi[i] < min_pot) min_pot = _pi[i]; } if (min_pot < 0) { - for (int i = 0; i != _node_num; ++i) + for (ArcsType i = 0; i != _node_num; ++i) _pi[i] -= min_pot; } } @@ -1548,5 +1552,3 @@ namespace lemon { ///@} } //namespace lemon - -#endif //LEMON_NETWORK_SIMPLEX_H diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h new file mode 100644 index 0000000..87e4c05 --- /dev/null +++ b/ot/lp/network_simplex_simple_omp.h @@ -0,0 +1,1699 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- +* +* +* This file has been adapted by Nicolas Bonneel (2013), +* from network_simplex.h from LEMON, a generic C++ optimization library, +* to implement a lightweight network simplex for mass transport, more +* memory efficient than the original file. A previous version of this file +* is used as part of the Displacement Interpolation project, +* Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/ +* +* Revisions: +* March 2015: added OpenMP parallelization +* March 2017: included Antoine Rolet's trick to make it more robust +* April 2018: IMPORTANT bug fix + uses 64bit integers (slightly slower but less risks of overflows), updated to a newer version of the algo by LEMON, sparse flow by default + minor edits. +* +* +**** Original file Copyright Notice : +* +* Copyright (C) 2003-2010 +* Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport +* (Egervary Research Group on Combinatorial Optimization, EGRES). +* +* Permission to use, modify and distribute this software is granted +* provided that this copyright notice appears in all copies. For +* precise terms see the accompanying LICENSE file. +* +* This software is provided "AS IS" with no warranty of any kind, +* express or implied, and with no claim as to its suitability for any +* purpose. 
+* +*/ + +#pragma once +#undef DEBUG_LVL +#define DEBUG_LVL 0 + +#if DEBUG_LVL>0 +#include <iomanip> +#endif + +#undef EPSILON +#undef _EPSILON +#undef MAX_DEBUG_ITER +#define EPSILON std::numeric_limits<Cost>::epsilon()*10 +#define _EPSILON 1e-8 +#define MAX_DEBUG_ITER 100000 + +/// \ingroup min_cost_flow_algs +/// +/// \file +/// \brief Network Simplex algorithm for finding a minimum cost flow. + +// if your compiler has troubles with unorderedmaps, just comment the following line to use a slower std::map instead +#define HASHMAP // now handled with unorderedmaps instead of stdext::hash_map. Should be better supported. + +#define SPARSE_FLOW // a sparse flow vector will be 10-15% slower for small problems but uses less memory and becomes faster for large problems (40k total nodes) + +#include <vector> +#include <limits> +#include <algorithm> +#include <iostream> +#ifdef HASHMAP +#include <unordered_map> +#else +#include <map> +#endif +//#include "core.h" +//#include "lmath.h" + +#ifdef OMP +#include <omp.h> +#endif +#include <cmath> + + +//#include "sparse_array_n.h" +#include "full_bipartitegraph_omp.h" + +#undef INVALIDNODE +#undef INVALID +#define INVALIDNODE -1 +#define INVALID (-1) + +namespace lemon_omp { + + int64_t max_threads = -1; + + template <typename T> + class ProxyObject; + + template<typename T> + class SparseValueVector + { + public: + SparseValueVector(size_t n = 0) // parameter n for compatibility with standard vectors + { + } + void resize(size_t n = 0) {}; + T operator[](const size_t id) const + { +#ifdef HASHMAP + typename std::unordered_map<size_t, T>::const_iterator it = data.find(id); +#else + typename std::map<size_t, T>::const_iterator it = data.find(id); +#endif + if (it == data.end()) + return 0; + else + return it->second; + } + + ProxyObject<T> operator[](const size_t id) + { + return ProxyObject<T>(this, id); + } + + //private: +#ifdef HASHMAP + std::unordered_map<size_t, T> data; +#else + std::map<size_t, T> data; +#endif + + }; + + template <typename T> + class ProxyObject { + public: + ProxyObject(SparseValueVector<T> *v, size_t idx) { _v = v; _idx = idx; }; + ProxyObject<T> & operator=(const T &v) { + // If we get here, we know that operator[] was called to perform a write access, + // so we can insert an item in the vector if needed + if (v != 0) + _v->data[_idx] = v; + return *this; + } + + operator T() { + // If we get here, we know that operator[] was called to perform a read access, + // so we can simply return the existing object +#ifdef HASHMAP + typename std::unordered_map<size_t, T>::iterator it = _v->data.find(_idx); +#else + typename std::map<size_t, T>::iterator it = _v->data.find(_idx); +#endif + if (it == _v->data.end()) + return 0; + else + return it->second; + } + + void operator+=(T val) + { + if (val == 0) return; +#ifdef HASHMAP + typename std::unordered_map<size_t, T>::iterator it = _v->data.find(_idx); +#else + typename std::map<size_t, T>::iterator it = _v->data.find(_idx); +#endif + if (it == _v->data.end()) + _v->data[_idx] = val; + else + { + T sum = it->second + val; + if (sum == 0) + _v->data.erase(it); + else + it->second = sum; + } + } + void operator-=(T val) + { + if (val == 0) return; +#ifdef HASHMAP + typename std::unordered_map<size_t, T>::iterator it = _v->data.find(_idx); +#else + typename std::map<size_t, T>::iterator it = _v->data.find(_idx); +#endif + if (it == _v->data.end()) + _v->data[_idx] = -val; + else + { + T sum = it->second - val; + if (sum == 0) + _v->data.erase(it); + else + it->second = sum; + } + } + 
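
In Python terms, the SparseValueVector/ProxyObject pair above behaves like
a dictionary that reads missing keys as 0 and erases entries whose updates
cancel to 0, which keeps the flow map small. A sketch of the semantics
only, not of the actual container:

>>> class SparseFlow(dict):
...     def __missing__(self, i):             # operator[] read access
...         return 0
...     def add(self, i, val):                # operator+= with erase-on-zero
...         s = self[i] + val
...         if s == 0:
...             self.pop(i, None)
...         else:
...             dict.__setitem__(self, i, s)
>>> f = SparseFlow(); f.add(3, 5); f.add(3, -5); len(f)
0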
+ SparseValueVector<T> *_v; + size_t _idx; + }; + + + + /// \addtogroup min_cost_flow_algs + /// @{ + + /// \brief Implementation of the primal Network Simplex algorithm + /// for finding a \ref min_cost_flow "minimum cost flow". + /// + /// \ref NetworkSimplexSimple implements the primal Network Simplex algorithm + /// for finding a \ref min_cost_flow "minimum cost flow" + /// \ref amo93networkflows, \ref dantzig63linearprog, + /// \ref kellyoneill91netsimplex. + /// This algorithm is a highly efficient specialized version of the + /// linear programming simplex method directly for the minimum cost + /// flow problem. + /// + /// In general, %NetworkSimplexSimple is the fastest implementation available + /// in LEMON for this problem. + /// Moreover, it supports both directions of the supply/demand inequality + /// constraints. For more information, see \ref SupplyType. + /// + /// Most of the parameters of the problem (except for the digraph) + /// can be given using separate functions, and the algorithm can be + /// executed using the \ref run() function. If some parameters are not + /// specified, then default values will be used. + /// + /// \tparam GR The digraph type the algorithm runs on. + /// \tparam V The number type used for flow amounts, capacity bounds + /// and supply values in the algorithm. By default, it is \c int. + /// \tparam C The number type used for costs and potentials in the + /// algorithm. By default, it is the same as \c V. + /// + /// \warning Both number types must be signed and all input data must + /// be integer. + /// + /// \note %NetworkSimplexSimple provides five different pivot rule + /// implementations, from which the most efficient one is used + /// by default. For more information, see \ref PivotRule. + template <typename GR, typename V = int, typename C = V, typename ArcsType = int64_t> + class NetworkSimplexSimple + { + public: + + /// \brief Constructor. + /// + /// The constructor of the class. + /// + /// \param graph The digraph the algorithm runs on. + /// \param arc_mixing Indicate if the arcs have to be stored in a + /// mixed order in the internal data structure. + /// In special cases, it could lead to better overall performance, + /// but it is usually slower. Therefore it is disabled by default. + NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) : + _graph(graph), //_arc_id(graph), + _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs), + MAX(std::numeric_limits<Value>::max()), + INF(std::numeric_limits<Value>::has_infinity ? + std::numeric_limits<Value>::infinity() : MAX) + { + // Reset data structures + reset(); + max_iter = maxiters; +#ifdef OMP + if (max_threads < 0) { + max_threads = omp_get_max_threads(); + } + if (numThreads > 0 && numThreads<=max_threads){ + num_threads = numThreads; + } else if (numThreads == -1 || numThreads>max_threads) { + num_threads = max_threads; + } else { + num_threads = 1; + } + omp_set_num_threads(num_threads); +#else + num_threads = 1; +#endif + } + + /// The type of the flow amounts, capacity bounds and supply values + typedef V Value; + /// The type of the arc costs + typedef C Cost; + + public: + /// \brief Problem type constants for the \c run() function. + /// + /// Enum type containing the problem type constants that can be + /// returned by the \ref run() function of the algorithm. + enum ProblemType { + /// The problem has no feasible solution (flow). 
+ INFEASIBLE, + /// The problem has optimal solution (i.e. it is feasible and + /// bounded), and the algorithm has found optimal flow and node + /// potentials (primal and dual solutions). + OPTIMAL, + /// The objective function of the problem is unbounded, i.e. + /// there is a directed cycle having negative total cost and + /// infinite upper bound. + UNBOUNDED, + // The maximum number of iteration has been reached + MAX_ITER_REACHED + }; + + /// \brief Constants for selecting the type of the supply constraints. + /// + /// Enum type containing constants for selecting the supply type, + /// i.e. the direction of the inequalities in the supply/demand + /// constraints of the \ref min_cost_flow "minimum cost flow problem". + /// + /// The default supply type is \c GEQ, the \c LEQ type can be + /// selected using \ref supplyType(). + /// The equality form is a special case of both supply types. + enum SupplyType { + /// This option means that there are <em>"greater or equal"</em> + /// supply/demand constraints in the definition of the problem. + GEQ, + /// This option means that there are <em>"less or equal"</em> + /// supply/demand constraints in the definition of the problem. + LEQ + }; + + + + private: + size_t max_iter; + int num_threads; + TEMPLATE_DIGRAPH_TYPEDEFS(GR); + + typedef std::vector<int> IntVector; + typedef std::vector<ArcsType> ArcVector; + typedef std::vector<Value> ValueVector; + typedef std::vector<Cost> CostVector; + // typedef SparseValueVector<Cost> CostVector; + typedef std::vector<char> BoolVector; + // Note: vector<char> is used instead of vector<bool> for efficiency reasons + + // State constants for arcs + enum ArcState { + STATE_UPPER = -1, + STATE_TREE = 0, + STATE_LOWER = 1 + }; + + typedef std::vector<signed char> StateVector; + // Note: vector<signed char> is used instead of vector<ArcState> for + // efficiency reasons + + private: + + // Data related to the underlying digraph + const GR &_graph; + int _node_num; + ArcsType _arc_num; + ArcsType _all_arc_num; + ArcsType _search_arc_num; + + // Parameters of the problem + SupplyType _stype; + Value _sum_supply; + + inline int _node_id(int n) const { return _node_num - n - 1; }; + + //IntArcMap _arc_id; + IntVector _source; // keep nodes as integers + IntVector _target; + bool _arc_mixing; + + // Node and arc data + CostVector _cost; + ValueVector _supply; +#ifdef SPARSE_FLOW + SparseValueVector<Value> _flow; +#else + ValueVector _flow; +#endif + + CostVector _pi; + + // Data for storing the spanning tree structure + IntVector _parent; + ArcVector _pred; + IntVector _thread; + IntVector _rev_thread; + IntVector _succ_num; + IntVector _last_succ; + IntVector _dirty_revs; + BoolVector _forward; + StateVector _state; + ArcsType _root; + + // Temporary data used in the current pivot iteration + ArcsType in_arc, join, u_in, v_in, u_out, v_out; + ArcsType first, second, right, last; + ArcsType stem, par_stem, new_stem; + Value delta; + + const Value MAX; + + ArcsType mixingCoeff; + + public: + + /// \brief Constant for infinite upper bounds (capacities). + /// + /// Constant for infinite upper bounds (capacities). + /// It is \c std::numeric_limits<Value>::infinity() if available, + /// \c std::numeric_limits<Value>::max() otherwise. + const Value INF; + + private: + + // thank you to DVK and MizardX from StackOverflow for this function! 
+ inline ArcsType sequence(ArcsType k) const { + ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1; + + k -= num_total_big_subsequence_numbers * smallv; + ArcsType subsequence_length2 = subsequence_length - smallv; + ArcsType subsequence_num = (k / subsequence_length2) + num_big_subsequences * smallv; + ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff; + + return subsequence_offset + subsequence_num; + } + ArcsType subsequence_length; + ArcsType num_big_subsequences; + ArcsType num_total_big_subsequence_numbers; + + inline ArcsType getArcID(const Arc &arc) const + { + //int n = _arc_num-arc._id-1; + ArcsType n = _arc_num - GR::id(arc) - 1; + + //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff; + //ArcsType b = _arc_id[arc]; + if (_arc_mixing) + return sequence(n); + else + return n; + } + + // finally unused because too slow + inline ArcsType getSource(const ArcsType arc) const + { + //ArcsType a = _source[arc]; + //return a; + + ArcsType n = _arc_num - arc - 1; + if (_arc_mixing) + n = mixingCoeff*(n%mixingCoeff) + n / mixingCoeff; + + ArcsType b; + if (n >= 0) + b = _node_id(_graph.source(GR::arcFromId(n))); + else + { + n = arc + 1 - _arc_num; + if (n <= _node_num) + b = _node_num; + else + if (n >= _graph._n1) + b = _graph._n1; + else + b = _graph._n1 - n; + } + + return b; + } + + + + // Implementation of the Block Search pivot rule + class BlockSearchPivotRule + { + private: + + // References to the NetworkSimplexSimple class + const IntVector &_source; + const IntVector &_target; + const CostVector &_cost; + const StateVector &_state; + const CostVector &_pi; + ArcsType &_in_arc; + ArcsType _search_arc_num; + + // Pivot rule data + ArcsType _block_size; + ArcsType _next_arc; + NetworkSimplexSimple &_ns; + + public: + + // Constructor + BlockSearchPivotRule(NetworkSimplexSimple &ns) : + _source(ns._source), _target(ns._target), + _cost(ns._cost), _state(ns._state), _pi(ns._pi), + _in_arc(ns.in_arc), _search_arc_num(ns._search_arc_num), + _next_arc(0), _ns(ns) + { + // The main parameters of the pivot rule + const double BLOCK_SIZE_FACTOR = 1; + const ArcsType MIN_BLOCK_SIZE = 10; + + _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE); + } + + // Find next entering arc + bool findEnteringArc() { + Cost min_val = 0; + + ArcsType N = _ns.num_threads; + + std::vector<Cost> minArray(N, 0); + std::vector<ArcsType> arcId(N); + ArcsType bs = (ArcsType)ceil(_block_size / (double)N); + + for (ArcsType i = 0; i < _search_arc_num; i += _block_size) { + + ArcsType e; + int j; +#pragma omp parallel + { +#ifdef OMP + int t = omp_get_thread_num(); +#else + int t = 0; +#endif + +#pragma omp for schedule(static, bs) lastprivate(e) + for (j = 0; j < std::min(i + _block_size, _search_arc_num) - i; j++) { + e = (_next_arc + i + j); if (e >= _search_arc_num) e -= _search_arc_num; + Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]); + if (c < minArray[t]) { + minArray[t] = c; + arcId[t] = e; + } + } + } + for (int j = 0; j < N; j++) { + if (minArray[j] < min_val) { + min_val = minArray[j]; + _in_arc = arcId[j]; + } + } + Cost a = std::abs(_pi[_source[_in_arc]]) > std::abs(_pi[_target[_in_arc]]) ? std::abs(_pi[_source[_in_arc]]) : std::abs(_pi[_target[_in_arc]]); + a = a > std::abs(_cost[_in_arc]) ? a : std::abs(_cost[_in_arc]); + if (min_val < -EPSILON*a) { + _next_arc = e; + return true; + } + } + + Cost a = fabs(_pi[_source[_in_arc]]) > fabs(_pi[_target[_in_arc]]) ? 
fabs(_pi[_source[_in_arc]]) : fabs(_pi[_target[_in_arc]]);
+ a = a > fabs(_cost[_in_arc]) ? a : fabs(_cost[_in_arc]);
+ if (min_val >= -EPSILON*a) return false;
+
+ return true;
+ }
+
+
+ // Find next entering arc
+ /*bool findEnteringArc() {
+ Cost min_val = 0;
+ int N = omp_get_max_threads();
+ std::vector<Cost> minArray(N);
+ std::vector<ArcsType> arcId(N);
+
+ ArcsType bs = (ArcsType)ceil(_block_size / (double)N);
+ for (ArcsType i = 0; i < _search_arc_num; i += _block_size) {
+
+ ArcsType maxJ = std::min(i + _block_size, _search_arc_num) - i;
+ ArcsType j;
+#pragma omp parallel
+ {
+ int t = omp_get_thread_num();
+ Cost minV = 0;
+ ArcsType arcStart = _next_arc + i;
+ ArcsType arc = -1;
+#pragma omp for schedule(static, bs)
+ for (j = 0; j < maxJ; j++) {
+ ArcsType e = arcStart + j; if (e >= _search_arc_num) e -= _search_arc_num;
+ Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < minV) {
+ minV = c;
+ arc = e;
+ }
+ }
+
+ minArray[t] = minV;
+ arcId[t] = arc;
+ }
+ for (int j = 0; j < N; j++) {
+ if (minArray[j] < min_val) {
+ min_val = minArray[j];
+ _in_arc = arcId[j];
+ }
+ }
+
+ //FIX by Antoine Rolet to avoid precision issues
+ Cost a = std::max(std::abs(_cost[_in_arc]), std::max(std::abs(_pi[_source[_in_arc]]), std::abs(_pi[_target[_in_arc]])));
+ if (min_val <-std::numeric_limits<Cost>::epsilon()*a) {
+ _next_arc = _next_arc + i + maxJ - 1;
+ if (_next_arc >= _search_arc_num) _next_arc -= _search_arc_num;
+ return true;
+ }
+ }
+
+ if (min_val >= 0) {
+ return false;
+ }
+
+ return true;
+ }*/
+
+
+ /*bool findEnteringArc() {
+ Cost c, min = 0;
+ int cnt = _block_size;
+ int e, min_arc = _next_arc;
+ for (e = _next_arc; e < _search_arc_num; ++e) {
+ c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < min) {
+ min = c;
+ min_arc = e;
+
+ }
+ if (--cnt == 0) {
+ if (min < 0) break;
+ cnt = _block_size;
+
+ }
+
+ }
+ if (min == 0 || cnt > 0) {
+ for (e = 0; e < _next_arc; ++e) {
+ c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < min) {
+ min = c;
+ min_arc = e;
+
+ }
+ if (--cnt == 0) {
+ if (min < 0) break;
+ cnt = _block_size;
+
+ }
+
+ }
+
+ }
+ if (min >= 0) return false;
+ _in_arc = min_arc;
+ _next_arc = e;
+ return true;
+ }*/
+
+
+
+ }; //class BlockSearchPivotRule
+
+
+
+ public:
+
+
+
+ int _init_nb_nodes;
+ ArcsType _init_nb_arcs;
+
+ /// \name Parameters
+ /// The parameters of the algorithm can be specified using these
+ /// functions.
+
+ /// @{
+
+
+ /// \brief Set the costs of the arcs.
+ ///
+ /// This function sets the costs of the arcs.
+ /// If it is not used before calling \ref run(), the costs
+ /// will be set to \c 1 on all arcs.
+ ///
+ /// \param map An arc map storing the costs.
+ /// Its \c Value type must be convertible to the \c Cost type
+ /// of the algorithm.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename CostMap>
+ NetworkSimplexSimple& costMap(const CostMap& map) {
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a)) {
+ _cost[getArcID(a)] = map[a];
+ }
+ return *this;
+ }
+
+
+ /// \brief Set the cost of one arc.
+ ///
+ /// This function sets the cost of a single arc.
+ /// Provided separately for memory reasons.
+ ///
+ /// \param arc An arc.
+ /// \param cost The cost value.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename Value>
+ NetworkSimplexSimple& setCost(const Arc& arc, const Value cost) {
+ _cost[getArcID(arc)] = cost;
+ return *this;
+ }
+
+
+ /// \brief Set the supply values of the nodes.
+ ///
+ /// This function sets the supply values of the nodes.
+ /// If neither this function nor \ref stSupply() is used before
+ /// calling \ref run(), the supply of each node will be set to zero.
+ ///
+ /// \param map A node map storing the supply values.
+ /// Its \c Value type must be convertible to the \c Value type
+ /// of the algorithm.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMap(const SupplyMap& map) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ _supply[_node_id(n)] = map[n];
+ }
+ return *this;
+ }
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMap(const SupplyMap* map1, int n1, const SupplyMap* map2, int n2) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ if (n<n1)
+ _supply[_node_id(n)] = map1[n];
+ else
+ _supply[_node_id(n)] = map2[n - n1];
+ }
+ return *this;
+ }
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMapAll(SupplyMap val1, int n1, SupplyMap val2, int n2) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ if (n<n1)
+ _supply[_node_id(n)] = val1;
+ else
+ _supply[_node_id(n)] = val2;
+ }
+ return *this;
+ }
+
+ /// \brief Set single source and target nodes and a supply value.
+ ///
+ /// This function sets a single source node and a single target node
+ /// and the required flow value.
+ /// If neither this function nor \ref supplyMap() is used before
+ /// calling \ref run(), the supply of each node will be set to zero.
+ ///
+ /// Using this function has the same effect as using \ref supplyMap()
+ /// with such a map in which \c k is assigned to \c s, \c -k is
+ /// assigned to \c t and all other nodes have zero supply value.
+ ///
+ /// \param s The source node.
+ /// \param t The target node.
+ /// \param k The required amount of flow from node \c s to node \c t
+ /// (i.e. the supply of \c s and the demand of \c t).
+ ///
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& stSupply(const Node& s, const Node& t, Value k) {
+ for (int i = 0; i != _node_num; ++i) {
+ _supply[i] = 0;
+ }
+ _supply[_node_id(s)] = k;
+ _supply[_node_id(t)] = -k;
+ return *this;
+ }
+
+ /// \brief Set the type of the supply constraints.
+ ///
+ /// This function sets the type of the supply/demand constraints.
+ /// If it is not used before calling \ref run(), the \ref GEQ supply
+ /// type will be used.
+ ///
+ /// For more information, see \ref SupplyType.
+ ///
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& supplyType(SupplyType supply_type) {
+ _stype = supply_type;
+ return *this;
+ }
+
+ /// @}
+
+ /// \name Execution Control
+ /// The algorithm can be executed using \ref run().
+
+ /// @{
+
+ /// \brief Run the algorithm.
+ ///
+ /// This function runs the algorithm.
+ /// The parameters can be specified using functions \ref lowerMap(),
+ /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(),
+ /// \ref supplyType().
+ /// For example,
+ /// \code
+ /// NetworkSimplexSimple<ListDigraph> ns(graph);
+ /// ns.lowerMap(lower).upperMap(upper).costMap(cost)
+ /// .supplyMap(sup).run();
+ /// \endcode
+ ///
+ /// This function can be called more than once. All the given parameters
+ /// are kept for the next call, unless \ref resetParams() or \ref reset()
+ /// is used, thus only the modified parameters have to be set again.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class (or the last \ref reset() call), then the \ref reset()
+ /// function must be called.
+ /// + /// \param pivot_rule The pivot rule that will be used during the + /// algorithm. For more information, see \ref PivotRule. + /// + /// \return \c INFEASIBLE if no feasible flow exists, + /// \n \c OPTIMAL if the problem has optimal solution + /// (i.e. it is feasible and bounded), and the algorithm has found + /// optimal flow and node potentials (primal and dual solutions), + /// \n \c UNBOUNDED if the objective function of the problem is + /// unbounded, i.e. there is a directed cycle having negative total + /// cost and infinite upper bound. + /// + /// \see ProblemType, PivotRule + /// \see resetParams(), reset() + ProblemType run() { +#if DEBUG_LVL>0 + std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ; +#endif + if (!init()) return INFEASIBLE; +#if DEBUG_LVL>0 + std::cout << "Init done, starting iterations\n"; +#endif + + return start(); + } + + /// \brief Reset all the parameters that have been given before. + /// + /// This function resets all the paramaters that have been given + /// before using functions \ref lowerMap(), \ref upperMap(), + /// \ref costMap(), \ref supplyMap(), \ref stSupply(), \ref supplyType(). + /// + /// It is useful for multiple \ref run() calls. Basically, all the given + /// parameters are kept for the next \ref run() call, unless + /// \ref resetParams() or \ref reset() is used. + /// If the underlying digraph was also modified after the construction + /// of the class or the last \ref reset() call, then the \ref reset() + /// function must be used, otherwise \ref resetParams() is sufficient. + /// + /// For example, + /// \code + /// NetworkSimplexSimple<ListDigraph> ns(graph); + /// + /// // First run + /// ns.lowerMap(lower).upperMap(upper).costMap(cost) + /// .supplyMap(sup).run(); + /// + /// // Run again with modified cost map (resetParams() is not called, + /// // so only the cost map have to be set again) + /// cost[e] += 100; + /// ns.costMap(cost).run(); + /// + /// // Run again from scratch using resetParams() + /// // (the lower bounds will be set to zero on all arcs) + /// ns.resetParams(); + /// ns.upperMap(capacity).costMap(cost) + /// .supplyMap(sup).run(); + /// \endcode + /// + /// \return <tt>(*this)</tt> + /// + /// \see reset(), run() + NetworkSimplexSimple& resetParams() { + for (int i = 0; i != _node_num; ++i) { + _supply[i] = 0; + } + for (ArcsType i = 0; i != _arc_num; ++i) { + _cost[i] = 1; + } + _stype = GEQ; + return *this; + } + + + /// \brief Reset the internal data structures and all the parameters + /// that have been given before. + /// + /// This function resets the internal data structures and all the + /// paramaters that have been given before using functions \ref lowerMap(), + /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(), + /// \ref supplyType(). + /// + /// It is useful for multiple \ref run() calls. Basically, all the given + /// parameters are kept for the next \ref run() call, unless + /// \ref resetParams() or \ref reset() is used. + /// If the underlying digraph was also modified after the construction + /// of the class or the last \ref reset() call, then the \ref reset() + /// function must be used, otherwise \ref resetParams() is sufficient. + /// + /// See \ref resetParams() for examples. 
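The multiple-run pattern described for run()/resetParams() has a simple sanity check at the Python level: re-solving after shifting every cost by a constant must shift the optimal value by that constant times the total transported mass. A hedged sketch (ot.emd2 returns the optimal cost; note the Python wrapper constructs a fresh solver per call, so cross-run parameter reuse applies when driving the C++ class directly):

import numpy as np
import ot

rng = np.random.RandomState(0)
a, b = ot.unif(5), ot.unif(5)
M = ot.dist(rng.randn(5, 2), rng.randn(5, 2))   # squared Euclidean by default

c1 = ot.emd2(a, b, M)           # first run
c2 = ot.emd2(a, b, M + 100.)    # "run again with modified cost map"
assert np.isclose(c2 - c1, 100.)  # total mass is 1, so the cost shifts by 100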
+ /// + /// \return <tt>(*this)</tt> + /// + /// \see resetParams(), run() + NetworkSimplexSimple& reset() { + // Resize vectors + _node_num = _init_nb_nodes; + _arc_num = _init_nb_arcs; + int all_node_num = _node_num + 1; + ArcsType max_arc_num = _arc_num + 2 * _node_num; + + _source.resize(max_arc_num); + _target.resize(max_arc_num); + + _cost.resize(max_arc_num); + _supply.resize(all_node_num); + _flow.resize(max_arc_num); + _pi.resize(all_node_num); + + _parent.resize(all_node_num); + _pred.resize(all_node_num); + _forward.resize(all_node_num); + _thread.resize(all_node_num); + _rev_thread.resize(all_node_num); + _succ_num.resize(all_node_num); + _last_succ.resize(all_node_num); + _state.resize(max_arc_num); + + + //_arc_mixing=false; + if (_arc_mixing && _node_num > 1) { + // Store the arcs in a mixed order + //ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10)); + const ArcsType k = std::max(ArcsType(_arc_num / _node_num), ArcsType(3)); + mixingCoeff = k; + subsequence_length = _arc_num / mixingCoeff + 1; + num_big_subsequences = _arc_num % mixingCoeff; + num_total_big_subsequence_numbers = subsequence_length * num_big_subsequences; + +#pragma omp parallel for schedule(static) + for (Arc a = 0; a <= _graph.maxArcId(); a++) { // --a <=> _graph.next(a) , -1 == INVALID + ArcsType i = sequence(_graph.maxArcId()-a); + _source[i] = _node_id(_graph.source(a)); + _target[i] = _node_id(_graph.target(a)); + } + } else { + // Store the arcs in the original order + ArcsType i = 0; + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a), ++i) { + _source[i] = _node_id(_graph.source(a)); + _target[i] = _node_id(_graph.target(a)); + //_arc_id[a] = i; + } + } + + // Reset parameters + resetParams(); + return *this; + } + + /// @} + + /// \name Query Functions + /// The results of the algorithm can be obtained using these + /// functions.\n + /// The \ref run() function must be called before using them. + + /// @{ + + /// \brief Return the total cost of the found flow. + /// + /// This function returns the total cost of the found flow. + /// Its complexity is O(e). + /// + /// \note The return type of the function can be specified as a + /// template parameter. For example, + /// \code + /// ns.totalCost<double>(); + /// \endcode + /// It is useful if the total cost cannot be stored in the \c Cost + /// type of the algorithm, which is the default return type of the + /// function. + /// + /// \pre \ref run() must be called before using this function. + /*template <typename Number> + Number totalCost() const { + Number c = 0; + for (ArcIt a(_graph); a != INVALID; ++a) { + int i = getArcID(a); + c += Number(_flow[i]) * Number(_cost[i]); + } + return c; + }*/ + + template <typename Number> + Number totalCost() const { + Number c = 0; + +#ifdef SPARSE_FLOW + #ifdef HASHMAP + typename std::unordered_map<size_t, Value>::const_iterator it; + #else + typename std::map<size_t, Value>::const_iterator it; + #endif + for (it = _flow.data.begin(); it!=_flow.data.end(); ++it) + c += Number(it->second) * Number(_cost[it->first]); + return c; +#else + for (ArcsType i = 0; i<_flow.size(); i++) + c += _flow[i] * Number(_cost[i]); + return c; +#endif + } + +#ifndef DOXYGEN + Cost totalCost() const { + return totalCost<Cost>(); + } +#endif + + /// \brief Return the flow on the given arc. + /// + /// This function returns the flow on the given arc. + /// + /// \pre \ref run() must be called before using this function. 
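totalCost() is just the inner product between the flow and the arc costs; on the Python side the same identity reads np.sum(G * M), which is what ot.emd2 returns. A quick consistency check (hedged sketch):

import numpy as np
import ot

a, b = ot.unif(4), ot.unif(6)
M = np.abs(np.linspace(0., 1., 4)[:, None] - np.linspace(0., 1., 6)[None, :])

G = ot.emd(a, b, M)                       # the flow on the bipartite arcs
assert np.isclose(np.sum(G * M), ot.emd2(a, b, M))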
+ Value flow(const Arc& a) const { + return _flow[getArcID(a)]; + } + + /// \brief Return the flow map (the primal solution). + /// + /// This function copies the flow value on each arc into the given + /// map. The \c Value type of the algorithm must be convertible to + /// the \c Value type of the map. + /// + /// \pre \ref run() must be called before using this function. + template <typename FlowMap> + void flowMap(FlowMap &map) const { + Arc a; _graph.first(a); + for (; a != INVALID; _graph.next(a)) { + map.set(a, _flow[getArcID(a)]); + } + } + + /// \brief Return the potential (dual value) of the given node. + /// + /// This function returns the potential (dual value) of the + /// given node. + /// + /// \pre \ref run() must be called before using this function. + Cost potential(const Node& n) const { + return _pi[_node_id(n)]; + } + + /// \brief Return the potential map (the dual solution). + /// + /// This function copies the potential (dual value) of each node + /// into the given map. + /// The \c Cost type of the algorithm must be convertible to the + /// \c Value type of the map. + /// + /// \pre \ref run() must be called before using this function. + template <typename PotentialMap> + void potentialMap(PotentialMap &map) const { + Node n; _graph.first(n); + for (; n != INVALID; _graph.next(n)) { + map.set(n, _pi[_node_id(n)]); + } + } + + /// @} + + private: + + // Initialize internal data structures + bool init() { + if (_node_num == 0) return false; + + // Check the sum of supply values + _sum_supply = 0; + for (int i = 0; i != _node_num; ++i) { + _sum_supply += _supply[i]; + } + /*if (!((_stype == GEQ && _sum_supply <= 0) || + (_stype == LEQ && _sum_supply >= 0))) return false;*/ + + + // Initialize artifical cost + Cost ART_COST; + if (std::numeric_limits<Cost>::is_exact) { + ART_COST = std::numeric_limits<Cost>::max() / 2 + 1; + } else { + ART_COST = 0; + for (ArcsType i = 0; i != _arc_num; ++i) { + if (_cost[i] > ART_COST) ART_COST = _cost[i]; + } + ART_COST = (ART_COST + 1) * _node_num; + } + + // Initialize arc maps + for (ArcsType i = 0; i != _arc_num; ++i) { +#ifndef SPARSE_FLOW + _flow[i] = 0; //by default, the sparse matrix is empty +#endif + _state[i] = STATE_LOWER; + } +#ifdef SPARSE_FLOW + _flow = SparseValueVector<Value>(); +#endif + + // Set data for the artificial root node + _root = _node_num; + _parent[_root] = -1; + _pred[_root] = -1; + _thread[_root] = 0; + _rev_thread[0] = _root; + _succ_num[_root] = _node_num + 1; + _last_succ[_root] = _root - 1; + _supply[_root] = -_sum_supply; + _pi[_root] = 0; + + // Add artificial arcs and initialize the spanning tree data structure + if (_sum_supply == 0) { + // EQ supply constraints + _search_arc_num = _arc_num; + _all_arc_num = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _pred[u] = e; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + _state[e] = STATE_TREE; + if (_supply[u] >= 0) { + _forward[u] = true; + _pi[u] = 0; + _source[e] = u; + _target[e] = _root; + _flow[e] = _supply[u]; + _cost[e] = 0; + } else { + _forward[u] = false; + _pi[u] = ART_COST; + _source[e] = _root; + _target[e] = u; + _flow[e] = -_supply[u]; + _cost[e] = ART_COST; + } + } + } else if (_sum_supply > 0) { + // LEQ supply constraints + _search_arc_num = _arc_num + _node_num; + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _thread[u] = u + 1; + _rev_thread[u + 
1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + if (_supply[u] >= 0) { + _forward[u] = true; + _pi[u] = 0; + _pred[u] = e; + _source[e] = u; + _target[e] = _root; + _flow[e] = _supply[u]; + _cost[e] = 0; + _state[e] = STATE_TREE; + } else { + _forward[u] = false; + _pi[u] = ART_COST; + _pred[u] = f; + _source[f] = _root; + _target[f] = u; + _flow[f] = -_supply[u]; + _cost[f] = ART_COST; + _state[f] = STATE_TREE; + _source[e] = u; + _target[e] = _root; + //_flow[e] = 0; //by default, the sparse matrix is empty + _cost[e] = 0; + _state[e] = STATE_LOWER; + ++f; + } + } + _all_arc_num = f; + } else { + // GEQ supply constraints + _search_arc_num = _arc_num + _node_num; + ArcsType f = _arc_num + _node_num; + for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) { + _parent[u] = _root; + _thread[u] = u + 1; + _rev_thread[u + 1] = u; + _succ_num[u] = 1; + _last_succ[u] = u; + if (_supply[u] <= 0) { + _forward[u] = false; + _pi[u] = 0; + _pred[u] = e; + _source[e] = _root; + _target[e] = u; + _flow[e] = -_supply[u]; + _cost[e] = 0; + _state[e] = STATE_TREE; + } else { + _forward[u] = true; + _pi[u] = -ART_COST; + _pred[u] = f; + _source[f] = u; + _target[f] = _root; + _flow[f] = _supply[u]; + _state[f] = STATE_TREE; + _cost[f] = ART_COST; + _source[e] = _root; + _target[e] = u; + //_flow[e] = 0; //by default, the sparse matrix is empty + _cost[e] = 0; + _state[e] = STATE_LOWER; + ++f; + } + } + _all_arc_num = f; + } + + return true; + } + + // Find the join node + void findJoinNode() { + int u = _source[in_arc]; + int v = _target[in_arc]; + while (u != v) { + if (_succ_num[u] < _succ_num[v]) { + u = _parent[u]; + } else { + v = _parent[v]; + } + } + join = u; + } + + // Find the leaving arc of the cycle and returns true if the + // leaving arc is not the same as the entering arc + bool findLeavingArc() { + // Initialize first and second nodes according to the direction + // of the cycle + if (_state[in_arc] == STATE_LOWER) { + first = _source[in_arc]; + second = _target[in_arc]; + } else { + first = _target[in_arc]; + second = _source[in_arc]; + } + delta = INF; + char result = 0; + Value d; + ArcsType e; + + // Search the cycle along the path form the first node to the root + for (int u = first; u != join; u = _parent[u]) { + e = _pred[u]; + d = _forward[u] ? _flow[e] : INF; + if (d < delta) { + delta = d; + u_out = u; + result = 1; + } + } + // Search the cycle along the path form the second node to the root + for (int u = second; u != join; u = _parent[u]) { + e = _pred[u]; + d = _forward[u] ? INF : _flow[e]; + if (d <= delta) { + delta = d; + u_out = u; + result = 2; + } + } + + if (result == 1) { + u_in = first; + v_in = second; + } else { + u_in = second; + v_in = first; + } + return result != 0; + } + + // Change _flow and _state vectors + void changeFlow(bool change) { + // Augment along the cycle + if (delta > 0) { + Value val = _state[in_arc] * delta; + _flow[in_arc] += val; + for (int u = _source[in_arc]; u != join; u = _parent[u]) { + _flow[_pred[u]] += _forward[u] ? -val : val; + } + for (int u = _target[in_arc]; u != join; u = _parent[u]) { + _flow[_pred[u]] += _forward[u] ? val : -val; + } + } + // Update the state of the entering and leaving arcs + if (change) { + _state[in_arc] = STATE_TREE; + _state[_pred[u_out]] = + (_flow[_pred[u_out]] == 0) ? 
STATE_LOWER : STATE_UPPER; + } else { + _state[in_arc] = -_state[in_arc]; + } + } + + // Update the tree structure + void updateTreeStructure() { + int old_rev_thread = _rev_thread[u_out]; + int old_succ_num = _succ_num[u_out]; + int old_last_succ = _last_succ[u_out]; + v_out = _parent[u_out]; + + // Check if u_in and u_out coincide + if (u_in == u_out) { + // Update _parent, _pred, _pred_dir + _parent[u_in] = v_in; + _pred[u_in] = in_arc; + _forward[u_in] = (u_in == _source[in_arc]); + + // Update _thread and _rev_thread + if (_thread[v_in] != u_out) { + ArcsType after = _thread[old_last_succ]; + _thread[old_rev_thread] = after; + _rev_thread[after] = old_rev_thread; + after = _thread[v_in]; + _thread[v_in] = u_out; + _rev_thread[u_out] = v_in; + _thread[old_last_succ] = after; + _rev_thread[after] = old_last_succ; + } + } else { + // Handle the case when old_rev_thread equals to v_in + // (it also means that join and v_out coincide) + int thread_continue = old_rev_thread == v_in ? + _thread[old_last_succ] : _thread[v_in]; + + // Update _thread and _parent along the stem nodes (i.e. the nodes + // between u_in and u_out, whose parent have to be changed) + int stem = u_in; // the current stem node + int par_stem = v_in; // the new parent of stem + int next_stem; // the next stem node + int last = _last_succ[u_in]; // the last successor of stem + int before, after = _thread[last]; + _thread[v_in] = u_in; + _dirty_revs.clear(); + _dirty_revs.push_back(v_in); + while (stem != u_out) { + // Insert the next stem node into the thread list + next_stem = _parent[stem]; + _thread[last] = next_stem; + _dirty_revs.push_back(last); + + // Remove the subtree of stem from the thread list + before = _rev_thread[stem]; + _thread[before] = after; + _rev_thread[after] = before; + + // Change the parent node and shift stem nodes + _parent[stem] = par_stem; + par_stem = stem; + stem = next_stem; + + // Update last and after + last = _last_succ[stem] == _last_succ[par_stem] ? + _rev_thread[par_stem] : _last_succ[stem]; + after = _thread[last]; + } + _parent[u_out] = par_stem; + _thread[last] = thread_continue; + _rev_thread[thread_continue] = last; + _last_succ[u_out] = last; + + // Remove the subtree of u_out from the thread list except for + // the case when old_rev_thread equals to v_in + if (old_rev_thread != v_in) { + _thread[old_rev_thread] = after; + _rev_thread[after] = old_rev_thread; + } + + // Update _rev_thread using the new _thread values + for (int i = 0; i != int(_dirty_revs.size()); ++i) { + int u = _dirty_revs[i]; + _rev_thread[_thread[u]] = u; + } + + // Update _pred, _pred_dir, _last_succ and _succ_num for the + // stem nodes from u_out to u_in + int tmp_sc = 0, tmp_ls = _last_succ[u_out]; + for (int u = u_out, p = _parent[u]; u != u_in; u = p, p = _parent[u]) { + _pred[u] = _pred[p]; + _forward[u] = !_forward[p]; + tmp_sc += _succ_num[u] - _succ_num[p]; + _succ_num[u] = tmp_sc; + _last_succ[p] = tmp_ls; + } + _pred[u_in] = in_arc; + _forward[u_in] = (u_in == _source[in_arc]); + _succ_num[u_in] = old_succ_num; + } + + // Update _last_succ from v_in towards the root + int up_limit_out = _last_succ[join] == v_in ? 
join : -1; + int last_succ_out = _last_succ[u_out]; + for (int u = v_in; u != -1 && _last_succ[u] == v_in; u = _parent[u]) { + _last_succ[u] = last_succ_out; + } + + // Update _last_succ from v_out towards the root + if (join != old_rev_thread && v_in != old_rev_thread) { + for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ; + u = _parent[u]) { + _last_succ[u] = old_rev_thread; + } + } else if (last_succ_out != old_last_succ) { + for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ; + u = _parent[u]) { + _last_succ[u] = last_succ_out; + } + } + + // Update _succ_num from v_in to join + for (int u = v_in; u != join; u = _parent[u]) { + _succ_num[u] += old_succ_num; + } + // Update _succ_num from v_out to join + for (int u = v_out; u != join; u = _parent[u]) { + _succ_num[u] -= old_succ_num; + } + } + + void updatePotential() { + Cost sigma = _pi[v_in] - _pi[u_in] - + ((_forward[u_in])?_cost[in_arc]:(-_cost[in_arc])); + int end = _thread[_last_succ[u_in]]; + for (int u = u_in; u != end; u = _thread[u]) { + _pi[u] += sigma; + } + } + + + // Heuristic initial pivots + bool initialPivots() { + Value curr, total = 0; + std::vector<Node> supply_nodes, demand_nodes; + Node u; _graph.first(u); + for (; u != INVALIDNODE; _graph.next(u)) { + curr = _supply[_node_id(u)]; + if (curr > 0) { + total += curr; + supply_nodes.push_back(u); + } else if (curr < 0) { + demand_nodes.push_back(u); + } + } + if (_sum_supply > 0) total -= _sum_supply; + if (total <= 0) return true; + + ArcVector arc_vector; + if (_sum_supply >= 0) { + if (supply_nodes.size() == 1 && demand_nodes.size() == 1) { + // Perform a reverse graph search from the sink to the source + //typename GR::template NodeMap<bool> reached(_graph, false); + BoolVector reached(_node_num, false); + Node s = supply_nodes[0], t = demand_nodes[0]; + std::vector<Node> stack; + reached[t] = true; + stack.push_back(t); + while (!stack.empty()) { + Node u, v = stack.back(); + stack.pop_back(); + if (v == s) break; + Arc a; _graph.firstIn(a, v); + for (; a != INVALID; _graph.nextIn(a)) { + if (reached[u = _graph.source(a)]) continue; + ArcsType j = getArcID(a); + arc_vector.push_back(j); + reached[u] = true; + stack.push_back(u); + } + } + } else { + arc_vector.resize(demand_nodes.size()); + // Find the min. cost incomming arc for each demand node +#pragma omp parallel for + for (int i = 0; i < demand_nodes.size(); ++i) { + Node v = demand_nodes[i]; + Cost min_cost = std::numeric_limits<Cost>::max(); + Arc min_arc = INVALID; + Arc a; _graph.firstIn(a, v); + for (; a != INVALID; _graph.nextIn(a)) { + Cost c = _cost[getArcID(a)]; + if (c < min_cost) { + min_cost = c; + min_arc = a; + } + } + arc_vector[i] = getArcID(min_arc); + } + arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end()); + } + } else { + arc_vector.resize(supply_nodes.size()); + // Find the min. 
cost outgoing arc for each supply node
+#pragma omp parallel for
+			for (int i = 0; i < int(supply_nodes.size()); ++i) {
+				Node u = supply_nodes[i];
+				Cost min_cost = std::numeric_limits<Cost>::max();
+				Arc min_arc = INVALID;
+				Arc a; _graph.firstOut(a, u);
+				for (; a != INVALID; _graph.nextOut(a)) {
+					Cost c = _cost[getArcID(a)];
+					if (c < min_cost) {
+						min_cost = c;
+						min_arc = a;
+					}
+				}
+				arc_vector[i] = getArcID(min_arc);
+			}
+			arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end());
+		}
+
+		// Perform heuristic initial pivots
+		for (ArcsType i = 0; i != ArcsType(arc_vector.size()); ++i) {
+			in_arc = arc_vector[i];
+			if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -
+				_pi[_target[in_arc]]) >= 0) continue;
+			findJoinNode();
+			bool change = findLeavingArc();
+			if (delta >= MAX) return false;
+			changeFlow(change);
+			if (change) {
+				updateTreeStructure();
+				updatePotential();
+			}
+		}
+		return true;
+	}
+
+	// Execute the algorithm
+	ProblemType start() {
+		return start<BlockSearchPivotRule>();
+	}
+
+	template <typename PivotRuleImpl>
+	ProblemType start() {
+		PivotRuleImpl pivot(*this);
+		ProblemType retVal = OPTIMAL;
+
+		// Perform heuristic initial pivots
+		if (!initialPivots()) return UNBOUNDED;
+
+		size_t iter_number = 0;
+		// Execute the Network Simplex algorithm
+		while (pivot.findEnteringArc()) {
+			if ((++iter_number <= max_iter && max_iter > 0) || max_iter <= 0) {
+#if DEBUG_LVL>0
+				if(iter_number>MAX_DEBUG_ITER)
+					break;
+				if(iter_number%1000==0||iter_number%1000==1){
+					Cost curCost=totalCost();
+					Value sumFlow=0;
+					Cost a;
+					a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
+					a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
+					for (int i=0; i<_flow.size(); i++) {
+						sumFlow+=_state[i]*_flow[i];
+					}
+					std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+					std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+					std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+					std::cout << _cost[in_arc] << "\n";
+					std::cout << _pi[_source[in_arc]] << "\n";
+					std::cout << _pi[_target[in_arc]] << "\n";
+					std::cout << a << "\n";
+				}
+#endif
+
+				findJoinNode();
+				bool change = findLeavingArc();
+				if (delta >= MAX) return UNBOUNDED;
+				changeFlow(change);
+				if (change) {
+					updateTreeStructure();
+					updatePotential();
+				}
+
+#if DEBUG_LVL>0
+				else{
+					std::cout << "No change\n";
+				}
+#endif
+
+#if DEBUG_LVL>1
+				std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n";
+#endif
+
+
+			} else {
+				char errMess[1000];
+				snprintf(errMess, sizeof(errMess), "RESULT MIGHT BE INACCURATE\nMax number of iterations reached, currently %lu. Sometimes iterations cycle even though the solution has been reached; to check whether that is the case here, have a look at the minimal reduced cost. If it is very close to machine precision you might actually have the correct solution; if not, try setting the maximum number of iterations a bit higher.\n", (unsigned long)iter_number);
+				std::cerr << errMess;
+				retVal = MAX_ITER_REACHED;
+				break;
+			}
+
+		}
+
+
+
+#if DEBUG_LVL>0
+		Cost curCost=totalCost();
+		Value sumFlow=0;
+		Cost a;
+		a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ?
fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]); + a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]); + for (int i=0; i<_flow.size(); i++) { + sumFlow+=_state[i]*_flow[i]; + } + + std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n"; + + std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n"; + std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n"; + +#endif + + + +#if DEBUG_LVL>1 + sumFlow=0; + for (int i=0; i<_flow.size(); i++) { + sumFlow+=_state[i]*_flow[i]; + if (_state[i]==STATE_TREE) { + std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n"; + } + } + std::cout << "Sum of the flow " << sumFlow << "\n"<< niter <<" iterations, current cost=" << totalCost() << "\n"; +#endif + + + + //Check feasibility + if(retVal == OPTIMAL){ + for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) { + if (_flow[e] != 0){ + if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126 + return INFEASIBLE; + else + _flow[e]=0; + } + } + } + + // Shift potentials to meet the requirements of the GEQ/LEQ type + // optimality conditions + if (_sum_supply == 0) { + if (_stype == GEQ) { + Cost max_pot = -std::numeric_limits<Cost>::max(); + for (ArcsType i = 0; i != _node_num; ++i) { + if (_pi[i] > max_pot) max_pot = _pi[i]; + } + if (max_pot > 0) { + for (ArcsType i = 0; i != _node_num; ++i) + _pi[i] -= max_pot; + } + } else { + Cost min_pot = std::numeric_limits<Cost>::max(); + for (ArcsType i = 0; i != _node_num; ++i) { + if (_pi[i] < min_pot) min_pot = _pi[i]; + } + if (min_pot < 0) { + for (ArcsType i = 0; i != _node_num; ++i) + _pi[i] -= min_pot; + } + } + } + + return retVal; + } + + }; //class NetworkSimplexSimple + + ///@} + +} //namespace lemon_omp diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py new file mode 100644 index 0000000..8b4d0c3 --- /dev/null +++ b/ot/lp/solver_1d.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- +""" +Exact solvers for the 1D Wasserstein distance using cvxopt +""" + +# Author: Remi Flamary <remi.flamary@unice.fr> +# Author: Nicolas Courty <ncourty@irisa.fr> +# +# License: MIT License + +import numpy as np +import warnings + +from .emd_wrap import emd_1d_sorted +from ..backend import get_backend +from ..utils import list_to_array + + +def quantile_function(qs, cws, xs): + r""" Computes the quantile function of an empirical distribution + + Parameters + ---------- + qs: array-like, shape (n,) + Quantiles at which the quantile function is evaluated + cws: array-like, shape (m, ...) + cumulative weights of the 1D empirical distribution, if batched, must be similar to xs + xs: array-like, shape (n, ...) 
+        locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions
+
+    Returns
+    -------
+    q: array-like, shape (..., n)
+        The quantiles of the distribution
+    """
+    nx = get_backend(qs, cws)
+    n = xs.shape[0]
+    if nx.__name__ == 'torch':
+        # this is to ensure the best performance for torch searchsorted
+        # and avoid a warning related to non-contiguous arrays
+        cws = cws.T.contiguous()
+        qs = qs.T.contiguous()
+    else:
+        cws = cws.T
+        qs = qs.T
+    idx = nx.searchsorted(cws, qs).T
+    return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0)
+
+
+def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
+    r"""
+    Computes the 1-dimensional OT loss [15] between two (batched) empirical
+    distributions
+
+    .. math::
+        OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq
+
+    It is formally the p-Wasserstein distance raised to the power p.
+    We do so in a vectorized way by first building the individual quantile functions then integrating them.
+
+    This function should be preferred to `emd_1d` whenever the backend is
+    different from numpy, and when gradients over
+    either sample positions or weights are required.
+
+    Parameters
+    ----------
+    u_values: array-like, shape (n, ...)
+        locations of the first empirical distribution
+    v_values: array-like, shape (m, ...)
+        locations of the second empirical distribution
+    u_weights: array-like, shape (n, ...), optional
+        weights of the first empirical distribution, if None then uniform weights are used
+    v_weights: array-like, shape (m, ...), optional
+        weights of the second empirical distribution, if None then uniform weights are used
+    p: int, optional
+        order of the ground metric used, should be at least 1 (see [2, Chap. 2]), default is 1
+    require_sort: bool, optional
+        sort the distributions' atom locations, if False we will consider they have been sorted prior to being passed to
+        the function, default is True
+
+    Returns
+    -------
+    cost: float/array-like, shape (...)
+        the batched EMD
+
+    References
+    ----------
+    .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.
+
+    """
+
+    assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+    if u_weights is not None and v_weights is not None:
+        nx = get_backend(u_values, v_values, u_weights, v_weights)
+    else:
+        nx = get_backend(u_values, v_values)
+
+    n = u_values.shape[0]
+    m = v_values.shape[0]
+
+    if u_weights is None:
+        u_weights = nx.full(u_values.shape, 1. / n)
+    elif u_weights.ndim != u_values.ndim:
+        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+    if v_weights is None:
+        v_weights = nx.full(v_values.shape, 1. / m)
+    elif v_weights.ndim != v_values.ndim:
+        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+    if require_sort:
+        u_sorter = nx.argsort(u_values, 0)
+        u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+        v_sorter = nx.argsort(v_values, 0)
+        v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+        u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+        v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+    u_cumweights = nx.cumsum(u_weights, 0)
+    v_cumweights = nx.cumsum(v_weights, 0)
+
+    qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0)
+    u_quantiles = quantile_function(qs, u_cumweights, u_values)
+    v_quantiles = quantile_function(qs, v_cumweights, v_values)
+    qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
+    delta = qs[1:, ...] - qs[:-1, ...]
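+    # the OT cost is the integral over [0, 1] of |F_u^{-1}(q) - F_v^{-1}(q)|^p;
+    # with piecewise-constant quantile functions it reduces to a weighted sum
+    # whose weights are the gaps `delta` between successive joint quantile levels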
+ diff_quantiles = nx.abs(u_quantiles - v_quantiles) + + if p == 1: + return nx.sum(delta * nx.abs(diff_quantiles), axis=0) + return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) + + +def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the OT matrix + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost. + Otherwise returns only the optimal transportation matrix. + + Returns + ------- + gamma: (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log: dict + If input log is True, a dictionary containing the cost + + + Examples + -------- + + Simple example with obvious solution. The function emd_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd_1d(x_a, x_b, a, b) + array([[0. , 0.5], + [0.5, 0. ]]) + >>> ot.emd_1d(x_a, x_b) + array([[0. , 0.5], + [0.5, 0. ]]) + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. 
+ + See Also + -------- + ot.lp.emd : EMD for multidimensional distributions + ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the + transportation matrix) + """ + a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + nx = get_backend(x_a, x_b) + + assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + + # if empty array given then use uniform distributions + if a is None or a.ndim == 0 or len(a) == 0: + a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] + if b is None or b.ndim == 0 or len(b) == 0: + b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] + + # ensure that same mass + np.testing.assert_almost_equal( + nx.to_numpy(nx.sum(a, axis=0)), + nx.to_numpy(nx.sum(b, axis=0)), + err_msg='a and b vector must have the same sum' + ) + b = b * nx.sum(a) / nx.sum(b) + + x_a_1d = nx.reshape(x_a, (-1,)) + x_b_1d = nx.reshape(x_b, (-1,)) + perm_a = nx.argsort(x_a_1d) + perm_b = nx.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted( + nx.to_numpy(a[perm_a]).astype(np.float64), + nx.to_numpy(b[perm_b]).astype(np.float64), + nx.to_numpy(x_a_1d[perm_a]).astype(np.float64), + nx.to_numpy(x_b_1d[perm_b]).astype(np.float64), + metric=metric, p=p + ) + + G = nx.coo_matrix( + G_sorted, + perm_a[indices[:, 0]], + perm_b[indices[:, 1]], + shape=(a.shape[0], b.shape[0]), + type_as=x_a + ) + if dense: + G = nx.todense(G) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") + if log: + log = {'cost': nx.from_numpy(cost, type_as=x_a)} + return G, log + return G + + +def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the loss + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Only used if log is set to True. Due to implementation details, + this function runs faster when dense is set to False. + log: boolean, optional (default=False) + If True, returns a dictionary containing the transportation matrix. + Otherwise returns only the loss. 
+ + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict + If input log is True, a dictionary containing the Optimal transportation + matrix for the given parameters + + + Examples + -------- + + Simple example with obvious solution. The function emd2_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd2_1d(x_a, x_b, a, b) + 0.5 + >>> ot.emd2_1d(x_a, x_b) + 0.5 + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd2 : EMD for multidimensional distributions + ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix + instead of the cost) + """ + # If we do not return G (log==False), then we should not to cast it to dense + # (useless overhead) + G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, + dense=dense and log, log=True) + cost = log_emd['cost'] + if log: + log_emd = {'G': G} + return cost, log_emd + return cost diff --git a/ot/optim.py b/ot/optim.py index b9ca891..bd8ca26 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -12,34 +12,36 @@ import numpy as np from scipy.optimize.linesearch import scalar_search_armijo from .lp import emd from .bregman import sinkhorn +from ot.utils import list_to_array +from .backend import get_backend # The corresponding scipy function does not work for matrices def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99): - """ + r""" Armijo linesearch function that works with matrices - find an approximate minimum of f(xk+alpha*pk) that satifies the + Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the armijo conditions. 
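For reference, the Armijo condition accepted here is f(x_k + α p_k) ≤ f(x_k) + c1 α ⟨∇f(x_k), p_k⟩. A minimal backtracking version of the same rule (an illustrative sketch, not the scipy scalar_search_armijo helper this function actually calls):

import numpy as np

def backtracking_armijo(f, xk, pk, gfk, c1=1e-4, alpha0=0.99, tau=0.5, max_iter=50):
    # halve the step until the sufficient-decrease condition holds
    f0 = f(xk)
    slope = np.sum(pk * gfk)          # directional derivative, works for matrices
    alpha = alpha0
    for _ in range(max_iter):
        if f(xk + alpha * pk) <= f0 + c1 * alpha * slope:
            return alpha
        alpha *= tau
    return None                       # no admissible step found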
Parameters ---------- f : callable loss function - xk : ndarray + xk : array-like initial position - pk : ndarray + pk : array-like descent direction - gfk : ndarray - gradient of f at xk + gfk : array-like + gradient of `f` at :math:`x_k` old_fval : float - loss value at xk + loss value at :math:`x_k` args : tuple, optional - arguments given to f + arguments given to `f` c1 : float, optional - c1 const in armijo rule (>0) + :math:`c_1` const in armijo rule (>0) alpha0 : float, optional initial step (>0) @@ -53,7 +55,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, loss value at step alpha """ - xk = np.atleast_1d(xk) + + xk, pk, gfk = list_to_array(xk, pk, gfk) + nx = get_backend(xk, pk) + + if len(xk.shape) == 0: + xk = nx.reshape(xk, (-1,)) + fc = [0] def phi(alpha1): @@ -65,10 +73,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval, else: phi0 = old_fval - derphi0 = np.sum(pk * gfk) # Quickfix for matrices + derphi0 = nx.sum(pk * gfk) # Quickfix for matrices alpha, phi1 = scalar_search_armijo( phi, phi0, derphi0, c1=c1, alpha0=alpha0) + # scalar_search_armijo can return alpha > 1 + if alpha is not None: + alpha = min(1, alpha) return alpha, fc[0], phi1 @@ -76,55 +87,64 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): """ Solve the linesearch in the FW iterations + Parameters ---------- cost : method Cost in the FW for the linesearch - G : ndarray, shape(ns,nt) + G : array-like, shape(ns,nt) The transport map at a given iteration of the FW - deltaG : ndarray (ns,nt) + deltaG : array-like (ns,nt) Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration - Mi : ndarray (ns,nt) + Mi : array-like (ns,nt) Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost - f_val : float - Value of the cost at G + f_val : float + Value of the cost at `G` armijo : bool, optional - If True the steps of the line-search is found via an armijo research. Else closed form is used. - If there is convergence issues use False. - C1 : ndarray (ns,ns), optional + If True the steps of the line-search is found via an armijo research. Else closed form is used. + If there is convergence issues use False. + C1 : array-like (ns,ns), optional Structure matrix in the source domain. Only used and necessary when armijo=False - C2 : ndarray (nt,nt), optional + C2 : array-like (nt,nt), optional Structure matrix in the target domain. Only used and necessary when armijo=False reg : float, optional - Regularization parameter. Only used and necessary when armijo=False - Gc : ndarray (ns,nt) + Regularization parameter. Only used and necessary when armijo=False + Gc : array-like (ns,nt) Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False - constC : ndarray (ns,nt) - Constant for the gromov cost. See [24]. Only used and necessary when armijo=False - M : ndarray (ns,nt), optional + constC : array-like (ns,nt) + Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False + M : array-like (ns,nt), optional Cost matrix between the features. Only used and necessary when armijo=False + Returns ------- alpha : float - The optimal step size of the FW + The optimal step size of the FW fc : int - nb of function call. Useless here - f_val : float - The value of the cost for the next iteration + nb of function call. 
Useless here + f_val : float + The value of the cost for the next iteration + + + .. _references-solve-linesearch: References ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain - and Courty Nicolas + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. """ if armijo: alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) else: # requires symetric matrices - dot1 = np.dot(C1, deltaG) - dot12 = dot1.dot(C2) - a = -2 * reg * np.sum(dot12 * deltaG) - b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG)) + G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, C1, C2, constC) + else: + nx = get_backend(G, deltaG, C1, C2, constC, M) + + dot = nx.dot(nx.dot(C1, deltaG), C2) + a = -2 * reg * nx.sum(dot * deltaG) + b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG)) c = cost(G) alpha = solve_1d_linesearch_quad(a, b, c) @@ -136,48 +156,49 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs): - """ + r""" Solve the general regularized OT problem with conditional gradient The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot f(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The algorithm used for solving the problem is conditional gradient as discussed in [1]_ + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>` Parameters ---------- - a : ndarray, shape (ns,) + a : array-like, shape (ns,) samples weights in the source domain - b : ndarray, shape (nt,) + b : array-like, shape (nt,) samples in the target domain - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) loss matrix reg : float Regularization term >0 - G0 : ndarray, shape (ns,nt), optional + G0 : array-like, shape (ns,nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations numItermaxEmd : int, optional Max number of iterations for emd stopThr : float, optional - Stop threshol on the relative variation (>0) + Stop threshold on the relative variation (>0) stopThr2 : float, optional - Stop threshol on the absolute variation (>0) + Stop threshold on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -193,6 +214,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, log dictionary return only if log==True in parameters + .. 
_references-cg: References ---------- @@ -204,6 +226,11 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, ot.bregman.sinkhorn : Entropic regularized optimal transport """ + a, b, M, G0 = list_to_array(a, b, M, G0) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(a, b) + else: + nx = get_backend(a, b, M) loop = 1 @@ -211,12 +238,12 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, log = {'loss': []} if G0 is None: - G = np.outer(a, b) + G = nx.outer(a, b) else: G = G0 def cost(G): - return np.sum(M * G) + reg * f(G) + return nx.sum(M * G) + reg * f(G) f_val = cost(G) if log: @@ -237,15 +264,17 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, # problem linearization Mi = M + reg * df(G) # set M positive - Mi += Mi.min() + Mi += nx.min(Mi) # solve linear program - Gc = emd(a, b, Mi, numItermax=numItermaxEmd) + Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) deltaG = Gc - G # line search alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + if alpha is None: + alpha = 0.0 G = G + alpha * deltaG @@ -268,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: + log.update(logemd) return G, log else: return G @@ -275,51 +305,52 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False): - """ + r""" Solve the general regularized OT problem with the generalized conditional gradient The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma) - s.t. \gamma 1 = a + s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} &= \mathbf{b} - \gamma\geq 0 + \gamma &\geq 0 where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_ + The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>` Parameters ---------- - a : ndarray, shape (ns,) + a : array-like, shape (ns,) samples weights in the source domain - b : ndarrayv (nt,) + b : array-like, (nt,) samples in the target domain - M : ndarray, shape (ns, nt) + M : array-like, shape (ns, nt) loss matrix reg1 : float Entropic Regularization term >0 reg2 : float Second Regularization term >0 - G0 : ndarray, shape (ns, nt), optional + G0 : array-like, shape (ns, nt), optional initial guess (default is indep joint density) numItermax : int, optional Max number of iterations numInnerItermax : int, optional Max number of iterations of Sinkhorn stopThr : float, optional - Stop threshol on the relative variation (>0) + Stop threshold on the relative variation (>0) stopThr2 : float, optional - Stop threshol on the absolute variation (>0) + Stop threshold on the absolute variation (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -332,9 +363,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, log : dict log dictionary return only if log==True in parameters + + .. _references-gcg: References ---------- + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567. See Also @@ -342,6 +377,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, ot.optim.cg : conditional gradient """ + a, b, M, G0 = list_to_array(a, b, M, G0) + nx = get_backend(a, b, M) loop = 1 @@ -349,12 +386,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, log = {'loss': []} if G0 is None: - G = np.outer(a, b) + G = nx.outer(a, b) else: G = G0 def cost(G): - return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G) + return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G) f_val = cost(G) if log: @@ -382,7 +419,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, deltaG = Gc - G # line search - dcost = Mi + reg1 * (1 + np.log(G)) # ?? + dcost = Mi + reg1 * (1 + nx.log(G)) # ?? alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val) G = G + alpha * deltaG @@ -413,10 +450,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, def solve_1d_linesearch_quad(a, b, c): - """ - For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem: + r""" + For any convex or non-convex 1d quadratic function `f`, solve the following problem: + .. 
math::
-        \argmin f(x)=a*x^{2}+b*x+c
+
+        \mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c
 
     Parameters
     ----------
diff --git a/ot/partial.py b/ot/partial.py
index eb707d8..b7093e4 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -20,13 +20,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
     The function considers the following problem:
 
     .. math::
-        \gamma = \arg\min_\gamma <\gamma,(M-\lambda)>_F
+        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, (\mathbf{M} - \lambda) \rangle_F
 
-        s.t.
-             \gamma\geq 0 \\
-             \gamma 1 \leq a\\
-             \gamma^T 1 \leq b\\
-             1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+    .. math::
+        s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
+
+             \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+             \gamma &\geq 0
+
+             \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
 
 
     or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X.
@@ -34,33 +37,32 @@
     metrics. Foundations of Computational Mathematics, 18(1), 1-44.)
 
     .. math::
-        \gamma = \arg\min_\gamma <\gamma,M>_F + \sqrt(\lambda/2)
-        (\|\gamma 1 - a\|_1 + \|\gamma^T 1 - b\|_1)
+        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+        \sqrt{\frac{\lambda}{2}} \left(\|\gamma \mathbf{1} - \mathbf{a}\|_1 + \|\gamma^T \mathbf{1} - \mathbf{b}\|_1\right)
 
-        s.t.
-             \gamma\geq 0 \\
+        s.t. \ \gamma \geq 0
 
     where :
 
-    - M is the metric cost matrix
-    - a and b are source and target unbalanced distributions
-    - :math:`\lambda` is the lagragian cost. Tuning its value allows attaining
-      a given mass to be transported m
+    - :math:`\mathbf{M}` is the metric cost matrix
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
+    - :math:`\lambda` is the lagrangian cost. Tuning its value allows attaining
+      a prescribed amount of transported mass `m`
 
-    The formulation of the problem has been proposed in [28]_
+    The formulation of the problem has been proposed in :ref:`[28] <references-partial-wasserstein-lagrange>`
 
 
     Parameters
     ----------
     a : np.ndarray (dim_a,)
-        Unnormalized histogram of dimension dim_a
+        Unnormalized histogram of dimension `dim_a`
     b : np.ndarray (dim_b,)
-        Unnormalized histograms of dimension dim_b
+        Unnormalized histograms of dimension `dim_b`
     M : np.ndarray (dim_a, dim_b)
         cost matrix for the quadratic cost
     reg_m : float, optional
-        Lagragian cost
+        Lagrangian cost
     nb_dummies : int, optional, default:1
         number of reservoir points to be added (to avoid numerical
         instabilities, increase its value if an error is raised)
@@ -69,6 +71,7 @@
     **kwargs : dict
         parameters can be directly passed to the emd solver
 
+
     .. warning::
         When dealing with a large number of points, the EMD solver may face
         some instabilities, especially when the mass associated to the dummy
@@ -77,7 +80,7 @@
 
     Returns
     -------
-    gamma : (dim_a x dim_b) ndarray
+    gamma : (dim_a, dim_b) ndarray
         Optimal transportation matrix for the given parameters
     log : dict
         log dictionary returned only if `log` is `True`
@@ -97,9 +100,10 @@
     array([[0.1, 0. ],
            [0. , 0. ]])
 
+
+    .. _references-partial-wasserstein-lagrange:
     References
     ----------
-
     .. [28] Caffarelli, L. A., & McCann, R. J.
(2010) Free boundaries in optimal transport and Monge-Ampere obstacle problems. Annals of mathematics, 673-730. @@ -162,27 +166,30 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): The function considers the following problem: .. math:: - \gamma = \arg\min_\gamma <\gamma,M>_F + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - s.t. - \gamma\geq 0 \\ - \gamma 1 \leq a\\ - \gamma^T 1 \leq b\\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} + .. math:: + s.t. \ \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \gamma &\geq 0 + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - M is the metric cost matrix - - a and b are source and target unbalanced distributions - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - `m` is the amount of mass to be transported Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) - Unnormalized histograms of dimension dim_b + Unnormalized histograms of dimension `dim_b` M : np.ndarray (dim_a, dim_b) cost matrix for the quadratic cost m : float, optional @@ -205,7 +212,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): Returns ------- - :math:`gamma` : (dim_a x dim_b) ndarray + gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -230,9 +237,9 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): .. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in optimal transport and Monge-Ampere obstacle problems. Annals of mathematics, 673-730. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. See Also -------- @@ -254,7 +261,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) M_extended = np.zeros((len(a_extended), len(b_extended))) - M_extended[-1, -1] = np.max(M) * 1e5 + M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 M_extended[:len(a), :len(b)] = M gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, @@ -278,27 +285,30 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): The function considers the following problem: .. math:: - \gamma = \arg\min_\gamma <\gamma,M>_F + \gamma = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + .. math:: + s.t. \ \gamma \mathbf{1} &\leq \mathbf{a} - s.t. 
- \gamma\geq 0 \\ - \gamma 1 \leq a\\ - \gamma^T 1 \leq b\\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \gamma &\geq 0 + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - M is the metric cost matrix - - a and b are source and target unbalanced distributions - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - `m` is the amount of mass to be transported Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) - Unnormalized histograms of dimension dim_b + Unnormalized histograms of dimension `dim_b` M : np.ndarray (dim_a, dim_b) cost matrix for the quadratic cost m : float, optional @@ -321,8 +331,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): Returns ------- - :math:`gamma` : (dim_a x dim_b) ndarray - Optimal transportation matrix for the given parameters + GW: float + partial GW discrepancy log : dict log dictionary returned only if `log` is `True` @@ -344,14 +354,13 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): .. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in optimal transport and Monge-Ampere obstacle problems. Annals of mathematics, 673-730. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True, **kwargs) - log_w['T'] = partial_gw if log: @@ -361,8 +370,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): def gwgrad_partial(C1, C2, T): - """Compute the GW gradient. Note: we can not use the trick in [12]_ as - the marginals may not sum to 1. + """Compute the GW gradient. Note: we can not use the trick in :ref:`[12] <references-gwgrad-partial>` + as the marginals may not sum to 1. Parameters ---------- @@ -380,6 +389,8 @@ def gwgrad_partial(C1, C2, T): numpy.array of shape (n_p+nb_dummies, n_u) gradient + + .. _references-gwgrad-partial: References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, @@ -426,22 +437,25 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, The function considers the following problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - s.t. \gamma 1 \leq a \\ - \gamma^T 1 \leq b \\ - \gamma\geq 0 \\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\ + .. math:: + s.t. 
\ \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \gamma &\geq 0 + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma) - =\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are the sample weights - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in [29]_ + The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein>` Parameters @@ -455,7 +469,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: min (|p|_1, |q|_1)) + Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -477,7 +491,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, Returns ------- - gamma : (dim_a x dim_b) ndarray + gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -501,14 +515,16 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, >>> np.round(partial_gromov_wasserstein(C1, C2, a, b, m=0.25),2) array([[0. , 0. , 0. , 0. ], [0. , 0. , 0. , 0. ], - [0. , 0. , 0. , 0. ], - [0. , 0. , 0. , 0.25]]) + [0. , 0. , 0.25, 0. ], + [0. , 0. , 0. , 0. ]]) + + .. _references-partial-gromov-wasserstein: References ---------- - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. 
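Note on the implementation change below: the fixed plan update is replaced by a Frank-Wolfe step with an exact line search. Restricted to the segment between the previous plan and the new EMD solution, the quadratic GW objective reads g(gamma) = a * gamma**2 + b * gamma, so the optimal step size has a closed form. A minimal sketch of that rule (the helper name is illustrative, not part of the diff):

    import numpy as np

    def fw_step_size(a, b):
        # minimize g(gamma) = a * gamma**2 + b * gamma over gamma in [0, 1];
        # a is the quadratic coefficient, b the linear one, both scalars
        if a > 0:  # convex parabola: clip the stationary point to [0, 1]
            return float(np.clip(-b / (2.0 * a), 0.0, 1.0))
        # concave or linear case: the minimum is attained at an endpoint
        return 1.0 if a + b < 0.0 else 0.0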
""" @@ -530,20 +546,18 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, cpt = 0 err = 1 - eps = 1e-20 + if log: log = {'err': []} while (err > tol and cpt < numItermax): - Gprev = G0 + Gprev = np.copy(G0) M = gwgrad_partial(C1, C2, G0) - M[M < eps] = np.quantile(M, thres) - M_emd = np.zeros(dim_G_extended) M_emd[:len(p), :len(q)] = M - M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 + M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2 M_emd = np.asarray(M_emd, dtype=np.float64) Gc, logemd = emd(p_extended, q_extended, M_emd, log=True, **kwargs) @@ -565,6 +579,22 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, print('{:5d}|{:8e}|{:8e}'.format(cpt, err, gwloss_partial(C1, C2, G0))) + deltaG = G0 - Gprev + a = gwloss_partial(C1, C2, deltaG) + b = 2 * np.sum(M * deltaG) + if b > 0: # due to numerical precision + gamma = 0 + cpt = numItermax + elif a > 0: + gamma = min(1, np.divide(-b, 2.0 * a)) + else: + if (a + b) < 0: + gamma = 1 + else: + gamma = 0 + cpt = numItermax + + G0 = Gprev + gamma * deltaG cpt += 1 if log: @@ -584,22 +614,25 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, The function considers the following problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + GW = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F - s.t. \gamma 1 \leq a \\ - \gamma^T 1 \leq b \\ - \gamma\geq 0 \\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\ + .. math:: + s.t. \ \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \gamma &\geq 0 + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are the sample weights - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in [29]_ + The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein2>` Parameters @@ -613,7 +646,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: min (|p|_1, |q|_1)) + Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -642,7 +675,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, Returns ------- - partial_gw_dist : (dim_a x dim_b) ndarray + partial_gw_dist : float partial GW discrepancy log : dict log dictionary returned only if `log` is `True` @@ -663,11 +696,13 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, >>> np.round(partial_gromov_wasserstein2(C1, C2, a, b, m=0.25),2) 0.0 + + .. _references-partial-gromov-wasserstein2: References ---------- - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. 
[29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ @@ -693,30 +728,29 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, The function considers the following problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 \leq a \\ - \gamma^T 1 \leq b \\ - \gamma\geq 0 \\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\ + s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\ + \gamma^T \mathbf{1} &\leq \mathbf{b} \\ + \gamma &\geq 0 \\ + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ where : - - M is the metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are the sample weights - - m is the amount of mass to be transported + - :math:`\mathbf{M}` is the metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in [3]_ (prop. 5) + The formulation of the problem has been proposed in :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5) Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) - Unnormalized histograms of dimension dim_b + Unnormalized histograms of dimension `dim_b` M : np.ndarray (dim_a, dim_b) cost matrix reg : float @@ -735,7 +769,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, Returns ------- - gamma : (dim_a x dim_b) ndarray + gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` @@ -751,6 +785,8 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, array([[0.06, 0.02], [0.01, 0. ]]) + + .. _references-entropic-partial-wasserstein: References ---------- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. @@ -825,32 +861,34 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein transport between (C1,p) and (C2,q) + Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_{\gamma} \quad \sum_{i,j,k,l} + L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) - s.t. - \gamma\geq 0 \\ - \gamma 1 \leq a\\ - \gamma^T 1 \leq b\\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} + .. math:: + s.t. 
\ \gamma &\geq 0 + + \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - C1 is the metric cost matrix in the source space - - C2 is the metric cost matrix in the target space - - p and q are the sample weights - - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - m is the amount of mass to be transported + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L`: quadratic loss function + - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in [12]_ and the - partial GW in [29]_. + The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein>` Parameters ---------- @@ -865,7 +903,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: min (|p|_1, |q|_1)) + Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional @@ -887,12 +925,12 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, >>> y = np.array([3,2,98,199]).reshape((-1,1)) >>> C1 = sp.spatial.distance.cdist(x, x) >>> C2 = sp.spatial.distance.cdist(y, y) - >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b,50), 2) + >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50), 2) array([[0.12, 0.13, 0. , 0. ], [0.13, 0.12, 0. , 0. ], [0. , 0. , 0.25, 0. ], [0. , 0. , 0. , 0.25]]) - >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50, m=0.25), 2) + >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50,0.25), 2) array([[0.02, 0.03, 0. , 0.03], [0.03, 0.03, 0. , 0.03], [0. , 0. , 0.03, 0. ], @@ -900,19 +938,22 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, Returns ------- - :math: `gamma` : (dim_a x dim_b) ndarray + :math: `gamma` : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if `log` is `True` + + .. _references-entropic-partial-gromov-wassertein: References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. 
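The constraints above are easy to sanity-check on a computed plan: exactly m mass should be transported while the row and column marginals stay below p and q. A minimal check reusing the doctest data (the 1e-6 slack for the solver tolerance is ad hoc):

    import numpy as np
    from scipy.spatial.distance import cdist
    import ot

    p, q = ot.unif(4), ot.unif(4)
    x = np.array([1., 2., 100., 200.]).reshape((-1, 1))
    y = np.array([3., 2., 98., 199.]).reshape((-1, 1))
    C1, C2 = cdist(x, x), cdist(y, y)
    G = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 50, m=0.25)
    print(np.round(G.sum(), 4))              # total transported mass = m
    print(np.all(G.sum(1) <= p + 1e-6))      # row marginals capped by p
    print(np.all(G.sum(0) <= q + 1e-6))      # column marginals capped by q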
See Also -------- @@ -964,33 +1005,33 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein discrepancy between (C1,p) and - (C2,q) + Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma) + GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) - s.t. - \gamma\geq 0 \\ - \gamma 1 \leq a\\ - \gamma^T 1 \leq b\\ - 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} + .. math:: + s.t. \ \gamma &\geq 0 + + \gamma \mathbf{1} &\leq \mathbf{a} + + \gamma^T \mathbf{1} &\leq \mathbf{b} + + \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - - C1 is the metric cost matrix in the source space - - C2 is the metric cost matrix in the target space - - p and q are the sample weights - - L : quadratic loss function - - :math:`\Omega` is the entropic regularization term - :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - m is the amount of mass to be transported + - :math:`\mathbf{C_1}` is the metric cost matrix in the source space + - :math:`\mathbf{C_2}` is the metric cost matrix in the target space + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights + - `L` : quadratic loss function + - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in [12]_ and the - partial GW in [29]_. + The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wassertein2>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wassertein2>` Parameters @@ -1006,7 +1047,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: min (|p|_1, |q|_1)) + Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional @@ -1039,14 +1080,17 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, >>> np.round(entropic_partial_gromov_wasserstein2(C1, C2, a, b,50), 2) 1.87 + + .. _references-entropic-partial-gromov-wassertein2: References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. 
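Throughout ot.partial, the `*2` variants (such as the wrapper implemented just below) return the scalar objective while the unsuffixed functions return the plan, and the two are consistent with each other. A minimal sketch of that convention on the exact (non-entropic) solver (toy values, illustrative only):

    import numpy as np
    import ot

    a, b = ot.unif(2), ot.unif(2)
    M = np.array([[0., 1.], [1., 0.]])
    G = ot.partial.partial_wasserstein(a, b, M, m=0.5)    # optimal plan
    w = ot.partial.partial_wasserstein2(a, b, M, m=0.5)   # scalar cost
    print(np.isclose(w, np.sum(G * M)))                   # cost = <G, M>_F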
""" partial_gw, log_gw = entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, @@ -18,10 +18,10 @@ from matplotlib import gridspec def plot1D_mat(a, b, M, title=''): - """ Plot matrix M with the source and target 1D distribution + """ Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution - Creates a subplot with the source distribution a on the left and - target distribution b on the tot. The matrix M is shown in between. + Creates a subplot with the source distribution :math:`\mathbf{a}` on the left and + target distribution :math:`\mathbf{b}` on the top. The matrix :math:`\mathbf{M}` is shown in between. Parameters @@ -61,10 +61,10 @@ def plot1D_mat(a, b, M, title=''): def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): - """ Plot matrix M in 2D with lines using alpha values + """ Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values Plot lines between source and target 2D samples with a color - proportional to the value of the matrix G between samples. + proportional to the value of the matrix :math:`\mathbf{G}` between samples. Parameters diff --git a/ot/regpath.py b/ot/regpath.py new file mode 100644 index 0000000..269937a --- /dev/null +++ b/ot/regpath.py @@ -0,0 +1,827 @@ +# -*- coding: utf-8 -*- +""" +Regularization path OT solvers +""" + +# Author: Haoran Wu <haoran.wu@univ-ubs.fr> +# License: MIT License + +import numpy as np +import scipy.sparse as sp + + +def recast_ot_as_lasso(a, b, C): + r"""This function recasts the l2-penalized UOT problem as a Lasso problem + + Recall the l2-penalized UOT problem defined in [Chapel et al., 2021] + .. math:: + UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2 + + \lambda \|T^T 1_n - b\|_2^2 + s.t. + T \geq 0 + where : + - C is the (dim_a, dim_b) metric cost matrix + - :math:`\lambda` is the l2-regularization coefficient + - a and b are source and target distributions + - T is the transport plan to optimize + + The problem above can be reformulated to a non-negative penalized + linear regression problem, particularly Lasso + .. math:: + UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + s.t. + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] + - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, + see [Chapel et al., 2021] for the design of H. The matrix product H t + computes both the source marginal and the target marginal. 
+ - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + Returns + ------- + H : np.ndarray (dim_a+dim_b, dim_a*dim_b) + Auxiliary matrix constituted by 0 and 1 + y : np.ndarray (ns + nt, ) + Concatenation of histogram a and histogram b + c : np.ndarray (ns * nt, ) + Flattened array of cost matrix + Examples + -------- + >>> import ot + >>> a = np.array([0.2, 0.3, 0.5]) + >>> b = np.array([0.1, 0.9]) + >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]]) + >>> H, y, c = ot.regpath.recast_ot_as_lasso(a, b, C) + >>> H.toarray() + array([[1., 1., 0., 0., 0., 0.], + [0., 0., 1., 1., 0., 0.], + [0., 0., 0., 0., 1., 1.], + [1., 0., 1., 0., 1., 0.], + [0., 1., 0., 1., 0., 1.]]) + >>> y + array([0.2, 0.3, 0.5, 0.1, 0.9]) + >>> c + array([16., 25., 28., 16., 40., 36.]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + dim_a = np.shape(a)[0] + dim_b = np.shape(b)[0] + y = np.concatenate((a, b)) + c = C.flatten() + jHa = np.arange(dim_a * dim_b) + iHa = np.repeat(np.arange(dim_a), dim_b) + jHb = np.arange(dim_a * dim_b) + iHb = np.tile(np.arange(dim_b), dim_a) + dim_a + j = np.concatenate((jHa, jHb)) + i = np.concatenate((iHa, iHb)) + H = sp.csc_matrix((np.ones(dim_a * dim_b * 2), (i, j)), + shape=(dim_a + dim_b, dim_a * dim_b)) + return H, y, c + + +def recast_semi_relaxed_as_lasso(a, b, C): + r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem + + .. math:: + semi-relaxed UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2 + s.t. + T^T 1_n = b + t \geq 0 + where : + - C is the (dim_a, dim_b) metric cost matrix + - :math:`\lambda` is the l2-regularization coefficient + - a and b are source and target distributions + - T is the transport plan to optimize + + The problem above can be reformulated as follows + .. math:: + semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + s.t. 
+ H_c t = b + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - H_r is a (dim_a, dim_a * dim_b) metric matrix, + which computes the sum along the rows of transport plan T + - H_c is a (dim_b, dim_a * dim_b) metric matrix, + which computes the sum along the columns of transport plan T + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + Returns + ------- + Hr : np.ndarray (dim_a, dim_a * dim_b) + Auxiliary matrix constituted by 0 and 1, which computes + the sum along the rows of transport plan T + Hc : np.ndarray (dim_b, dim_a * dim_b) + Auxiliary matrix constituted by 0 and 1, which computes + the sum along the columns of transport plan T + c : np.ndarray (ns * nt, ) + Flattened array of cost matrix + Examples + -------- + >>> import ot + >>> a = np.array([0.2, 0.3, 0.5]) + >>> b = np.array([0.1, 0.9]) + >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]]) + >>> Hr,Hc,c = ot.regpath.recast_semi_relaxed_as_lasso(a, b, C) + >>> Hr.toarray() + array([[1., 1., 0., 0., 0., 0.], + [0., 0., 1., 1., 0., 0.], + [0., 0., 0., 0., 1., 1.]]) + >>> Hc.toarray() + array([[1., 0., 1., 0., 1., 0.], + [0., 1., 0., 1., 0., 1.]]) + >>> c + array([16., 25., 28., 16., 40., 36.]) + """ + + dim_a = np.shape(a)[0] + dim_b = np.shape(b)[0] + + c = C.flatten() + jHr = np.arange(dim_a * dim_b) + iHr = np.repeat(np.arange(dim_a), dim_b) + jHc = np.arange(dim_a * dim_b) + iHc = np.tile(np.arange(dim_b), dim_a) + + Hr = sp.csc_matrix((np.ones(dim_a * dim_b), (iHr, jHr)), + shape=(dim_a, dim_a * dim_b)) + Hc = sp.csc_matrix((np.ones(dim_a * dim_b), (iHc, jHc)), + shape=(dim_b, dim_a * dim_b)) + + return Hr, Hc, c + + +def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma): + r""" This function computes the next value of gamma if a variable + will be added in next iteration of the regularization path + + We look for the largest value of gamma such that + the gradient of an inactive variable vanishes + .. 
math:: + \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i} + where : + - A is the current active set + - h_i is the ith column of auxiliary matrix H + - H_A is the sub-matrix constructed by the columns of H + whose indices belong to the active set A + - c_i is the ith element of cost vector c + - y is the concatenation of source and target distributions + - :math:`\phi` is the intercept of the solutions in current iteration + - :math:`\delta` is the slope of the solutions in current iteration + Parameters + ---------- + phi : np.ndarray (|A|, ) + Intercept of the solutions in current iteration (t is piecewise linear) + delta : np.ndarray (|A|, ) + Slope of the solutions in current iteration (t is piecewise linear) + HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b) + Matrix product of H^T H + Hty : np.ndarray (dim_a * dim_b, ) + Matrix product of H^T y + c: np.ndarray (dim_a * dim_b, ) + Flattened array of cost matrix C + active_index : list + Indices of active variables + current_gamma : float + Value of regularization coefficient at the start of current iteration + Returns + ------- + next_gamma : float + Value of gamma if a variable is added to active set in next iteration + next_active_index : int + Index of variable to be activated + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + M = (HtH[:, active_index].dot(phi) - Hty) / \ + (HtH[:, active_index].dot(delta) - c + 1e-16) + M[active_index] = 0 + M[M > (current_gamma - 1e-10 * current_gamma)] = 0 + return np.max(M), np.argmax(M) + + +def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, + c, active_index, current_gamma): + r""" This function computes the next value of gamma when a variable is + active in the regularization path of semi-relaxed UOT. + + By taking the Lagrangian form of the problem, we obtain a similar update + as the two-sided relaxed UOT + ..
math:: + \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T + \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i}^T \delta_u - c_i} + where : + - A is the current active set + - h_{r i} is the ith column of the matrix H_r + - h_{c i} is the ith column of the matrix H_c + - H_{r A} is the sub-matrix constructed by the columns of H_r + whose indices belong to the active set A + - c_i is the ith element of cost vector c + - y is the concatenation of source and target distributions + - :math:`\phi` is the intercept of the solutions in current iteration + - :math:`\delta` is the slope of the solutions in current iteration + - :math:`\phi_u` is the intercept of Lagrange parameter in current + iteration + - :math:`\delta_u` is the slope of Lagrange parameter in current iteration + Parameters + ---------- + phi : np.ndarray (|A|, ) + Intercept of the solutions in current iteration (t is piecewise linear) + delta : np.ndarray (|A|, ) + Slope of the solutions in current iteration (t is piecewise linear) + phi_u : np.ndarray (dim_b, ) + Intercept of the Lagrange parameter in current iteration (also linear) + delta_u : np.ndarray (dim_b, ) + Slope of the Lagrange parameter in current iteration (also linear) + HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) + Matrix product of H_r^T H_r + Hc : np.ndarray (dim_b, dim_a * dim_b) + Matrix that computes the sum along the columns of transport plan T + Hra : np.ndarray (dim_a * dim_b, ) + Matrix product of H_r^T a + c: np.ndarray (dim_a * dim_b, ) + Flattened array of cost matrix C + active_index : list + Indices of active variables + current_gamma : float + Value of regularization coefficient at the start of current iteration + Returns + ------- + next_gamma : float + Value of gamma if a variable is added to active set in next iteration + next_active_index : int + Index of variable to be activated + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \ + (HrHr[:, active_index].dot(delta) - c + Hc.T.dot(delta_u) + 1e-16) + M[active_index] = 0 + M[M > (current_gamma - 1e-10 * current_gamma)] = 0 + return np.max(M), np.argmax(M) + + +def compute_next_removal(phi, delta, current_gamma): + r""" This function computes the next value of gamma if a variable + is removed in next iteration of regularization path + + We look for the largest value of gamma such that + an element of current solution vanishes + .. math:: + \max_{j \in A} \frac{\phi_j}{\delta_j} + where : + - A is the current active set + - phi_j is the jth element of the intercept of current solution + - delta_j is the jth element of the slope of current solution + Parameters + ---------- + phi : np.ndarray (|A|, ) + Intercept of the solutions in current iteration (t is piecewise linear) + delta : np.ndarray (|A|, ) + Slope of the solutions in current iteration (t is piecewise linear) + current_gamma : float + Value of regularization coefficient at the start of current iteration + Returns + ------- + next_removal_gamma : float + Value of gamma if a variable is removed in next iteration + next_removal_index : int + Index of the variable to remove in next iteration + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression.
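The incremental inverse maintained by complement_schur (defined just below) is the standard block-inverse identity; note that the off-diagonal blocks of the grown inverse are -s^{-1} M_k^{-1} b and its transpose, not bare -s^{-1}. A minimal numerical check of the growing-active-set case (illustrative, with a random SPD matrix standing in for H_A^T H_A):

    import numpy as np

    rng = np.random.RandomState(0)
    A = rng.randn(4, 4)
    K = A.T.dot(A) + np.eye(4)            # SPD stand-in for H_A^T H_A
    M_inv = np.linalg.inv(K[:3, :3])      # inverse known on the current active set
    b = K[:3, 3:4]                        # new column (upper-right block)
    d = K[3, 3]                           # new diagonal entry
    s = d - float(b.T.dot(M_inv).dot(b))  # Schur complement
    X = M_inv.dot(b)
    M_new = np.block([[M_inv + X.dot(X.T) / s, -X / s],
                      [-X.T / s, np.array([[1.0 / s]])]])
    print(np.allclose(M_new, np.linalg.inv(K)))   # True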
+ """ + r_candidate = phi / (delta - 1e-16) + r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0 + return np.max(r_candidate), np.argmax(r_candidate) + + +def complement_schur(M_current, b, d, id_pop): + r""" This function computes the inverse of matrix in regularization path + using Schur complement + + Two cases may arise: Firstly one variable is added to the active set + .. math:: + M_{k+1}^{-1} = + \begin{bmatrix} + M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\ + - s^{-1} b^T M_{k}^{-1} & s^{-1} + \end{bmatrix} + where : + - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and + :math:`M_k` is the upper left block matrix in Schur formulation + - b is the upper right block matrix in Schur formulation. In our case, + b is reduced to a column vector and b^T is the lower left block matrix + - s is the Schur complement, given by + :math:`s = d - b^T M_{k}^{-1} b` in our case + + Secondly, one variable is removed from the active set + .. math:: + M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} - + \frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}} + where : + - q is the index of column and row to delete + - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix + without qth column and qth row + - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element + - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}` + Parameters + ---------- + M_current : np.ndarray (|A|-1, |A|-1) + Inverse matrix in previous iteration + b : np.ndarray (|A|-1, ) + Upper right matrix in Schur complement, a column vector in our case + d : float + Lower right matrix in Schur complement, a scalar in our case + id_pop + Index of the variable to be removed, equal to -1 + if none of the variables is deleted in current iteration + Returns + ------- + M : np.ndarray (|A|, |A|) + Inverse matrix needed in current iteration + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + if b is None: + b = M_current[id_pop, :] + b = np.delete(b, id_pop) + M_del = np.delete(M_current, id_pop, 0) + a = M_del[:, id_pop] + M_del = np.delete(M_del, id_pop, 1) + M = M_del - np.outer(a, b) / M_current[id_pop, id_pop] + else: + n = b.shape[0] + 1 + if np.shape(b)[0] == 0: + M = np.array([[0.5]]) + else: + X = M_current.dot(b) + s = d - b.T.dot(X) + M = np.zeros((n, n)) + M[:-1, :-1] = M_current + X.dot(X.T) / s + X_ravel = X.ravel() + M[-1, :-1] = -X_ravel / s + M[:-1, -1] = -X_ravel / s + M[-1, -1] = 1 / s + return M + + +def construct_augmented_H(active_index, m, Hc, HrHr): + r""" This function construct an augmented matrix for the first iteration of + semi-relaxed regularization path + + .. 
math:: + Augmented_H = + \begin{bmatrix} + 0 & H_{c A} \\ + H_{c A}^T & H_{r A}^T H_{r A} + \end{bmatrix} + where : + - H_{r A} is the sub-matrix constructed by the columns of H_r + whose indices belong to the active set A + - H_{c A} is the sub-matrix constructed by the columns of H_c + whose indices belong to the active set A + Parameters + ---------- + active_index : list + Indices of active variables + m : int + Length of the target distribution + Hc : np.ndarray (dim_b, dim_a * dim_b) + Matrix that computes the sum along the columns of transport plan T + HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) + Matrix product of H_r^T H_r + Returns + ------- + H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|) + Augmented matrix for the first iteration of the semi-relaxed + regularization path + """ + Hc_sub = Hc[:, active_index].toarray() + HrHr_sub = HrHr[:, active_index] + HrHr_sub = HrHr_sub[active_index, :].toarray() + H_augmented = np.block([[np.zeros((m, m)), Hc_sub], [Hc_sub.T, HrHr_sub]]) + return H_augmented + + +def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, + itmax=50000): + r"""This function gives the regularization path of l2-penalized UOT problem + + The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: + \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + s.t. + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] + - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, + see [Chapel et al., 2021] for the design of H. The matrix product Ht + computes both the source marginal and the target marginal. + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + reg: float + l2-regularization coefficient + itmax: int + Maximum number of iteration + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Flattened vector of optimal transport matrix + t_list : list + List of solutions in regularization path + gamma_list : list + List of regularization coefficient in regularization path + Examples + -------- + >>> import ot + >>> import numpy as np + >>> n = 3 + >>> xs = np.array([1., 2., 3.]).reshape((n, 1)) + >>> xt = np.array([5., 6., 7.]).reshape((n, 1)) + >>> C = ot.dist(xs, xt) + >>> C /= C.max() + >>> a = np.array([0.2, 0.5, 0.3]) + >>> b = np.array([0.2, 0.5, 0.3]) + >>> t, _, _ = ot.regpath.fully_relaxed_path(a, b, C, 1e-4) + >>> t + array([1.99958333e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05, + 4.99938889e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05, + 2.99958333e-01]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. 
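Usage sketch for the path solvers of this module (the regularization_path front-end and the interpolation helper compute_transport_plan appear further below): the path is computed once, after which the plan at any regularization strength is read off by piecewise-linear interpolation. Data reuses the doctests; the printed masses are indicative only:

    import numpy as np
    import ot

    a = np.array([0.2, 0.5, 0.3])
    b = np.array([0.2, 0.5, 0.3])
    xs = np.array([1., 2., 3.]).reshape((3, 1))
    xt = np.array([5., 6., 7.]).reshape((3, 1))
    C = ot.dist(xs, xt)
    C /= C.max()
    t, t_list, g_list = ot.regpath.regularization_path(a, b, C, reg=1e-4)
    for gamma in (1.0, 1e-2, 1e-4):
        plan = ot.regpath.compute_transport_plan(gamma, g_list, t_list)
        # transported mass grows as gamma (= 1/lambda) shrinks
        print(gamma, np.round(plan.sum(), 4))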
+ """ + + n = np.shape(a)[0] + m = np.shape(b)[0] + H, y, c = recast_ot_as_lasso(a, b, C) + HtH = H.T.dot(H) + Hty = H.T.dot(y) + n_iter = 1 + + # initialization + M0 = Hty / c + gamma_list = [np.max(M0)] + active_index = [np.argmax(M0)] + t_list = [np.zeros((n * m,))] + H_inv = np.array([[]]) + add_col = np.array([]) + id_pop = -1 + + while n_iter < itmax and gamma_list[-1] > reg: + H_inv = complement_schur(H_inv, add_col, 2., id_pop) + current_gamma = gamma_list[-1] + + # compute the intercept and slope of solutions in current iteration + # t = phi - gamma * delta + phi = H_inv.dot(Hty[active_index]) + delta = H_inv.dot(c[active_index]) + gamma, ik = ot_next_gamma(phi, delta, HtH, Hty, c, active_index, + current_gamma) + + # compute the next lambda when removing a point from the active set + alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma) + + # if the positivity constraint is violated, we remove id_pop + # from active set, otherwise we add ik to active set + if alt_gamma > gamma: + gamma = alt_gamma + else: + id_pop = -1 + + # compute the solution of current segment + tA = phi - gamma * delta + sol = np.zeros((n * m, )) + sol[active_index] = tA + + if id_pop != -1: + active_index.pop(id_pop) + add_col = None + else: + active_index.append(ik) + add_col = HtH[active_index[:-1], ik].toarray() + + gamma_list.append(gamma) + t_list.append(sol) + n_iter += 1 + + if itmax <= n_iter: + print('maximum iteration has been reached !') + + # correct the last solution and gamma + if len(t_list) > 1: + t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) * + (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2])) + t_list[-1] = t_final + gamma_list[-1] = reg + else: + gamma_list[-1] = reg + print('Regularization path does not exist !') + + return t_list[-1], t_list, gamma_list + + +def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, + itmax=50000): + r"""This function gives the regularization path of semi-relaxed + l2-UOT problem + + The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: + \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + s.t. 
+ H_c t = b + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - H_r is a (dim_a, dim_a * dim_b) metric matrix, + which computes the sum along the rows of transport plan T + - H_c is a (dim_b, dim_a * dim_b) metric matrix, + which computes the sum along the columns of transport plan T + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + reg: float (optional) + l2-regularization coefficient + itmax: int (optional) + Maximum number of iteration + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Flattened vector of optimal transport matrix + t_list : list + List of solutions in regularization path + gamma_list : list + List of regularization coefficient in regularization path + Examples + -------- + >>> import ot + >>> import numpy as np + >>> n = 3 + >>> xs = np.array([1., 2., 3.]).reshape((n, 1)) + >>> xt = np.array([5., 6., 7.]).reshape((n, 1)) + >>> C = ot.dist(xs, xt) + >>> C /= C.max() + >>> a = np.array([0.2, 0.5, 0.3]) + >>> b = np.array([0.2, 0.5, 0.3]) + >>> t, _, _ = ot.regpath.semi_relaxed_path(a, b, C, 1e-4) + >>> t + array([1.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05, + 4.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05, + 3.00000000e-01]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + n = np.shape(a)[0] + m = np.shape(b)[0] + Hr, Hc, c = recast_semi_relaxed_as_lasso(a, b, C) + Hra = Hr.T.dot(a) + HrHr = Hr.T.dot(Hr) + n_iter = 1 + active_index = [] + + # initialization + for j in range(np.shape(C)[1]): + i = np.argmin(C[:, j]) + active_index.append(i * m + j) + gamma_list = [] + t_list = [] + current_gamma = np.Inf + augmented_H0 = construct_augmented_H(active_index, m, Hc, HrHr) + add_col = np.array([]) + id_pop = -1 + + while n_iter < itmax and current_gamma > reg: + if n_iter == 1: + H_inv = np.linalg.inv(augmented_H0) + else: + H_inv = complement_schur(H_inv, add_col, 1., id_pop + m) + # compute the intercept and slope of solutions in current iteration + augmented_phi = H_inv.dot(np.concatenate((b, Hra[active_index]))) + augmented_delta = H_inv[:, m:].dot(c[active_index]) + phi = augmented_phi[m:] + delta = augmented_delta[m:] + phi_u = augmented_phi[0:m] + delta_u = augmented_delta[0:m] + gamma, ik = semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, + HrHr, Hc, Hra, c, active_index, + current_gamma) + + # compute the next lambda when removing a point from the active set + alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma) + + # if the positivity constraint is violated, we remove id_pop + # from active set, otherwise we add ik to active set + if alt_gamma > gamma: + gamma = alt_gamma + else: + id_pop = -1 + + # compute the solution of current segment + tA = phi - gamma * delta + sol = np.zeros((n * m, )) + sol[active_index] = tA + if id_pop != -1: + active_index.pop(id_pop) + add_col = None + else: + active_index.append(ik) + add_col = np.concatenate((Hc.toarray()[:, ik], + HrHr.toarray()[active_index[:-1], ik])) + add_col = add_col[:, np.newaxis] + + gamma_list.append(gamma) + t_list.append(sol) + current_gamma = gamma + n_iter += 1 + + 
if itmax <= n_iter: + print('maximum iteration has been reached !') + + # correct the last solution and gamma + if len(t_list) > 1: + t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) * + (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2])) + t_list[-1] = t_final + gamma_list[-1] = reg + else: + gamma_list[-1] = reg + print('Regularization path does not exist !') + + return t_list[-1], t_list, gamma_list + + +def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, + semi_relaxed=False, itmax=50000): + r"""This function combines both the semi-relaxed and the fully-relaxed + regularization paths of l2-UOT problem + + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + reg: float (optional) + l2-regularization coefficient + semi_relaxed : bool (optional) + Give the semi-relaxed path if true + itmax: int (optional) + Maximum number of iteration + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Flattened vector of optimal transport matrix + t_list : list + List of solutions in regularization path + gamma_list : list + List of regularization coefficient in regularization path + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + if semi_relaxed: + t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg, + itmax=itmax) + else: + t, t_list, gamma_list = fully_relaxed_path(a, b, C, reg=reg, + itmax=itmax) + return t, t_list, gamma_list + + +def compute_transport_plan(gamma, gamma_list, Pi_list): + r""" Given the regularization path, this function computes the transport + plan for any value of gamma by the piecewise linearity of the path + + .. math:: + t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma) + where : + - :math:`\gamma` is the regularization coefficient + - :math:`\phi(\gamma)` is the corresponding intercept + - :math:`\delta(\gamma)` is the corresponding slope + - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix) + Parameters + ---------- + gamma : float + Regularization coefficient + gamma_list : list + List of regularization coefficients in regularization path + Pi_list : list + List of solutions in regularization path + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Transport vector corresponding to the given value of gamma + Examples + -------- + >>> import ot + >>> import numpy as np + >>> n = 3 + >>> xs = np.array([1., 2., 3.]).reshape((n, 1)) + >>> xt = np.array([5., 6., 7.]).reshape((n, 1)) + >>> C = ot.dist(xs, xt) + >>> C /= C.max() + >>> a = np.array([0.2, 0.5, 0.3]) + >>> b = np.array([0.2, 0.5, 0.3]) + >>> t, pi_list, g_list = ot.regpath.regularization_path(a, b, C, reg=1e-4) + >>> gamma = 1 + >>> t2 = ot.regpath.compute_transport_plan(gamma, g_list, pi_list) + >>> t2 + array([0. , 0. , 0. , 0.19722222, 0.05555556, + 0. , 0. , 0.24722222, 0. ]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. 
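As a complement to the examples above: along the semi-relaxed path the column constraint H_c t = b holds exactly (it is enforced through the Lagrange multipliers, and linear interpolation between two feasible breakpoints stays feasible), whereas the fully relaxed path only approaches both marginals as the regularization vanishes. A minimal check on the doctest data:

    import numpy as np
    import ot

    a = np.array([0.2, 0.5, 0.3])
    b = np.array([0.2, 0.5, 0.3])
    xs = np.array([1., 2., 3.]).reshape((3, 1))
    xt = np.array([5., 6., 7.]).reshape((3, 1))
    C = ot.dist(xs, xt)
    C /= C.max()
    t_sr, _, _ = ot.regpath.semi_relaxed_path(a, b, C, reg=1e-4)
    print(np.allclose(t_sr.reshape(3, 3).sum(0), b))   # exact column sums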
+ """ + + if gamma >= gamma_list[0]: + Pi = Pi_list[0] + elif gamma <= gamma_list[-1]: + Pi = Pi_list[-1] + else: + idx = np.where(gamma <= np.array(gamma_list))[0][-1] + gamma_k0 = gamma_list[idx] + gamma_k1 = gamma_list[idx + 1] + pi_k0 = Pi_list[idx] + pi_k1 = Pi_list[idx + 1] + Pi = pi_k0 + (pi_k1 - pi_k0) * (gamma - gamma_k0) \ + / (gamma_k1 - gamma_k0) + return Pi diff --git a/ot/sliced.py b/ot/sliced.py new file mode 100644 index 0000000..cf2d3be --- /dev/null +++ b/ot/sliced.py @@ -0,0 +1,258 @@ +""" +Sliced OT Distances + +""" + +# Author: Adrien Corenflos <adrien.corenflos@aalto.fi> +# Nicolas Courty <ncourty@irisa.fr> +# Rémi Flamary <remi.flamary@polytechnique.edu> +# +# License: MIT License + + +import numpy as np +from .backend import get_backend, NumpyBackend +from .utils import list_to_array + + +def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): + r""" + Generates n_projections samples from the uniform on the unit sphere of dimension :math:`d-1`: :math:`\mathcal{U}(\mathcal{S}^{d-1})` + + Parameters + ---------- + d : int + dimension of the space + n_projections : int + number of samples requested + seed: int or RandomState, optional + Seed used for numpy random number generator + backend: + Backend to ue for random generation + + Returns + ------- + out: ndarray, shape (d, n_projections) + The uniform unit vectors on the sphere + + Examples + -------- + >>> n_projections = 100 + >>> d = 5 + >>> projs = get_random_projections(d, n_projections) + >>> np.allclose(np.sum(np.square(projs), 0), 1.) # doctest: +NORMALIZE_WHITESPACE + True + + """ + + if backend is None: + nx = NumpyBackend() + else: + nx = backend + + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + projections = seed.randn(d, n_projections) + else: + if seed is not None: + nx.seed(seed) + projections = nx.randn(d, n_projections, type_as=type_as) + + projections = projections / nx.sqrt(nx.sum(projections**2, 0, keepdims=True)) + return projections + + +def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): + r""" + Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance + + .. math:: + \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}\left(\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)\right)^{\frac{1}{p}} + + + where : + + - :math:`\theta_\# \mu` stands for the pushforwards of the projection :math:`X \in \mathbb{R}^d \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. 
+ + Returns + ------- + cost: float + Sliced Wasserstein Cost + log : dict, optional + log dictionary returned only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + + .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + """ + from .lp import wasserstein_1d + + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + + if a is None: + a = nx.full(n, 1 / n, type_as=X_s) + if b is None: + b = nx.full(m, 1 / m, type_as=X_s) + + d = X_s.shape[1] + + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) + + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) + + res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p) + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res + + +def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): + r""" + Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance + + .. math:: + \mathcal{Max-SWD}_p(\mu, \nu) = \underset{\theta \in + \mathcal{U}(\mathbb{S}^{d-1})}{\max} [\mathcal{W}_p^p(\theta_\# + \mu, \theta_\# \nu)]^{\frac{1}{p}} + + where : + + - :math:`\theta_\# \mu` stands for the pushforwards of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, max_sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Max-Sliced Wasserstein Cost + log : dict, optional + log dictionary returned only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> max_sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + + .. [35] Deshpande, I., Hu, Y.
T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). Max-sliced wasserstein distance and its use for gans. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + """ + from .lp import wasserstein_1d + + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + + if a is None: + a = nx.full(n, 1 / n, type_as=X_s) + if b is None: + b = nx.full(m, 1 / m, type_as=X_s) + + d = X_s.shape[1] + + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) + + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) + + res = nx.max(projected_emd) ** (1.0 / p) + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res diff --git a/ot/smooth.py b/ot/smooth.py index 81f6a3e..6855005 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -47,15 +47,24 @@ from scipy.optimize import minimize def projection_simplex(V, z=1, axis=None): - """ Projection of x onto the simplex, scaled by z + r""" Projection of :math:`\mathbf{V}` onto the simplex, scaled by `z` - P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2 + .. math:: + P\left(\mathbf{V}, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z}} \quad \|\mathbf{y} - \mathbf{V}\|^2 + + Parameters + ---------- + V: ndarray, rank 2 z: float or array - If array, len(z) must be compatible with V + If array, len(z) must be compatible with :math:`\mathbf{V}` axis: None or int - - axis=None: project V by P(V.ravel(); z) - - axis=1: project each V[i] by P(V[i]; z[i]) - - axis=0: project each V[:, j] by P(V[:, j]; z[j]) + - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), z)` + - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, z_i)` + - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, z_j)` + + Returns + ------- + projection: ndarray, shape :math:`\mathbf{V}`.shape """ if axis == 1: n_features = V.shape[1] @@ -77,12 +86,12 @@ def projection_simplex(V, z=1, axis=None): class Regularization(object): - """Base class for Regularization objects + r"""Base class for Regularization objects Notes ----- - This class is not intended for direct use but as aparent for true - regularizatiojn implementation. + This class is not intended for direct use but as apparent for true + regularization implementation. """ def __init__(self, gamma=1.0): @@ -98,40 +107,48 @@ class Regularization(object): self.gamma = gamma def delta_Omega(X): - """ - Compute delta_Omega(X[:, j]) for each X[:, j]. - delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y). + r""" + Compute :math:`\delta_\Omega(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`. + + .. 
math:: + \delta_\Omega(\mathbf{x}) = \sup_{\mathbf{y} >= 0} \ + \mathbf{y}^T \mathbf{x} - \Omega(\mathbf{y}) Parameters ---------- - X: array, shape = len(a) x len(b) + X: array, shape = (len(a), len(b)) Input array. Returns ------- - v: array, len(b) - Values: v[j] = delta_Omega(X[:, j]) - G: array, len(a) x len(b) - Gradients: G[:, j] = nabla delta_Omega(X[:, j]) + v: array, (len(b), ) + Values: :math:`\mathbf{v}_j = \delta_\Omega(\mathbf{X}_{:, j})` + G: array, (len(a), len(b)) + Gradients: :math:`\mathbf{G}_{:, j} = \nabla \delta_\Omega(\mathbf{X}_{:, j})` """ raise NotImplementedError def max_Omega(X, b): - """ - Compute max_Omega_j(X[:, j]) for each X[:, j]. - max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j]. + r""" + Compute :math:`\mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`. + + .. math:: + \mathrm{max}_{\Omega, j}(\mathbf{x}) = + \sup_{\substack{\mathbf{y} >= 0 \ \sum_i \mathbf{y}_i = 1}} + \mathbf{y}^T \mathbf{x} - \frac{1}{\mathbf{b}_j} \Omega(\mathbf{b}_j \mathbf{y}) Parameters ---------- - X: array, shape = len(a) x len(b) + X: array, shape = (len(a), len(b)) Input array. + b: array, shape = (len(b), ) Returns ------- - v: array, len(b) - Values: v[j] = max_Omega_j(X[:, j]) - G: array, len(a) x len(b) - Gradients: G[:, j] = nabla max_Omega_j(X[:, j]) + v: array, (len(b), ) + Values: :math:`\mathbf{v}_j = \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})` + G: array, (len(a), len(b)) + Gradients: :math:`\mathbf{G}_{:, j} = \nabla \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})` """ raise NotImplementedError @@ -192,7 +209,7 @@ class SquaredL2(Regularization): def dual_obj_grad(alpha, beta, a, b, C, regul): - """ + r""" Compute objective value and gradients of dual objective. Parameters @@ -203,19 +220,19 @@ def dual_obj_grad(alpha, beta, a, b, C, regul): a: array, shape = len(a) b: array, shape = len(b) Input histograms (should be non-negative and sum to 1). - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a delta_Omega(X) method. + Should implement a `delta_Omega(X)` method. Returns ------- obj: float Objective value (higher is better). grad_alpha: array, shape = len(a) - Gradient w.r.t. alpha. + Gradient w.r.t. `alpha`. grad_beta: array, shape = len(b) - Gradient w.r.t. beta. + Gradient w.r.t. `beta`. """ obj = np.dot(alpha, a) + np.dot(beta, b) grad_alpha = a.copy() @@ -242,13 +259,13 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, Parameters ---------- - a: array, shape = len(a) - b: array, shape = len(b) + a: array, shape = (len(a), ) + b: array, shape = (len(b), ) Input histograms (should be non-negative and sum to 1). - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a delta_Omega(X) method. + Should implement a `delta_Omega(X)` method. method: str Solver to be used (passed to `scipy.optimize.minimize`). tol: float @@ -258,8 +275,8 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, Returns ------- - alpha: array, shape = len(a) - beta: array, shape = len(b) + alpha: array, shape = (len(a), ) + beta: array, shape = (len(b), ) Dual potentials. """ @@ -302,10 +319,10 @@ def semi_dual_obj_grad(alpha, a, b, C, regul): a: array, shape = len(a) b: array, shape = len(b) Input histograms (should be non-negative and sum to 1). 
- C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a max_Omega(X) method. + Should implement a `max_Omega(X)` method. Returns ------- @@ -337,13 +354,13 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, Parameters ---------- - a: array, shape = len(a) - b: array, shape = len(b) + a: array, shape = (len(a), ) + b: array, shape = (len(b), ) Input histograms (should be non-negative and sum to 1). - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a max_Omega(X) method. + Should implement a `max_Omega(X)` method. method: str Solver to be used (passed to `scipy.optimize.minimize`). tol: float @@ -353,7 +370,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, Returns ------- - alpha: array, shape = len(a) + alpha: array, shape = (len(a), ) Semi-dual potentials. """ @@ -371,7 +388,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500, def get_plan_from_dual(alpha, beta, C, regul): - """ + r""" Retrieve optimal transportation plan from optimal dual potentials. Parameters @@ -379,14 +396,14 @@ def get_plan_from_dual(alpha, beta, C, regul): alpha: array, shape = len(a) beta: array, shape = len(b) Optimal dual potentials. - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a delta_Omega(X) method. + Should implement a `delta_Omega(X)` method. Returns ------- - T: array, shape = len(a) x len(b) + T: array, shape = (len(a), len(b)) Optimal transportation plan. """ X = alpha[:, np.newaxis] + beta - C @@ -394,7 +411,7 @@ def get_plan_from_dual(alpha, beta, C, regul): def get_plan_from_semi_dual(alpha, b, C, regul): - """ + r""" Retrieve optimal transportation plan from optimal semi-dual potentials. Parameters @@ -403,14 +420,14 @@ def get_plan_from_semi_dual(alpha, b, C, regul): Optimal semi-dual potentials. b: array, shape = len(b) Second input histogram (should be non-negative and sum to 1). - C: array, shape = len(a) x len(b) + C: array, shape = (len(a), len(b)) Ground cost matrix. regul: Regularization object - Should implement a delta_Omega(X) method. + Should implement a `delta_Omega(X)` method. Returns ------- - T: array, shape = len(a) x len(b) + T: array, shape = (len(a), len(b)) Optimal transportation plan. """ X = alpha[:, np.newaxis] - C @@ -422,19 +439,21 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, r""" Solve the regularized OT problem in the dual and return the OT matrix - The function solves the smooth relaxed dual formulation (7) in [17]_ : + The function solves the smooth relaxed dual formulation (7) in + :ref:`[17] <references-smooth-ot-dual>`: .. 
math:: - \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j) + \max_{\alpha,\beta}\quad \mathbf{a}^T\alpha + \mathbf{b}^T\beta - + \sum_j \delta_\Omega \left(\alpha+\beta_j-\mathbf{m}_j \right) where : - - :math:`\mathbf{m}_j` is the jth column of the cost matrix + - :math:`\mathbf{m}_j` is the j-th column of the cost matrix - :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The OT matrix can be reconstructed from the gradient of :math:`\delta_\Omega` - (See [17]_ Proposition 1). + (See :ref:`[17] <references-smooth-ot-dual>` Proposition 1). The optimization algorithm uses gradient descent (L-BFGS by default). @@ -444,21 +463,25 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, samples weights in the source domain b : np.ndarray (nt,) or np.ndarray (nt,nbb) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : np.ndarray (ns,nt) loss matrix reg : float Regularization term >0 reg_type : str - Regularization type, can be the following (default ='l2'): - - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_) - - 'l2' : Squared Euclidean regularization + Regularization type, can be the following (default ='l2'): + + - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn + :ref:`[2] <references-smooth-ot-dual>`) + + - 'l2' : Squared Euclidean regularization method : str Solver to use for scipy.optimize.minimize numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -467,15 +490,15 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if log==True in parameters + .. _references-smooth-ot-dual: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). @@ -514,21 +537,23 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= r""" Solve the regularized OT problem in the semi-dual and return the OT matrix - The function solves the smooth relaxed dual formulation (10) in [17]_ : + The function solves the smooth relaxed dual formulation (10) in + :ref:`[17] <references-smooth-ot-semi-dual>`: .. math:: - \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b) + \max_{\alpha}\quad \mathbf{a}^T\alpha- \mathrm{OT}_\Omega^*(\alpha, \mathbf{b}) where : .. math:: - OT_\Omega^*(\alpha,b)=\sum_j b_j + \mathrm{OT}_\Omega^*(\alpha,b)=\sum_j \mathbf{b}_j - - :math:`\mathbf{m}_j` is the jth column of the cost matrix - - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. 
(9) in [17] - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{m}_j` is the j-th column of the cost matrix + - :math:`\mathrm{OT}_\Omega^*(\alpha,b)` is defined in Eq. (9) in + :ref:`[17] <references-smooth-ot-semi-dual>` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - The OT matrix can is reconstructed using [17]_ Proposition 2. + The OT matrix can be reconstructed using :ref:`[17] <references-smooth-ot-semi-dual>` Proposition 2. The optimization algorithm uses gradient descent (L-BFGS by default). @@ -538,21 +563,25 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= samples weights in the source domain b : np.ndarray (nt,) or np.ndarray (nt,nbb) samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : np.ndarray (ns,nt) loss matrix reg : float Regularization term >0 reg_type : str - Regularization type, can be the following (default ='l2'): - - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_) - - 'l2' : Squared Euclidean regularization + Regularization type, can be the following (default ='l2'): + + - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn + :ref:`[2] <references-smooth-ot-semi-dual>`) + + - 'l2' : Squared Euclidean regularization method : str Solver to use for scipy.optimize.minimize numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -561,15 +590,15 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= Returns ------- - gamma : (ns x nt) ndarray + gamma : (ns, nt) ndarray Optimal transportation matrix for the given parameters log : dict log dictionary returned only if log==True in parameters + .. _references-smooth-ot-semi-dual: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). diff --git a/ot/stochastic.py b/ot/stochastic.py index 13ed9cc..693675f 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -18,22 +18,25 @@ import numpy as np def coordinate_grad_semi_dual(b, M, reg, beta, i): r''' - Compute the coordinate gradient update for regularized discrete distributions for (i, :) + Compute the coordinate gradient update for regularized discrete distributions for :math:`(i, :)` The function computes the gradient of the semi dual problem: .. 
math:: - \max_v \sum_i (\sum_j v_j * b_j - reg * log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i + \max_\mathbf{v} \ \sum_i \mathbf{a}_i \left[ \sum_j \mathbf{v}_j \mathbf{b}_j - \mathrm{reg} + \cdot \log \left( \sum_j \mathbf{b}_j + \exp \left( \frac{\mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} + \right) \right) \right] Where : - - M is the (ns,nt) metric cost matrix - - v is a dual variable in R^J + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{v}` is a dual variable in :math:`\mathbb{R}^{nt}` - reg is the regularization term - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the ASGD & SAG algorithms - as proposed in [18]_ [alg.1 & alg.2] + as proposed in :ref:`[18] <references-coordinate-grad-semi-dual>` [alg.1 & alg.2] Parameters @@ -47,7 +50,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): v : ndarray, shape (nt,) Dual variable. i : int - Picked number i. + Picked number `i`. Returns ------- @@ -74,12 +77,10 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + .. _references-coordinate-grad-semi-dual: References ---------- - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). ''' r = M[i, :] - beta exp_beta = np.exp(-r / reg) * b @@ -88,29 +89,29 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i): def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): - r''' - Compute the SAG algorithm to solve the regularized discrete measures - optimal transport max problem + r""" + Compute the SAG algorithm to solve the regularized discrete measures optimal transport max problem The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1 = b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the SAG algorithm - as proposed in [18]_ [alg.1] + as proposed in :ref:`[18] <references-sag-entropic-transport>` [alg.1] Parameters @@ -131,7 +132,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): Returns ------- - v : ndarray, shape (nt,) + v : ndarray, shape (`nt`,) Dual variable. Examples @@ -154,14 +155,12 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None): [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + .. 
_references-sag-entropic-transport: References ---------- - - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. - ''' + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). + """ if lr is None: lr = 1. / max(a / reg) @@ -187,22 +186,23 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the problem is the ASGD algorithm - as proposed in [18]_ [alg.2] + as proposed in :ref:`[18] <references-averaged-sgd-entropic-transport>` [alg.2] Parameters @@ -220,7 +220,7 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): Returns ------- - ave_v : ndarray, shape (nt,) + ave_v : ndarray, shape (`nt`,) dual variable Examples @@ -243,13 +243,11 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None): [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + .. _references-averaged-sgd-entropic-transport: References ---------- - - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). ''' if lr is None: @@ -271,20 +269,21 @@ def c_transform_entropic(b, M, reg, beta): r''' The goal is to recover u from the c-transform. - The function computes the c_transform of a dual variable from the other + The function computes the c-transform of a dual variable from the other dual variable: .. math:: - u = v^{c,reg} = -reg \sum_j exp((v - M)/reg) b_j + \mathbf{u} = \mathbf{v}^{c,reg} = - \mathrm{reg} \sum_j \mathbf{b}_j + \exp\left( \frac{\mathbf{v} - \mathbf{M}}{\mathrm{reg}} \right) Where : - - M is the (ns,nt) metric cost matrix - - u, v are dual variables in R^IxR^J + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{u}`, :math:`\mathbf{v}` are dual variables in :math:`\mathbb{R}^{ns} \times \mathbb{R}^{nt}` - reg is the regularization term It is used to recover an optimal u from optimal v solving the semi dual - problem, see Proposition 2.1 of [18]_ + problem, see Proposition 2.1 of :ref:`[18] <references-c-transform-entropic>` Parameters @@ -300,7 +299,7 @@ def c_transform_entropic(b, M, reg, beta): Returns ------- - u : ndarray, shape (ns,) + u : ndarray, shape (`ns`,) Dual variable. 
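A minimal sketch of how the semi-dual pieces above chain together; the sizes, seed and random cost matrix below are illustrative assumptions, not part of the library. Solve for the dual variable v with SAG, recover u by the c-transform, then assemble the entropic plan from the two potentials:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    n_source, n_target, reg = 7, 4, 1.0
    a = ot.unif(n_source)
    b = ot.unif(n_target)
    M = ot.dist(rng.randn(n_source, 2), rng.randn(n_target, 2))  # squared Euclidean cost

    v = ot.stochastic.sag_entropic_transport(a, b, M, reg, numItermax=30000)  # alg. 1 of [18]
    u = ot.stochastic.c_transform_entropic(b, M, reg, v)  # recover u from v
    pi = np.exp((u[:, None] + v[None, :] - M) / reg) * a[:, None] * b[None, :]  # entropic plan
    print(pi.shape, pi.sum())  # (7, 4), total mass close to 1
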
Examples @@ -323,13 +322,11 @@ def c_transform_entropic(b, M, reg, beta): [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + .. _references-c-transform-entropic: References ---------- - - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). ''' n_source = np.shape(M)[0] @@ -345,27 +342,28 @@ def c_transform_entropic(b, M, reg, beta): def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, log=False): r''' - Compute the transportation matrix to solve the regularized discrete - measures optimal transport max problem + Compute the transportation matrix to solve the regularized discrete measures optimal transport max problem The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + The algorithm used for solving the problem is the SAG or ASGD algorithms - as proposed in [18]_ + as proposed in :ref:`[18] <references-solve-semi-dual-entropic>` Parameters @@ -419,13 +417,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01], [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]]) + + .. _references-solve-semi-dual-entropic: References ---------- - - [Genevay et al., 2016] : - Stochastic Optimization for Large-scale Optimal Transport, - Advances in Neural Information Processing Systems (2016), - arXiv preprint arxiv:1605.08527. + .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016). ''' if method.lower() == "sag": @@ -459,26 +455,30 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, r''' Computes the partial gradient of the dual optimal transport problem. - For each (i,j) in a batch of coordinates, the partial gradients are : + For each :math:`(i,j)` in a batch of coordinates, the partial gradients are : .. 
math:: - \partial_{u_i} F = u_i * b_s/l_{v} - \sum_{j \in B_v} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j + \partial_{\mathbf{u}_i} F = \frac{b_s}{l_v} \mathbf{u}_i - + \sum_{j \in B_v} \mathbf{a}_i \mathbf{b}_j + \exp\left( \frac{\mathbf{u}_i + \mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} \right) - \partial_{v_j} F = v_j * b_s/l_{u} - \sum_{i \in B_u} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j + \partial_{\mathbf{v}_j} F = \frac{b_s}{l_u} \mathbf{v}_j - + \sum_{i \in B_u} \mathbf{a}_i \mathbf{b}_j + \exp\left( \frac{\mathbf{u}_i + \mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} \right) Where : - - M is the (ns,nt) metric cost matrix - - u, v are dual variables in R^ixR^J + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\mathbf{u}`, :math:`\mathbf{v}` are dual variables in :math:`\mathbb{R}^{ns} \times \mathbb{R}^{nt}` - reg is the regularization term - :math:`B_u` and :math:`B_v` are lists of indices - - :math:`b_s` is the size of the batchs :math:`B_u` and :math:`B_v` - - :math:`l_u` and :math:`l_v` are the lenghts of :math:`B_u` and :math:`B_v` - - a and b are source and target weights (sum to 1) + - :math:`b_s` is the size of the batches :math:`B_u` and :math:`B_v` + - :math:`l_u` and :math:`l_v` are the lengths of :math:`B_u` and :math:`B_v` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) The algorithm used for solving the dual problem is the SGD algorithm - as proposed in [19]_ [alg.1] + as proposed in :ref:`[19] <references-batch-grad-dual>` [alg.1] Parameters @@ -504,7 +504,7 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, Returns ------- - grad : ndarray, shape (ns,) + grad : ndarray, shape (`ns`,) partial grad F Examples @@ -533,12 +533,11 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, [5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04], [6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]]) + + .. _references-batch-grad-dual: References ---------- - - [Seguy et al., 2018] : - International Conference on Learning Representation (2018), - arXiv preprint arxiv:1711.02283. + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) ''' G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] - M[batch_alpha, :][:, batch_beta]) / reg) * @@ -555,25 +554,25 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha, def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): r''' - Compute the sgd algorithm to solve the regularized discrete measures - optimal transport dual problem + Compute the sgd algorithm to solve the regularized discrete measures optimal transport dual problem The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. 
\ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- @@ -632,9 +631,7 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr): References ---------- - [Seguy et al., 2018] : - International Conference on Learning Representation (2018), - arXiv preprint arxiv:1711.02283. + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) ''' n_source = np.shape(M)[0] @@ -657,25 +654,25 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, log=False): r''' - Compute the transportation matrix to solve the regularized discrete measures - optimal transport dual problem + Compute the transportation matrix to solve the regularized discrete measures optimal transport dual problem The function solves the following optimization problem: .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) - s.t. \gamma 1 = a + s.t. \ \gamma \mathbf{1} = \mathbf{a} - \gamma^T 1= b + \gamma^T \mathbf{1} = \mathbf{b} \gamma \geq 0 Where : - - M is the (ns,nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) Parameters ---------- @@ -736,10 +733,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, References ---------- - - [Seguy et al., 2018] : - International Conference on Learning Representation (2018), - arXiv preprint arxiv:1711.02283. + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) ''' opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size, diff --git a/ot/unbalanced.py b/ot/unbalanced.py index e37f10c..15e180b 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -23,29 +23,31 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, The function solves the following optimization problem: .. math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. 
- \gamma\geq 0 + \gamma \geq 0 + where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization - term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target unbalanced distributions + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced>` Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) - One or multiple unnormalized histograms of dimension dim_b - If many, compute all the OT distances (a, b_i) + One or multiple unnormalized histograms of dimension `dim_b`. + If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : np.ndarray (dim_a, dim_b) loss matrix reg : float @@ -58,7 +60,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -68,14 +70,14 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Returns ------- if n_hists == 1: - gamma : (dim_a x dim_b) ndarray + - gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters - log : dict + - log : dict log dictionary returned only if `log` is `True` else: - ot_distance : (n_hists,) ndarray - the OT distance between `a` and each of the histograms `b_i` - log : dict + - ot_distance : (n_hists,) ndarray + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - log : dict log dictionary returned only if `log` is `True` Examples -------- @@ -90,9 +92,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, [0.18807035, 0.51122823]]) + .. _references-sinkhorn-unbalanced: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 @@ -111,11 +113,11 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, See Also -------- - ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10] + ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced>` ot.unbalanced.sinkhorn_stabilized_unbalanced: - Unbalanced Stabilized sinkhorn [9][10] + Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced>` ot.unbalanced.sinkhorn_reg_scaling_unbalanced: - Unbalanced Sinkhorn with epslilon scaling [9][10] + Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced>` """ @@ -151,29 +153,30 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', The function solves the following optimization problem: .. 
math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. \gamma\geq 0 where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target unbalanced distributions + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced2>` Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) - One or multiple unnormalized histograms of dimension dim_b - If many, compute all the OT distances (a, b_i) + One or multiple unnormalized histograms of dimension `dim_b`. + If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : np.ndarray (dim_a, dim_b) loss matrix reg : float @@ -186,7 +189,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -196,7 +199,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Returns ------- ot_distance : (n_hists,) ndarray - the OT distance between `a` and each of the histograms `b_i` + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` log : dict log dictionary returned only if `log` is `True` Examples -------- @@ -211,10 +214,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', array([0.31912866]) - + .. _references-sinkhorn-unbalanced2: References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 @@ -232,9 +234,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', See Also -------- - ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10] - ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10] - ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10] + ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced2>` + ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced2>` + ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>` """ b = np.asarray(b, dtype=np.float64) @@ -270,26 +272,29 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, The function solves the following optimization problem: .. 
math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \reg_m KL(\gamma 1, a) + \reg_m KL(\gamma^T 1, b) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. - \gamma\geq 0 + \gamma \geq 0 + where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target unbalanced distributions + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence - The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-knopp-unbalanced>` Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) - One or multiple unnormalized histograms of dimension dim_b + One or multiple unnormalized histograms of dimension `dim_b` If many, compute all the OT distances (a, b_i) M : np.ndarray (dim_a, dim_b) loss matrix @@ -300,7 +305,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -310,15 +315,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Returns ------- if n_hists == 1: - gamma : (dim_a x dim_b) ndarray + - gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters - log : dict + - log : dict log dictionary returned only if `log` is `True` else: - ot_distance : (n_hists,) ndarray - the OT distance between `a` and each of the histograms `b_i` - log : dict + - ot_distance : (n_hists,) ndarray + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - log : dict log dictionary returned only if `log` is `True` + Examples -------- @@ -330,9 +336,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, array([[0.51122823, 0.18807035], [0.18807035, 0.51122823]]) + + .. _references-sinkhorn-knopp-unbalanced: References ---------- - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. @@ -445,32 +452,34 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 problem and return the loss The function solves the following optimization problem using log-domain - stabilization as proposed in [10]: + stabilization as proposed in :ref:`[10] <references-sinkhorn-stabilized-unbalanced>`: .. 
math:: - W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b) + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) s.t. - \gamma\geq 0 + \gamma \geq 0 + where : - - M is the (dim_a, dim_b) metric cost matrix - - :math:`\Omega` is the entropic regularization - term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target unbalanced distributions + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions - KL is the Kullback-Leibler divergence The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-stabilized-unbalanced>` Parameters ---------- a : np.ndarray (dim_a,) - Unnormalized histogram of dimension dim_a + Unnormalized histogram of dimension `dim_a` b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) - One or multiple unnormalized histograms of dimension dim_b - If many, compute all the OT distances (a, b_i) + One or multiple unnormalized histograms of dimension `dim_b`. + If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` M : np.ndarray (dim_a, dim_b) loss matrix reg : float @@ -482,7 +491,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (>0) + Stop threshold on error (>0) verbose : bool, optional Print information along iterations log : bool, optional @@ -492,14 +501,14 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Returns ------- if n_hists == 1: - gamma : (dim_a x dim_b) ndarray + - gamma : (dim_a, dim_b) ndarray Optimal transportation matrix for the given parameters - log : dict + - log : dict log dictionary returned only if `log` is `True` else: - ot_distance : (n_hists,) ndarray - the OT distance between `a` and each of the histograms `b_i` - log : dict + - ot_distance : (n_hists,) ndarray + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - log : dict log dictionary returned only if `log` is `True` Examples -------- @@ -512,9 +521,10 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 array([[0.51122823, 0.18807035], [0.18807035, 0.51122823]]) + + .. _references-sinkhorn-stabilized-unbalanced: References ---------- - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. @@ -654,29 +664,27 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax=1000, stopThr=1e-6, verbose=False, log=False): - r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization. + r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i) where : - - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized - Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - - :math:`\mathbf{a}_i` are training distributions in the columns of - matrix :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - reg_m is the marginal relaxation hyperparameter - The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + + The algorithm used for solving the problem is the generalized + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-stabilized>` Parameters ---------- A : np.ndarray (dim, n_hists) - `n_hists` training distributions a_i of dimension dim + `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` M : np.ndarray (dim, dim) ground metric matrix for OT. reg : float @@ -691,7 +699,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -706,9 +714,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, log dictionary returned only if log==True in parameters + .. _references-barycenter-unbalanced-stabilized: References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. @@ -806,29 +814,27 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False): - r"""Compute the entropic unbalanced wasserstein barycenter of A. + r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`. - The function solves the following optimization problem with a + The function solves the following optimization problem for :math:`\mathbf{a}`: .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i) where : - - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized - Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - reg_m is the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-sinkhorn>` Parameters ---------- A : np.ndarray (dim, n_hists) - `n_hists` training distributions a_i of dimension dim + `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` M : np.ndarray (dim, dim) ground metric matrix for OT. reg : float @@ -841,7 +847,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -856,9 +862,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, log dictionary returned only if log==True in parameters + .. _references-barycenter-unbalanced-sinkhorn: References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. @@ -936,29 +942,27 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): - r"""Compute the entropic unbalanced wasserstein barycenter of A. + r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`. - The function solves the following optimization problem with a + The function solves the following optimization problem for :math:`\mathbf{a}`: .. 
math:: - \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i) + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i) where : - - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized - Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` - - reg and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT + - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` + - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT - reg_m is the marginal relaxation hyperparameter + The algorithm used for solving the problem is the generalized - Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_ + Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced>` Parameters ---------- A : np.ndarray (dim, n_hists) - `n_hists` training distributions a_i of dimension dim + `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` M : np.ndarray (dim, dim) ground metric matrix for OT. reg : float @@ -971,7 +975,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, numItermax : int, optional Max number of iterations stopThr : float, optional - Stop threshol on error (> 0) + Stop threshold on error (> 0) verbose : bool, optional Print information along iterations log : bool, optional @@ -986,9 +990,9 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, log dictionary returned only if log==True in parameters + .. _references-barycenter-unbalanced: References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. 
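A small usage sketch for the unbalanced solvers documented above; the histograms, cost matrix and regularization values are made-up toy inputs, not values prescribed by the library:

    import numpy as np
    import ot

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])

    G = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=1., reg_m=1.)   # unbalanced OT plan
    d = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=1., reg_m=1.)  # loss instead of plan

    A = np.stack([a, b], axis=1)  # two histograms as columns, shape (dim, n_hists)
    bary = ot.unbalanced.barycenter_unbalanced(A, M, reg=1., reg_m=1.)  # unbalanced barycenter
    print(G.shape, d, bary.shape)
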
diff --git a/ot/utils.py b/ot/utils.py index f9911a1..c878563 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -7,7 +7,6 @@ Various useful functions # # License: MIT License -import multiprocessing from functools import reduce import time @@ -15,67 +14,127 @@ import numpy as np from scipy.spatial.distance import cdist import sys import warnings -try: - from inspect import signature -except ImportError: - from .externals.funcsigs import signature +from inspect import signature +from .backend import get_backend __time_tic_toc = time.time() def tic(): - """ Python implementation of Matlab tic() function """ + r""" Python implementation of Matlab tic() function """ global __time_tic_toc __time_tic_toc = time.time() def toc(message='Elapsed time : {} s'): - """ Python implementation of Matlab toc() function """ + r""" Python implementation of Matlab toc() function """ t = time.time() print(message.format(t - __time_tic_toc)) return t - __time_tic_toc def toq(): - """ Python implementation of Julia toc() function """ + r""" Python implementation of Julia toc() function """ t = time.time() return t - __time_tic_toc def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): - """Compute kernel matrix""" + r"""Compute kernel matrix""" + + nx = get_backend(x1, x2) + if method.lower() in ['gaussian', 'gauss', 'rbf']: - K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K def laplacian(x): - """Compute Laplacian matrix""" + r"""Compute Laplacian matrix""" L = np.diag(np.sum(x, axis=0)) - x return L -def unif(n): - """ return a uniform histogram of length n (simplex) +def list_to_array(*lst): + r""" Convert lists in the arguments to numpy arrays if needed """ + if len(lst) > 1: + return [np.array(a) if isinstance(a, list) else a for a in lst] + else: + return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + + +def proj_simplex(v, z=1): + r"""Compute the closest point (orthogonal projection) on the + generalized `(n-1)`-simplex of a vector :math:`\mathbf{v}` w.r.t. the Euclidean + distance, thus solving: + + .. math:: + \mathcal{P}(\mathbf{v}) \in \mathop{\arg \min}_\gamma \| \gamma - \mathbf{v} \|_2 + + s.t. \ \gamma^T \mathbf{1} = z + + \gamma \geq 0 + + If :math:`\mathbf{v}` is a 2d array, compute all the projections w.r.t. axis 0 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- + v : {array-like}, shape (n, d) + z : int, optional + 'size' of the simplex (each vector sums to z, 1 by default) + + Returns + ------- + h : ndarray, shape (`n`, `d`) + Array of projections on the simplex + """ + nx = get_backend(v) + n = v.shape[0] + if v.ndim == 1: + d1 = 1 + v = v[:, None] + else: + d1 = 0 + d = v.shape[1] + + # sort v in ascending order + u = nx.sort(v, axis=0) + # take the descending order + u = nx.flip(u, 0) + cssv = nx.cumsum(u, axis=0) - z + ind = nx.arange(n, type_as=v)[:, None] + 1 + cond = u - cssv / ind > 0 + rho = nx.sum(cond, 0) + theta = cssv[rho - 1, nx.arange(d)] / rho + w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v)) + if d1: + return w[:, 0] + else: + return w + + +def unif(n): + r""" + Return a uniform histogram of length `n` (simplex). 
+ Parameters + ---------- n : int number of bins in the histogram Returns ------- - h : np.array (n,) - histogram of length n such that h_i=1/n for all i - - + h : np.array (`n`,) + histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ return np.ones((n,)) / n def clean_zeros(a, b, M): - """ Remove all components with zeros weights in a and b + r""" Remove all components with zero weights in :math:`\mathbf{a}` and :math:`\mathbf{b}` """ M2 = M[a > 0, :][:, b > 0].copy() # copy forces C-style matrix (for emd) a2 = a[a > 0] @@ -84,55 +143,71 @@ def clean_zeros(a, b, M): def euclidean_distances(X, Y, squared=False): - """ - Considering the rows of X (and Y=X) as vectors, compute the + r""" + Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the distance matrix between each pair of vectors. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + Parameters ---------- - X : {array-like}, shape (n_samples_1, n_features) - Y : {array-like}, shape (n_samples_2, n_features) + X : array-like, shape (n_samples_1, n_features) + Y : array-like, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. + Returns ------- - distances : {array}, shape (n_samples_1, n_samples_2) + distances : array-like, shape (`n_samples_1`, `n_samples_2`) """ - XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] - YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] - distances = np.dot(X, Y.T) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + + nx = get_backend(X, Y) + + a2 = nx.einsum('ij,ij->i', X, X) + b2 = nx.einsum('ij,ij->i', Y, Y) + + c = -2 * nx.dot(X, Y.T) + c += a2[:, None] + c += b2[None, :] + + c = nx.maximum(c, 0) + + if not squared: + c = nx.sqrt(c) + if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 - return distances if squared else np.sqrt(distances, out=distances) + c = c * (1 - nx.eye(X.shape[0], type_as=c)) + + return c + +def dist(x1, x2=None, metric='sqeuclidean', p=2): + r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` -def dist(x1, x2=None, metric='sqeuclidean'): - """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- - x1 : ndarray, shape (n1,d) - matrix with n1 samples of size d - x2 : array, shape (n2,d), optional - matrix with n2 samples of size d (if None then x2=x1) + x1 : array-like, shape (n1,d) + matrix with `n1` samples of size `d` + x2 : array-like, shape (n2,d), optional + matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`) metric : str | callable, optional - Name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', - 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sqeuclidean' or 'euclidean' on all backends. 
On numpy the function also + accepts the metrics from the scipy.spatial.distance.cdist function: 'braycurtis', + 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. + p : float, optional + p-norm for the Minkowski and Weighted Minkowski metrics (only used on the + numpy backend). Default value is 2. Returns ------- - M : np.array (n1,n2) + M : array-like, shape (`n1`, `n2`) distance matrix computed with given metric """ @@ -140,11 +215,17 @@ def dist(x1, x2=None, metric='sqeuclidean'): x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) - return cdist(x1, x2, metric=metric) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False) + else: + if not get_backend(x1, x2).__name__ == 'numpy': + raise NotImplementedError() + else: + return cdist(x1, x2, metric=metric, p=p) def dist0(n, method='lin_square'): - """Compute standard cost matrices of size (n, n) for OT problems + r"""Compute standard cost matrices of size (`n`, `n`) for OT problems Parameters ---------- @@ -153,11 +234,11 @@ def dist0(n, method='lin_square'): method : str, optional Type of loss matrix chosen from: - * 'lin_square' : linear sampling between 0 and n-1, quadratic loss + * 'lin_square' : linear sampling between 0 and `n-1`, quadratic loss Returns ------- - M : ndarray, shape (n1,n2) + M : ndarray, shape (`n`, `n`) Distance matrix computed with given metric. """ res = 0 @@ -168,7 +249,7 @@ def cost_normalization(C, norm=None): - """ Apply normalization to the loss matrix + r""" Apply normalization to the loss matrix Parameters ---------- C : ndarray, shape (n1, n2) The cost matrix to normalize. norm : str Type of normalization from 'median', 'max', 'log', 'loglog'. Any other value do not normalize. Returns ------- - C : ndarray, shape (n1, n2) + C : ndarray, shape (`n1`, `n2`) The input cost matrix normalized according to given norm. """ @@ -202,23 +283,23 @@ def dots(*args): - """ dots function for multiple matrix multiply """ + r""" dots function for multiple matrix multiplications """ return reduce(np.dot, args) def label_normalization(y, start=0): - """ Transform labels to start at a given value + r""" Transform labels to start at a given value Parameters ---------- y : array-like, shape (n, ) The vector of labels to be normalized. start : int - Desired value for the smallest label in y (default=0) + Desired value for the smallest label in :math:`\mathbf{y}` (default=0) Returns ------- - y : array-like, shape (n1, ) + y : array-like, shape (`n`, ) The input vector of labels normalized according to given start value. 
""" @@ -228,42 +309,15 @@ def label_normalization(y, start=0): return y -def fun(f, q_in, q_out): - """ Utility function for parmap with no serializing problems """ - while True: - i, x = q_in.get() - if i is None: - break - q_out.put((i, f(x))) - - -def parmap(f, X, nprocs=multiprocessing.cpu_count()): - """ paralell map for multiprocessing (only map on windows)""" - - if not sys.platform.endswith('win32'): - - q_in = multiprocessing.Queue(1) - q_out = multiprocessing.Queue() - - proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out)) - for _ in range(nprocs)] - for p in proc: - p.daemon = True - p.start() - - sent = [q_in.put((i, x)) for i, x in enumerate(X)] - [q_in.put((None, None)) for _ in range(nprocs)] - res = [q_out.get() for _ in range(len(sent))] - - [p.join() for p in proc] - - return [x for i, x in sorted(res)] - else: - return list(map(f, X)) +def parmap(f, X, nprocs="default"): + r""" parallel map for multiprocessing. + The function has been deprecated and only performs a regular map. + """ + return list(map(f, X)) def check_params(**kwargs): - """check_params: check whether some parameters are missing + r"""check_params: check whether some parameters are missing """ missing_params = [] @@ -284,14 +338,14 @@ def check_params(**kwargs): def check_random_state(seed): - """Turn seed into a np.random.RandomState instance + r"""Turn `seed` into a np.random.RandomState instance Parameters ---------- seed : None | int | instance of RandomState - If seed is None, return the RandomState singleton used by np.random. - If seed is an int, return a new RandomState instance seeded with seed. - If seed is already a RandomState instance, return it. + If `seed` is None, return the RandomState singleton used by np.random. + If `seed` is an int, return a new RandomState instance seeded with `seed`. + If `seed` is already a RandomState instance, return it. Otherwise raise ValueError. """ if seed is None or seed is np.random: @@ -305,18 +359,21 @@ def check_random_state(seed): class deprecated(object): - """Decorator to mark a function or class as deprecated. + r"""Decorator to mark a function or class as deprecated. deprecated class from scikit-learn package https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py Issue a warning when the function is called/the class is instantiated and adds a warning to the docstring. The optional extra argument will be appended to the deprecation message - and the docstring. Note: to use this with the default value for extra, put - in an empty of parentheses: - >>> from ot.deprecation import deprecated # doctest: +SKIP - >>> @deprecated() # doctest: +SKIP - ... def some_function(): pass # doctest: +SKIP + and the docstring. + + .. note:: + To use this with the default value for extra, use empty parentheses: + + >>> from ot.deprecation import deprecated # doctest: +SKIP + >>> @deprecated() # doctest: +SKIP + ... 
def some_function(): pass # doctest: +SKIP Parameters ---------- @@ -331,7 +388,7 @@ class deprecated(object): self.extra = extra def __call__(self, obj): - """Call method + r"""Call method Parameters ---------- obj : object @@ -362,7 +419,7 @@ class deprecated(object): return cls def _decorate_fun(self, fun): - """Decorate function fun""" + r"""Decorate function fun""" msg = "Function %s is deprecated" % fun.__name__ if self.extra: @@ -388,7 +445,7 @@ class deprecated(object): def _is_deprecated(func): - """Helper to check if func is wraped by our deprecated decorator""" + r"""Helper to check if func is wrapped by our deprecated decorator""" if sys.version_info < (3, 5): raise NotImplementedError("This is only available for python3.5 " "or above") @@ -402,7 +459,7 @@ def _is_deprecated(func): class BaseEstimator(object): - """Base class for most objects in POT + r"""Base class for most objects in POT Code adapted from sklearn BaseEstimator class @@ -415,7 +472,7 @@ class BaseEstimator(object): @classmethod def _get_param_names(cls): - """Get parameter names for the estimator""" + r"""Get parameter names for the estimator""" # fetch the constructor or the original constructor before # deprecation wrapping if any @@ -442,7 +499,7 @@ class BaseEstimator(object): return sorted([p.name for p in parameters]) def get_params(self, deep=True): - """Get parameters for this estimator. + r"""Get parameters for this estimator. Parameters ---------- @@ -479,7 +536,7 @@ class BaseEstimator(object): return out def set_params(self, **params): - """Set the parameters of this estimator. + r"""Set the parameters of this estimator. The method works on simple estimators as well as on nested objects (such as pipelines). The latter have parameters of the form @@ -519,7 +576,7 @@ class BaseEstimator(object): class UndefinedParameter(Exception): - """ + r""" Aims at raising an Exception when an undefined parameter is called """
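The reworked helpers in ot/utils.py above are backend-agnostic; a quick numpy-only sketch (the random inputs are illustrative assumptions):

    import numpy as np
    import ot

    x = np.random.randn(5, 2)
    y = np.random.randn(6, 2)

    # 'sqeuclidean' and 'euclidean' go through the backend-agnostic path;
    # other metrics fall back to scipy's cdist on the numpy backend
    M = ot.dist(x, y, metric='euclidean')

    # column-wise projection of a 2d array onto the probability simplex
    w = ot.utils.proj_simplex(np.random.randn(4, 3))
    print(M.shape, w.sum(axis=0))  # (5, 6); each column of w sums to 1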