Diffstat (limited to 'ot')
-rw-r--r--   ot/__init__.py                         24
-rw-r--r--   ot/backend.py                        1502
-rw-r--r--   ot/bregman.py                        2811
-rw-r--r--   ot/da.py                              507
-rw-r--r--   ot/datasets.py                         12
-rw-r--r--   ot/dr.py                              156
-rw-r--r--   ot/gpu/__init__.py                     12
-rw-r--r--   ot/gpu/bregman.py                      12
-rw-r--r--   ot/gpu/da.py                            2
-rw-r--r--   ot/gromov.py                         1312
-rw-r--r--   ot/helpers/__init__.py                  3
-rw-r--r--   ot/helpers/openmp_helpers.py           85
-rw-r--r--   ot/helpers/pre_build_helpers.py        87
-rw-r--r--   ot/lp/EMD.h                             5
-rw-r--r--   ot/lp/EMD_wrapper.cpp                 124
-rw-r--r--   ot/lp/__init__.py                     597
-rw-r--r--   ot/lp/cvx.py                            3
-rw-r--r--   ot/lp/emd_wrap.pyx                     32
-rw-r--r--   ot/lp/full_bipartitegraph.h            27
-rw-r--r--   ot/lp/full_bipartitegraph_omp.h       234
-rw-r--r--   ot/lp/network_simplex_simple.h        212
-rw-r--r--   ot/lp/network_simplex_simple_omp.h   1699
-rw-r--r--   ot/lp/solver_1d.py                    367
-rw-r--r--   ot/optim.py                           189
-rwxr-xr-x   ot/partial.py                         352
-rw-r--r--   ot/plot.py                             10
-rw-r--r--   ot/regpath.py                         827
-rw-r--r--   ot/sliced.py                          258
-rw-r--r--   ot/smooth.py                          183
-rw-r--r--   ot/stochastic.py                      192
-rw-r--r--   ot/unbalanced.py                      220
-rw-r--r--   ot/utils.py                           269
32 files changed, 9780 insertions, 2545 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 0e6e2e2..b6dc2b4 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -5,7 +5,8 @@
:py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim`
:py:mod:`ot.utils`, :py:mod:`ot.datasets`,
:py:mod:`ot.gromov`, :py:mod:`ot.smooth`
- :py:mod:`ot.stochastic`
+ :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath`,
+ :py:mod:`ot.unbalanced`.
The following sub-modules are not imported due to additional dependencies:
@@ -33,21 +34,30 @@ from . import smooth
from . import stochastic
from . import unbalanced
from . import partial
+from . import backend
+from . import regpath
# OT functions
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
+from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
+ sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
+from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
+from .gromov import (gromov_wasserstein, gromov_wasserstein2,
+ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "0.7.0"
+__version__ = "0.8.0"
-__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets',
- 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d',
+__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
+ 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
+ 'emd2_1d', 'wasserstein_1d', 'backend',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
- 'sinkhorn_unbalanced2']
+ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
+           'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters',
+           'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
+ 'max_sliced_wasserstein_distance',
+ 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
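For reference, the functions newly re-exported at the top level can be called directly from the ot namespace. A minimal editorial sketch (toy data, default parameters; the keyword arguments follow the usual POT signatures and are an assumption here, not part of the patch):

    import numpy as np
    import ot

    xs = np.random.randn(50, 2)            # source samples
    xt = np.random.randn(60, 2)            # target samples

    # sliced Wasserstein distance between the two empirical distributions
    sw = ot.sliced_wasserstein_distance(xs, xt, n_projections=100)

    # Gromov-Wasserstein coupling between the two intra-domain distance matrices
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    T = ot.gromov_wasserstein(C1, C2, ot.unif(50), ot.unif(60), 'square_loss')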
diff --git a/ot/backend.py b/ot/backend.py
new file mode 100644
index 0000000..a044f84
--- /dev/null
+++ b/ot/backend.py
@@ -0,0 +1,1502 @@
+# -*- coding: utf-8 -*-
+"""
+Multi-lib backend for POT
+
+The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
+or Jax, POT code should work nonetheless.
+To achieve that, POT provides backend classes which implement functions from their
+respective backends, imitating the Numpy API. As a convention, we use nx instead of
+np to refer to the backend.
+
+Examples
+--------
+
+>>> from ot.utils import list_to_array
+>>> from ot.backend import get_backend
+>>> def f(a, b): # the function does not know which backend to use
+... a, b = list_to_array(a, b) # if a list is given, make it an array
+... nx = get_backend(a, b) # infer the backend from the arguments
+... c = nx.dot(a, b) # now use the backend to do any calculation
+... return c
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import scipy.special as scipy
+from scipy.sparse import issparse, coo_matrix, csr_matrix
+
+try:
+ import torch
+ torch_type = torch.Tensor
+except ImportError:
+ torch = False
+ torch_type = float
+
+try:
+ import jax
+ import jax.numpy as jnp
+ import jax.scipy.special as jscipy
+ jax_type = jax.numpy.ndarray
+except ImportError:
+ jax = False
+ jax_type = float
+
+str_type_error = "All array should be from the same type/backend. Current types are : {}"
+
+
+def get_backend_list():
+ """Returns the list of available backends"""
+ lst = [NumpyBackend(), ]
+
+ if torch:
+ lst.append(TorchBackend())
+
+ if jax:
+ lst.append(JaxBackend())
+
+ return lst
+
+
+def get_backend(*args):
+ """Returns the proper backend for a list of input arrays
+
+    Also raises ValueError if all arrays are not from the same backend
+ """
+ # check that some arrays given
+ if not len(args) > 0:
+ raise ValueError(" The function takes at least one parameter")
+ # check all same type
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
+
+ if isinstance(args[0], np.ndarray):
+ return NumpyBackend()
+ elif isinstance(args[0], torch_type):
+ return TorchBackend()
+ elif isinstance(args[0], jax_type):
+ return JaxBackend()
+ else:
+ raise ValueError("Unknown type of non implemented backend.")
+
+
+def to_numpy(*args):
+ """Returns numpy arrays from any compatible backend"""
+
+ if len(args) == 1:
+ return get_backend(args[0]).to_numpy(args[0])
+ else:
+ return [get_backend(a).to_numpy(a) for a in args]
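Taken together, get_backend and to_numpy let library code compute with whatever arrays it receives and hand plain numpy arrays back to the caller. A small editorial sketch (numpy only here; mixing array types raises the ValueError above):

    import numpy as np
    from ot.backend import get_backend, to_numpy

    x = np.linspace(0.0, 1.0, 5)
    nx = get_backend(x)          # -> NumpyBackend instance
    y = nx.exp(-x)               # backend-agnostic computation
    y_np = to_numpy(y)           # guaranteed to be a numpy array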
+
+
+class Backend():
+ """
+ Backend abstract class.
+ Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+
+ - The `__name__` class attribute refers to the name of the backend.
+ - The `__type__` class attribute refers to the data structure used by the backend.
+ """
+
+ __name__ = None
+ __type__ = None
+ __type_list__ = None
+
+ rng_ = None
+
+ def __str__(self):
+ return self.__name__
+
+ # convert to numpy
+ def to_numpy(self, a):
+ """Returns the numpy version of a tensor"""
+ raise NotImplementedError()
+
+ # convert from numpy
+ def from_numpy(self, a, type_as=None):
+ """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
+ raise NotImplementedError()
+
+ def set_gradients(self, val, inputs, grads):
+ """Define the gradients for the value val wrt the inputs """
+ raise NotImplementedError()
+
+ def zeros(self, shape, type_as=None):
+ r"""
+ Creates a tensor full of zeros.
+
+ This function follows the api from :any:`numpy.zeros`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html
+ """
+ raise NotImplementedError()
+
+ def ones(self, shape, type_as=None):
+ r"""
+ Creates a tensor full of ones.
+
+ This function follows the api from :any:`numpy.ones`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html
+ """
+ raise NotImplementedError()
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ r"""
+ Returns evenly spaced values within a given interval.
+
+ This function follows the api from :any:`numpy.arange`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
+ """
+ raise NotImplementedError()
+
+ def full(self, shape, fill_value, type_as=None):
+ r"""
+ Creates a tensor with given shape, filled with given value.
+
+ This function follows the api from :any:`numpy.full`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.full.html
+ """
+ raise NotImplementedError()
+
+ def eye(self, N, M=None, type_as=None):
+ r"""
+ Creates the identity matrix of given size.
+
+ This function follows the api from :any:`numpy.eye`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html
+ """
+ raise NotImplementedError()
+
+ def sum(self, a, axis=None, keepdims=False):
+ r"""
+ Sums tensor elements over given dimensions.
+
+ This function follows the api from :any:`numpy.sum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html
+ """
+ raise NotImplementedError()
+
+ def cumsum(self, a, axis=None):
+ r"""
+ Returns the cumulative sum of tensor elements over given dimensions.
+
+ This function follows the api from :any:`numpy.cumsum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
+ """
+ raise NotImplementedError()
+
+ def max(self, a, axis=None, keepdims=False):
+ r"""
+ Returns the maximum of an array or maximum along given dimensions.
+
+ This function follows the api from :any:`numpy.amax`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html
+ """
+ raise NotImplementedError()
+
+ def min(self, a, axis=None, keepdims=False):
+ r"""
+        Returns the minimum of an array or minimum along given dimensions.
+
+ This function follows the api from :any:`numpy.amin`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html
+ """
+ raise NotImplementedError()
+
+ def maximum(self, a, b):
+ r"""
+ Returns element-wise maximum of array elements.
+
+ This function follows the api from :any:`numpy.maximum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
+ """
+ raise NotImplementedError()
+
+ def minimum(self, a, b):
+ r"""
+ Returns element-wise minimum of array elements.
+
+ This function follows the api from :any:`numpy.minimum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
+ """
+ raise NotImplementedError()
+
+ def dot(self, a, b):
+ r"""
+ Returns the dot product of two tensors.
+
+ This function follows the api from :any:`numpy.dot`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
+ """
+ raise NotImplementedError()
+
+ def abs(self, a):
+ r"""
+ Computes the absolute value element-wise.
+
+ This function follows the api from :any:`numpy.absolute`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html
+ """
+ raise NotImplementedError()
+
+ def exp(self, a):
+ r"""
+ Computes the exponential value element-wise.
+
+ This function follows the api from :any:`numpy.exp`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html
+ """
+ raise NotImplementedError()
+
+ def log(self, a):
+ r"""
+ Computes the natural logarithm, element-wise.
+
+ This function follows the api from :any:`numpy.log`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.log.html
+ """
+ raise NotImplementedError()
+
+ def sqrt(self, a):
+ r"""
+        Returns the non-negative square root of a tensor, element-wise.
+
+ This function follows the api from :any:`numpy.sqrt`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html
+ """
+ raise NotImplementedError()
+
+ def power(self, a, exponents):
+ r"""
+ First tensor elements raised to powers from second tensor, element-wise.
+
+ This function follows the api from :any:`numpy.power`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.power.html
+ """
+ raise NotImplementedError()
+
+ def norm(self, a):
+ r"""
+        Computes the matrix Frobenius norm.
+
+ This function follows the api from :any:`numpy.linalg.norm`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
+ """
+ raise NotImplementedError()
+
+ def any(self, a):
+ r"""
+ Tests whether any tensor element along given dimensions evaluates to True.
+
+ This function follows the api from :any:`numpy.any`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.any.html
+ """
+ raise NotImplementedError()
+
+ def isnan(self, a):
+ r"""
+ Tests element-wise for NaN and returns result as a boolean tensor.
+
+ This function follows the api from :any:`numpy.isnan`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html
+ """
+ raise NotImplementedError()
+
+ def isinf(self, a):
+ r"""
+ Tests element-wise for positive or negative infinity and returns result as a boolean tensor.
+
+ This function follows the api from :any:`numpy.isinf`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html
+ """
+ raise NotImplementedError()
+
+ def einsum(self, subscripts, *operands):
+ r"""
+ Evaluates the Einstein summation convention on the operands.
+
+ This function follows the api from :any:`numpy.einsum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
+ """
+ raise NotImplementedError()
+
+ def sort(self, a, axis=-1):
+ r"""
+ Returns a sorted copy of a tensor.
+
+ This function follows the api from :any:`numpy.sort`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html
+ """
+ raise NotImplementedError()
+
+ def argsort(self, a, axis=None):
+ r"""
+ Returns the indices that would sort a tensor.
+
+ This function follows the api from :any:`numpy.argsort`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
+ """
+ raise NotImplementedError()
+
+ def searchsorted(self, a, v, side='left'):
+ r"""
+ Finds indices where elements should be inserted to maintain order in given tensor.
+
+ This function follows the api from :any:`numpy.searchsorted`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
+ """
+ raise NotImplementedError()
+
+ def flip(self, a, axis=None):
+ r"""
+ Reverses the order of elements in a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.flip`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html
+ """
+ raise NotImplementedError()
+
+ def clip(self, a, a_min, a_max):
+ """
+ Limits the values in a tensor.
+
+ This function follows the api from :any:`numpy.clip`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html
+ """
+ raise NotImplementedError()
+
+ def repeat(self, a, repeats, axis=None):
+ r"""
+ Repeats elements of a tensor.
+
+ This function follows the api from :any:`numpy.repeat`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html
+ """
+ raise NotImplementedError()
+
+ def take_along_axis(self, arr, indices, axis):
+ r"""
+ Gathers elements of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.take_along_axis`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
+ """
+ raise NotImplementedError()
+
+ def concatenate(self, arrays, axis=0):
+ r"""
+ Joins a sequence of tensors along an existing dimension.
+
+ This function follows the api from :any:`numpy.concatenate`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html
+ """
+ raise NotImplementedError()
+
+ def zero_pad(self, a, pad_width):
+ r"""
+ Pads a tensor.
+
+ This function follows the api from :any:`numpy.pad`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
+ """
+ raise NotImplementedError()
+
+ def argmax(self, a, axis=None):
+ r"""
+ Returns the indices of the maximum values of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.argmax`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
+ """
+ raise NotImplementedError()
+
+ def mean(self, a, axis=None):
+ r"""
+ Computes the arithmetic mean of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.mean`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html
+ """
+ raise NotImplementedError()
+
+ def std(self, a, axis=None):
+ r"""
+ Computes the standard deviation of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.std`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.std.html
+ """
+ raise NotImplementedError()
+
+ def linspace(self, start, stop, num):
+ r"""
+ Returns a specified number of evenly spaced values over a given interval.
+
+ This function follows the api from :any:`numpy.linspace`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html
+ """
+ raise NotImplementedError()
+
+ def meshgrid(self, a, b):
+ r"""
+ Returns coordinate matrices from coordinate vectors (Numpy convention).
+
+ This function follows the api from :any:`numpy.meshgrid`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
+ """
+ raise NotImplementedError()
+
+ def diag(self, a, k=0):
+ r"""
+ Extracts or constructs a diagonal tensor.
+
+ This function follows the api from :any:`numpy.diag`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html
+ """
+ raise NotImplementedError()
+
+ def unique(self, a):
+ r"""
+ Finds unique elements of given tensor.
+
+ This function follows the api from :any:`numpy.unique`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
+ """
+ raise NotImplementedError()
+
+ def logsumexp(self, a, axis=None):
+ r"""
+ Computes the log of the sum of exponentials of input elements.
+
+ This function follows the api from :any:`scipy.special.logsumexp`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
+ """
+ raise NotImplementedError()
+
+ def stack(self, arrays, axis=0):
+ r"""
+ Joins a sequence of tensors along a new dimension.
+
+ This function follows the api from :any:`numpy.stack`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html
+ """
+ raise NotImplementedError()
+
+ def outer(self, a, b):
+ r"""
+ Computes the outer product between two vectors.
+
+ This function follows the api from :any:`numpy.outer`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html
+ """
+ raise NotImplementedError()
+
+ def reshape(self, a, shape):
+ r"""
+ Gives a new shape to a tensor without changing its data.
+
+ This function follows the api from :any:`numpy.reshape`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html
+ """
+ raise NotImplementedError()
+
+ def seed(self, seed=None):
+ r"""
+ Sets the seed for the random generator.
+
+ This function follows the api from :any:`numpy.random.seed`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html
+ """
+ raise NotImplementedError()
+
+ def rand(self, *size, type_as=None):
+ r"""
+        Generates uniform random numbers.
+
+ This function follows the api from :any:`numpy.random.rand`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
+ """
+ raise NotImplementedError()
+
+ def randn(self, *size, type_as=None):
+ r"""
+        Generates Gaussian (standard normal) random numbers.
+
+        This function follows the api from :any:`numpy.random.randn`
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.random.randn.html
+ """
+ raise NotImplementedError()
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ r"""
+ Creates a sparse tensor in COOrdinate format.
+
+ This function follows the api from :any:`scipy.sparse.coo_matrix`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
+ """
+ raise NotImplementedError()
+
+ def issparse(self, a):
+ r"""
+ Checks whether or not the input tensor is a sparse tensor.
+
+ This function follows the api from :any:`scipy.sparse.issparse`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html
+ """
+ raise NotImplementedError()
+
+ def tocsr(self, a):
+ r"""
+ Converts this matrix to Compressed Sparse Row format.
+
+ This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html
+ """
+ raise NotImplementedError()
+
+ def eliminate_zeros(self, a, threshold=0.):
+ r"""
+ Removes entries smaller than the given threshold from the sparse tensor.
+
+ This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros`
+
+ See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html
+ """
+ raise NotImplementedError()
+
+ def todense(self, a):
+ r"""
+ Converts a sparse tensor to a dense tensor.
+
+ This function follows the api from :any:`scipy.sparse.csr_matrix.toarray`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html
+ """
+ raise NotImplementedError()
+
+ def where(self, condition, x, y):
+ r"""
+ Returns elements chosen from x or y depending on condition.
+
+ This function follows the api from :any:`numpy.where`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.where.html
+ """
+ raise NotImplementedError()
+
+ def copy(self, a):
+ r"""
+ Returns a copy of the given tensor.
+
+ This function follows the api from :any:`numpy.copy`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html
+ """
+ raise NotImplementedError()
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ r"""
+ Returns True if two arrays are element-wise equal within a tolerance.
+
+ This function follows the api from :any:`numpy.allclose`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
+ """
+ raise NotImplementedError()
+
+ def dtype_device(self, a):
+ r"""
+ Returns the dtype and the device of the given tensor.
+ """
+ raise NotImplementedError()
+
+ def assert_same_dtype_device(self, a, b):
+ r"""
+ Checks whether or not the two given inputs have the same dtype as well as the same device
+ """
+ raise NotImplementedError()
+
+
+class NumpyBackend(Backend):
+ """
+ NumPy implementation of the backend
+
+ - `__name__` is "numpy"
+ - `__type__` is np.ndarray
+ """
+
+ __name__ = 'numpy'
+ __type__ = np.ndarray
+ __type_list__ = [np.array(1, dtype=np.float32),
+ np.array(1, dtype=np.float64)]
+
+ rng_ = np.random.RandomState()
+
+ def to_numpy(self, a):
+ return a
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return a
+ elif isinstance(a, float):
+ return a
+ else:
+ return a.astype(type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # No gradients for numpy
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return np.zeros(shape)
+ else:
+ return np.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return np.ones(shape)
+ else:
+ return np.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return np.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return np.full(shape, fill_value)
+ else:
+ return np.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return np.eye(N, M)
+ else:
+ return np.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return np.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return np.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return np.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return np.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return np.maximum(a, b)
+
+ def minimum(self, a, b):
+ return np.minimum(a, b)
+
+ def dot(self, a, b):
+ return np.dot(a, b)
+
+ def abs(self, a):
+ return np.abs(a)
+
+ def exp(self, a):
+ return np.exp(a)
+
+ def log(self, a):
+ return np.log(a)
+
+ def sqrt(self, a):
+ return np.sqrt(a)
+
+ def power(self, a, exponents):
+ return np.power(a, exponents)
+
+ def norm(self, a):
+ return np.sqrt(np.sum(np.square(a)))
+
+ def any(self, a):
+ return np.any(a)
+
+ def isnan(self, a):
+ return np.isnan(a)
+
+ def isinf(self, a):
+ return np.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return np.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return np.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return np.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return np.searchsorted(a, v, side)
+ else:
+            # this is not a very efficient way to make numpy
+ # searchsorted work on 2d arrays
+ ret = np.empty(v.shape, dtype=int)
+ for i in range(a.shape[0]):
+ ret[i, :] = np.searchsorted(a[i, :], v[i, :], side)
+ return ret
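For example (an editorial illustration, not part of the patch), each row of v is searched in the matching row of a:

    a = np.array([[1, 3, 5], [2, 4, 6]])
    v = np.array([[4], [5]])
    NumpyBackend().searchsorted(a, v)    # -> array([[2], [2]])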
+
+ def flip(self, a, axis=None):
+ return np.flip(a, axis)
+
+ def outer(self, a, b):
+ return np.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return np.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return np.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return np.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return np.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return np.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return np.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return np.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return np.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return np.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return np.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return np.diag(a, k)
+
+ def unique(self, a):
+ return np.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ return scipy.logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return np.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return np.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_.seed(seed)
+
+ def rand(self, *size, type_as=None):
+ return self.rng_.rand(*size)
+
+ def randn(self, *size, type_as=None):
+ return self.rng_.randn(*size)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ if type_as is None:
+ return coo_matrix((data, (rows, cols)), shape=shape)
+ else:
+ return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
+
+ def issparse(self, a):
+ return issparse(a)
+
+ def tocsr(self, a):
+ if self.issparse(a):
+ return a.tocsr()
+ else:
+ return csr_matrix(a)
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if threshold > 0:
+ if self.issparse(a):
+ a.data[self.abs(a.data) <= threshold] = 0
+ else:
+ a[self.abs(a) <= threshold] = 0
+ if self.issparse(a):
+ a.eliminate_zeros()
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return a.toarray()
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return np.where(condition, x, y)
+
+ def copy(self, a):
+ return a.copy()
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+ def dtype_device(self, a):
+ if hasattr(a, "dtype"):
+ return a.dtype, "cpu"
+ else:
+ return type(a), "cpu"
+
+ def assert_same_dtype_device(self, a, b):
+        # numpy has implicit type conversion so this check always passes
+ pass
+
+
+class JaxBackend(Backend):
+ """
+ JAX implementation of the backend
+
+ - `__name__` is "jax"
+ - `__type__` is jax.numpy.ndarray
+ """
+
+ __name__ = 'jax'
+ __type__ = jax_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+ self.rng_ = jax.random.PRNGKey(42)
+
+ for d in jax.devices():
+ self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d),
+ jax.device_put(jnp.array(1, dtype=jnp.float64), d)]
+
+ def to_numpy(self, a):
+ return np.array(a)
+
+ def _change_device(self, a, type_as):
+ return jax.device_put(a, type_as.device_buffer.device())
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return jnp.array(a)
+ else:
+ return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
+
+ def set_gradients(self, val, inputs, grads):
+ from jax.flatten_util import ravel_pytree
+ val, = jax.lax.stop_gradient((val,))
+
+ ravelled_inputs, _ = ravel_pytree(inputs)
+ ravelled_grads, _ = ravel_pytree(grads)
+
+ aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
+ aux = aux - jax.lax.stop_gradient(aux)
+
+ val, = jax.tree_map(lambda z: z + aux, (val,))
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.zeros(shape)
+ else:
+ return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as)
+
+ def ones(self, shape, type_as=None):
+ if type_as is None:
+ return jnp.ones(shape)
+ else:
+ return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return jnp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if type_as is None:
+ return jnp.full(shape, fill_value)
+ else:
+ return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return jnp.eye(N, M)
+ else:
+ return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return jnp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return jnp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return jnp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return jnp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return jnp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return jnp.minimum(a, b)
+
+ def dot(self, a, b):
+ return jnp.dot(a, b)
+
+ def abs(self, a):
+ return jnp.abs(a)
+
+ def exp(self, a):
+ return jnp.exp(a)
+
+ def log(self, a):
+ return jnp.log(a)
+
+ def sqrt(self, a):
+ return jnp.sqrt(a)
+
+ def power(self, a, exponents):
+ return jnp.power(a, exponents)
+
+ def norm(self, a):
+ return jnp.sqrt(jnp.sum(jnp.square(a)))
+
+ def any(self, a):
+ return jnp.any(a)
+
+ def isnan(self, a):
+ return jnp.isnan(a)
+
+ def isinf(self, a):
+ return jnp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return jnp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return jnp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return jnp.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return jnp.searchsorted(a, v, side)
+ else:
+            # this is not a very efficient way to make jax numpy
+ # searchsorted work on 2d arrays
+ return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])])
+
+ def flip(self, a, axis=None):
+ return jnp.flip(a, axis)
+
+ def outer(self, a, b):
+ return jnp.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return jnp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return jnp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return jnp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return jnp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return jnp.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return jnp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return jnp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return jnp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return jnp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return jnp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return jnp.diag(a, k)
+
+ def unique(self, a):
+ return jnp.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ return jscipy.logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return jnp.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return jnp.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_ = jax.random.PRNGKey(seed)
+
+ def rand(self, *size, type_as=None):
+ self.rng_, subkey = jax.random.split(self.rng_)
+ if type_as is not None:
+ return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype)
+ else:
+ return jax.random.uniform(subkey, shape=size)
+
+ def randn(self, *size, type_as=None):
+ self.rng_, subkey = jax.random.split(self.rng_)
+ if type_as is not None:
+ return jax.random.normal(subkey, shape=size, dtype=type_as.dtype)
+ else:
+ return jax.random.normal(subkey, shape=size)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ # Currently, JAX does not support sparse matrices
+ data = self.to_numpy(data)
+ rows = self.to_numpy(rows)
+ cols = self.to_numpy(cols)
+ nx = NumpyBackend()
+ coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as)
+ matrix = nx.todense(coo_matrix)
+ return self.from_numpy(matrix)
+
+ def issparse(self, a):
+ # Currently, JAX does not support sparse matrices
+ return False
+
+ def tocsr(self, a):
+ # Currently, JAX does not support sparse matrices
+ return a
+
+ def eliminate_zeros(self, a, threshold=0.):
+ # Currently, JAX does not support sparse matrices
+ if threshold > 0:
+ return self.where(
+ self.abs(a) <= threshold,
+ self.zeros((1,), type_as=a),
+ a
+ )
+ return a
+
+ def todense(self, a):
+ # Currently, JAX does not support sparse matrices
+ return a
+
+ def where(self, condition, x, y):
+ return jnp.where(condition, x, y)
+
+ def copy(self, a):
+ # No need to copy, JAX arrays are immutable
+ return a
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+ def dtype_device(self, a):
+ return a.dtype, a.device_buffer.device()
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ assert a_dtype == b_dtype, "Dtype discrepancy"
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+
+class TorchBackend(Backend):
+ """
+ PyTorch implementation of the backend
+
+ - `__name__` is "torch"
+ - `__type__` is torch.Tensor
+ """
+
+ __name__ = 'torch'
+ __type__ = torch_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+
+ self.rng_ = torch.Generator()
+ self.rng_.seed()
+
+ self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
+ torch.tensor(1, dtype=torch.float64)]
+
+ if torch.cuda.is_available():
+ self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
+ self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+
+ from torch.autograd import Function
+
+ # define a function that takes inputs val and grads
+        # and returns a val tensor with proper gradients
+ class ValFunction(Function):
+
+ @staticmethod
+ def forward(ctx, val, grads, *inputs):
+ ctx.grads = grads
+ return val
+
+ @staticmethod
+ def backward(ctx, grad_output):
+                # the gradients returned are the ones stored in ctx.grads at forward time
+ return (None, None) + ctx.grads
+
+ self.ValFunction = ValFunction
+
+ def to_numpy(self, a):
+ return a.cpu().detach().numpy()
+
+ def from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
+ if type_as is None:
+ return torch.from_numpy(a)
+ else:
+ return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
+
+ def set_gradients(self, val, inputs, grads):
+
+ Func = self.ValFunction()
+
+ res = Func.apply(val, grads, *inputs)
+
+ return res
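To make the mechanism above concrete, here is a rough editorial sketch (not part of the patch) of how a solver can attach externally computed gradients to a detached value; the value and the gradients below are made up purely for illustration:

    import torch
    from ot.backend import TorchBackend

    nx = TorchBackend()
    a = torch.tensor([0.5, 0.5], requires_grad=True)
    b = torch.tensor([0.5, 0.5], requires_grad=True)

    # pretend a solver produced this value and these analytic gradients
    val = (a * b).sum().detach()
    grad_a, grad_b = b.detach(), a.detach()

    val = nx.set_gradients(val, (a, b), (grad_a, grad_b))
    val.backward()
    # a.grad is now grad_a and b.grad is grad_b, even though val was detached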
+
+ def zeros(self, shape, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
+ if type_as is None:
+ return torch.zeros(shape)
+ else:
+ return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def ones(self, shape, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
+ if type_as is None:
+ return torch.ones(shape)
+ else:
+ return torch.ones(shape, dtype=type_as.dtype, device=type_as.device)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ if type_as is None:
+ return torch.arange(start, stop, step)
+ else:
+ return torch.arange(start, stop, step, device=type_as.device)
+
+ def full(self, shape, fill_value, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
+ if type_as is None:
+ return torch.full(shape, fill_value)
+ else:
+ return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device)
+
+ def eye(self, N, M=None, type_as=None):
+ if M is None:
+ M = N
+ if type_as is None:
+ return torch.eye(N, m=M)
+ else:
+ return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device)
+
+ def sum(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.sum(a)
+ else:
+ return torch.sum(a, axis, keepdim=keepdims)
+
+ def cumsum(self, a, axis=None):
+ if axis is None:
+ return torch.cumsum(a.flatten(), 0)
+ else:
+ return torch.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.max(a)
+ else:
+ return torch.max(a, axis, keepdim=keepdims)[0]
+
+ def min(self, a, axis=None, keepdims=False):
+ if axis is None:
+ return torch.min(a)
+ else:
+ return torch.min(a, axis, keepdim=keepdims)[0]
+
+ def maximum(self, a, b):
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ if hasattr(torch, "maximum"):
+ return torch.maximum(a, b)
+ else:
+ return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
+
+ def minimum(self, a, b):
+ if isinstance(a, int) or isinstance(a, float):
+ a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
+ if isinstance(b, int) or isinstance(b, float):
+ b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
+ if hasattr(torch, "minimum"):
+ return torch.minimum(a, b)
+ else:
+ return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
+
+ def dot(self, a, b):
+ return torch.matmul(a, b)
+
+ def abs(self, a):
+ return torch.abs(a)
+
+ def exp(self, a):
+ return torch.exp(a)
+
+ def log(self, a):
+ return torch.log(a)
+
+ def sqrt(self, a):
+ return torch.sqrt(a)
+
+ def power(self, a, exponents):
+ return torch.pow(a, exponents)
+
+ def norm(self, a):
+ return torch.sqrt(torch.sum(torch.square(a)))
+
+ def any(self, a):
+ return torch.any(a)
+
+ def isnan(self, a):
+ return torch.isnan(a)
+
+ def isinf(self, a):
+ return torch.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return torch.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ sorted0, indices = torch.sort(a, dim=axis)
+ return sorted0
+
+ def argsort(self, a, axis=-1):
+ sorted, indices = torch.sort(a, dim=axis)
+ return indices
+
+ def searchsorted(self, a, v, side='left'):
+ right = (side != 'left')
+ return torch.searchsorted(a, v, right=right)
+
+ def flip(self, a, axis=None):
+ if axis is None:
+ return torch.flip(a, tuple(i for i in range(len(a.shape))))
+ if isinstance(axis, int):
+ return torch.flip(a, (axis,))
+ else:
+ return torch.flip(a, dims=axis)
+
+ def outer(self, a, b):
+ return torch.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return torch.clamp(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return torch.repeat_interleave(a, repeats, dim=axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return torch.gather(arr, axis, indices)
+
+ def concatenate(self, arrays, axis=0):
+ return torch.cat(arrays, dim=axis)
+
+ def zero_pad(self, a, pad_width):
+ from torch.nn.functional import pad
+        # pad_width is a list of ndim tuples indicating how many zeros to add before and
+        # after each dimension. We first need to make it compliant with the torch syntax,
+        # which starts with the last dim, then the second to last, etc.
+ how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
+ return pad(a, how_pad)
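As a concrete illustration (editorial, not part of the patch) of the reordering performed above:

    # numpy convention: ((before_0, after_0), (before_1, after_1))
    pad_width = [(1, 2), (3, 4)]
    # torch convention: last dimension first -> (before_1, after_1, before_0, after_0)
    how_pad = tuple(e for t in pad_width[::-1] for e in t)   # (3, 4, 1, 2)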
+
+ def argmax(self, a, axis=None):
+ return torch.argmax(a, dim=axis)
+
+ def mean(self, a, axis=None):
+ if axis is not None:
+ return torch.mean(a, dim=axis)
+ else:
+ return torch.mean(a)
+
+ def std(self, a, axis=None):
+ if axis is not None:
+ return torch.std(a, dim=axis, unbiased=False)
+ else:
+ return torch.std(a, unbiased=False)
+
+ def linspace(self, start, stop, num):
+ return torch.linspace(start, stop, num, dtype=torch.float64)
+
+ def meshgrid(self, a, b):
+ X, Y = torch.meshgrid(a, b)
+ return X.T, Y.T
+
+ def diag(self, a, k=0):
+ return torch.diag(a, diagonal=k)
+
+ def unique(self, a):
+ return torch.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ if axis is not None:
+ return torch.logsumexp(a, dim=axis)
+ else:
+ return torch.logsumexp(a, dim=tuple(range(len(a.shape))))
+
+ def stack(self, arrays, axis=0):
+ return torch.stack(arrays, dim=axis)
+
+ def reshape(self, a, shape):
+ return torch.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if isinstance(seed, int):
+ self.rng_.manual_seed(seed)
+ elif isinstance(seed, torch.Generator):
+ self.rng_ = seed
+ else:
+ raise ValueError("Non compatible seed : {}".format(seed))
+
+ def rand(self, *size, type_as=None):
+ if type_as is not None:
+ return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+ else:
+ return torch.rand(size=size, generator=self.rng_)
+
+ def randn(self, *size, type_as=None):
+ if type_as is not None:
+ return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+ else:
+ return torch.randn(size=size, generator=self.rng_)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ if type_as is None:
+ return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape)
+ else:
+ return torch.sparse_coo_tensor(
+ torch.stack([rows, cols]), data, size=shape,
+ dtype=type_as.dtype, device=type_as.device
+ )
+
+ def issparse(self, a):
+ return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False)
+
+ def tocsr(self, a):
+        # Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer only very limited support
+ return self.todense(a)
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if self.issparse(a):
+ if threshold > 0:
+ mask = self.abs(a) <= threshold
+ mask = ~mask
+ mask = mask.nonzero()
+ else:
+ mask = a._values().nonzero()
+ nv = a._values().index_select(0, mask.view(-1))
+ ni = a._indices().index_select(1, mask.view(-1))
+ return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a)
+ else:
+ if threshold > 0:
+ a[self.abs(a) <= threshold] = 0
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return a.to_dense()
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return torch.where(condition, x, y)
+
+ def copy(self, a):
+ return torch.clone(a)
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+ def dtype_device(self, a):
+ return a.dtype, a.device
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ assert a_dtype == b_dtype, "Dtype discrepancy"
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
diff --git a/ot/bregman.py b/ot/bregman.py
index f1f8437..cce52e2 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,70 +7,104 @@ Bregman projections solvers for entropic regularized OT
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-# Hicham Janati <hicham.janati@inria.fr>
+# Hicham Janati <hicham.janati100@gmail.com>
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
# Alexander Tong <alexander.tong@yale.edu>
# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
+# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
-import numpy as np
import warnings
-from .utils import unif, dist
+
+import numpy as np
from scipy.optimize import fmin_l_bfgs_b
+from ot.utils import unif, dist, list_to_array
+from .backend import get_backend
+
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=True,
+ **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma^T 1= b
+ \gamma &\geq 0
- \gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn>`
+
+ **Choosing a Sinkhorn solver**
+
+    By default, and when the regularization parameter is not too small, the default
+    sinkhorn solver should be enough. If you need a small regularization to get
+    sharper OT matrices, you should use the :py:func:`ot.bregman.sinkhorn_stabilized`
+    solver, which avoids numerical errors. This last solver can be very slow in
+    practice and might not even converge to a reasonable OT matrix in a finite time.
+    This is why :py:func:`ot.bregman.sinkhorn_epsilon_scaling`, which relies on
+    iterating on the value of the regularization (and on warm starts), sometimes
+    leads to better solutions. Note that the greedy version of the sinkhorn,
+    :py:func:`ot.bregman.greenkhorn`, can also lead to a speedup, and the screening
+    version, :py:func:`ot.bregman.screenkhorn`, aims at providing a fast
+    approximation of the Sinkhorn problem. For GPU usage and gradient computation
+    with a small number of iterations, we strongly recommend the
+    :py:func:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+    numerical problems.
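For instance (editorial sketch, not in the original docstring), switching solvers only changes the `method` argument; with the same toy data as in the Examples section below:

    >>> import ot
    >>> a = [.5, .5]
    >>> b = [.5, .5]
    >>> M = [[0., 1.], [1., 0.]]
    >>> G = ot.sinkhorn(a, b, M, reg=0.05, method='sinkhorn_log')  # log-domain solver, robust to small reg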
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver either 'sinkhorn','sinkhorn_log',
+ 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see
+ those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
-
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -86,102 +120,152 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
-
+ .. _references-sinkhorn:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
+ A., & Peyré, G. (2019, April). Interpolating between optimal transport
+ and MMD using Sinkhorn divergences. In The 22nd International Conference
+ on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
+ :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
+    ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling
+ :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
"""
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
**kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
+ **kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log)
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn,
+ **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
return sinkhorn_epsilon_scaling(a, b, M, reg,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn,
+ **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the loss
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma^T 1= b
+ \gamma &\geq 0
- \gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>`
+
+
+ **Choosing a Sinkhorn solver**
+
+    By default, and when the regularization parameter is not too small, the default
+    sinkhorn solver should be enough. If you need a small regularization to get
+    sharper OT matrices, you should use the :py:func:`ot.bregman.sinkhorn_log`
+    solver, which avoids numerical errors. This last solver can be very slow in
+    practice and might not even converge to a reasonable OT matrix in a finite time.
+    This is why :py:func:`ot.bregman.sinkhorn_epsilon_scaling`, which relies on
+    iterating on the value of the regularization (and on warm starts), sometimes
+    leads to better solutions. Note that the greedy version of the sinkhorn,
+    :py:func:`ot.bregman.greenkhorn`, can also lead to a speedup, and the screening
+    version, :py:func:`ot.bregman.screenkhorn`, aims at providing a fast
+    approximation of the Sinkhorn problem. For GPU usage and gradient computation
+    with a small number of iterations, we strongly recommend the
+    :py:func:`ot.bregman.sinkhorn_log` solver, which does not need to check for
+    numerical problems.
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
- 'sinkhorn_epsilon_scaling', see those function for specific parameters
+ method used for the solver either 'sinkhorn','sinkhorn_log',
+ 'sinkhorn_stabilized', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- W : (n_hists) ndarray or float
+ W : (n_hists) float/array-like
Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
+
Examples
--------
@@ -190,99 +274,142 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn2(a, b, M, 1)
- array([0.26894142])
-
+ 0.26894142136999516
+ .. _references-sinkhorn2:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
- [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
+ algorithms for optimal transport via Sinkhorn iteration,
+ Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
+ Trouvé, A., & Peyré, G. (2019, April).
+ Interpolating between optimal transport and MMD using Sinkhorn
+ divergences. In The 22nd International Conference on Artificial
+ Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.greenkhorn : Greenkhorn [21]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
-
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>`
+ ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
+ :ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>`
"""
- b = np.asarray(b, dtype=np.float64)
+
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
+
if len(b.shape) < 2:
- b = b[:, None]
- if method.lower() == 'sinkhorn':
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- elif method.lower() == 'sinkhorn_stabilized':
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ if method.lower() == 'sinkhorn':
+ res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+ if log:
+ return nx.sum(M * res[0]), res[1]
+ else:
+ return nx.sum(M * res)
+
else:
- raise ValueError("Unknown method '%s'." % method)
+
+ if method.lower() == 'sinkhorn':
+ return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
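
A minimal usage sketch of the dispatch above (toy histograms, not part of the patch; `ot.sinkhorn2` is the re-exported entry point):

    import numpy as np
    import ot

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])

    loss = ot.sinkhorn2(a, b, M, 1., method='sinkhorn')          # ~0.2689, cf. the doctest above
    loss_log = ot.sinkhorn2(a, b, M, 1., method='sinkhorn_log')  # same value, log-domain solver
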
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, warn=True,
+ **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp
+ matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -299,10 +426,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
[0.13447071, 0.36552929]])
+ .. _references-sinkhorn-knopp:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
See Also
@@ -312,18 +442,18 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
# init data
dim_a = len(a)
- dim_b = len(b)
+ dim_b = b.shape[0]
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -336,66 +466,64 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # print(reg)
-
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
-
- # print(np.min(K))
- tmp2 = np.empty(b.shape, dtype=M.dtype)
+ K = nx.exp(M / (-reg))
Kp = (1 / a).reshape(-1, 1) * K
- cpt = 0
+
err = 1
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
uprev = u
vprev = v
+ KtransposeU = nx.dot(K.T, u)
+ v = b / KtransposeU
+ u = 1. / nx.dot(Kp, v)
- KtransposeU = np.dot(K.T, u)
- v = np.divide(b, KtransposeU)
- u = 1. / np.dot(Kp, v)
-
- if (np.any(KtransposeU == 0)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(KtransposeU == 0)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+            warnings.warn('Numerical errors at iteration %d' % ii)
u = uprev
v = vprev
break
- if cpt % 10 == 0:
+ if ii % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
- np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
+ tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
- np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b) # violation of marginal
+ tmp2 = nx.einsum('i,ij,j->j', u, K, v)
+ err = nx.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
+ log['niter'] = ii
log['u'] = u
log['v'] = v
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
@@ -409,58 +537,259 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
- log=False):
+def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
+ log=False, warn=True, **kwargs):
r"""
- Solve the entropic regularization optimal transport problem and return the OT matrix
+ Solve the entropic regularization optimal transport problem in log space
+ and return the OT matrix
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma^T \mathbf{1} &= \mathbf{b}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the
+ implementation from :ref:`[34] <references-sinkhorn-log>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (dim_a,)
+ samples weights in the source domain
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
+ samples in the target domain, compute sinkhorn with multiple targets
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
+ loss matrix
+ reg : float
+ Regularization term >0
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
+
+ Returns
+ -------
+ gamma : array-like, shape (dim_a, dim_b)
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
+
+
+ .. _references-sinkhorn-log:
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
+
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
+ Trouvé, A., & Peyré, G. (2019, April). Interpolating between
+ optimal transport and MMD using Sinkhorn divergences. In The
+ 22nd International Conference on Artificial Intelligence and
+ Statistics (pp. 2681-2690). PMLR.
+
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+
+ if len(a) == 0:
+ a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
+ if len(b) == 0:
+ b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
+
+ # init data
+ dim_a = len(a)
+ dim_b = b.shape[0]
+
+ if len(b.shape) > 1:
+ n_hists = b.shape[1]
+ else:
+ n_hists = 0
- The algorithm used is based on the paper
+    if n_hists:  # we do not want to use tensors, so we do a loop
- Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
- by Jason Altschuler, Jonathan Weed, Philippe Rigollet
- appeared at NIPS 2017
+ lst_loss = []
+ lst_u = []
+ lst_v = []
- which is a stochastic version of the Sinkhorn-Knopp algorithm [2].
+ for k in range(n_hists):
+ res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+
+ if log:
+ lst_loss.append(nx.sum(M * res[0]))
+ lst_u.append(res[1]['log_u'])
+ lst_v.append(res[1]['log_v'])
+ else:
+ lst_loss.append(nx.sum(M * res))
+ res = nx.stack(lst_loss)
+ if log:
+ log = {'log_u': nx.stack(lst_u, 1),
+ 'log_v': nx.stack(lst_v, 1), }
+ log['u'] = nx.exp(log['log_u'])
+ log['v'] = nx.exp(log['log_v'])
+ return res, log
+ else:
+ return res
+
+ else:
+
+ if log:
+ log = {'err': []}
+
+ Mr = - M / reg
+
+ # we assume that no distances are null except those of the diagonal of
+ # distances
+
+ u = nx.zeros(dim_a, type_as=M)
+ v = nx.zeros(dim_b, type_as=M)
+
+ def get_logT(u, v):
+ if n_hists:
+ return Mr[:, :, None] + u + v
+ else:
+ return Mr + u[:, None] + v[None, :]
+
+ loga = nx.log(a)
+ logb = nx.log(b)
+
+ err = 1
+ for ii in range(numItermax):
+
+ v = logb - nx.logsumexp(Mr + u[:, None], 0)
+ u = loga - nx.logsumexp(Mr + v[None, :], 1)
+
+ if ii % 10 == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+
+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
+ tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
+ err = nx.norm(tmp2 - b) # violation of marginal
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+
+ if log:
+ log['niter'] = ii
+ log['log_u'] = u
+ log['log_v'] = v
+ log['u'] = nx.exp(u)
+ log['v'] = nx.exp(v)
+
+ return nx.exp(get_logT(u, v)), log
+
+ else:
+ return nx.exp(get_logT(u, v))
+
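
A quick sketch of where the log-domain solver pays off (toy data, for illustration only): for very small `reg`, the kernel exp(-M/reg) underflows in `sinkhorn_knopp`, while the logsumexp updates above stay finite:

    import numpy as np
    from ot.bregman import sinkhorn_log

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])
    G = sinkhorn_log(a, b, M, reg=1e-3)   # OT plan, shape (2, 2)
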
+
+def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
+ log=False, warn=True):
+ r"""
+ Solve the entropic regularization optimal transport problem and return the OT matrix
+
+ The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>`
+ which is a stochastic version of the Sinkhorn-Knopp
+ algorithm :ref:`[2] <references-greenkhorn>`
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
-
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
log : bool, optional
record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -477,11 +806,18 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
[0.13447071, 0.36552929]])
+ .. _references-greenkhorn:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
+
+ .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time
+ approximation algorithms for optimal transport via Sinkhorn
+ iteration, Advances in Neural Information Processing
+ Systems (NIPS) 31, 2017
See Also
@@ -491,68 +827,70 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+ if nx.__name__ == "jax":
+ raise TypeError("JAX arrays have been received. Greenkhorn is not "
+ "compatible with JAX")
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
dim_a = a.shape[0]
dim_b = b.shape[0]
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty_like(M)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
- u = np.full(dim_a, 1. / dim_a)
- v = np.full(dim_b, 1. / dim_b)
- G = u[:, np.newaxis] * K * v[np.newaxis, :]
+ u = nx.full((dim_a,), 1. / dim_a, type_as=K)
+ v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ G = u[:, None] * K * v[None, :]
- viol = G.sum(1) - a
- viol_2 = G.sum(0) - b
+ viol = nx.sum(G, axis=1) - a
+ viol_2 = nx.sum(G, axis=0) - b
stopThr_val = 1
-
if log:
log = dict()
log['u'] = u
log['v'] = v
- for i in range(numItermax):
- i_1 = np.argmax(np.abs(viol))
- i_2 = np.argmax(np.abs(viol_2))
- m_viol_1 = np.abs(viol[i_1])
- m_viol_2 = np.abs(viol_2[i_2])
- stopThr_val = np.maximum(m_viol_1, m_viol_2)
+ for ii in range(numItermax):
+ i_1 = nx.argmax(nx.abs(viol))
+ i_2 = nx.argmax(nx.abs(viol_2))
+ m_viol_1 = nx.abs(viol[i_1])
+ m_viol_2 = nx.abs(viol_2[i_2])
+ stopThr_val = nx.maximum(m_viol_1, m_viol_2)
if m_viol_1 > m_viol_2:
old_u = u[i_1]
- u[i_1] = a[i_1] / (K[i_1, :].dot(v))
- G[i_1, :] = u[i_1] * K[i_1, :] * v
-
- viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
- viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v)
+ new_u = a[i_1] / (K[i_1, :].dot(v))
+ G[i_1, :] = new_u * K[i_1, :] * v
+ viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1]
+ viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
+ u[i_1] = new_u
else:
old_v = v[i_2]
- v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
- G[:, i_2] = u * K[:, i_2] * v[i_2]
+ new_v = b[i_2] / (K[:, i_2].T.dot(u))
+ G[:, i_2] = u * K[:, i_2] * new_v
# aviol = (G@one_m - a)
# aviol_2 = (G.T@one_n - b)
- viol += (-old_v + v[i_2]) * K[:, i_2] * u
- viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
-
- # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+ viol += (-old_v + new_v) * K[:, i_2] * u
+ viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2]
+ v[i_2] = new_v
if stopThr_val <= stopThr:
break
else:
- print('Warning: Algorithm did not converge')
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
+ log["n_iter"] = ii
log['u'] = u
log['v'] = v
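
A minimal usage sketch (toy data): at each step greenkhorn picks the row or column whose marginal is most violated and rescales only that one, which is what the update above implements:

    import numpy as np
    from ot.bregman import greenkhorn

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])
    G = greenkhorn(a, b, M, reg=1.)   # OT plan, shape (2, 2)
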
@@ -564,58 +902,66 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20,
- log=False, **kwargs):
+ log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization OT problem with log stabilization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in [2]_ but with the log stabilization
- proposed in [10]_ an defined in [9]_ (Algo 3.1) .
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>`
+ but with the log stabilization
+    proposed in :ref:`[10] <references-sinkhorn-stabilized>` and defined in
+    :ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1).
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,)
+ b : array-like, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
- thershold for max value in u or v for log scaling
- warmstart : tible of vectors
- if given then sarting values for alpha an beta log scalings
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
+    warmstart : tuple of vectors
+ if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -632,14 +978,21 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
[0.13447071, 0.36552929]])
+ .. _references-sinkhorn-stabilized:
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
See Also
@@ -649,19 +1002,19 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# test if multiple target
if len(b.shape) > 1:
n_hists = b.shape[1]
- a = a[:, np.newaxis]
+ a = a[:, None]
else:
n_hists = 0
@@ -669,123 +1022,123 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
dim_a = len(a)
dim_b = len(b)
- cpt = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+ u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M)
+ u /= dim_a
+ v /= dim_b
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1))
+ return nx.exp(-(M - alpha.reshape((dim_a, 1))
- beta.reshape((1, dim_b))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
- / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b))))
-
- # print(np.min(K))
+ return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
+ / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b))))
K = get_K(alpha, beta)
transp = K
- loop = 1
- cpt = 0
err = 1
- while loop:
+ for ii in range(numItermax):
uprev = u
vprev = v
# sinkhorn update
- v = b / (np.dot(K.T, u) + 1e-16)
- u = a / (np.dot(K, v) + 1e-16)
+ v = b / (nx.dot(K.T, u))
+ u = a / (nx.dot(K, v))
# remove numerical problems and store them in K
- if np.abs(u).max() > tau or np.abs(v).max() > tau:
+ if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
- alpha, beta = alpha + reg * \
- np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
+ alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
- alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
+ alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
- u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
K = get_K(alpha, beta)
- if cpt % print_period == 0:
+ if ii % print_period == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
- err_u = abs(u - uprev).max()
- err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max()
- err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err_u = nx.max(nx.abs(u - uprev))
+ err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0)
+ err_v = nx.max(nx.abs(v - vprev))
+ err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0)
err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))
+ err = nx.norm(nx.sum(transp, axis=0) - b)
if log:
log['err'].append(err)
if verbose:
- if cpt % (print_period * 20) == 0:
+ if ii % (print_period * 20) == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
if err <= stopThr:
- loop = False
-
- if cpt >= numItermax:
- loop = False
+ break
- if np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %d' % ii)
u = uprev
v = vprev
break
-
- cpt = cpt + 1
-
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
if n_hists:
alpha = alpha[:, None]
beta = beta[:, None]
- logu = alpha / reg + np.log(u)
- logv = beta / reg + np.log(v)
+ logu = alpha / reg + nx.log(u)
+ logv = beta / reg + nx.log(v)
+ log["n_iter"] = ii
log['logu'] = logu
log['logv'] = logv
- log['alpha'] = alpha + reg * np.log(u)
- log['beta'] = beta + reg * np.log(v)
+ log['alpha'] = alpha + reg * nx.log(u)
+ log['beta'] = beta + reg * nx.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
if n_hists:
- res = np.zeros((n_hists))
- for i in range(n_hists):
- res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ res = nx.stack([
+ nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ for i in range(n_hists)
+ ])
return res, log
else:
return get_Gamma(alpha, beta, u, v), log
else:
if n_hists:
- res = np.zeros((n_hists))
- for i in range(n_hists):
- res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ res = nx.stack([
+ nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ for i in range(n_hists)
+ ])
return res
else:
return get_Gamma(alpha, beta, u, v)
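
The `warmstart` entry stored in the log above holds the dual potentials (alpha, beta) and can seed a second solve at a smaller regularization; a minimal sketch with toy data:

    import numpy as np
    from ot.bregman import sinkhorn_stabilized

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])
    G, log = sinkhorn_stabilized(a, b, M, reg=1e-1, log=True)
    G2 = sinkhorn_stabilized(a, b, M, reg=1e-2, warmstart=log['warmstart'])
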
@@ -794,70 +1147,73 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
numInnerItermax=100, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=10,
- log=False, **kwargs):
+ log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
-
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
- where :
+ \gamma &\geq 0
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ where :
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in [2]_ but with the log stabilization
- proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2
-
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>`
+ but with the log stabilization
+ proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling
+ proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,)
+ b : array-like, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
- thershold for max value in u or v for log scaling
+        threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
warmstart : tuple of vectors
- if given then sarting values for alpha an beta log scalings
+ if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
- Max number of iterationsin the inner slog stabilized sinkhorn
+        Max number of iterations in the inner log-stabilized sinkhorn
epsilon0 : int, optional
first epsilon regularization value (then exponential decrease to reg)
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
-
Examples
--------
-
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
@@ -866,29 +1222,32 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
-
+ .. _references-sinkhorn-epsilon-scaling:
References
----------
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
-
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# init data
dim_a = len(a)
@@ -898,14 +1257,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
numItermin = 35
    numItermax = max(numItermin, numItermax)  # ensure that last value is exact
- cpt = 0
+ ii = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
@@ -913,12 +1272,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def get_reg(n): # exponential decreasing
return (epsilon0 - reg) * np.exp(-n) + reg
- loop = 1
- cpt = 0
err = 1
- while loop:
+ for ii in range(numItermax):
- regi = get_reg(cpt)
+ regi = get_reg(ii)
G, logi = sinkhorn_stabilized(a, b, M, regi,
numItermax=numInnerItermax, stopThr=1e-9,
@@ -928,33 +1285,31 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
alpha = logi['alpha']
beta = logi['beta']
- if cpt >= numItermax:
- loop = False
-
- if cpt % (print_period) == 0: # spsion nearly converged
+        if ii % (print_period) == 0:  # epsilon nearly converged
# we can speed up the process by checking for the error only all
# the 10th iterations
transp = G
- err = np.linalg.norm(
- (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2
+ err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2
if log:
log['err'].append(err)
if verbose:
- if cpt % (print_period * 10) == 0:
- print(
- '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- if err <= stopThr and cpt > numItermin:
- loop = False
+ if ii % (print_period * 10) == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
- cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
+ if err <= stopThr and ii > numItermin:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
log['alpha'] = alpha
log['beta'] = beta
log['warmstart'] = (log['alpha'], log['beta'])
+ log['niter'] = ii
return G, log
else:
return G
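
Illustrative call with toy data: the regularization starts at `epsilon0` and decreases exponentially towards `reg`, each inner stabilized solve being warm-started from the previous dual potentials:

    import numpy as np
    from ot.bregman import sinkhorn_epsilon_scaling

    a = np.array([.5, .5])
    b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])
    G = sinkhorn_epsilon_scaling(a, b, M, reg=1e-2, epsilon0=100)
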
@@ -962,76 +1317,94 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def geometricBar(weights, alldistribT):
"""return the weighted geometric mean of distributions"""
+ weights, alldistribT = list_to_array(weights, alldistribT)
+ nx = get_backend(weights, alldistribT)
assert (len(weights) == alldistribT.shape[1])
- return np.exp(np.dot(np.log(alldistribT), weights.T))
+ return nx.exp(nx.dot(nx.log(alldistribT), weights.T))
def geometricMean(alldistribT):
"""return the geometric mean of distributions"""
- return np.exp(np.mean(np.log(alldistribT), axis=1))
+ alldistribT = list_to_array(alldistribT)
+ nx = get_backend(alldistribT)
+ return nx.exp(nx.mean(nx.log(alldistribT), axis=1))
def projR(gamma, p):
"""return the KL projection on the row constrints """
- return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T
+ gamma, p = list_to_array(gamma, p)
+ nx = get_backend(gamma, p)
+ return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T
def projC(gamma, q):
"""return the KL projection on the column constrints """
- return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
+ gamma, q = list_to_array(gamma, q)
+ nx = get_backend(gamma, q)
+ return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10)
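
projR and projC above are the KL (Bregman) projections on the row and column marginal constraints; a small numpy check of the row case (made-up coupling, illustration only):

    import numpy as np

    gamma = np.array([[0.2, 0.1],
                      [0.3, 0.4]])
    p = np.array([0.5, 0.5])
    gamma_r = (gamma.T * p / np.maximum(gamma.sum(axis=1), 1e-10)).T
    assert np.allclose(gamma_r.sum(axis=1), p)   # rows now match the target marginal
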
def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
- stopThr=1e-4, verbose=False, log=False, **kwargs):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`.
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
+ algorithm as proposed in :ref:`[3] <references-barycenter>`
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
method : str (optional)
- method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+        method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_log'
+ weights : array-like, shape (n_hists,)
+        Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+ .. _references-barycenter:
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
@@ -1039,232 +1412,327 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
return barycenter_sinkhorn(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return barycenter_stabilized(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _barycenter_sinkhorn_log(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
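
A minimal sketch of the dispatch above with the new 'sinkhorn_log' method (toy histograms on a 1D grid, cost matrix built by hand for the example):

    import numpy as np
    from ot.bregman import barycenter

    dim = 10
    a1 = np.ones(dim) / dim
    a2 = np.linspace(1., 2., dim)
    a2 /= a2.sum()
    A = np.stack([a1, a2], axis=1)           # shape (dim, n_hists)
    x = np.arange(dim, dtype=np.float64)
    M = (x[:, None] - x[None, :]) ** 2       # squared-distance cost on the grid
    M /= M.max()
    bary = barycenter(A, M, reg=1e-2, method='sinkhorn_log')
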
def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+    scaling algorithm as proposed in :ref:`[3] <references-barycenter-sinkhorn>`.
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ weights : array-like, shape (n_hists,)
+        Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+ .. _references-barycenter-sinkhorn:
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
if weights is None:
- weights = np.ones(A.shape[1]) / A.shape[1]
+ weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
else:
assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
- # M = M/np.median(M) # suggested by G. Peyre
- K = np.exp(-M / reg)
+ K = nx.exp(-M / reg)
- cpt = 0
err = 1
- UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T)
+ UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)
+
u = (geometricMean(UKv) / UKv.T).T
- while (err > stopThr and cpt < numItermax):
- cpt = cpt + 1
- UKv = u * np.dot(K, np.divide(A, np.dot(K, u)))
+ for ii in range(numItermax):
+
+ UKv = u * nx.dot(K, A / nx.dot(K, u))
u = (u.T * geometricBar(weights, UKv)).T / UKv
- if cpt % 10 == 1:
- err = np.sum(np.std(UKv, axis=1))
+ if ii % 10 == 1:
+ err = nx.sum(nx.std(UKv, axis=1))
# log and verbose print
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
return geometricBar(weights, UKv), log
else:
return geometricBar(weights, UKv)
+def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic wasserstein barycenter in log-domain
+ """
+
+ A, M = list_to_array(A, M)
+ dim, n_hists = A.shape
+
+ nx = get_backend(A, M)
+
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ if weights is None:
+ weights = nx.ones(n_hists, type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ M = - M / reg
+ logA = nx.log(A + 1e-15)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros(dim, type_as=A)
+ for k in range(n_hists):
+ f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
+ log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
+ log_bar = log_bar + weights[k] * log_KU[:, k]
+
+ if ii % 10 == 1:
+ err = nx.exp(G + log_KU).std(axis=1).sum()
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if err < stopThr:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+
+ G = log_bar[:, None] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
- with stabilization.
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization.
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
+ algorithm as proposed in :ref:`[3] <references-barycenter-stabilized>`
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
tau : float
- thershold for max value in u or v for log scaling
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
+ weights : array-like, shape (n_hists,)
+        Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+ .. _references-barycenter-stabilized:
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones((n_hists,), type_as=M) / n_hists
else:
assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
- u = np.ones((dim, n_hists)) / dim
- v = np.ones((dim, n_hists)) / dim
+ u = nx.ones((dim, n_hists), type_as=M) / dim
+ v = nx.ones((dim, n_hists), type_as=M) / dim
- # print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
- cpt = 0
err = 1.
- alpha = np.zeros(dim)
- beta = np.zeros(dim)
- q = np.ones(dim) / dim
- while (err > stopThr and cpt < numItermax):
+ alpha = nx.zeros((dim,), type_as=M)
+ beta = nx.zeros((dim,), type_as=M)
+ q = nx.ones((dim,), type_as=M) / dim
+ for ii in range(numItermax):
qprev = q
- Kv = K.dot(v)
- u = A / (Kv + 1e-16)
- Ktu = K.T.dot(u)
+ Kv = nx.dot(K, v)
+ u = A / Kv
+ Ktu = nx.dot(K.T, u)
q = geometricBar(weights, Ktu)
Q = q[:, None]
- v = Q / (Ktu + 1e-16)
+ v = Q / Ktu
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha += reg * nx.log(nx.max(u, 1))
+ beta += reg * nx.log(nx.max(v, 1))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(tuple(v.shape), type_as=v)
+ Kv = nx.dot(K, v)
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration %s' % cpt)
+ warnings.warn('Numerical errors at iteration %s' % ii)
q = qprev
break
- if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ if (ii % 10 == 0 and not absorbing) or ii == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(u * Kv - A).max()
+ err = nx.max(nx.abs(u * Kv - A))
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 50 == 0:
+ if ii % 50 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
- cpt += 1
- if err > stopThr:
- warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
- "Try a larger entropy `reg`" +
- "Or a larger absorption threshold `tau`.")
+ else:
+ if warn:
+ warnings.warn("Stabilized Sinkhorn did not converge." +
+ "Try a larger entropy `reg`" +
+ "Or a larger absorption threshold `tau`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['logu'] = np.log(u + 1e-16)
log['logv'] = np.log(v + 1e-16)
return q, log
@@ -1272,157 +1740,717 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
return q
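
Illustrative call of the stabilized variant (same toy grid as in the earlier sketch; `tau` is the absorption threshold used in the scaling-absorption step above):

    import numpy as np
    from ot.bregman import barycenter_stabilized

    dim = 10
    A = np.stack([np.ones(dim) / dim,
                  np.linspace(1., 2., dim) / np.linspace(1., 2., dim).sum()], axis=1)
    x = np.arange(dim, dtype=np.float64)
    M = (x[:, None] - x[None, :]) ** 2
    M /= M.max()
    q = barycenter_stabilized(A, M, reg=1e-2, tau=1e8)
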
-def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
- stopThr=1e-9, stabThr=1e-30, verbose=False,
- log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
- where A is a collection of 2D images.
+def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
+ r"""Compute the debiased Sinkhorn barycenter of distributions A
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}`
- - reg is the regularization strength scalar value
+ - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence
+ (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
+ The algorithm used for solving the problem is the debiased Sinkhorn
+ algorithm as proposed in :ref:`[37] <references-barycenter-debiased>`
Parameters
----------
- A : ndarray, shape (n_hists, width, height)
- n distributions (2D images) of size width x height
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
+ M : array-like, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ method : str (optional)
+ method used for the solver either 'sinkhorn' or 'sinkhorn_log'
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
+
+
+ Returns
+ -------
+ a : (dim,) array-like
+ Wasserstein barycenter
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ .. _references-barycenter-debiased:
+ References
+ ----------
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters.
+ Proceedings of the 37th International Conference on Machine Learning,
+ PMLR 119:4692-4701, 2020
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _barycenter_debiased(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _barycenter_debiased_log(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
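A quick usage sketch for the dispatcher above (all values illustrative; `make_1D_gauss` and `dist` are the usual POT helpers):

import numpy as np
import ot

n = 50
x = np.arange(n, dtype=np.float64).reshape(-1, 1)
a1 = ot.datasets.make_1D_gauss(n, m=15, s=5)   # first histogram
a2 = ot.datasets.make_1D_gauss(n, m=35, s=5)   # second histogram
A = np.vstack((a1, a2)).T                      # shape (dim, n_hists)
M = ot.dist(x, x)                              # squared Euclidean cost
M /= M.max()

bary = ot.bregman.barycenter_debiased(A, M, reg=1e-2, method="sinkhorn")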
+def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the debiased sinkhorn barycenter of distributions A.
+ """
+
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
+ if weights is None:
+ weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ K = nx.exp(-M / reg)
+
+ err = 1
+
+ UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)
+
+ u = (geometricMean(UKv) / UKv.T).T
+ c = nx.ones(A.shape[0], type_as=A)
+ bar = nx.ones(A.shape[0], type_as=A)
+
+ for ii in range(numItermax):
+ bold = bar
+ UKv = nx.dot(K, A / nx.dot(K, u))
+ bar = c * geometricBar(weights, UKv)
+ u = bar[:, None] / UKv
+ c = (c * bar / nx.dot(K, c)) ** 0.5
+
+ if ii % 10 == 9:
+ err = abs(bar - bold).max() / max(bar.max(), 1.)
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ # debiased Sinkhorn does not converge monotonically
+ # guarantee a few iterations are done before stopping
+ if err < stopThr and ii > 20:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return bar, log
+ else:
+ return bar
+
+
+def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False,
+ warn=True):
+ r"""Compute the debiased sinkhorn barycenter in log domain.
+ """
+
+ A, M = list_to_array(A, M)
+ dim, n_hists = A.shape
+
+ nx = get_backend(A, M)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ if weights is None:
+ weights = nx.ones(n_hists, type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ M = - M / reg
+ logA = nx.log(A + 1e-15)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ c = nx.zeros(dim, type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros(dim, type_as=A)
+ for k in range(n_hists):
+ f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
+ log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
+ log_bar += weights[k] * log_KU[:, k]
+ log_bar += c
+ if ii % 10 == 1:
+ err = nx.exp(G + log_KU).std(axis=1).sum()
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if err < stopThr and ii > 20:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+
+ G = log_bar[:, None] - log_KU
+ for _ in range(10):
+ c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0))
+
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
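The correction variable `c` is what distinguishes the debiased iteration from the plain Sinkhorn barycenter: it is driven by damped fixed-point steps toward exp(c) * (K @ exp(c)) = bar, following [37]. A standalone NumPy/SciPy sketch of one such step in the log domain (names are illustrative, `M_reg` stands for -M / reg):

import numpy as np
from scipy.special import logsumexp

def update_correction(c, log_bar, M_reg):
    # one damped step toward the fixed point c + LSE(M_reg + c) = log_bar,
    # i.e. exp(c) * (K @ exp(c)) = bar with K = exp(M_reg)
    return 0.5 * (c + log_bar - logsumexp(M_reg + c[:, None], axis=0))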
+def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False,
+ warn=True, **kwargs):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
+ where :math:`\mathbf{A}` is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions
+ of matrix :math:`\mathbf{A}`
+ - `reg` is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm
+ as proposed in :ref:`[21] <references-convolutional-barycenter-2d>`
+
+ Parameters
+ ----------
+ A : array-like, shape (n_hists, width, height)
+ `n` distributions (2D images) of size `width` x `height`
reg : float
Regularization term >0
- weights : ndarray, shape (n_hists,)
+ weights : array-like, shape (n_hists,)
Weights of each image on the simplex (barycentric coordinates)
+ method : string, optional
+ method used for the solver either 'sinkhorn' or 'sinkhorn_log'
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
stabThr : float, optional
Stabilization threshold to avoid numerical precision issue
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- a : ndarray, shape (width, height)
+ a : array-like, shape (width, height)
2D Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+
+ .. _references-convolutional-barycenter-2d:
References
----------
- .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
- Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
- ACM Transactions on Graphics (TOG), 34(4), 66
+ .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher,
+ A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances:
+ Efficient optimal transportation on geometric domains. ACM Transactions
+ on Graphics (TOG), 34(4), 66
+
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters.
+ Proceedings of the 37th International Conference on Machine Learning,
+ PMLR 119:4692-4701, 2020
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _convolutional_barycenter2d(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _convolutional_barycenter2d_log(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
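A minimal usage sketch for the 2D barycenter (random images and an illustrative regularization value):

import numpy as np
import ot

rng = np.random.RandomState(42)
A = np.abs(rng.randn(2, 40, 40))         # two positive 40 x 40 "images"
A /= A.sum(axis=(1, 2), keepdims=True)   # each image sums to one

bary_img = ot.bregman.convolutional_barycenter2d(A, reg=0.004)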
+def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-9, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images.
"""
+ A = list_to_array(A)
+
+ nx = get_backend(A)
+
if weights is None:
- weights = np.ones(A.shape[0]) / A.shape[0]
+ weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
else:
assert (len(weights) == A.shape[0])
if log:
log = {'err': []}
- b = np.zeros_like(A[0, :, :])
- U = np.ones_like(A)
- KV = np.ones_like(A)
-
- cpt = 0
+ bar = nx.ones(A.shape[1:], type_as=A)
+ bar /= bar.sum()
+ U = nx.ones(A.shape, type_as=A)
+ V = nx.ones(A.shape, type_as=A)
err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
- t = np.linspace(0, 1, A.shape[1])
- [Y, X] = np.meshgrid(t, t)
- xi1 = np.exp(-(X - Y) ** 2 / reg)
+ t = nx.linspace(0, 1, A.shape[1])
+ [Y, X] = nx.meshgrid(t, t)
+ K1 = nx.exp(-(X - Y) ** 2 / reg)
+
+ t = nx.linspace(0, 1, A.shape[2])
+ [Y, X] = nx.meshgrid(t, t)
+ K2 = nx.exp(-(X - Y) ** 2 / reg)
+
+ def convol_imgs(imgs):
+ kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+ kxy = nx.einsum("...ij,klj->kli", K2, kx)
+ return kxy
+
+ KU = convol_imgs(U)
+ for ii in range(numItermax):
+ V = bar[None] / KU
+ KV = convol_imgs(V)
+ U = A / KV
+ KU = convol_imgs(U)
+ bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ if ii % 10 == 9:
+ err = (V * KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ log['U'] = U
+ return bar, log
+ else:
+ return bar
+
+
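Note that `convol_imgs` above never forms the full (width*height) x (width*height) Gibbs kernel: blurring along rows with `K1` and then along columns with `K2` is equivalent to multiplying the flattened image by the Kronecker product of the two kernels, which is what keeps the 2D barycenter tractable. A small NumPy check of that equivalence (illustrative sizes):

import numpy as np

rng = np.random.RandomState(0)
w, h, reg = 5, 4, 0.1
t1, t2 = np.linspace(0, 1, w), np.linspace(0, 1, h)
K1 = np.exp(-(t1[:, None] - t1[None, :]) ** 2 / reg)
K2 = np.exp(-(t2[:, None] - t2[None, :]) ** 2 / reg)
img = rng.rand(w, h)

separable = K1 @ img @ K2.T                           # blur rows, then columns
full = (np.kron(K1, K2) @ img.ravel()).reshape(w, h)  # full (w*h x w*h) kernel
assert np.allclose(separable, full)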
+def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-4, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images in log-domain.
+ """
+
+ A = list_to_array(A)
+
+ nx = get_backend(A)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ n_hists, width, height = A.shape
+
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == n_hists)
+
+ if log:
+ log = {'err': []}
+
+ err = 1
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ M1 = - (X - Y) ** 2 / reg
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ M2 = - (X - Y) ** 2 / reg
+
+ def convol_img(log_img):
+ log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
+ log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
+ return log_img
+
+ logA = nx.log(A + stabThr)
+ log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros((width, height), type_as=A)
+ for k in range(n_hists):
+ f = logA[k] - convol_img(G[k])
+ log_KU[k] = convol_img(f)
+ log_bar = log_bar + weights[k] * log_KU[k]
+
+ if ii % 10 == 9:
+ err = nx.exp(G + log_KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ G = log_bar[None, :, :] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
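For small `reg` the kernel values exp(-(x - y)**2 / reg) underflow, which is why the `sinkhorn_log` variant applies the same separable blur through log-sum-exp instead. A standalone sketch of the log-domain convolution used above, with SciPy's `logsumexp` standing in for the backend call:

import numpy as np
from scipy.special import logsumexp

def convol_img_log(log_img, M1, M2):
    # blur the log-image along rows with M1 = -(x - y)**2 / reg ...
    out = logsumexp(M1[:, :, None] + log_img[None], axis=1)
    # ... then along columns with M2, mirroring the standard-domain version
    return logsumexp(M2[:, :, None] + out.T[None], axis=1).T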
+def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn",
+ numItermax=10000, stopThr=1e-3,
+ verbose=False, log=False, warn=True,
+ **kwargs):
+ r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}`
+ where :math:`\mathbf{A}` is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.barycenter_debiased`)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two
+ dimensions of matrix :math:`\mathbf{A}`
+ - `reg` is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the debiased Sinkhorn scaling
+ algorithm as proposed in :ref:`[37] <references-convolutional-barycenter2d-debiased>`
+
+ Parameters
+ ----------
+ A : array-like, shape (n_hists, width, height)
+ `n` distributions (2D images) of size `width` x `height`
+ reg : float
+ Regularization term >0
+ weights : array-like, shape (n_hists,)
+ Weights of each image on the simplex (barycentric coordinates)
+ method : string, optional
+ method used for the solver either 'sinkhorn' or 'sinkhorn_log'
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ stabThr : float, optional
+ Stabilization threshold to avoid numerical precision issue
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
+
+
+ Returns
+ -------
+ a : array-like, shape (width, height)
+ 2D Wasserstein barycenter
+ log : dict
+ log dictionary returned only if log==True in parameters
+
+
+ .. _references-convolutional-barycenter2d-debiased:
+ References
+ ----------
+
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters.
+ Proceedings of the 37th International Conference on Machine Learning,
+ PMLR 119:4692-4701, 2020
+ """
- t = np.linspace(0, 1, A.shape[2])
- [Y, X] = np.meshgrid(t, t)
- xi2 = np.exp(-(X - Y) ** 2 / reg)
+ if method.lower() == 'sinkhorn':
+ return _convolutional_barycenter2d_debiased(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
- def K(x):
- return np.dot(np.dot(xi1, x), xi2)
- while (err > stopThr and cpt < numItermax):
+def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-3, stabThr=1e-15, verbose=False,
+ log=False, warn=True):
+ r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions.
+ """
- bold = b
- cpt = cpt + 1
+ A = list_to_array(A)
+ n_hists, width, height = A.shape
- b = np.zeros_like(A[0, :, :])
- for r in range(A.shape[0]):
- KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
- b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
- b = np.exp(b)
- for r in range(A.shape[0]):
- U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
+ nx = get_backend(A)
- if cpt % 10 == 1:
- err = np.sum(np.abs(bold - b))
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == n_hists)
+
+ if log:
+ log = {'err': []}
+
+ bar = nx.ones((width, height), type_as=A)
+ bar /= width * height
+ U = nx.ones(A.shape, type_as=A)
+ V = nx.ones(A.shape, type_as=A)
+ c = nx.ones(A.shape[1:], type_as=A)
+ err = 1
+
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ K1 = nx.exp(-(X - Y) ** 2 / reg)
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ K2 = nx.exp(-(X - Y) ** 2 / reg)
+
+ def convol_imgs(imgs):
+ kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+ kxy = nx.einsum("...ij,klj->kli", K2, kx)
+ return kxy
+
+ KU = convol_imgs(U)
+ for ii in range(numItermax):
+ V = bar[None] / KU
+ KV = convol_imgs(V)
+ U = A / KV
+ KU = convol_imgs(U)
+ bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+
+ for _ in range(10):
+ c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5
+
+ if ii % 10 == 9:
+ err = (V * KU).std(axis=0).sum()
# log and verbose print
if log:
log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
+ # debiased Sinkhorn does not converge monotonically
+ # guarantee a few iterations are done before stopping
+ if err < stopThr and ii > 20:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['U'] = U
- return b, log
+ return bar, log
+ else:
+ return bar
+
+
+def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-3, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the debiased barycenter of 2D images in log-domain.
+ """
+
+ A = list_to_array(A)
+ n_hists, width, height = A.shape
+ nx = get_backend(A)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[0])
+
+ if log:
+ log = {'err': []}
+
+ err = 1
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ M1 = - (X - Y) ** 2 / reg
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ M2 = - (X - Y) ** 2 / reg
+
+ def convol_img(log_img):
+ log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
+ log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
+ return log_img
+
+ logA = nx.log(A + stabThr)
+ log_bar, c = nx.zeros((2, width, height), type_as=A)
+ log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros((width, height), type_as=A)
+ for k in range(n_hists):
+ f = logA[k] - convol_img(G[k])
+ log_KU[k] = convol_img(f)
+ log_bar = log_bar + weights[k] * log_KU[k]
+ log_bar += c
+ for _ in range(10):
+ c = 0.5 * (c + log_bar - convol_img(c))
+
+ if ii % 10 == 9:
+ err = nx.exp(G + log_KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr and ii > 20:
+ break
+ G = log_bar[None, :, :] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
else:
- return b
+ return nx.exp(log_bar)
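As in the 1D case, the debiasing image `c` is refined by a few damped fixed-point steps so that, at convergence, c * convol(c) equals the barycenter (or, equivalently, c + convol_log(c) = log(bar) in the log domain). A NumPy sketch of the standard-domain update, with `convol` standing for any positive blur operator such as the separable Gaussian above:

import numpy as np

def debias_correction(c, bar, convol, n_inner=10):
    # damped fixed-point steps toward c * convol(c) = bar, following [37]
    for _ in range(n_inner):
        c = np.sqrt(c * bar / convol(c))
    return c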
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
- stopThr=1e-3, verbose=False, log=False):
+ stopThr=1e-3, verbose=False, log=False, warn=True):
r"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
The function solve the following optimization problem:
.. math::
- \mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h})
+
+ \mathbf{h} = \mathop{\arg \min}_\mathbf{h} \quad
+ (1 - \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) +
+ \alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h})
where :
- - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
- - :math: `\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
+ - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`,
+ its expected shape is `(dim_a, n_atoms)`
- :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
- :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
- - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting
- - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization
- - :math:`\\alpha`weight data fitting and regularization
+ - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the
+ cost matrix (`dim_a`, `dim_a`) for OT data fitting
+ - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization
+ term and the cost matrix (`dim_prior`, `n_atoms`) for the regularization
+ - :math:`\alpha` weights the data fitting against the regularization
- The optimization problem is solved suing the algorithm described in [4]
+ The optimization problem is solved following the algorithm described
+ in :ref:`[4] <references-unmix>`
Parameters
----------
- a : ndarray, shape (dim_a)
+ a : array-like, shape (dim_a)
observed distribution (histogram, sums to 1)
- D : ndarray, shape (dim_a, n_atoms)
+ D : array-like, shape (dim_a, n_atoms)
dictionary matrix
- M : ndarray, shape (dim_a, dim_a)
+ M : array-like, shape (dim_a, dim_a)
loss matrix
- M0 : ndarray, shape (n_atoms, dim_prior)
+ M0 : array-like, shape (n_atoms, dim_prior)
loss matrix
- h0 : ndarray, shape (n_atoms,)
+ h0 : array-like, shape (n_atoms,)
prior on the estimated unmixing h
reg : float
Regularization term >0 (Wasserstein data fitting)
@@ -1433,105 +2461,125 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
-
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- h : ndarray, shape (n_atoms,)
+ h : array-like, shape (n_atoms,)
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+
+ .. _references-unmix:
References
----------
- .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
-
+ .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
+ Supervised planetary unmixing with optimal transport, Workshop
+ on Hyperspectral Image and Signal Processing:
+ Evolution in Remote Sensing (WHISPERS), 2016.
"""
+ a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0)
+
+ nx = get_backend(a, D, M, M0, h0)
+
# M = M/np.median(M)
- K = np.exp(-M / reg)
+ K = nx.exp(-M / reg)
# M0 = M0/np.median(M0)
- K0 = np.exp(-M0 / reg0)
+ K0 = nx.exp(-M0 / reg0)
old = h0
err = 1
- cpt = 0
# log = {'niter':0, 'all_err':[]}
if log:
log = {'err': []}
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
K = projC(K, a)
K0 = projC(K0, h0)
- new = np.sum(K0, axis=1)
+ new = nx.sum(K0, axis=1)
# we recombine the current selection from the dictionary
- inv_new = np.dot(D, new)
- other = np.sum(K, axis=1)
+ inv_new = nx.dot(D, new)
+ other = nx.sum(K, axis=1)
# geometric interpolation
- delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new))
+ delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
- K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0)
+ K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0)
- err = np.linalg.norm(np.sum(K0, axis=1) - old)
+ err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt = cpt + 1
-
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Unmixing algorithm did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
- return np.sum(K0, axis=1), log
+ log['niter'] = ii
+ return nx.sum(K0, axis=1), log
else:
- return np.sum(K0, axis=1)
+ return nx.sum(K0, axis=1)
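A small synthetic usage sketch for `unmix` (all sizes, costs and regularization values below are illustrative):

import numpy as np
import ot

rng = np.random.RandomState(0)
dim_a, n_atoms = 20, 3
D = rng.rand(dim_a, n_atoms)
D /= D.sum(axis=0, keepdims=True)               # columns are histogram atoms
a = D @ np.array([0.2, 0.5, 0.3])               # observation mixed from the atoms

x = np.arange(dim_a, dtype=np.float64).reshape(-1, 1)
M = ot.dist(x, x)                               # data-fitting cost
M /= M.max()
M0 = ot.dist(np.eye(n_atoms), np.eye(n_atoms))  # regularization cost between atoms
M0 /= M0.max()
h0 = np.ones(n_atoms) / n_atoms                 # uniform prior

h = ot.bregman.unmix(a, D, M, M0, h0, reg=0.05, reg0=0.05, alpha=0.1)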
def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
- stopThr=1e-6, verbose=False, log=False, **kwargs):
- r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
+ stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
+ r'''Joint OT and proportion estimation for multi-source target shift as
+ proposed in :ref:`[27] <references-jcpot-barycenter>`
The function solves the following optimization problem:
.. math::
- \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
+ \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \quad \sum_{k=1}^{K} \lambda_k
W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})
s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
where :
- - :math:`\lambda_k` is the weight of k-th source domain
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
- - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C
+ - :math:`\lambda_k` is the weight of `k`-th source domain
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain
+ defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape
+ is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source
+ domain and `C` is the number of classes
+ - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C`
- :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
- - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in
+ [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)`
- The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+ The problem consists in solving a Wasserstein barycenter problem to estimate
+ the proportions :math:`\mathbf{h}` in the target domain.
The algorithm used for solving the problem is the Iterative Bregman projections algorithm
- with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution.
+ with two sets of marginal constraints related to the unknown vector
+ :math:`\mathbf{h}` and uniform target distribution.
Parameters
----------
- Xs : list of K np.ndarray(nsk,d)
+ Xs : list of K array-like(nsk,d)
features of all source domains' samples
- Ys : list of K np.ndarray(nsk,)
+ Ys : list of K array-like(nsk,)
labels of all source domains' samples
- Xt : np.ndarray (nt,d)
+ Xt : array-like (nt,d)
samples in the target domain
reg : float
Regularization term > 0
@@ -1541,28 +2589,37 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Max number of iterations
stopThr : float, optional
Stop threshold on relative change in the barycenter (>0)
- log : bool, optional
- record log if True
verbose : bool, optional (default=False)
Controls the verbosity of the optimization algorithm
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- h : (C,) ndarray
+ h : (C,) array-like
proportion estimation in the target domain
log : dict
log dictionary return only if log==True in parameters
+ .. _references-jcpot-barycenter:
References
----------
.. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
- "Optimal transport for multi-source domain adaptation under target shift",
- International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
-
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
'''
- nbclasses = len(np.unique(Ys[0]))
+
+ Xs = list_to_array(*Xs)
+ Ys = list_to_array(*Ys)
+ Xt = list_to_array(Xt)
+
+ nx = get_backend(*Xs, *Ys, Xt)
+
+ nbclasses = len(nx.unique(Ys[0]))
nbdomains = len(Xs)
# log dictionary
@@ -1579,19 +2636,19 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
dom = {}
nsk = Xs[d].shape[0] # get number of elements for this domain
dom['nbelem'] = nsk
- classes = np.unique(Ys[d]) # get number of classes for this domain
+ classes = nx.unique(Ys[d]) # get number of classes for this domain
# format classes to start from 0 for convenience
- if np.min(classes) != 0:
- Ys[d] = Ys[d] - np.min(classes)
- classes = np.unique(Ys[d])
+ if nx.min(classes) != 0:
+ Ys[d] -= nx.min(classes)
+ classes = nx.unique(Ys[d])
# build the corresponding D_1 and D_2 matrices
- Dtmp1 = np.zeros((nbclasses, nsk))
- Dtmp2 = np.zeros((nbclasses, nsk))
+ Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
+ Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
for c in classes:
- nbelemperclass = np.sum(Ys[d] == c)
+ nbelemperclass = nx.sum(Ys[d] == c)
if nbelemperclass != 0:
Dtmp1[int(c), Ys[d] == c] = 1.
Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
@@ -1602,51 +2659,54 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Mtmp = dist(Xs[d], Xt, metric=metric)
M.append(Mtmp)
- Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
- np.divide(Mtmp, -reg, out=Ktmp)
- np.exp(Ktmp, out=Ktmp)
+ Ktmp = nx.exp(-Mtmp / reg)
K.append(Ktmp)
# uniform target distribution
- a = unif(np.shape(Xt)[0])
+ a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0])
- cpt = 0 # iterations count
err = 1
- old_bary = np.ones((nbclasses))
+ old_bary = nx.ones((nbclasses,), type_as=Xs[0])
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
- bary = np.zeros((nbclasses))
+ bary = nx.zeros((nbclasses,), type_as=Xs[0])
# update coupling matrices for marginal constraints w.r.t. uniform target distribution
for d in range(nbdomains):
K[d] = projC(K[d], a)
- other = np.sum(K[d], axis=1)
- bary = bary + np.log(np.dot(D1[d], other)) / nbdomains
+ other = nx.sum(K[d], axis=1)
+ bary += nx.log(nx.dot(D1[d], other)) / nbdomains
- bary = np.exp(bary)
+ bary = nx.exp(bary)
# update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
for d in range(nbdomains):
- new = np.dot(D2[d].T, bary)
+ new = nx.dot(D2[d].T, bary)
K[d] = projR(K[d], new)
- err = np.linalg.norm(bary - old_bary)
- cpt = cpt + 1
+ err = nx.norm(bary - old_bary)
+
old_bary = bary
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- bary = bary / np.sum(bary)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Algorithm did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ bary = bary / nx.sum(bary)
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['M'] = M
log['D1'] = D1
log['D2'] = D2
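A minimal usage sketch for `jcpot_barycenter` (synthetic two-source data, illustrative regularization; assumes integer class labels and the NumPy backend):

import numpy as np
import ot

rng = np.random.RandomState(0)
# two source domains whose class proportions differ from the target's
Xs = [rng.randn(50, 2), rng.randn(60, 2) + 1.0]
Ys = [rng.randint(0, 2, 50), rng.randint(0, 2, 60)]
Xt = rng.randn(40, 2) + 0.5

h = ot.bregman.jcpot_barycenter(Xs, Ys, Xt, reg=1.0)  # estimated target proportions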
@@ -1657,8 +2717,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, verbose=False,
- log=False, **kwargs):
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
+ log=False, warn=True, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -1666,45 +2726,56 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
+ isLazy: boolean, optional
+ If True, then only calculate the cost matrix by block and return
+ the dual potentials only (to save memory). If False, calculate full
+ cost matrix and return outputs of sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
+ gamma : array-like, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1715,9 +2786,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
>>> n_samples_a = 2
>>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
- >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
+ >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
+ >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
array([[4.99977301e-01, 2.26989344e-05],
[2.26989344e-05, 4.99977301e-01]])
@@ -1725,30 +2796,115 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = nx.from_numpy(unif(ns), type_as=X_s)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = nx.from_numpy(unif(nt), type_as=X_s)
+
+ if isLazy:
+ if log:
+ dict_log = {"err": []}
- M = dist(X_s, X_t, metric=metric)
+ log_a, log_b = nx.log(a), nx.log(b)
+ f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+
+ if isinstance(batchSize, int):
+ bs, bt = batchSize, batchSize
+ elif isinstance(batchSize, tuple) and len(batchSize) == 2:
+ bs, bt = batchSize[0], batchSize[1]
+ else:
+ raise ValueError("Batch size must be in integer or a tuple of two integers")
+
+ range_s, range_t = range(0, ns, bs), range(0, nt, bt)
+
+ lse_f = nx.zeros((ns,), type_as=a)
+ lse_g = nx.zeros((nt,), type_as=a)
+
+ X_s_np = nx.to_numpy(X_s)
+ X_t_np = nx.to_numpy(X_t)
+
+ for i_ot in range(numIterMax):
+
+ lse_f_cols = []
+ for i in range_s:
+ M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ lse_f_cols.append(
+ nx.logsumexp(g[None, :] - M / reg, axis=1)
+ )
+ lse_f = nx.concatenate(lse_f_cols, axis=0)
+ f = log_a - lse_f
+
+ lse_g_cols = []
+ for j in range_t:
+ M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ lse_g_cols.append(
+ nx.logsumexp(f[:, None] - M / reg, axis=0)
+ )
+ lse_g = nx.concatenate(lse_g_cols, axis=0)
+ g = log_b - lse_g
+
+ if (i_ot + 1) % 10 == 0:
+ m1_cols = []
+ for i in range_s:
+ M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ m1_cols.append(
+ nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1)
+ )
+ m1 = nx.concatenate(m1_cols, axis=0)
+ err = nx.sum(nx.abs(m1 - a))
+ if log:
+ dict_log["err"].append(err)
+
+ if verbose and (i_ot + 1) % 100 == 0:
+ print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+
+ if err <= stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ dict_log["u"] = f
+ dict_log["v"] = g
+ return (f, g, dict_log)
+ else:
+ return (f, g)
- if log:
- pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
- return pi, log
else:
- pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
- return pi
+ M = dist(X_s, X_t, metric=metric)
+ if log:
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=True, **kwargs)
+ return pi, log
+ else:
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=False, **kwargs)
+ return pi
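With `isLazy=True` the cost matrix is only ever built `batchSize` rows (or columns) at a time, so memory stays proportional to n * batchSize and only the dual potentials are returned. Minimal usage sketch (illustrative sizes):

import numpy as np
import ot

rng = np.random.RandomState(0)
X_s = rng.randn(500, 2)
X_t = rng.randn(400, 2) + 1.0

# only the dual potentials (f, g) are returned; the 500 x 400 plan is never stored
f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1.0, isLazy=True, batchSize=100)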
-def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9, isLazy=False,
+ batchSize=100, verbose=False, log=False, warn=True, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -1757,46 +2913,57 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
+ isLazy: boolean, optional
+ If True, then only calculate the cost matrix by block and return
+ the dual potentials only (to save memory). If False, calculate
+ full cost matrix and return outputs of sinkhorn function.
+ batchSize: int or tuple of 2 int, optional
+ Size of the batches used to compute the sinkhorn update without memory overhead.
+ When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (n_hists) array-like or float
+ Optimal transportation loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1806,41 +2973,94 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
>>> n_samples_a = 2
>>> n_samples_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
- >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
- >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
- array([4.53978687e-05])
+ >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
+ >>> b = np.full((n_samples_b, 3), 1/n_samples_b)
+ >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False)
+ array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05])
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling
+ Algorithms for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+
+ ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(np.shape(X_s)[0])
+ a = nx.from_numpy(unif(ns), type_as=X_s)
if b is None:
- b = unif(np.shape(X_t)[0])
+ b = nx.from_numpy(unif(nt), type_as=X_s)
- M = dist(X_s, X_t, metric=metric)
+ if isLazy:
+ if log:
+ f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+ numIterMax=numIterMax,
+ stopThr=stopThr,
+ isLazy=isLazy,
+ batchSize=batchSize,
+ verbose=verbose, log=log,
+ warn=warn)
+ else:
+ f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+ numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize,
+ verbose=verbose, log=log,
+ warn=warn)
+
+ bs = batchSize if isinstance(batchSize, int) else batchSize[0]
+ range_s = range(0, ns, bs)
+
+ loss = 0
+
+ X_s_np = nx.to_numpy(X_s)
+ X_t_np = nx.to_numpy(X_t)
+
+ for i in range_s:
+ M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M_block = nx.from_numpy(M_block, type_as=a)
+ pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
+ loss += nx.sum(M_block * pi_block)
+
+ if log:
+ return loss, dict_log
+ else:
+ return loss
- if log:
- sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss, log
else:
- sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
- return sinkhorn_loss
+ M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+
+ if log:
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ return sinkhorn_loss, log
+ else:
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ return sinkhorn_loss
-def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9,
+ verbose=False, log=False, warn=True,
+ **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -1849,64 +3069,72 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
.. math::
- W &= \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ W &= \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- W_a &= \min_{\gamma_a} <\gamma_a,M_a>_F + reg\cdot\Omega(\gamma_a)
+ W_a &= \min_{\gamma_a} \quad \langle \gamma_a, \mathbf{M_a} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma_a)
- W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
+ W_b &= \min_{\gamma_b} \quad \langle \gamma_b, \mathbf{M_b} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma_b)
- S &= W - 1/2 * (W_a + W_b)
+ S &= W - \frac{W_a + W_b}{2}
.. math::
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
- \gamma_a 1 = a
+ \gamma_a \mathbf{1} &= \mathbf{a}
- \gamma_a^T 1= a
+ \gamma_a^T \mathbf{1} &= \mathbf{a}
- \gamma_a\geq 0
+ \gamma_a &\geq 0
- \gamma_b 1 = b
+ \gamma_b \mathbf{1} &= \mathbf{b}
- \gamma_b^T 1= b
+ \gamma_b^T \mathbf{1} &= \mathbf{b}
- \gamma_b\geq 0
+ \gamma_b &\geq 0
where :
- - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b))
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`)
+ is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+ (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`))
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
- Regularized optimal transportation matrix for the given parameters
+ W : (1,) array-like
+ Optimal transportation symmetrized loss for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1915,27 +3143,36 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
>>> n_samples_a = 2
>>> n_samples_b = 4
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
- >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
>>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
- array([1.499...])
+ 1.499887176049052
References
----------
- .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+ .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative
+ Models with Sinkhorn Divergences, Proceedings of the Twenty-First
+ International Conference on Artificial Intelligence and Statistics,
+ (AISTATS) 21, 2018
'''
if log:
- sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -1948,99 +3185,119 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
return max(0, sinkhorn_div), log
else:
- sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
+
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ return max(0, sinkhorn_div)
- sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
- sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
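Since S = W(a, b) - (W(a, a) + W(b, b)) / 2, the divergence vanishes when both samples coincide, which makes it a convenient loss for fitting generative models [23]. Quick usage sketch (illustrative data):

import numpy as np
import ot

rng = np.random.RandomState(0)
X_s = rng.randn(30, 2)
X_t = rng.randn(30, 2) + 1.0

d_st = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.0)  # > 0
d_ss = ot.bregman.empirical_sinkhorn_divergence(X_s, X_s, reg=1.0)  # 0 by construction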
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
+ restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09,
+ verbose=False, log=False):
+ r"""
+ Screening Sinkhorn Algorithm for Regularized Optimal Transport
- sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
- return max(0, sinkhorn_div)
+ The function solves an approximate dual of Sinkhorn divergence :ref:`[2]
+ <references-screenkhorn>` which is written as the following optimization problem:
+
+ .. math::
+ (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \quad
+ \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} -
+ \langle \kappa \mathbf{u}, \mathbf{a} \rangle -
+ \langle \frac{1}{\kappa} \mathbf{v}, \mathbf{b} \rangle
-def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
- maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
- r""""
- Screening Sinkhorn Algorithm for Regularized Optimal Transport
+ where:
- The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem:
+ .. math::
- ..math::
- (u, v) = \argmin_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - <v/\kappa, b>
+ \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and}
- where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and
+ .. math::
- s.t. e^{u_i} \geq \epsilon / \kappa, for all i \in {1, ..., ns}
+ s.t. \ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\}
- e^{v_j} \geq \epsilon \kappa, for all j \in {1, ..., nt}
+ e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\}
- The parameters \kappa and \epsilon are determined w.r.t the couple number budget of points (ns_budget, nt_budget), see Equation (5) in [26]
+ The parameters `kappa` and `epsilon` are determined w.r.t. the pair of point
+ budgets (`ns_budget`, `nt_budget`), see Equation (5)
+ in :ref:`[26] <references-screenkhorn>`
Parameters
----------
- a : `numpy.ndarray`, shape=(ns,)
+ a: array-like, shape=(ns,)
samples weights in the source domain
-
- b : `numpy.ndarray`, shape=(nt,)
+ b: array-like, shape=(nt,)
samples weights in the target domain
-
- M : `numpy.ndarray`, shape=(ns, nt)
+ M: array-like, shape=(ns, nt)
Cost matrix
-
- reg : `float`
+ reg: `float`
Level of the entropy regularisation
-
- ns_budget : `int`, deafult=None
- Number budget of points to be keeped in the source domain
- If it is None then 50% of the source sample points will be keeped
-
- nt_budget : `int`, deafult=None
- Number budget of points to be keeped in the target domain
- If it is None then 50% of the target sample points will be keeped
-
- uniform : `bool`, default=False
- If `True`, the source and target distribution are supposed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt
-
+    ns_budget : `int`, default=None
+        Budget on the number of points to be kept in the source domain.
+        If it is None then 50% of the source sample points will be kept
+    nt_budget : `int`, default=None
+        Budget on the number of points to be kept in the target domain.
+        If it is None then 50% of the target sample points will be kept
+    uniform : `bool`, default=False
+        If `True`, the source and target distributions are assumed to be uniform,
+        i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt`
restricted : `bool`, default=True
If `True`, a warm-start initialization for the L-BFGS-B solver
using a restricted Sinkhorn algorithm with at most 5 iterations
-
- maxiter : `int`, default=10000
+    maxiter : `int`, default=10000
        Maximum number of iterations in LBFGS solver
+    maxfun : `int`, default=10000
+        Maximum number of function evaluations in LBFGS solver
+    pgtol : `float`, default=1e-09
+        Final objective function accuracy in LBFGS solver
+    verbose : `bool`, default=False
+        If `True`, display information about the cardinalities of the active sets
+ and the parameters kappa and epsilon
- maxfun : `int`, default=10000
- Maximum number of function evaluations in LBFGS solver
- pgtol : `float`, default=1e-09
- Final objective function accuracy in LBFGS solver
+ .. admonition:: Dependency
- verbose : `bool`, default=False
- If `True`, dispaly informations about the cardinals of the active sets and the paramerters kappa
- and epsilon
+ To gain more efficiency, :py:func:`ot.bregman.screenkhorn` needs to call the "Bottleneck"
+ package (https://pypi.org/project/Bottleneck/) in the screening pre-processing step.
- Dependency
- ----------
- To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
- in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears:
- "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
+ If Bottleneck isn't installed, the following error message appears:
+
+ "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
Returns
-------
- gamma : `numpy.ndarray`, shape=(ns, nt)
+ gamma : array-like, shape=(ns, nt)
Screened optimal transportation matrix for the given parameters
log : `dict`, default=False
Log dictionary return only if log==True in parameters
+ .. _references-screenkhorn:
References
-----------
- .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport,
+ Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+ .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
+ Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
"""
# check if bottleneck module exists
@@ -2048,12 +3305,17 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
import bottleneck
except ImportError:
warnings.warn(
- "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+ "Bottleneck module is not installed. Install it from"
+ " https://pypi.org/project/Bottleneck/ for better performance.")
bottleneck = np
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+ if nx.__name__ == "jax":
+ raise TypeError("JAX arrays have been received but screenkhorn is not "
+ "compatible with JAX.")
+
ns, nt = M.shape
# by default, we keep only 50% of the sample data points
@@ -2063,9 +3325,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
nt_budget = int(np.floor(0.5 * nt))
# calculate the Gibbs kernel
- K = np.empty_like(M)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
def projection(u, epsilon):
u[u <= epsilon] = epsilon
@@ -2077,8 +3337,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
if ns_budget == ns and nt_budget == nt:
# full number of budget points (ns, nt) = (ns_budget, nt_budget)
- Isel = np.ones(ns, dtype=bool)
- Jsel = np.ones(nt, dtype=bool)
+ Isel = nx.from_numpy(np.ones(ns, dtype=bool))
+ Jsel = nx.from_numpy(np.ones(nt, dtype=bool))
epsilon = 0.0
kappa = 1.0
@@ -2094,57 +3354,63 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
K_IJc = []
K_IcJ = []
- vec_eps_IJc = np.zeros(nt)
- vec_eps_IcJ = np.zeros(ns)
+ vec_eps_IJc = nx.zeros((nt,), type_as=M)
+ vec_eps_IcJ = nx.zeros((ns,), type_as=M)
else:
# sum of rows and columns of K
- K_sum_cols = K.sum(axis=1)
- K_sum_rows = K.sum(axis=0)
+ K_sum_cols = nx.sum(K, axis=1)
+ K_sum_rows = nx.sum(K, axis=0)
if uniform:
if ns / ns_budget < 4:
- aK_sort = np.sort(K_sum_cols)
+ aK_sort = nx.sort(K_sum_cols)
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
- aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1]
+ aK_sort = nx.from_numpy(
+ bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1],
+ type_as=M
+ )
epsilon_u_square = a[0] / aK_sort
if nt / nt_budget < 4:
- bK_sort = np.sort(K_sum_rows)
+ bK_sort = nx.sort(K_sum_rows)
epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
else:
- bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1]
+ bK_sort = nx.from_numpy(
+ bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1],
+ type_as=M
+ )
epsilon_v_square = b[0] / bK_sort
else:
aK = a / K_sum_cols
bK = b / K_sum_rows
- aK_sort = np.sort(aK)[::-1]
+ aK_sort = nx.flip(nx.sort(aK), axis=0)
epsilon_u_square = aK_sort[ns_budget - 1]
- bK_sort = np.sort(bK)[::-1]
+ bK_sort = nx.flip(nx.sort(bK), axis=0)
epsilon_v_square = bK_sort[nt_budget - 1]
# active sets I and J (see Lemma 1 in [26])
Isel = a >= epsilon_u_square * K_sum_cols
Jsel = b >= epsilon_v_square * K_sum_rows
- if sum(Isel) != ns_budget:
+ if nx.sum(Isel) != ns_budget:
if uniform:
aK = a / K_sum_cols
- aK_sort = np.sort(aK)[::-1]
- epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean()
+ aK_sort = nx.flip(nx.sort(aK), axis=0)
+ epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1])
Isel = a >= epsilon_u_square * K_sum_cols
- ns_budget = sum(Isel)
+ ns_budget = nx.sum(Isel)
- if sum(Jsel) != nt_budget:
+ if nx.sum(Jsel) != nt_budget:
if uniform:
bK = b / K_sum_rows
- bK_sort = np.sort(bK)[::-1]
- epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean()
+ bK_sort = nx.flip(nx.sort(bK), axis=0)
+ epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1])
Jsel = b >= epsilon_v_square * K_sum_rows
- nt_budget = sum(Jsel)
+ nt_budget = nx.sum(Jsel)
epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
@@ -2152,7 +3418,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
if verbose:
print("epsilon = %s\n" % epsilon)
print("kappa = %s\n" % kappa)
- print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel)))
+ print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n'
+ % (sum(Isel), sum(Jsel)))
# Ic, Jc: complementary of the active sets I and J
Ic = ~Isel
@@ -2162,18 +3429,18 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
K_IcJ = K[np.ix_(Ic, Jsel)]
K_IJc = K[np.ix_(Isel, Jc)]
- K_min = K_IJ.min()
+ K_min = nx.min(K_IJ)
if K_min == 0:
- K_min = np.finfo(float).tiny
+ K_min = float(np.finfo(float).tiny)
# a_I, b_J, a_Ic, b_Jc
a_I = a[Isel]
b_J = b[Jsel]
if not uniform:
- a_I_min = a_I.min()
- a_I_max = a_I.max()
- b_J_max = b_J.max()
- b_J_min = b_J.min()
+ a_I_min = nx.min(a_I)
+ a_I_max = nx.max(a_I)
+ b_J_max = nx.max(b_J)
+ b_J_min = nx.min(b_J)
else:
a_I_min = a_I[0]
a_I_max = a_I[0]
@@ -2182,33 +3449,37 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
# box constraints in L-BFGS-B (see Proposition 1 in [26])
bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
- ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+ ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
bounds_v = [(
- max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
- epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+ max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+ epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
# pre-calculated constants for the objective
- vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
- vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0)
+ vec_eps_IJc = epsilon * kappa * nx.sum(
+ K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :],
+ axis=1
+ )
+ vec_eps_IcJ = (epsilon / kappa) * nx.sum(
+ nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ,
+ axis=0
+ )
# initialisation
- u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa)
- v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa)
+ u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M)
+ v0 = nx.full((nt_budget,), 1. / nt_budget + epsilon * kappa, type_as=M)
# pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26])
if restricted:
if ns_budget != ns or nt_budget != nt:
- cst_u = kappa * epsilon * K_IJc.sum(axis=1)
- cst_v = epsilon * K_IcJ.sum(axis=0) / kappa
+ cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1)
+ cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa
- cpt = 1
- while cpt < 5: # 5 iterations
- K_IJ_v = np.dot(K_IJ.T, u0) + cst_v
+ for _ in range(5): # 5 iterations
+ K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v
v0 = b_J / (kappa * K_IJ_v)
- KIJ_u = np.dot(K_IJ, v0) + cst_u
+ KIJ_u = nx.dot(K_IJ, v0) + cst_u
u0 = (kappa * a_I) / KIJ_u
- cpt += 1
u0 = projection(u0, epsilon / kappa)
v0 = projection(v0, epsilon * kappa)
@@ -2219,15 +3490,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
def restricted_sinkhorn(usc, vsc, max_iter=5):
"""
- Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26])
+        Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26])
"""
- cpt = 1
- while cpt < max_iter:
- K_IJ_v = np.dot(K_IJ.T, usc) + cst_v
+ for _ in range(max_iter):
+ K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v
vsc = b_J / (kappa * K_IJ_v)
- KIJ_u = np.dot(K_IJ, vsc) + cst_u
+ KIJ_u = nx.dot(K_IJ, vsc) + cst_u
usc = (kappa * a_I) / KIJ_u
- cpt += 1
usc = projection(usc, epsilon / kappa)
vsc = projection(vsc, epsilon * kappa)
@@ -2235,17 +3504,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
return usc, vsc
def screened_obj(usc, vsc):
- part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J,
- np.log(vsc))
- part_IJc = np.dot(usc, vec_eps_IJc)
- part_IcJ = np.dot(vec_eps_IcJ, vsc)
+ part_IJ = (
+ nx.dot(nx.dot(usc, K_IJ), vsc)
+ - kappa * nx.dot(a_I, nx.log(usc))
+ - (1. / kappa) * nx.dot(b_J, nx.log(vsc))
+ )
+ part_IJc = nx.dot(usc, vec_eps_IJc)
+ part_IcJ = nx.dot(vec_eps_IcJ, vsc)
psi_epsilon = part_IJ + part_IJc + part_IcJ
return psi_epsilon
def screened_grad(usc, vsc):
# gradients of Psi_(kappa,epsilon) w.r.t u and v
- grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
- grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
+ grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
+ grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
return grad_u, grad_v
def bfgspost(theta):
@@ -2255,20 +3527,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
f = screened_obj(u, v)
# gradient
g_u, g_v = screened_grad(u, v)
- g = np.hstack([g_u, g_v])
- return f, g
+ g = nx.concatenate([g_u, g_v], axis=0)
+ return nx.to_numpy(f), nx.to_numpy(g)
# ----------------------------------------------------------------------------------------------------------------#
# Step 2: L-BFGS-B solver #
# ----------------------------------------------------------------------------------------------------------------#
u0, v0 = restricted_sinkhorn(u0, v0)
- theta0 = np.hstack([u0, v0])
+ theta0 = nx.concatenate([u0, v0], axis=0)
bounds = bounds_u + bounds_v # constraint bounds
def obj(theta):
- return bfgspost(theta)
+ return bfgspost(nx.from_numpy(theta, type_as=M))
theta, _, _ = fmin_l_bfgs_b(func=obj,
x0=theta0,
@@ -2276,12 +3548,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
maxfun=maxfun,
pgtol=pgtol,
maxiter=maxiter)
+ theta = nx.from_numpy(theta, type_as=M)
usc = theta[:ns_budget]
vsc = theta[ns_budget:]
- usc_full = np.full(ns, epsilon / kappa)
- vsc_full = np.full(nt, epsilon * kappa)
+ usc_full = nx.full((ns,), epsilon / kappa, type_as=M)
+ vsc_full = nx.full((nt,), epsilon * kappa, type_as=M)
usc_full[Isel] = usc
vsc_full[Jsel] = vsc
@@ -2293,7 +3566,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
log['Jsel'] = Jsel
gamma = usc_full[:, None] * K * vsc_full[None, :]
- gamma = gamma / gamma.sum()
+ gamma = gamma / nx.sum(gamma)
if log:
return gamma, log
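
A minimal usage sketch for the now backend-aware screenkhorn above; the sample sizes, budgets and the `ot.dist`/`ot.unif` helpers are illustrative, not part of the diff:

import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(100, 2), rng.randn(80, 2)
a, b = ot.unif(100), ot.unif(80)     # uniform source / target weights
M = ot.dist(xs, xt)                  # squared Euclidean cost matrix

# screened problem keeps only 50 source and 40 target points
G_screened = ot.bregman.screenkhorn(a, b, M, reg=1.0,
                                    ns_budget=50, nt_budget=40)
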
diff --git a/ot/da.py b/ot/da.py
index b881a8b..4fd97df 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -26,34 +26,36 @@ from .optim import gcg
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
log=False):
- """
+ r"""
Solve the entropic regularization optimal transport problem with nonconvex
group lasso regularization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
- + \eta \Omega_g(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot \Omega_e(\gamma) + \eta \ \Omega_g(\gamma)
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
- s.t. \gamma 1 = a
+ \gamma \geq 0
- \gamma^T 1= b
- \gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e
(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\Omega_g` is the group lasso regularization term
:math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
- where :math:`\mathcal{I}_c` are the index of samples from class c
+      where :math:`\mathcal{I}_c` is the set of indices of samples from class `c`
in the source domain.
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the generalized conditional
- gradient as proposed in [5]_ [7]_
+ gradient as proposed in :ref:`[5, 7] <references-sinkhorn-lpl1-mm>`.
Parameters
@@ -84,19 +86,20 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-sinkhorn-lpl1-mm:
References
----------
-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE
Transactions on Pattern Analysis and Machine Intelligence ,
vol.PP, no.99, pp.1-1
+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
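
A short usage sketch of the group-lasso regularized solver documented above; the toy data and label construction are assumptions for illustration:

import numpy as np
import ot

ns, nt = 50, 60
rng = np.random.RandomState(0)
xs, xt = rng.randn(ns, 2), rng.randn(nt, 2) + 1.0
labels_a = rng.randint(0, 3, ns)     # class labels of the source samples
a, b = ot.unif(ns), ot.unif(nt)
M = ot.dist(xs, xt)

# entropic OT with non-convex group-lasso regularization on source classes
G = ot.da.sinkhorn_lpl1_mm(a, labels_a, b, M, reg=1.0, eta=0.1)
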
@@ -137,34 +140,36 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
log=False):
- """
+ r"""
Solve the entropic regularization optimal transport problem with group
lasso regularization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+
- \eta \Omega_g(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot \Omega_e(\gamma) + \eta \ \Omega_g(\gamma)
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+ \gamma^T \mathbf{1} = \mathbf{b}
- s.t. \gamma 1 = a
+ \gamma \geq 0
- \gamma^T 1= b
- \gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega_e` is the entropic regularization term
:math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\Omega_g` is the group lasso regularization term
      :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2`
      where :math:`\mathcal{I}_c` is the set of indices of samples from class
- c in the source domain.
- - a and b are source and target weights (sum to 1)
+ `c` in the source domain.
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the generalised conditional
- gradient as proposed in [5]_ [7]_
+ gradient as proposed in :ref:`[5, 7] <references-sinkhorn-l1l2-gl>`.
Parameters
@@ -195,18 +200,19 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-sinkhorn-l1l2-gl:
References
----------
-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE Transactions
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence and
applications. arXiv preprint arXiv:1510.06567.
@@ -245,38 +251,40 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
verbose2=False, numItermax=100, numInnerItermax=10,
stopInnerThr=1e-6, stopThr=1e-5, log=False,
**kwargs):
- """Joint OT and linear mapping estimation as proposed in [8]
+ r"""Joint OT and linear mapping estimation as proposed in
+ :ref:`[8] <references-joint-OT-mapping-linear>`.
The function solves the following optimization problem:
.. math::
- \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F +
- \mu<\gamma,M>_F + \eta \|L -I\|^2_F
+ \min_{\gamma,L}\quad \|L(\mathbf{X_s}) - n_s\gamma \mathbf{X_t} \|^2_F +
+ \mu \langle \gamma, \mathbf{M} \rangle_F + \eta \|L - \mathbf{I}\|^2_F
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
- \gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in
- Xs and Xt (scaled by ns)
- - :math:`L` is a dxd linear operator that approximates the barycentric
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in
+ :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`)
+ - :math:`L` is a :math:`d\times d` linear operator that approximates the barycentric
mapping
- - :math:`I` is the identity matrix (neutral linear mapping)
- - a and b are uniform source and target weights
+ - :math:`\mathbf{I}` is the identity matrix (neutral linear mapping)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights
    The problem consists in solving jointly an optimal transport matrix
:math:`\gamma` and a linear mapping that fits the barycentric mapping
- :math:`n_s\gamma X_t`.
+ :math:`n_s\gamma \mathbf{X_t}`.
One can also estimate a mapping with constant bias (see supplementary
- material of [8]) using the bias optional argument.
+ material of :ref:`[8] <references-joint-OT-mapping-linear>`) using the bias optional argument.
The algorithm used for solving the problem is the block coordinate
- descent that alternates between updates of G (using conditionnal gradient)
- and the update of L using a classical least square solver.
+    descent that alternates between updates of :math:`\mathbf{G}` (using conditional gradient)
+    and the update of :math:`\mathbf{L}` using a classical least squares solver.
Parameters
@@ -307,17 +315,17 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
- L : (d x d) ndarray
- Linear mapping matrix (d+1 x d if bias)
+ L : (d, d) ndarray
+ Linear mapping matrix ((:math:`d+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
+ .. _references-joint-OT-mapping-linear:
References
----------
-
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
"Mapping estimation for discrete optimal transport",
Neural Information Processing Systems (NIPS), 2016.
@@ -434,37 +442,41 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
numItermax=100, numInnerItermax=10,
stopInnerThr=1e-6, stopThr=1e-5, log=False,
**kwargs):
- """Joint OT and nonlinear mapping estimation with kernels as proposed in [8]
+ r"""Joint OT and nonlinear mapping estimation with kernels as proposed in
+ :ref:`[8] <references-joint-OT-mapping-kernel>`.
The function solves the following optimization problem:
.. math::
- \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -
- n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
+ \min_{\gamma, L\in\mathcal{H}}\quad \|L(\mathbf{X_s}) -
+ n_s\gamma \mathbf{X_t}\|^2_F + \mu \langle \gamma, \mathbf{M} \rangle_F +
+ \eta \|L\|^2_\mathcal{H}
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
- \gamma^T 1= b
- \gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in
- Xs and Xt (scaled by ns)
- - :math:`L` is a ns x d linear operator on a kernel matrix that
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) squared euclidean cost matrix between samples in
+ :math:`\mathbf{X_s}` and :math:`\mathbf{X_t}` (scaled by :math:`n_s`)
+ - :math:`L` is a :math:`n_s \times d` linear operator on a kernel matrix that
approximates the barycentric mapping
- - a and b are uniform source and target weights
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are uniform source and target weights
    The problem consists in solving jointly an optimal transport matrix
:math:`\gamma` and the nonlinear mapping that fits the barycentric mapping
- :math:`n_s\gamma X_t`.
+ :math:`n_s\gamma \mathbf{X_t}`.
One can also estimate a mapping with constant bias (see supplementary
- material of [8]) using the bias optional argument.
+ material of :ref:`[8] <references-joint-OT-mapping-kernel>`) using the bias optional argument.
The algorithm used for solving the problem is the block coordinate
- descent that alternates between updates of G (using conditionnal gradient)
- and the update of L using a classical kernel least square solver.
+    descent that alternates between updates of :math:`\mathbf{G}` (using conditional gradient)
+    and the update of :math:`\mathbf{L}` using a classical kernel least squares solver.
Parameters
@@ -478,7 +490,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
eta : float, optional
Regularization term for the linear mapping L (>0)
kerneltype : str,optional
- kernel used by calling function ot.utils.kernel (gaussian by default)
+ kernel used by calling function :py:func:`ot.utils.kernel` (gaussian by default)
sigma : float, optional
Gaussian kernel bandwidth.
bias : bool,optional
@@ -501,17 +513,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
- L : (ns x d) ndarray
- Nonlinear mapping matrix (ns+1 x d if bias)
+ L : (ns, d) ndarray
+ Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
+ .. _references-joint-OT-mapping-kernel:
References
----------
-
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
"Mapping estimation for discrete optimal transport",
Neural Information Processing Systems (NIPS), 2016.
@@ -645,26 +657,27 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
wt=None, bias=True, log=False):
- """ return OT linear operator between samples
+ r"""Return OT linear operator between samples.
The function estimates the optimal linear operator that aligns the two
empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
- and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark
- 2.29 in [15].
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in
+ :ref:`[15] <references-OT-mapping-linear>`.
The linear operator from source to target :math:`M`
.. math::
- M(x)=Ax+b
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
where :
.. math::
- A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
\Sigma_s^{-1/2}
- .. math::
- b=\mu_t-A\mu_s
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
Parameters
----------
@@ -673,35 +686,35 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
xt : np.ndarray (nt,d)
samples in the target domain
reg : float,optional
- regularization added to the diagonals of convariances (>0)
+ regularization added to the diagonals of covariances (>0)
ws : np.ndarray (ns,1), optional
weights for the source samples
wt : np.ndarray (ns,1), optional
weights for the target samples
bias: boolean, optional
- estimate bias b else b=0 (default:True)
+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
log : bool, optional
record log if True
Returns
-------
- A : (d x d) ndarray
+ A : (d, d) ndarray
Linear operator
- b : (1 x d) ndarray
+ b : (1, d) ndarray
bias
log : dict
log dictionary return only if log==True in parameters
+ .. _references-OT-mapping-linear:
References
----------
-
.. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
distributions", Journal of Optimization Theory and Applications
Vol 43, 1984
- .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
Transport", 2018.
@@ -754,24 +767,34 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
r"""Solve the optimal transport problem (OT) with Laplacian regularization
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + eta\Omega_\alpha(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \eta \cdot \Omega_\alpha(\gamma)
- s.t.\ \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
- \gamma\geq 0
+ \gamma \geq 0
where:
- - a and b are source and target weights (sum to 1)
- - xs and xt are source and target samples
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+ - :math:`\mathbf{x_s}` and :math:`\mathbf{x_t}` are source and target samples
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega_\alpha` is the Laplacian regularization term
- :math:`\Omega_\alpha = (1-\alpha)/n_s^2\sum_{i,j}S^s_{i,j}\|T(\mathbf{x}^s_i)-T(\mathbf{x}^s_j)\|^2+\alpha/n_t^2\sum_{i,j}S^t_{i,j}^'\|T(\mathbf{x}^t_i)-T(\mathbf{x}^t_j)\|^2`
- with :math:`S^s_{i,j}, S^t_{i,j}` denoting source and target similarity matrices and :math:`T(\cdot)` being a barycentric mapping
- The algorithm used for solving the problem is the conditional gradient algorithm as proposed in [5].
+ .. math::
+ \Omega_\alpha = \frac{1 - \alpha}{n_s^2} \sum_{i,j}
+ \mathbf{S^s}_{i,j} \|T(\mathbf{x}^s_i) - T(\mathbf{x}^s_j) \|^2 +
+ \frac{\alpha}{n_t^2} \sum_{i,j}
+ \mathbf{S^t}_{i,j} \|T(\mathbf{x}^t_i) - T(\mathbf{x}^t_j) \|^2
+
+
+ with :math:`\mathbf{S^s}_{i,j}, \mathbf{S^t}_{i,j}` denoting source and target similarity
+ matrices and :math:`T(\cdot)` being a barycentric mapping.
+
+ The algorithm used for solving the problem is the conditional gradient algorithm as proposed in
+ :ref:`[5] <references-emd-laplace>`.
Parameters
----------
@@ -811,22 +834,23 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-emd-laplace:
References
----------
-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE
- Transactions on Pattern Analysis and Machine Intelligence ,
+ Transactions on Pattern Analysis and Machine Intelligence,
vol.PP, no.99, pp.1-1
+
.. [30] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
"Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
- in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+ in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
See Also
--------
@@ -882,7 +906,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
def distribution_estimation_uniform(X):
- """estimates a uniform distribution from an array of samples X
+ """estimates a uniform distribution from an array of samples :math:`\mathbf{X}`
Parameters
----------
@@ -892,7 +916,7 @@ def distribution_estimation_uniform(X):
Returns
-------
mu : array-like, shape (n_samples,)
- The uniform distribution estimated from X
+ The uniform distribution estimated from :math:`\mathbf{X}`
"""
return unif(X.shape[0])
@@ -902,32 +926,32 @@ class BaseTransport(BaseEstimator):
"""Base class for OTDA objects
- Notes
- -----
- All estimators should specify all the parameters that can be set
- at the class level in their ``__init__`` as explicit keyword
- arguments (no ``*args`` or ``**kwargs``).
+ .. note::
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
- the fit method should:
+ The fit method should:
- estimate a cost matrix and store it in a `cost_` attribute
- - estimate a coupling matrix and store it in a `coupling_`
- attribute
+ - estimate a coupling matrix and store it in a `coupling_` attribute
- estimate distributions from source and target data and store them in
- mu_s and mu_t attributes
- - store Xs and Xt in attributes to be used later on in transform and
- inverse_transform methods
+ `mu_s` and `mu_t` attributes
+ - store `Xs` and `Xt` in attributes to be used later on in `transform` and
+ `inverse_transform` methods
+
+ `transform` method should always get as input a `Xs` parameter
+
+ `inverse_transform` method should always get as input a `Xt` parameter
- transform method should always get as input a Xs parameter
- inverse_transform method should always get as input a Xt parameter
+ `transform_labels` method should always get as input a `ys` parameter
- transform_labels method should always get as input a ys parameter
- inverse_transform_labels method should always get as input a yt parameter
+ `inverse_transform_labels` method should always get as input a `yt` parameter
"""
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -938,8 +962,8 @@ class BaseTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -987,8 +1011,8 @@ class BaseTransport(BaseEstimator):
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt) and transports source samples Xs onto target
- ones Xt
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
+ and transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -999,8 +1023,8 @@ class BaseTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1014,7 +1038,7 @@ class BaseTransport(BaseEstimator):
return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt)
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples Xs onto target ones Xt
+ """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -1025,8 +1049,8 @@ class BaseTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels for target. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels for target. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1081,7 +1105,8 @@ class BaseTransport(BaseEstimator):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels ys to obtain estimated target labels as in [27]
+ """Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in
+ :ref:`[27] <references-basetransport-transform-labels>`.
Parameters
----------
@@ -1093,9 +1118,10 @@ class BaseTransport(BaseEstimator):
transp_ys : array-like, shape (n_target_samples, nb_classes)
Estimated soft target labels.
+
+ .. _references-basetransport-transform-labels:
References
----------
-
.. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
"Optimal transport for multi-source domain adaptation under target shift",
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
@@ -1111,7 +1137,7 @@ class BaseTransport(BaseEstimator):
D1 = np.zeros((n, len(ysTemp)))
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True)
# set nans to 0
transp[~ np.isfinite(transp)] = 0
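
The normalization change above divides each column of the coupling by its total mass, so every target sample receives a proper distribution over source points. A plain NumPy sketch of the resulting label propagation (the helper name and shapes are illustrative, not the class method itself):

import numpy as np

def propagate_labels(coupling, ys):
    """Soft target labels from a coupling and source labels (sketch)."""
    classes = np.unique(ys)
    D1 = (ys[None, :] == classes[:, None]).astype(float)     # (n_classes, ns) indicator rows
    transp = coupling / np.sum(coupling, 0, keepdims=True)   # columns sum to 1
    transp[~np.isfinite(transp)] = 0                         # guard against empty columns
    return (D1 @ transp).T                                   # (nt, n_classes) soft labels
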
@@ -1126,7 +1152,7 @@ class BaseTransport(BaseEstimator):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples Xt onto source samples Xs
+ """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
Parameters
----------
@@ -1137,8 +1163,8 @@ class BaseTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The target input samples.
yt : array-like, shape (n_target_samples,)
- The target class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The target class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1192,7 +1218,8 @@ class BaseTransport(BaseEstimator):
return transp_Xt
def inverse_transform_labels(self, yt=None):
- """Propagate target labels yt to obtain estimated source labels ys
+ """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
+ :math:`\mathbf{y_s}`
Parameters
----------
@@ -1228,39 +1255,41 @@ class BaseTransport(BaseEstimator):
class LinearTransport(BaseTransport):
- """ OT linear operator between empirical distributions
+ r""" OT linear operator between empirical distributions
The function estimates the optimal linear operator that aligns the two
empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)`
- and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in
- remark 2.29 in [15].
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[14] <references-lineartransport>` and discussed in remark 2.29 in
+ :ref:`[15] <references-lineartransport>`.
The linear operator from source to target :math:`M`
.. math::
- M(x)=Ax+b
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
where :
.. math::
- A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
\Sigma_s^{-1/2}
- .. math::
- b=\mu_t-A\mu_s
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
Parameters
----------
reg : float,optional
- regularization added to the daigonals of convariances (>0)
+        regularization added to the diagonals of covariances (>0)
bias: boolean, optional
- estimate bias b else b=0 (default:True)
+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
log : bool, optional
record log if True
+
+ .. _references-lineartransport:
References
----------
-
.. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
distributions", Journal of Optimization Theory and Applications
Vol 43, 1984
@@ -1279,7 +1308,7 @@ class LinearTransport(BaseTransport):
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1290,8 +1319,8 @@ class LinearTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1325,7 +1354,7 @@ class LinearTransport(BaseTransport):
return self
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples Xs onto target ones Xt
+ """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -1336,8 +1365,8 @@ class LinearTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1358,7 +1387,7 @@ class LinearTransport(BaseTransport):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples Xt onto target samples Xs
+ """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}`
Parameters
----------
@@ -1369,8 +1398,8 @@ class LinearTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1392,7 +1421,7 @@ class LinearTransport(BaseTransport):
class SinkhornTransport(BaseTransport):
- """Domain Adapatation OT method based on Sinkhorn Algorithm
+ """Domain Adaptation OT method based on Sinkhorn Algorithm
Parameters
----------
@@ -1400,7 +1429,7 @@ class SinkhornTransport(BaseTransport):
Entropic regularization parameter
max_iter : int, float, optional (default=1000)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
tol : float, optional (default=10e-9)
The precision required to stop the optimization algorithm.
verbose : bool, optional (default=False)
@@ -1417,8 +1446,8 @@ class SinkhornTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
- limit_max: float, optional (defaul=np.infty)
+ "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhorntransport>`.
+ limit_max: float, optional (default=np.infty)
Controls the semi supervised mode. Transport between labeled source
        and target samples of different classes will exhibit a cost defined
by this variable
@@ -1428,16 +1457,20 @@ class SinkhornTransport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
+ .. _references-sinkhorntransport:
References
----------
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE Transactions
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
Transport, Advances in Neural Information Processing Systems (NIPS)
26, 2013
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.
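
A minimal fit/transform sketch for the class documented above (toy data; `reg_e` is the entropic regularization parameter listed in the Parameters section):

import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(50, 2)            # source samples
Xt = rng.randn(60, 2) + 2.0      # shifted target samples

da = ot.da.SinkhornTransport(reg_e=1.0)
da.fit(Xs=Xs, Xt=Xt)
Xs_mapped = da.transform(Xs=Xs)  # source samples transported onto the target
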
@@ -1461,7 +1494,7 @@ class SinkhornTransport(BaseTransport):
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1472,8 +1505,8 @@ class SinkhornTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1504,7 +1537,7 @@ class SinkhornTransport(BaseTransport):
class EMDTransport(BaseTransport):
- """Domain Adapatation OT method based on Earth Mover's Distance
+ """Domain Adaptation OT method based on Earth Mover's Distance
Parameters
----------
@@ -1520,7 +1553,7 @@ class EMDTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-emdtransport>`.
limit_max: float, optional (default=10)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
@@ -1534,14 +1567,16 @@ class EMDTransport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
+
+ .. _references-emdtransport:
References
----------
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
- "Optimal Transport for Domain Adaptation," in IEEE Transactions
- on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
+ on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
- Regularized discrete optimal transport. SIAM Journal on Imaging
- Sciences, 7(3), 1853-1882.
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, metric="sqeuclidean", norm=None, log=False,
@@ -1558,7 +1593,7 @@ class EMDTransport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1569,8 +1604,8 @@ class EMDTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1597,8 +1632,7 @@ class EMDTransport(BaseTransport):
class SinkhornLpl1Transport(BaseTransport):
-
- """Domain Adapatation OT method based on sinkhorn algorithm +
+ r"""Domain Adaptation OT method based on sinkhorn algorithm +
LpL1 class regularization.
Parameters
@@ -1609,7 +1643,7 @@ class SinkhornLpl1Transport(BaseTransport):
Class regularization parameter
max_iter : int, float, optional (default=10)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
max_inner_iter : int, float, optional (default=200)
The number of iteration in the inner loop
log : bool, optional (default=False)
@@ -1628,8 +1662,8 @@ class SinkhornLpl1Transport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
- limit_max: float, optional (defaul=np.infty)
+ "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhornlpl1transport>`.
+ limit_max: float, optional (default=np.infty)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit a cost defined by
limit_max.
@@ -1639,16 +1673,19 @@ class SinkhornLpl1Transport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
+
+ .. _references-sinkhornlpl1transport:
References
----------
-
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE
Transactions on Pattern Analysis and Machine Intelligence ,
vol.PP, no.99, pp.1-1
+
.. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.
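
A brief sketch of the class above in use; the source labels drive the LpL1 class regularizer, and the data and `reg_e`/`reg_cl` values are illustrative:

import numpy as np
import ot

rng = np.random.RandomState(0)
Xs, Xt = rng.randn(60, 2), rng.randn(60, 2) + 1.0
ys = rng.randint(0, 3, 60)       # source class labels

da = ot.da.SinkhornLpl1Transport(reg_e=1.0, reg_cl=0.1)
da.fit(Xs=Xs, ys=ys, Xt=Xt)
Xs_mapped = da.transform(Xs=Xs)
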
@@ -1675,7 +1712,7 @@ class SinkhornLpl1Transport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1686,8 +1723,8 @@ class SinkhornLpl1Transport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1719,13 +1756,14 @@ class SinkhornLpl1Transport(BaseTransport):
class EMDLaplaceTransport(BaseTransport):
- """Domain Adapatation OT method based on Earth Mover's Distance with Laplacian regularization
+ """Domain Adaptation OT method based on Earth Mover's Distance with Laplacian regularization
Parameters
----------
reg_type : string optional (default='pos')
Type of the regularization term: 'pos' and 'disp' for
- regularization term defined in [2] and [6], respectively.
+ regularization term defined in :ref:`[2] <references-emdlaplacetransport>` and
+ :ref:`[6] <references-emdlaplacetransport>`, respectively.
reg_lap : float, optional (default=1)
Laplacian regularization parameter
reg_src : float, optional (default=0.5)
@@ -1756,24 +1794,27 @@ class EMDLaplaceTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-emdlaplacetransport>`.
Attributes
----------
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
+
+ .. _references-emdlaplacetransport:
References
----------
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE Transactions
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [2] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
"Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
- in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+ in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
- Regularized discrete optimal transport. SIAM Journal on Imaging
- Sciences, 7(3), 1853-1882.
+ Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean",
@@ -1799,7 +1840,7 @@ class EMDLaplaceTransport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1810,8 +1851,8 @@ class EMDLaplaceTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1840,8 +1881,8 @@ class EMDLaplaceTransport(BaseTransport):
class SinkhornL1l2Transport(BaseTransport):
- """Domain Adapatation OT method based on sinkhorn algorithm +
- l1l2 class regularization.
+ """Domain Adaptation OT method based on sinkhorn algorithm +
+ L1L2 class regularization.
Parameters
----------
@@ -1851,7 +1892,7 @@ class SinkhornL1l2Transport(BaseTransport):
Class regularization parameter
max_iter : int, float, optional (default=10)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
max_inner_iter : int, float, optional (default=200)
The number of iteration in the inner loop
tol : float, optional (default=10e-9)
@@ -1870,7 +1911,7 @@ class SinkhornL1l2Transport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhornl1l2transport>`.
limit_max: float, optional (default=10)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
@@ -1881,18 +1922,21 @@ class SinkhornL1l2Transport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
+ .. _references-sinkhornl1l2transport:
References
----------
-
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE
Transactions on Pattern Analysis and Machine Intelligence ,
vol.PP, no.99, pp.1-1
+
.. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.
@@ -1919,7 +1963,7 @@ class SinkhornL1l2Transport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -1930,8 +1974,8 @@ class SinkhornL1l2Transport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -1973,7 +2017,7 @@ class MappingTransport(BaseEstimator):
mu : float, optional (default=1)
Weight for the linear OT loss (>0)
eta : float, optional (default=0.001)
- Regularization term for the linear mapping L (>0)
+ Regularization term for the linear mapping `L` (>0)
bias : bool, optional (default=False)
Estimate linear mapping with constant bias
metric : string, optional (default="sqeuclidean")
@@ -2004,17 +2048,20 @@ class MappingTransport(BaseEstimator):
----------
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
- mapping_ : array-like, shape (n_features (+ 1), n_features)
- (if bias) for kernel == linear
+ mapping_ :
The associated mapping
- array-like, shape (n_source_samples (+ 1), n_features)
- (if bias) for kernel == gaussian
+
+ - array-like, shape (`n_features` (+ 1), `n_features`),
+ (if bias) for kernel == linear
+
+ - array-like, shape (`n_source_samples` (+ 1), `n_features`),
+ (if bias) for kernel == gaussian
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
References
----------
-
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
"Mapping estimation for discrete optimal transport",
Neural Information Processing Systems (NIPS), 2016.
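
A short sketch of the estimator in use, showing the out-of-sample mapping that motivates it (data and parameter values are illustrative):

import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(40, 2)
Xt = rng.randn(40, 2) + 1.0

tm = ot.da.MappingTransport(kernel="gaussian", mu=1.0, eta=1e-3, bias=True)
tm.fit(Xs=Xs, Xt=Xt)

Xnew = rng.randn(5, 2)
Xnew_mapped = tm.transform(Xs=Xnew)  # new source points go through the learned mapping
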
@@ -2042,7 +2089,8 @@ class MappingTransport(BaseEstimator):
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
"""Builds an optimal coupling and estimates the associated mapping
- from source and target sets of samples (Xs, ys) and (Xt, yt)
+ from source and target sets of samples
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -2053,8 +2101,8 @@ class MappingTransport(BaseEstimator):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -2098,7 +2146,7 @@ class MappingTransport(BaseEstimator):
return self
def transform(self, Xs):
- """Transports source samples Xs onto target ones Xt
+ """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -2138,7 +2186,7 @@ class MappingTransport(BaseEstimator):
class UnbalancedSinkhornTransport(BaseTransport):
- """Domain Adapatation unbalanced OT method based on sinkhorn algorithm
+ """Domain Adaptation unbalanced OT method based on sinkhorn algorithm
Parameters
----------
@@ -2151,7 +2199,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
'sinkhorn_epsilon_scaling', see those function for specific parameters
max_iter : int, float, optional (default=10)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
tol : float, optional (default=10e-9)
Stop threshold on error (inner sinkhorn solver) (>0)
verbose : bool, optional (default=False)
@@ -2168,7 +2216,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-unbalancedsinkhorntransport>`.
limit_max: float, optional (default=10)
Controls the semi supervised mode. Transport between labeled source
and target samples of different classes will exhibit an infinite cost
@@ -2179,14 +2227,16 @@ class UnbalancedSinkhornTransport(BaseTransport):
coupling_ : array-like, shape (n_source_samples, n_target_samples)
The optimal coupling
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
+ .. _references-unbalancedsinkhorntransport:
References
----------
-
.. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
- Scaling algorithms for unbalanced transport problems. arXiv preprint
- arXiv:1607.05816.
+ Scaling algorithms for unbalanced transport problems. arXiv preprint
+ arXiv:1607.05816.
+
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
Regularized discrete optimal transport. SIAM Journal on Imaging
Sciences, 7(3), 1853-1882.
@@ -2212,7 +2262,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Build a coupling matrix from source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -2223,8 +2273,8 @@ class UnbalancedSinkhornTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -2258,7 +2308,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
class JCPOTTransport(BaseTransport):
- """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm.
+ """Domain Adaptation OT method for multi-source target shift based on Wasserstein barycenter algorithm.
Parameters
----------
@@ -2266,7 +2316,7 @@ class JCPOTTransport(BaseTransport):
Entropic regularization parameter
max_iter : int, float, optional (default=10)
The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
+ algorithm if it has not converged
tol : float, optional (default=10e-9)
Stop threshold on error (inner sinkhorn solver) (>0)
verbose : bool, optional (default=False)
@@ -2283,7 +2333,7 @@ class JCPOTTransport(BaseTransport):
out_of_sample_map : string, optional (default="ferradans")
The kind of out of sample mapping to apply to transport samples
from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
+ "ferradans" which uses the method proposed in :ref:`[6] <references-jcpottransport>`.
Attributes
----------
@@ -2292,11 +2342,12 @@ class JCPOTTransport(BaseTransport):
proportions_ : array-like, shape (n_classes,)
Estimated class proportions in the target domain
log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
+ The dictionary of log, empty dict if parameter log is not True
+
+ .. _references-jcpottransport:
References
----------
-
.. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
"Optimal transport for multi-source domain adaptation under target shift",
International Conference on Artificial Intelligence and Statistics (AISTATS),
@@ -2323,7 +2374,7 @@ class JCPOTTransport(BaseTransport):
def fit(self, Xs, ys=None, Xt=None, yt=None):
"""Building coupling matrices from a list of source and target sets of samples
- (Xs, ys) and (Xt, yt)
+ :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})`
Parameters
----------
@@ -2334,8 +2385,8 @@ class JCPOTTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -2368,7 +2419,7 @@ class JCPOTTransport(BaseTransport):
return self
def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples Xs onto target ones Xt
+ """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
Parameters
----------
@@ -2379,8 +2430,8 @@ class JCPOTTransport(BaseTransport):
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
+ The class labels. If some target samples are unlabelled, fill the
+ :math:`\mathbf{y_t}`'s elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
class label
@@ -2440,7 +2491,8 @@ class JCPOTTransport(BaseTransport):
return transp_Xs
def transform_labels(self, ys=None):
- """Propagate source labels ys to obtain target labels as in [27]
+ """Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in
+ :ref:`[27] <references-jcpottransport-transform-labels>`
Parameters
----------
@@ -2451,6 +2503,14 @@ class JCPOTTransport(BaseTransport):
-------
yt : array-like, shape (n_target_samples, nb_classes)
Estimated soft target labels.
+
+
+ .. _references-jcpottransport-transform-labels:
+ References
+ ----------
+ .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
"""
# check the necessary inputs parameters are here
@@ -2482,11 +2542,12 @@ class JCPOTTransport(BaseTransport):
return yt.T
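A minimal sketch of the multi-source workflow, assuming two toy source domains with integer labels (all values illustrative)::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xs = [rng.randn(40, 2), rng.randn(30, 2) + 1]          # two source domains
    ys = [rng.randint(0, 2, 40), rng.randint(0, 2, 30)]    # labels of each source domain
    Xt = rng.randn(50, 2)

    jcpot = ot.da.JCPOTTransport(reg_e=1, max_iter=1000)
    jcpot.fit(Xs=Xs, ys=ys, Xt=Xt)
    yt_soft = jcpot.transform_labels(ys)   # soft class labels propagated to the target samples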
def inverse_transform_labels(self, yt=None):
- """Propagate source labels ys to obtain target labels
+ """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels
+ :math:`\mathbf{y_s}`
Parameters
----------
- yt : array-like, shape (n_source_samples,)
+ yt : array-like, shape (n_target_samples,)
The target class labels
Returns
diff --git a/ot/datasets.py b/ot/datasets.py
index b86ef3b..ad6390c 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -13,7 +13,7 @@ from .utils import check_random_state, deprecated
def make_1D_gauss(n, m, s):
- """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
+ """return a 1D histogram for a gaussian distribution (`n` bins, mean `m` and std `s`)
Parameters
----------
@@ -26,7 +26,7 @@ def make_1D_gauss(n, m, s):
Returns
-------
- h : ndarray (n,)
+ h : ndarray (`n`,)
1D histogram for a gaussian distribution
"""
x = np.arange(n, dtype=np.float64)
@@ -41,7 +41,7 @@ def get_1D_gauss(n, m, sigma):
def make_2D_samples_gauss(n, m, sigma, random_state=None):
- """Return n samples drawn from 2D gaussian N(m,sigma)
+ """Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)`
Parameters
----------
@@ -59,8 +59,8 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None):
Returns
-------
- X : ndarray, shape (n, 2)
- n samples drawn from N(m, sigma).
+ X : ndarray, shape (`n`, 2)
+ n samples drawn from :math:`\mathcal{N}(m, \sigma)`.
"""
generator = check_random_state(random_state)
@@ -102,7 +102,7 @@ def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwa
Returns
-------
X : ndarray, shape (n, d)
- n observation of size d
+ `n` observations of size `d`
y : ndarray, shape (n,)
labels of the samples.
"""
diff --git a/ot/dr.py b/ot/dr.py
index 11d2e10..c2f51f8 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -10,6 +10,7 @@ Dimension reduction with OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License
@@ -21,7 +22,7 @@ from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1, x2):
- """ Compute squared euclidean distance between samples (autograd)
+ r""" Compute squared euclidean distance between samples (autograd)
"""
x1p2 = np.sum(np.square(x1), 1)
x2p2 = np.sum(np.square(x2), 1)
@@ -29,7 +30,7 @@ def dist(x1, x2):
def sinkhorn(w1, w2, M, reg, k):
- """Sinkhorn algorithm with fixed number of iteration (autograd)
+ r"""Sinkhorn algorithm with fixed number of iteration (autograd)
"""
K = np.exp(-M / reg)
ui = np.ones((M.shape[0],))
@@ -42,14 +43,14 @@ def sinkhorn(w1, w2, M, reg, k):
def split_classes(X, y):
- """split samples in X by classes in y
+ r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
"""
lstsclass = np.unique(y)
return [X[y == i, :].astype(np.float32) for i in lstsclass]
def fda(X, y, p=2, reg=1e-16):
- """Fisher Discriminant Analysis
+ r"""Fisher Discriminant Analysis
Parameters
----------
@@ -108,20 +109,21 @@ def fda(X, y, p=2, reg=1e-16):
return Popt, proj
-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
- """
- Wasserstein Discriminant Analysis [11]_
+def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+ r"""
+ Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`
The function solves the following optimization problem:
.. math::
- P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
+ \mathbf{P} = \mathop{\arg \min}_\mathbf{P} \quad
+ \frac{\sum\limits_i W(P \mathbf{X}^i, P \mathbf{X}^i)}{\sum\limits_{i, j \neq i} W(P \mathbf{X}^i, P \mathbf{X}^j)}
where :
- - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
+ - :math:`P` is a linear projection operator in the Stiefel(`p`, `d`) manifold
- :math:`W` is entropic regularized Wasserstein distances
- - :math:`X^i` are samples in the dataset corresponding to class i
+ - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i
Parameters
----------
@@ -138,6 +140,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
else should be a pymanopt.solvers
P0 : ndarray, shape (d, p)
Initial starting point for projection.
+ normalize : bool, optional
+ Normalise the Wasserstein distance by the average distance on P0 (default: False)
verbose : int, optional
Print information along iterations.
@@ -148,6 +152,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
proj : callable
Projection function including mean centering.
+
+ .. _references-wda:
References
----------
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
@@ -163,6 +169,18 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
# compute uniform weighs
wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]
+ # pre-compute reg_c,c'
+ if P0 is not None and normalize:
+ regmean = np.zeros((len(xc), len(xc)))
+ for i, xi in enumerate(xc):
+ xi = np.dot(xi, P0)
+ for j, xj in enumerate(xc[i:]):
+ xj = np.dot(xj, P0)
+ M = dist(xi, xj)
+ regmean[i, j] = np.sum(M) / (len(xi) * len(xj))
+ else:
+ regmean = np.ones((len(xc), len(xc)))
+
def cost(P):
# wda loss
loss_b = 0
@@ -173,7 +191,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
for j, xj in enumerate(xc[i:]):
xj = np.dot(xj, P)
M = dist(xi, xj)
- G = sinkhorn(wc[i], wc[j + i], M, reg, k)
+ G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
if j == 0:
loss_w += np.sum(G * M)
else:
@@ -198,3 +216,119 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
return (X - mx.reshape((1, -1))).dot(Popt)
return Popt, proj
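A minimal sketch of the new `normalize` option on a toy two-class dataset, assuming `pymanopt` and `autograd` are installed; note that the normalisation is only applied when an initial projection `P0` is given::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X = np.vstack([rng.randn(20, 5), rng.randn(20, 5) + 3])
    y = np.array([0] * 20 + [1] * 20)

    P0, _ = np.linalg.qr(rng.randn(5, 2))       # initial point on the Stiefel manifold
    Pwda, proj = ot.dr.wda(X, y, p=2, reg=1, k=10, maxiter=20,
                           P0=P0, normalize=True)
    Xp = proj(X)                                # samples projected on the 2D subspace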
+
+
+def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
+ r"""
+ Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \max_{U \in St(d, k)} \ \min_{\pi \in \Pi(\mu,\nu)} \quad \sum_{i,j} \pi_{i,j}
+ \|U^T(\mathbf{x}_i - \mathbf{y}_j)\|^2 - \mathrm{reg} \cdot H(\pi)
+
+ - :math:`U` is a linear projection operator in the Stiefel(`d`, `k`) manifold
+ - :math:`H(\pi)` is entropy regularizer
+ - :math:`\mathbf{x}_i`, :math:`\mathbf{y}_j` are samples of measures :math:`\mu` and :math:`\nu` respectively
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, d)
+ Samples from measure :math:`\mu`
+ Y : ndarray, shape (n, d)
+ Samples from measure :math:`\nu`
+ a : ndarray, shape (n, )
+ weights for measure :math:`\mu`
+ b : ndarray, shape (n, )
+ weights for measure :math:`\nu`
+ tau : float
+ stepsize for Riemannian Gradient Descent
+ U0 : ndarray, shape (d, k)
+ Initial starting point for projection.
+ reg : float, optional
+ Regularization term >0 (entropic regularization)
+ k : int
+ Subspace dimension
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : int, optional
+ Print information along iterations.
+
+ Returns
+ -------
+ pi : ndarray, shape (n, n)
+ Optimal transportation matrix for the given parameters
+ U : ndarray, shape (d, k)
+ Projection operator.
+
+
+ .. _references-projection-robust-wasserstein:
+ References
+ ----------
+ .. [32] Huang, M. , Ma S. & Lai L. (2021).
+ A Riemannian Block Coordinate Descent Method for Computing
+ the Projection Robust Wasserstein Distance, ICML.
+ """ # noqa
+
+ # initialization
+ n, d = X.shape
+ m, d = Y.shape
+ a = np.asarray(a, dtype=np.float64)
+ b = np.asarray(b, dtype=np.float64)
+ u = np.ones(n) / n
+ v = np.ones(m) / m
+ ones = np.ones((n, m))
+
+ assert d > k
+
+ if U0 is None:
+ U = np.random.randn(d, k)
+ U, _ = np.linalg.qr(U)
+ else:
+ U = U0
+
+ def Vpi(X, Y, a, b, pi):
+ # Return the second order matrix of the displacements: sum_ij { (pi)_ij (X_i-Y_j)(X_i-Y_j)^T }.
+ A = X.T.dot(pi).dot(Y)
+ return X.T.dot(np.diag(a)).dot(X) + Y.T.dot(np.diag(np.sum(pi, 0))).dot(Y) - A - A.T
+
+ err = 1
+ iter = 0
+
+ while err > stopThr and iter < maxiter:
+
+ # Projected cost matrix
+ UUT = U.dot(U.T)
+ M = np.diag(np.diag(X.dot(UUT.dot(X.T)))).dot(ones) + ones.dot(
+ np.diag(np.diag(Y.dot(UUT.dot(Y.T))))) - 2 * X.dot(UUT.dot(Y.T))
+
+ A = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=A)
+ np.exp(A, out=A)
+
+ # Sinkhorn update
+ Ap = (1 / a).reshape(-1, 1) * A
+ AtransposeU = np.dot(A.T, u)
+ v = np.divide(b, AtransposeU)
+ u = 1. / np.dot(Ap, v)
+ pi = u.reshape((-1, 1)) * A * v.reshape((1, -1))
+
+ V = Vpi(X, Y, a, b, pi)
+
+ # Riemannian gradient descent
+ G = 2 / reg * V.dot(U)
+ GTU = G.T.dot(U)
+ xi = G - U.dot(GTU + GTU.T) / 2 # Riemannian gradient
+ U, _ = np.linalg.qr(U + tau * xi) # Retraction by QR decomposition
+
+ grad_norm = np.linalg.norm(xi)
+ err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1))
+
+ f_val = np.trace(U.T.dot(V.dot(U)))
+ if verbose:
+ print('RBCD Iteration: ', iter, ' error', err, '\t fval: ', f_val)
+
+ iter = iter + 1
+
+ return pi, U
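A minimal usage sketch on toy data; the step size `tau`, the regularization and the subspace dimension below are illustrative only::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    n, d, k = 20, 10, 2
    X = rng.randn(n, d)
    Y = rng.randn(n, d) + 1
    a = np.ones(n) / n
    b = np.ones(n) / n

    pi, U = ot.dr.projection_robust_wasserstein(
        X, Y, a, b, tau=0.002, reg=0.2, k=k, maxiter=100)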
diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py
index 7478fb9..12db605 100644
--- a/ot/gpu/__init__.py
+++ b/ot/gpu/__init__.py
@@ -7,7 +7,13 @@ The GPU backend in handled by `cupy
<https://cupy.chainer.org/>`_.
.. warning::
- Note that by default the module is not import in :mod:`ot`. In order to
+ This module is now deprecated and will be removed in future releases. POT
+ now provides a backend mechanism that allows for solving problems on GPU with
+ the pytorch backend.
+
+
+.. warning::
+ Note that by default the module is not imported in :mod:`ot`. In order to
use it you need to explicitely import :mod:`ot.gpu` .
By default, the functions in this module accept and return numpy arrays
@@ -25,6 +31,8 @@ result of the function with parameter ``to_numpy=False``.
#
# License: MIT License
+import warnings
+
from . import bregman
from . import da
from .bregman import sinkhorn
@@ -34,7 +42,7 @@ from . import utils
from .utils import dist, to_gpu, to_np
-
+warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning)
__all__ = ["utils", "dist", "sinkhorn",
diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py
index 2e2df83..76af00e 100644
--- a/ot/gpu/bregman.py
+++ b/ot/gpu/bregman.py
@@ -15,7 +15,7 @@ from . import utils
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
verbose=False, log=False, to_numpy=True, **kwargs):
- """
+ r"""
Solve the entropic regularization optimal transport on GPU
If the input matrix are in numpy format, they will be uploaded to the
@@ -54,7 +54,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -148,13 +148,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
# we can speed up the process by checking for the error only all
# the 10th iterations
if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ err = np.sqrt(
+ np.sum((u - uprev)**2) / np.sum((u)**2)
+ + np.sum((v - vprev)**2) / np.sum((v)**2)
+ )
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
#tmp2=np.einsum('i,ij,j->j', u, K, v)
- err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
+ err = np.linalg.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
index 4a98038..7adb830 100644
--- a/ot/gpu/da.py
+++ b/ot/gpu/da.py
@@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
labels_a2 = cp.asnumpy(labels_a)
classes = npp.unique(labels_a2)
for c in classes:
- idxc, = utils.to_gpu(npp.where(labels_a2 == c))
+ idxc = utils.to_gpu(*npp.where(labels_a2 == c))
indices_labels.append(idxc)
W = np.zeros(M.shape)
diff --git a/ot/gromov.py b/ot/gromov.py
index 4427a96..ea667e4 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -14,63 +14,85 @@ import numpy as np
from .bregman import sinkhorn
-from .utils import dist, UndefinedParameter
+from .utils import dist, UndefinedParameter, list_to_array
from .optim import cg
+from .lp import emd_1d, emd
+from .utils import check_random_state
+from .backend import get_backend
def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
- """Return loss matrices and tensors for Gromov-Wasserstein fast computation
+ r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation
- Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
- function as the loss function of Gromow-Wasserstein discrepancy.
+ Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
+ selected loss function as the loss function of Gromov-Wasserstein discrepancy.
- The matrices are computed as described in Proposition 1 in [12]
+ The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
Where :
- * C1 : Metric cost matrix in the source space
- * C2 : Metric cost matrix in the target space
- * T : A coupling between those two spaces
-
- The square-loss function L(a,b)=|a-b|^2 is read as :
- L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- * f1(a)=(a^2)
- * f2(b)=(b^2)
- * h1(a)=a
- * h2(b)=2*b
-
- The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
- L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
- * f1(a)=a*log(a)-a
- * f2(b)=b
- * h1(a)=a
- * h2(b)=log(b)
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{T}`: A coupling between those two spaces
+
+ The square-loss function :math:`L(a, b) = |a - b|^2` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a^2
+
+ f_2(b) &= b^2
+
+ h_1(a) &= a
+
+ h_2(b) &= 2b
+
+ The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a \log(a) - a
+
+ f_2(b) &= b
+
+ h_1(a) &= a
+
+ h_2(b) &= \log(b)
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- T : ndarray, shape (ns, nt)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ T : array-like, shape (ns, nt)
Coupling between source and target spaces
- p : ndarray, shape (ns,)
+ p : array-like, shape (ns,)
Returns
-------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+
+ .. _references-init-matrix:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
if loss_fun == 'square_loss':
def f1(a):
@@ -86,7 +108,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
return 2 * b
elif loss_fun == 'kl_loss':
def f1(a):
- return a * np.log(a + 1e-15) - a
+ return a * nx.log(a + 1e-15) - a
def f2(b):
return b
@@ -95,12 +117,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
return a
def h2(b):
- return np.log(b + 1e-15)
-
- constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)),
- np.ones(len(q)).reshape(1, -1))
- constC2 = np.dot(np.ones(len(p)).reshape(-1, 1),
- np.dot(q.reshape(1, -1), f2(C2).T))
+ return nx.log(b + 1e-15)
+
+ constC1 = nx.dot(
+ nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
+ nx.ones((1, len(q)), type_as=q)
+ )
+ constC2 = nx.dot(
+ nx.ones((len(p), 1), type_as=p),
+ nx.dot(nx.reshape(q, (1, -1)), f2(C2).T)
+ )
constC = constC1 + constC2
hC1 = h1(C1)
hC2 = h2(C2)
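The square-loss decomposition above can be checked numerically, for instance::

    import numpy as np

    a, b = 3.0, 5.0
    f1, f2 = (lambda a: a ** 2), (lambda b: b ** 2)
    h1, h2 = (lambda a: a), (lambda b: 2 * b)

    # f1(a) + f2(b) - h1(a) * h2(b) = a^2 + b^2 - 2ab = (a - b)^2
    assert np.isclose(f1(a) + f2(b) - h1(a) * h2(b), (a - b) ** 2)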
@@ -109,61 +135,70 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
def tensor_product(constC, hC1, hC2, T):
- """Return the tensor for Gromov-Wasserstein fast computation
+ r"""Return the tensor for Gromov-Wasserstein fast computation
- The tensor is computed as described in Proposition 1 Eq. (6) in [12].
+ The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-tensor-product>`
Parameters
----------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
Returns
-------
- tens : ndarray, shape (ns, nt)
- \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
+ tens : array-like, shape (`ns`, `nt`)
+ :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result
+
+ .. _references-tensor-product:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
- A = -np.dot(hC1, T).dot(hC2.T)
+ constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
+ nx = get_backend(constC, hC1, hC2, T)
+
+ A = - nx.dot(
+ nx.dot(hC1, T), hC2.T
+ )
tens = constC + A
# tens -= tens.min()
return tens
def gwloss(constC, hC1, hC2, T):
- """Return the Loss for Gromov-Wasserstein
+ r"""Return the Loss for Gromov-Wasserstein
- The loss is computed as described in Proposition 1 Eq. (6) in [12].
+ The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
Parameters
----------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
- T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
Returns
-------
loss : float
Gromov Wasserstein loss
+
+ .. _references-gwloss:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -171,33 +206,38 @@ def gwloss(constC, hC1, hC2, T):
tens = tensor_product(constC, hC1, hC2, T)
- return np.sum(tens * T)
+ tens, T = list_to_array(tens, T)
+ nx = get_backend(tens, T)
+
+ return nx.sum(tens * T)
def gwggrad(constC, hC1, hC2, T):
- """Return the gradient for Gromov-Wasserstein
+ r"""Return the gradient for Gromov-Wasserstein
- The gradient is computed as described in Proposition 2 in [12].
+ The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
Parameters
----------
- constC : ndarray, shape (ns, nt)
- Constant C matrix in Eq. (6)
- hC1 : ndarray, shape (ns, ns)
- h1(C1) matrix in Eq. (6)
- hC2 : ndarray, shape (nt, nt)
- h2(C) matrix in Eq. (6)
- T : ndarray, shape (ns, nt)
- Current value of transport matrix T
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
Returns
-------
- grad : ndarray, shape (ns, nt)
+ grad : array-like, shape (`ns`, `nt`)
Gromov Wasserstein gradient
+
+ .. _references-gwggrad:
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -207,89 +247,109 @@ def gwggrad(constC, hC1, hC2, T):
def update_square_loss(p, lambdas, T, Cs):
- """
- Updates C according to the L2 Loss kernel with the S Ts couplings
- calculated at each iteration
+ r"""
+ Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
+ couplings calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
Masses in the targeted barycenter.
lambdas : list of float
- List of the S spaces' weights.
- T : list of S np.ndarray of shape (ns,N)
- The S Ts couplings calculated at each iteration.
- Cs : list of S ndarray, shape(ns,ns)
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape(ns,ns)
Metric cost matrices.
Returns
----------
- C : ndarray, shape (nt, nt)
- Updated C matrix.
+ C : array-like, shape (`nt`, `nt`)
+ Updated :math:`\mathbf{C}` matrix.
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
- for s in range(len(T))])
- ppt = np.outer(p, p)
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
- return np.divide(tmpsum, ppt)
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+
+ return tmpsum / ppt
def update_kl_loss(p, lambdas, T, Cs):
- """
- Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration
+ r"""
+ Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
Weights in the targeted barycenter.
- lambdas : list of the S spaces' weights
- T : list of S np.ndarray of shape (ns,N)
- The S Ts couplings calculated at each iteration.
- Cs : list of S ndarray, shape(ns,ns)
+ lambdas : list of float
+ List of the `S` spaces' weights
+ T : list of S array-like of shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape(ns,ns)
Metric cost matrices.
Returns
----------
- C : ndarray, shape (ns,ns)
- updated C matrix
+ C : array-like, shape (`ns`, `ns`)
+ updated :math:`\mathbf{C}` matrix
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s])
- for s in range(len(T))])
- ppt = np.outer(p, p)
+ Cs = list_to_array(*Cs)
+ T = list_to_array(*T)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
- return np.exp(np.divide(tmpsum, ppt))
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+ return nx.exp(tmpsum / ppt)
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
- """
- Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
+
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
-
max_iter : int, optional
Max number of iterations
tol : float, optional
@@ -299,22 +359,23 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
log : bool, optional
record log if True
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an Armijo search. Else closed form is used.
+ If there are convergence issues use False.
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver
Returns
-------
- T : ndarray, shape (ns, nt)
- Doupling between the two spaces that minimizes:
- \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ T : array-like, shape (`ns`, `nt`)
+ Coupling between the two spaces that minimizes:
+
+ :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
log : dict
Convergence information and loss.
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -323,6 +384,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
mathematics 11.4 (2011): 417-487.
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -336,37 +406,45 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
if log:
res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log['gw_dist'] = gwloss(constC, hC1, hC2, res)
- return res, log
+ log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
else:
- return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
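A minimal usage sketch on toy point clouds; the cost matrices are normalised for numerical stability and all sizes are illustrative::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs = rng.randn(30, 2)
    xt = rng.randn(40, 3)

    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)
    C1 /= C1.max()
    C2 /= C2.max()
    p = ot.unif(30)
    q = ot.unif(40)

    T, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', log=True)
    gw_dist = log['gw_dist']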
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
- """
- Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
+ r"""
+ Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
+ C2 : array-like, shape (nt, nt)
Metric cost matrix in the target space
- p : ndarray, shape (ns,)
+ p : array-like, shape (ns,)
Distribution in the source space.
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
@@ -379,8 +457,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
log : bool, optional
record log if True
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an Armijo search. Else closed form is used.
+ If there are convergence issues use False.
Returns
-------
@@ -391,7 +469,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -399,7 +477,20 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
metric approach to object matching. Foundations of computational
mathematics 11.4 (2011): 417-487.
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20 = p, q, C1, C2
+ nx = get_backend(p0, q0, C10, C20)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -410,53 +501,71 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
- log_gw['T'] = res
+
+ T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10)
+ log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10)
+ log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10)
+ log_gw['T'] = T0
+
+ gw = log_gw['gw_dist']
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ gw = nx.set_gradients(gw, (p0, q0, C10, C20),
+ (log_gw['u'], log_gw['v'], gC1, gC2))
+
if log:
- return log_gw['gw_dist'], log_gw
+ return gw, log_gw
else:
- return log_gw['gw_dist']
+ return gw
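A hedged sketch of the differentiability mentioned above, assuming the PyTorch backend is available; gradients with respect to the structure matrix are obtained with a plain backward pass::

    import torch
    import ot

    xs = torch.randn(20, 2)
    xt = torch.randn(30, 2)
    C1 = ot.dist(xs, xs)
    C2 = ot.dist(xt, xt)
    C1 = C1 / C1.max()
    C2 = C2 / C2.max()
    C1.requires_grad_(True)
    p = torch.ones(20) / 20
    q = torch.ones(30) / 30

    gw = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss')
    gw.backward()        # C1.grad now holds the gradient of the GW loss wrt C1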
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
- """
- Computes the FGW transport between two graphs see [24]
+ r"""
+ Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
.. math::
- \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
- L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. \gamma 1 = p
- \gamma^T 1= q
- \gamma\geq 0
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - p and q are source and target weights (sum to 1)
- - L is a loss function to account for the misfit between the similarity matrices
- The algorithm used for solving the problem is conditional gradient as discussed in [24]_
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
Parameters
----------
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
Metric cost matrix between features across domains
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix representative of the structure in the source space
- C2 : ndarray, shape (nt, nt)
+ C2 : array-like, shape (nt, nt)
Metric cost matrix representative of the structure in the target space
- p : ndarray, shape (ns,)
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : str, optional
Loss function used for the solver
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
+ If True the step of the line-search is found via an Armijo search. Else closed form is used.
+ If there are convergence issues use False.
log : bool, optional
record log if True
**kwargs : dict
@@ -464,18 +573,30 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : array-like, shape (`ns`, `nt`)
Optimal transportation matrix for the given parameters.
log : dict
Log dictionary return only if log==True in parameters.
+
+ .. _references-fused-gromov-wasserstein:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas "Optimal Transport for structured data with
application on graphs", International Conference on Machine Learning
(ICML). 2019.
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -489,69 +610,98 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
if log:
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
- log['fgw_dist'] = log['loss'][::-1][0]
- return res, log
+
+ fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
+
+ log['fgw_dist'] = fgw_dist
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
+
else:
- return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+ return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
- """
- Computes the FGW distance between two graphs see [24]
+ r"""
+ Computes the FGW distance between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
.. math::
- \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l}
- L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+ \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
- s.t. \gamma 1 = p
- \gamma^T 1= q
- \gamma\geq 0
+ \mathbf{\gamma} &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - p and q are source and target weights (sum to 1)
- - L is a loss function to account for the misfit between the similarity matrices
- The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as
+ discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices and weights for quadratic loss using the gradients from [38]_.
Parameters
----------
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
Metric cost matrix between features across domains
- C1 : ndarray, shape (ns, ns)
- Metric cost matrix respresentative of the structure in the source space.
- C2 : ndarray, shape (nt, nt)
- Metric cost matrix espresentative of the structure in the target space.
- p : ndarray, shape (ns,)
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : array-like, shape (ns,)
Distribution in the source space.
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space.
loss_fun : str, optional
Loss function used for the solver.
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research.
- Else closed form is used. If there is convergence issues use False.
+ If True the step of the line-search is found via an Armijo search.
+ Else closed form is used. If there are convergence issues use False.
log : bool, optional
Record log if True.
**kwargs : dict
- Parameters can be directly pased to the ot.optim.cg solver.
+ Parameters can be directly passed to the ot.optim.cg solver.
Returns
-------
- gamma : ndarray, shape (ns, nt)
- Optimal transportation matrix for the given parameters.
+ fgw-distance : float
+ Fused Gromov-Wasserstein distance for the given parameters.
log : dict
Log dictionary return only if log==True in parameters.
+
+ .. _references-fused-gromov-wasserstein2:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
+
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
"""
+ p, q = list_to_array(p, q)
+
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ nx = get_backend(p0, q0, C10, C20, M0)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -563,50 +713,462 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
def df(G):
return gwggrad(constC, hC1, hC2, G)
- res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+ T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
+
+ fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10)
+
+ T0 = nx.from_numpy(T, type_as=C10)
+
+ log_fgw['fgw_dist'] = fgw_dist
+ log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10)
+ log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10)
+ log_fgw['T'] = T0
+
+ if loss_fun == 'square_loss':
+ gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+ gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+ fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
+ (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+
if log:
- log['fgw_dist'] = log['loss'][::-1][0]
- log['T'] = res
- return log['fgw_dist'], log
+ return fgw_dist, log_fgw
else:
- return log['fgw_dist']
+ return fgw_dist
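A minimal sketch computing both the FGW coupling and the FGW distance on toy data; the feature matrices and the value of alpha are illustrative::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(20, 2), rng.randn(30, 2)
    fs, ft = rng.randn(20, 5), rng.randn(30, 5)   # node features of each graph

    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)     # structure matrices
    M = ot.dist(fs, ft)                           # feature cost across domains
    p, q = ot.unif(20), ot.unif(30)

    T, log = ot.gromov.fused_gromov_wasserstein(
        M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, log=True)
    fgw_dist = ot.gromov.fused_gromov_wasserstein2(
        M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5)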
-def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
+def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
+ nb_samples_p=None, nb_samples_q=None, std=True, random_state=None):
+ r"""
+ Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ with a fixed transport plan :math:`\mathbf{T}`.
+
+ The function gives an unbiased approximation of the following equation:
+
+ .. math::
+
+ GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - `L` : Loss function to account for the misfit between the similarity matrices
+ - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}`
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance, the transport plan does not depend on the loss function
+ T : csr or array-like, shape (ns, nt)
+ Transport plan matrix, either a sparse csr or a dense matrix
+ nb_samples_p : int, optional
+ `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}`
+ nb_samples_q : int, optional
+ `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first
+ std : bool, optional
+ Standard deviation associated with the prediction of the gromov-wasserstein cost
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ : float
+ Gromov-wasserstein cost
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ generator = check_random_state(random_state)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ # It is always better to sample from the biggest distribution first.
+ if len_p < len_q:
+ p, q = q, p
+ len_p, len_q = len_q, len_p
+ C1, C2 = C2, C1
+ T = T.T
+
+ if nb_samples_p is None:
+ if nx.issparse(T):
+ # If T is sparse, it probably means that PoGroW was used, thus the number of samples is reduced
+ nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p)
+ else:
+ nb_samples_p = len_p
+ else:
+ # Samples along the first dimension are drawn without replacement.
+ nb_samples_p = min(nb_samples_p, len_p)
+ if nb_samples_q is None:
+ nb_samples_q = 1
+ if std:
+ nb_samples_q = max(2, nb_samples_q)
+
+ index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+ index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+
+ index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+ index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+
+ for i in range(nb_samples_p):
+ if nx.issparse(T):
+ T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
+ T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
+ else:
+ T_indexi = T[index_i[i], :]
+ T_indexj = T[index_j[i], :]
+ # For each of the rows sampled, a column is sampled.
+ index_k[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=T_indexi / nx.sum(T_indexi),
+ replace=True
+ )
+ index_l[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=T_indexj / nx.sum(T_indexj),
+ replace=True
+ )
+
+ list_value_sample = nx.stack([
+ loss_fun(
+ C1[np.ix_(index_i, index_j)],
+ C2[np.ix_(index_k[:, n], index_l[:, n])]
+ ) for n in range(nb_samples_q)
+ ], axis=2)
+
+ if std:
+ std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5
+ return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p)
+ else:
+ return nx.mean(list_value_sample)
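A minimal sketch, assuming a user-defined elementwise loss and a coupling computed beforehand with the exact solver::

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(50, 2), rng.randn(60, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    p, q = ot.unif(50), ot.unif(60)

    def square_loss(a, b):       # user-defined elementwise loss
        return (a - b) ** 2

    T = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss')
    gw_mean, gw_std = ot.gromov.GW_distance_estimation(
        C1, C2, p, q, square_loss, T, std=True, random_state=0)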
+
+
+def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance, the transport plan does not depend on the loss function
+ alpha : float
+ Step of the Frank-Wolfe algorithm, should be between 0 and 1
+ max_iter : int, optional
+ Max number of iterations
+ threshold_plan : float, optional
+ Threshold for deleting very small values in the transport plan. If above zero, the marginal constraints are only approximately satisfied.
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Gives the distance estimated and the standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
"""
- Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ index = np.zeros(2, dtype=int)
- (C1,p) and (C2,q)
+ # Initialize with default marginal
+ index[0] = generator.choice(len_p, size=1, p=p)
+ index[1] = generator.choice(len_q, size=1, p=q)
+ T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
+
+ best_gw_dist_estimated = np.inf
+ for cpt in range(max_iter):
+ index[0] = generator.choice(len_p, size=1, p=p)
+ T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
+ index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+
+ if alpha == 1:
+ T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ else:
+ new_T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ T = (1 - alpha) * T + alpha * new_T
+ # To limit the number of non-zero entries, values below the threshold are set to 0.
+ T = nx.eliminate_zeros(T, threshold=threshold_plan)
+
+ if cpt % 10 == 0 or cpt == (max_iter - 1):
+ gw_dist_estimated = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, std=False, random_state=generator
+ )
+
+ if gw_dist_estimated < best_gw_dist_estimated:
+ best_gw_dist_estimated = gw_dist_estimated
+ best_T = nx.copy(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated))
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=best_T, random_state=generator
+ )
+ return best_T, log
+ return best_T
+
+
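For orientation, here is a minimal usage sketch of the pointwise solver whose body ends above. It assumes the function is exposed as ot.gromov.pointwise_gromov_wasserstein and that loss_fun is an elementwise callable, as the docstring states; treat it as illustrative rather than as part of this changeset.

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    xs, xt = rng.randn(20, 2), rng.randn(30, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)   # intra-domain structure matrices
    p, q = ot.unif(20), ot.unif(30)

    def loss(x, y):                 # L(C1_ik, C2_jl), applied elementwise
        return np.abs(x - y)

    T = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss,
                                                max_iter=100, random_state=42)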
+def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False,
+ random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations.
The function solves the following optimization problem:
.. math::
- GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
- s.t. T 1 = p
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
- T^T 1= q
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
- T\geq 0
+ \mathbf{T} &\geq 0
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
- - H : entropy
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ nb_samples_grad : int
+ Number of samples to approximate the gradient
+ epsilon : float
+ Weight of the Kullback-Leibler regularization
+ max_iter : int, optional
+ Max number of iterations
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ If True, also returns the estimated distance and its standard deviation in a log dictionary
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ # The most natural way to define nb_sample is with a simple integer.
+ if isinstance(nb_samples_grad, int):
+ if nb_samples_grad > len_p:
+ # As the sampling along the first dimension is done without replacement, the remainder is carried over to the second
+ # dimension.
+ nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad
+ T = nx.outer(p, q)
+ # continue_loop allows stopping the loop if there are several successive small modifications of T.
+ continue_loop = 0
+
+ # The gradient of GW is more complex if the two matrices are not symmetric.
+ C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
+
+ for cpt in range(max_iter):
+ index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+ Lik = 0
+ for i, index0_i in enumerate(index0):
+ index1 = generator.choice(len_q,
+ size=nb_samples_grad_q,
+ p=T[index0_i, :] / nx.sum(T[index0_i, :]),
+ replace=False)
+ # If the matrices C are not symmetric, the gradient has 2 terms, so one of them is chosen at random.
+ if (not C_are_symmetric) and generator.rand(1) > 0.5:
+ Lik += nx.mean(loss_fun(
+ C1[:, [index0[i]] * nb_samples_grad_q][:, None, :],
+ C2[:, index1][None, :, :]
+ ), axis=2)
+ else:
+ Lik += nx.mean(loss_fun(
+ C1[[index0[i]] * nb_samples_grad_q, :][:, :, None],
+ C2[index1, :][:, None, :]
+ ), axis=0)
+
+ max_Lik = nx.max(Lik)
+ if max_Lik == 0:
+ continue
+ # This division by the max is here to facilitate the choice of epsilon.
+ Lik /= max_Lik
+
+ if epsilon > 0:
+ # Set the log to -inf wherever T was clipped at exp(-200), to avoid taking the log of 0.
+ log_T = nx.log(nx.clip(T, np.exp(-200), 1))
+ log_T = nx.where(log_T == -200, -np.inf, log_T)
+ Lik = Lik - epsilon * log_T
+
+ try:
+ new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon)
+ except (RuntimeWarning, UserWarning):
+ print("Warning catched in Sinkhorn: Return last stable T")
+ break
+ else:
+ new_T = emd(a=p, b=q, M=Lik)
+
+ change_T = nx.mean((T - new_T) ** 2)
+ if change_T <= 10e-20:
+ continue_loop += 1
+ if continue_loop > 100: # Number max of low modifications of T
+ T = nx.copy(new_T)
+ break
+ else:
+ continue_loop = 0
+
+ if verbose and cpt % 10 == 0:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, change_T))
+ T = nx.copy(new_T)
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, random_state=generator
+ )
+ return T, log
+ return T
+
+
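A companion sketch for the sampled solver above, under the same assumptions (exposed as ot.gromov.sampled_gromov_wasserstein): with epsilon > 0 each update goes through Sinkhorn, with epsilon = 0 through the exact emd solver.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(20, 2), rng.randn(30, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    p, q = ot.unif(20), ot.unif(30)

    T, log = ot.gromov.sampled_gromov_wasserstein(
        C1, C2, p, q, lambda x, y: np.abs(x - y),    # elementwise loss
        nb_samples_grad=10, epsilon=0.1, max_iter=200, log=True, random_state=0)
    print(log["gw_dist_estimated"], log["gw_dist_std"])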
+def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : string
Loss function used for the solver either 'square_loss' or 'kl_loss'
@@ -623,21 +1185,20 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
Returns
-------
- T : ndarray, shape (ns, nt)
+ T : array-like, shape (`ns`, `nt`)
Optimal coupling between the two spaces
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
- C1 = np.asarray(C1, dtype=np.float64)
- C2 = np.asarray(C2, dtype=np.float64)
-
- T = np.outer(p, q) # Initialization
+ T = nx.outer(p, q)
constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
@@ -654,12 +1215,12 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
# compute the gradient
tens = gwggrad(constC, hC1, hC2, T)
- T = sinkhorn(p, q, tens, epsilon)
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.linalg.norm(T - Tprev)
+ err = nx.norm(T - Tprev)
if log:
log['err'].append(err)
@@ -681,33 +1242,33 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
- """
- Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices
-
- (C1,p) and (C2,q)
+ r"""
+ Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})
+ \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
Where :
- - C1 : Metric cost matrix in the source space
- - C2 : Metric cost matrix in the target space
- - p : distribution in the source space
- - q : distribution in the target space
- - L : loss function to account for the misfit between the similarity matrices
- - H : entropy
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
Parameters
----------
- C1 : ndarray, shape (ns, ns)
+ C1 : array-like, shape (ns, ns)
Metric cost matrix in the source space
- C2 : ndarray, shape (nt, nt)
- Metric costfr matrix in the target space
- p : ndarray, shape (ns,)
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
Distribution in the source space
- q : ndarray, shape (nt,)
+ q : array-like, shape (nt,)
Distribution in the target space
loss_fun : str
Loss function used for the solver either 'square_loss' or 'kl_loss'
@@ -729,7 +1290,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
@@ -746,76 +1307,79 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
- """
- Returns the gromov-wasserstein barycenters of S measured similarity matrices
-
- (Cs)_{s=1}^{s=S}
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
+ r"""
+ Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
The function solves the following optimization problem:
.. math::
- C = argmin_{C\in R^{NxN}} \sum_s \lambda_s GW(C,C_s,p,p_s)
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
Where :
- - :math:`C_s` : metric cost matrix
- - :math:`p_s` : distribution
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
Parameters
----------
N : int
Size of the targeted barycenter
- Cs : list of S np.ndarray of shape (ns,ns)
+ Cs : list of S array-like of shape (ns,ns)
Metric cost matrices
- ps : list of S np.ndarray of shape (ns,)
- Sample weights in the S spaces
- p : ndarray, shape(N,)
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape(N,)
Weights in the targeted barycenter
lambdas : list of float
- List of the S spaces' weights.
+ List of the `S` spaces' weights.
loss_fun : callable
Tensor-matrix multiplication function based on specific loss function.
update : callable
- function(p,lambdas,T,Cs) that updates C according to a specific Kernel
- with the S Ts couplings calculated at each iteration
+ function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
+ :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
+ calculated at each iteration
epsilon : float
Regularization term >0
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
- init_C : bool | ndarray, shape (N, N)
- Random initial value for the C matrix provided by user.
+ init_C : bool | array-like, shape (N, N)
+ Random initial value for the :math:`\mathbf{C}` matrix provided by user.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
Returns
-------
- C : ndarray, shape (N, N)
+ C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
S = len(Cs)
- Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
- lambdas = np.asarray(lambdas, dtype=np.float64)
-
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
- # XXX use random state
- xalea = np.random.randn(N, 2)
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
else:
C = init_C
@@ -828,7 +1392,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C
T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-5, verbose, log) for s in range(S)]
+ max_iter, 1e-4, verbose, log) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
@@ -838,7 +1402,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.linalg.norm(C - Cprev)
+ err = nx.norm(C - Cprev)
error.append(err)
if log:
@@ -856,72 +1420,78 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None):
- """
- Returns the gromov-wasserstein barycenters of S measured similarity matrices
-
- (Cs)_{s=1}^{s=S}
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
+ r"""
+ Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
- The function solves the following optimization problem with block
- coordinate descent:
+ The function solves the following optimization problem with block coordinate descent:
.. math::
- C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps)
+
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
Where :
- - Cs : metric cost matrix
- - ps : distribution
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
Parameters
----------
N : int
Size of the targeted barycenter
- Cs : list of S np.ndarray of shape (ns, ns)
+ Cs : list of S array-like of shape (ns, ns)
Metric cost matrices
- ps : list of S np.ndarray of shape (ns,)
- Sample weights in the S spaces
- p : ndarray, shape (N,)
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape (N,)
Weights in the targeted barycenter
lambdas : list of float
- List of the S spaces' weights
- loss_fun : tensor-matrix multiplication function based on specific loss function
- update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
- with the S Ts couplings calculated at each iteration
+ List of the `S` spaces' weights
+ loss_fun : callable
+ tensor-matrix multiplication function based on specific loss function
+ update : callable
+ function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
+ :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
+ calculated at each iteration
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0).
+ Stop threshold on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
- init_C : bool | ndarray, shape(N,N)
- Random initial value for the C matrix provided by user.
+ init_C : bool | array-like, shape(N,N)
+ Random initial value for the :math:`\mathbf{C}` matrix provided by user.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
Returns
-------
- C : ndarray, shape (N, N)
+ C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
References
----------
- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
"""
- S = len(Cs)
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
- Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
- lambdas = np.asarray(lambdas, dtype=np.float64)
+ S = len(Cs)
# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
- # XXX : should use a random state and not use the global seed
- xalea = np.random.randn(N, 2)
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
else:
C = init_C
@@ -944,7 +1514,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = np.linalg.norm(C - Cprev)
+ err = nx.norm(C - Cprev)
error.append(err)
if log:
@@ -963,21 +1533,21 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
- verbose=False, log=False, init_C=None, init_X=None):
- """Compute the fgw barycenter as presented eq (5) in [24].
+ verbose=False, log=False, init_C=None, init_X=None, random_state=None):
+ r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
Parameters
----------
- N : integer
+ N : int
Desired number of samples of the target barycenter
- Ys: list of ndarray, each element has shape (ns,d)
+ Ys: list of array-like, each element has shape (ns,d)
Features of all samples
- Cs : list of ndarray, each element has shape (ns,ns)
+ Cs : list of array-like, each element has shape (ns,ns)
Structure matrices of all samples
- ps : list of ndarray, each element has shape (ns,)
+ ps : list of array-like, each element has shape (ns,)
Masses of all samples.
lambdas : list of float
- List of the S spaces' weights
+ List of the `S` spaces' weights
alpha : float
Alpha parameter for the fgw distance
fixed_structure : bool
@@ -989,46 +1559,51 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0).
+ Stop threshold on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
- init_C : ndarray, shape (N,N), optional
+ init_C : array-like, shape (N,N), optional
Initialization for the barycenters' structure matrix. If not set
a random init is used.
- init_X : ndarray, shape (N,d), optional
+ init_X : array-like, shape (N,d), optional
Initialization for the barycenters' features. If not set a
random init is used.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
Returns
-------
- X : ndarray, shape (N, d)
+ X : array-like, shape (`N`, `d`)
Barycenters' features
- C : ndarray, shape (N, N)
+ C : array-like, shape (`N`, `N`)
Barycenters' structure matrix
- log_: dict
+ log : dict
Only returned when log=True. It contains the keys:
- T : list of (N,ns) transport matrices
- Ms : all distance matrices between the feature of the barycenter and the
- other features dist(X,Ys) shape (N,ns)
+ - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
+ - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
+
+
+ .. _references-fgw-barycenters:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ Ys = list_to_array(*Ys)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *Ys, *ps)
+
S = len(Cs)
d = Ys[0].shape[1] # dimension on the node features
if p is None:
- p = np.ones(N) / N
-
- Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
- Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
-
- lambdas = np.asarray(lambdas, dtype=np.float64)
+ p = nx.ones(N, type_as=Cs[0]) / N
if fixed_structure:
if init_C is None:
@@ -1037,8 +1612,10 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
C = init_C
else:
if init_C is None:
- xalea = np.random.randn(N, 2)
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
+ C = nx.from_numpy(C, type_as=ps[0])
else:
C = init_C
@@ -1049,13 +1626,13 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
X = init_X
else:
if init_X is None:
- X = np.zeros((N, d))
+ X = nx.zeros((N, d), type_as=ps[0])
else:
X = init_X
- T = [np.outer(p, q) for q in ps]
+ T = [nx.outer(p, q) for q in ps]
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
cpt = 0
err_feature = 1
@@ -1075,20 +1652,19 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
- Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
if not fixed_structure:
if loss_fun == 'square_loss':
T_temp = [t.T for t in T]
- C = update_sructure_matrix(p, lambdas, T_temp, Cs)
+ C = update_structure_matrix(p, lambdas, T_temp, Cs)
T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
# T is N,ns
- err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
- err_structure = np.linalg.norm(C - Cprev)
-
+ err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
+ err_structure = nx.norm(C - Cprev)
if log:
log_['err_feature'].append(err_feature)
log_['err_structure'].append(err_structure)
@@ -1114,64 +1690,80 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
return X, C
-def update_sructure_matrix(p, lambdas, T, Cs):
- """Updates C according to the L2 Loss kernel with the S Ts couplings.
+def update_structure_matrix(p, lambdas, T, Cs):
+ r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
It is calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
Masses in the targeted barycenter.
lambdas : list of float
- List of the S spaces' weights.
- T : list of S ndarray of shape (ns, N)
- The S Ts couplings calculated at each iteration.
- Cs : list of S ndarray, shape (ns, ns)
- Metric cost matrices.
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns, N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape (ns, ns)
+ Metric cost matrices.
Returns
-------
- C : ndarray, shape (nt, nt)
- Updated C matrix.
+ C : array-like, shape (`nt`, `nt`)
+ Updated :math:`\mathbf{C}` matrix.
"""
- tmpsum = sum([lambdas[s] * np.dot(T[s].T, Cs[s]).dot(T[s]) for s in range(len(T))])
- ppt = np.outer(p, p)
+ p = list_to_array(p)
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ nx = get_backend(*Cs, *T, p)
- return np.divide(tmpsum, ppt)
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+ return tmpsum / ppt
def update_feature_matrix(lambdas, Ys, Ts, p):
- """Updates the feature with respect to the S Ts couplings.
+ r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
- in [24] calculated at each iteration
+ in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
Parameters
----------
- p : ndarray, shape (N,)
+ p : array-like, shape (N,)
masses in the targeted barycenter
lambdas : list of float
- List of the S spaces' weights
- Ts : list of S np.ndarray(ns,N)
- the S Ts couplings calculated at each iteration
- Ys : list of S ndarray, shape(d,ns)
+ List of the `S` spaces' weights
+ Ts : list of S array-like, shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
+ Ys : list of S array-like, shape (d,ns)
The features.
Returns
-------
- X : ndarray, shape (d, N)
+ X : array-like, shape (`d`, `N`)
+
+ .. _references-update-feature-matrix:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
- and Courty Nicolas
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
- p = np.array(1. / p).reshape(-1,)
-
- tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))])
-
+ p = list_to_array(p)
+ Ts = list_to_array(*Ts)
+ Ys = list_to_array(*Ys)
+ nx = get_backend(*Ys, *Ts, p)
+
+ p = 1. / p
+ tmpsum = sum([
+ lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
+ for s in range(len(Ts))
+ ])
return tmpsum
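Taken together, the gromov.py changes above replace the hard NumPy dependency with the backend layer (list_to_array / get_backend) and thread a random_state through the barycenter solvers. A short reproducibility sketch, assuming the functions keep the signatures shown in this diff and that the structure matrices are normalized as in the POT examples:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xs = [rng.randn(10, 2), rng.randn(15, 2)]
    Cs = [ot.dist(x, x) for x in Xs]
    Cs = [C / C.max() for C in Cs]              # normalize structure matrices
    ps = [ot.unif(10), ot.unif(15)]
    p, lambdas = ot.unif(6), [0.5, 0.5]

    C = ot.gromov.gromov_barycenters(6, Cs, ps, p, lambdas, 'square_loss',
                                     max_iter=50, random_state=42)
    Ce = ot.gromov.entropic_gromov_barycenters(6, Cs, ps, p, lambdas,
                                               'square_loss', epsilon=5e-2,
                                               max_iter=20, random_state=42)
    # the same random_state gives the same random initialization of C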
diff --git a/ot/helpers/__init__.py b/ot/helpers/__init__.py
new file mode 100644
index 0000000..b948671
--- /dev/null
+++ b/ot/helpers/__init__.py
@@ -0,0 +1,3 @@
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
diff --git a/ot/helpers/openmp_helpers.py b/ot/helpers/openmp_helpers.py
new file mode 100644
index 0000000..a6ad38b
--- /dev/null
+++ b/ot/helpers/openmp_helpers.py
@@ -0,0 +1,85 @@
+"""Helpers for OpenMP support during the build."""
+
+# This code is adapted for a large part from the astropy openmp helpers, which
+# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa
+
+
+import os
+import sys
+import textwrap
+import subprocess
+
+from distutils.errors import CompileError, LinkError
+
+from .pre_build_helpers import compile_test_program
+
+
+def get_openmp_flag(compiler):
+ """Get openmp flags for a given compiler"""
+
+ if hasattr(compiler, 'compiler'):
+ compiler = compiler.compiler[0]
+ else:
+ compiler = compiler.__class__.__name__
+
+ if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler):
+ omp_flag = ['/Qopenmp']
+ elif sys.platform == "win32":
+ omp_flag = ['/openmp']
+ elif sys.platform in ("darwin", "linux") and "icc" in compiler:
+ omp_flag = ['-qopenmp']
+ elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''):
+ omp_flag = []
+ else:
+ # Default flag for GCC and clang:
+ omp_flag = ['-fopenmp']
+ if sys.platform.startswith("darwin"):
+ omp_flag += ["-Xpreprocessor", "-lomp"]
+ return omp_flag
+
+
+def check_openmp_support():
+ """Check whether OpenMP test code can be compiled and run"""
+
+ code = textwrap.dedent(
+ """\
+ #include <omp.h>
+ #include <stdio.h>
+ int main(void) {
+ #pragma omp parallel
+ printf("nthreads=%d\\n", omp_get_num_threads());
+ return 0;
+ }
+ """)
+
+ extra_preargs = os.getenv('LDFLAGS', None)
+ if extra_preargs is not None:
+ extra_preargs = extra_preargs.strip().split(" ")
+ extra_preargs = [
+ flag for flag in extra_preargs
+ if flag.startswith(('-L', '-Wl,-rpath', '-l'))]
+
+ extra_postargs = get_openmp_flag
+
+ try:
+ output, compile_flags = compile_test_program(
+ code,
+ extra_preargs=extra_preargs,
+ extra_postargs=extra_postargs
+ )
+
+ if output and 'nthreads=' in output[0]:
+ nthreads = int(output[0].strip().split('=')[1])
+ openmp_supported = len(output) == nthreads
+ elif "PYTHON_CROSSENV" in os.environ:
+ # Since we can't run the test program when cross-compiling
+ # assume that openmp is supported if the program can be
+ # compiled.
+ openmp_supported = True
+ else:
+ openmp_supported = False
+
+ except (CompileError, LinkError, subprocess.CalledProcessError):
+ openmp_supported = False
+ compile_flags = []
+ return openmp_supported, compile_flags
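A sketch of how this helper is meant to be consumed at build time. The exact wiring in POT's setup.py is not shown in this diff, so the snippet below is illustrative only.

    # e.g. inside setup.py, before declaring the compiled extensions
    from ot.helpers.openmp_helpers import check_openmp_support

    openmp_supported, omp_flags = check_openmp_support()
    extra_args = omp_flags if openmp_supported else []
    # extra_args can then be passed as extra_compile_args / extra_link_args
    # of the Extension objects that need OpenMP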
diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py
new file mode 100644
index 0000000..93ecd6a
--- /dev/null
+++ b/ot/helpers/pre_build_helpers.py
@@ -0,0 +1,87 @@
+"""Helpers to check build environment before actual build of POT"""
+
+import os
+import sys
+import glob
+import tempfile
+import setuptools # noqa
+import subprocess
+
+from distutils.dist import Distribution
+from distutils.sysconfig import customize_compiler
+from numpy.distutils.ccompiler import new_compiler
+from numpy.distutils.command.config_compiler import config_cc
+
+
+def _get_compiler():
+ """Get a compiler equivalent to the one that will be used to build POT
+ Handles compiler specified as follows:
+ - python setup.py build_ext --compiler=<compiler>
+ - CC=<compiler> python setup.py build_ext
+ """
+ dist = Distribution({'script_name': os.path.basename(sys.argv[0]),
+ 'script_args': sys.argv[1:],
+ 'cmdclass': {'config_cc': config_cc}})
+
+ cmd_opts = dist.command_options.get('build_ext')
+ if cmd_opts is not None and 'compiler' in cmd_opts:
+ compiler = cmd_opts['compiler'][1]
+ else:
+ compiler = None
+
+ ccompiler = new_compiler(compiler=compiler)
+ customize_compiler(ccompiler)
+
+ return ccompiler
+
+
+def compile_test_program(code, extra_preargs=[], extra_postargs=[]):
+ """Check that some C code can be compiled and run"""
+ ccompiler = _get_compiler()
+
+ # extra_(pre/post)args can be a callable to make it possible to get its
+ # value from the compiler
+ if callable(extra_preargs):
+ extra_preargs = extra_preargs(ccompiler)
+ if callable(extra_postargs):
+ extra_postargs = extra_postargs(ccompiler)
+
+ start_dir = os.path.abspath('.')
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ try:
+ os.chdir(tmp_dir)
+
+ # Write test program
+ with open('test_program.c', 'w') as f:
+ f.write(code)
+
+ os.mkdir('objects')
+
+ # Compile the test program
+ ccompiler.compile(['test_program.c'], output_dir='objects',
+ extra_postargs=extra_postargs)
+
+ # Link test program
+ objects = glob.glob(
+ os.path.join('objects', '*' + ccompiler.obj_extension))
+ ccompiler.link_executable(objects, 'test_program',
+ extra_preargs=extra_preargs,
+ extra_postargs=extra_postargs)
+
+ if "PYTHON_CROSSENV" not in os.environ:
+ # Run test program if not cross compiling
+ # will raise a CalledProcessError if return code was non-zero
+ output = subprocess.check_output('./test_program')
+ output = output.decode(
+ sys.stdout.encoding or 'utf-8').splitlines()
+ else:
+ # Return an empty output if we are cross compiling
+ # as we cannot run the test_program
+ output = []
+ except Exception:
+ raise
+ finally:
+ os.chdir(start_dir)
+
+ return output, extra_postargs
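For completeness, a minimal way to exercise compile_test_program on its own, assuming a working C compiler is available on the build machine:

    from ot.helpers.pre_build_helpers import compile_test_program

    code = ('#include <stdio.h>\n'
            'int main(void) { printf("hello\\n"); return 0; }\n')
    output, used_flags = compile_test_program(code)
    print(output)   # ['hello'] when not cross-compiling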
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index c0fe7a3..8a1f9ac 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -18,19 +18,18 @@
#include <iostream>
#include <vector>
-#include "network_simplex_simple.h"
-using namespace lemon;
typedef unsigned int node_id_type;
enum ProblemType {
INFEASIBLE,
OPTIMAL,
UNBOUNDED,
- MAX_ITER_REACHED
+ MAX_ITER_REACHED
};
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
+int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index bc873ed..2bdc172 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -12,16 +12,22 @@
*
*/
+
+#include "network_simplex_simple.h"
+#include "network_simplex_simple_omp.h"
#include "EMD.h"
+#include <cstdint>
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int maxIter) {
- // beware M and C anre strored in row major C style!!!
- int n, m, i, cur;
+ // beware M and C are stored in row major C style!!!
+
+ using namespace lemon;
+ int n, m, cur;
typedef FullBipartiteDigraph Digraph;
- DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
+ DIGRAPH_TYPEDEFS(Digraph);
// Get the number of non zero coordinates for r and c
n=0;
@@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
std::vector<int> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
// Set supply and demand, don't account for 0 values (faster)
@@ -76,10 +82,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
net.supplyMap(&weights1[0], n, &weights2[0], m);
// Set the cost of each edge
+ int64_t idarc = 0;
for (int i=0; i<n; i++) {
for (int j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
- net.setCost(di.arcFromId(i*m+j), val);
+ net.setCost(di.arcFromId(idarc), val);
+ ++idarc;
}
}
@@ -87,12 +95,13 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm
int ret=net.run();
+ int i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
for (; a != INVALID; di.next(a)) {
- int i = di.source(a);
- int j = di.target(a);
+ i = di.source(a);
+ j = di.target(a);
double flow = net.flow(a);
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
*(G+indI[i]*n2+indJ[j-n]) = flow;
@@ -106,3 +115,104 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
return ret;
}
+
+
+
+
+
+
+int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
+ double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
+ // beware M and C are stored in row major C style!!!
+
+ using namespace lemon_omp;
+ int n, m, cur;
+
+ typedef FullBipartiteDigraph Digraph;
+ DIGRAPH_TYPEDEFS(Digraph);
+
+ // Get the number of non zero coordinates for r and c
+ n=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ n++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+ m=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ m++;
+ }else if(val<0){
+ return INFEASIBLE;
+ }
+ }
+
+ // Define the graph
+
+ std::vector<int> indI(n), indJ(m);
+ std::vector<double> weights1(n), weights2(m);
+ Digraph di(n, m);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
+
+ // Set supply and demand, don't account for 0 values (faster)
+
+ cur=0;
+ for (int i=0; i<n1; i++) {
+ double val=*(X+i);
+ if (val>0) {
+ weights1[ cur ] = val;
+ indI[cur++]=i;
+ }
+ }
+
+ // Demand is actually negative supply...
+
+ cur=0;
+ for (int i=0; i<n2; i++) {
+ double val=*(Y+i);
+ if (val>0) {
+ weights2[ cur ] = -val;
+ indJ[cur++]=i;
+ }
+ }
+
+
+ net.supplyMap(&weights1[0], n, &weights2[0], m);
+
+ // Set the cost of each edge
+ int64_t idarc = 0;
+ for (int i=0; i<n; i++) {
+ for (int j=0; j<m; j++) {
+ double val=*(D+indI[i]*n2+indJ[j]);
+ net.setCost(di.arcFromId(idarc), val);
+ ++idarc;
+ }
+ }
+
+
+ // Solve the problem with the network simplex algorithm
+
+ int ret=net.run();
+ int i, j;
+ if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
+ *cost = 0;
+ Arc a; di.first(a);
+ for (; a != INVALID; di.next(a)) {
+ i = di.source(a);
+ j = di.target(a);
+ double flow = net.flow(a);
+ *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
+ *(G+indI[i]*n2+indJ[j-n]) = flow;
+ *(alpha + indI[i]) = -net.potential(i);
+ *(beta + indJ[j-n]) = net.potential(j);
+ }
+
+ }
+
+
+ return ret;
+}
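The OpenMP wrapper above is what the Python layer reaches when a numThreads value other than 1 is requested (see the changes to ot.emd below). A minimal sketch, assuming the extension was compiled with OpenMP support:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(500), ot.unif(600)
    M = ot.dist(rng.randn(500, 2), rng.randn(600, 2))
    G_par = ot.emd(a, b, M, numThreads="max")   # parallel network simplex
    G_seq = ot.emd(a, b, M, numThreads=1)       # sequential solver, same optimal cost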
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 514a607..5da897d 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -8,25 +8,50 @@ Solvers for the original linear program OT problem
#
# License: MIT License
+import os
import multiprocessing
import sys
import numpy as np
-from scipy.sparse import coo_matrix
+import warnings
from . import cvx
from .cvx import barycenter
+
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from ..utils import dist
+from .solver_1d import emd_1d, emd2_1d, wasserstein_1d
+
+from ..utils import dist, list_to_array
from ..utils import parmap
+from ..backend import get_backend
-__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
+__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
+def check_number_threads(numThreads):
+ """Checks whether or not the requested number of threads has a valid value.
+
+ Parameters
+ ----------
+ numThreads : int or str
+ The requested number of threads; should be a strictly positive integer, "max", or None
+
+ Returns
+ -------
+ numThreads : int
+ Corrected number of threads
+ """
+ if (numThreads is None) or (isinstance(numThreads, str) and numThreads.lower() == 'max'):
+ return -1
+ if (not isinstance(numThreads, int)) or numThreads < 1:
+ raise ValueError('numThreads should either be "max" or a strictly positive integer')
+ return numThreads
+
+
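In short, the helper's contract (assuming it stays importable as ot.lp.check_number_threads):

    from ot.lp import check_number_threads

    check_number_threads(4)        # -> 4
    check_number_threads("max")    # -> -1, i.e. let OpenMP use all available threads
    check_number_threads(None)     # -> -1 as well
    # check_number_threads(0)      # raises ValueError (must be strictly positive)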
def center_ot_dual(alpha0, beta0, a=None, b=None):
- r"""Center dual OT potentials w.r.t. theirs weights
+ r"""Center dual OT potentials w.r.t. their weights
The main idea of this function is to find unique dual potentials
that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having
@@ -37,7 +62,7 @@ def center_ot_dual(alpha0, beta0, a=None, b=None):
is the following:
.. math::
- \alpha^T a= \beta^T b
+ \alpha^T \mathbf{a} = \beta^T \mathbf{b}
in addition to the OT problem constraints.
@@ -45,11 +70,11 @@ def center_ot_dual(alpha0, beta0, a=None, b=None):
a constant from both :math:`\alpha_0` and :math:`\beta_0`.
.. math::
- c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta}
+ c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}}
- \alpha=\alpha_0+c
+ \alpha &= \alpha_0 + c
- \beta=\beta0+c
+ \beta &= \beta_0 + c
Parameters
----------
@@ -92,35 +117,35 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
The feasible values are computed efficiently but rather coarsely.
.. warning::
- This function is necessary because the C++ solver in emd_c
- discards all samples in the distributions with
- zeros weights. This means that while the primal variable (transport
+ This function is necessary because the C++ solver in `emd_c`
+ discards all samples in the distributions with
+ zeros weights. This means that while the primal variable (transport
matrix) is exact, the solver only returns feasible dual potentials
- on the samples with weights different from zero.
+ on the samples with weights different from zero.
First we compute the constraints violations:
.. math::
- V=\alpha+\beta^T-M
+ \mathbf{V} = \alpha + \beta^T - \mathbf{M}
- Next we compute the max amount of violation per row (alpha) and
- columns (beta)
+ Next we compute the max amount of violation per row (:math:`\alpha`) and
+ columns (:math:`\beta`)
.. math::
- v^a_i=\max_j V_{i,j}
+ \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j}
- v^b_j=\max_i V_{i,j}
+ \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j}
Finally we update the dual potential with 0 weights if a
constraint is violated
.. math::
- \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0
+ \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0
- \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0
+ \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0
In the end the dual potentials are centered using function
- :ref:`center_ot_dual`.
+ :py:func:`ot.lp.center_ot_dual`.
Note that all those updates do not change the objective value of the
solution but provide dual potentials that do not violate the constraints.
@@ -172,54 +197,62 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
return center_ot_dual(alpha, beta, a, b)
-def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
+def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
r"""Solves the Earth Movers distance problem and returns the OT matrix
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
+
+ \gamma \geq 0
- \gamma\geq 0
where :
- - M is the metric cost matrix
- - a and b are the sample weights
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order
+ numpy.array in float64 format. It will be converted if not in this
+ format
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
- Uses the algorithm proposed in [1]_
+ Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
+ b : (nt,) array-like, float
Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
+ M : (ns,nt) array-like, float
+ Loss matrix (c-order array in numpy with type float64)
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
log: bool, optional (default=False)
- If True, returns a dictionary containing the cost and dual
- variables. Otherwise returns only the optimal transportation matrix.
+ If True, returns a dictionary containing the cost and dual variables.
+ Otherwise returns only the optimal transportation matrix.
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
:ref:`center_ot_dual`.
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
Returns
-------
- gamma: (ns x nt) numpy.ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is true, a dictionary containing the cost and dual
- variables and exit status
+ gamma: array-like, shape (ns, nt)
+ Optimal transportation matrix for the given
+ parameters
+ log: dict, optional
+ If input log is true, a dictionary containing the
+ cost and dual variables and exit status
Examples
@@ -232,26 +265,39 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.emd(a,b,M)
+ >>> ot.emd(a, b, M)
array([[0.5, 0. ],
[0. , 0.5]])
+
+ .. _references-emd:
References
----------
-
- .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
- (2011, December). Displacement interpolation using Lagrangian mass
- transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
- 158). ACM.
+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
+ December). Displacement interpolation using Lagrangian mass transport.
+ In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT"""
+ ot.optim.cg : General regularized OT
+ """
+
+ # convert to numpy if list
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
+
+ # ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -262,81 +308,91 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True):
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0), err_msg='a and b vectors must have the same sum')
+ b = b * a.sum() / b.sum()
+
asel = a != 0
bsel = b != 0
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ numThreads = check_number_threads(numThreads)
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
-
+
result_code_string = check_result(result_code)
if log:
log = {}
log['cost'] = cost
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
- return G, log
- return G
+ return nx.from_numpy(G, type_as=M0), log
+ return nx.from_numpy(G, type_as=M0)
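As the updated docstring notes, emd is now backend-compatible: inputs are converted to C-ordered float64 NumPy arrays for the C solver, and the plan (plus the dual potentials in the log) is cast back to the input type. A NumPy sketch:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(3), ot.unif(4)
    M = ot.dist(rng.randn(3, 2), rng.randn(4, 2))
    G, log = ot.emd(a, b, M, log=True)
    print(log["cost"], log["u"].shape, log["v"].shape)   # cost and centered duals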
-def emd2(a, b, M, processes=multiprocessing.cpu_count(),
+def emd2(a, b, M, processes=1,
numItermax=100000, log=False, return_matrix=False,
- center_dual=True):
+ center_dual=True, numThreads=1):
r"""Solves the Earth Movers distance problem and returns the loss
.. math::
- \min_\gamma <\gamma,M>_F
+ \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
+
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- s.t. \gamma 1 = a
+ \gamma^T \mathbf{1} = \mathbf{b}
- \gamma^T 1= b
+ \gamma \geq 0
- \gamma\geq 0
where :
- - M is the metric cost matrix
- - a and b are the sample weights
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
- .. warning::
- Note that the M matrix needs to be a C-order numpy.array in float64
- format.
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
- Uses the algorithm proposed in [1]_
+ Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
----------
- a : (ns,) numpy.ndarray, float64
+ a : (ns,) array-like, float64
Source histogram (uniform weight if empty list)
- b : (nt,) numpy.ndarray, float64
+ b : (nt,) array-like, float64
Target histogram (uniform weight if empty list)
- M : (ns,nt) numpy.ndarray, float64
- Loss matrix (c-order array with type float64)
- processes : int, optional (default=nb cpu)
- Nb of processes used for multiple emd computation (not used on windows)
+ M : (ns,nt) array-like, float64
+ Loss matrix (for numpy c-order array with type float64)
+ processes : int, optional (default=1)
+ Nb of processes used for multiple emd computation (deprecated)
numItermax : int, optional (default=100000)
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
log: boolean, optional (default=False)
- If True, returns a dictionary containing the cost and dual
+ If True, returns a dictionary containing dual
variables. Otherwise returns only the optimal transportation cost.
return_matrix: boolean, optional (default=False)
If True, returns the optimal transportation matrix in the log.
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
:ref:`center_ot_dual`.
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
Returns
-------
- gamma: (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- log: dictnp
- If input log is true, a dictionary containing the cost and dual
+ W: float, array-like
+ Optimal transportation loss for the given parameters
+ log: dict
+ If input log is true, a dictionary containing dual
variables and exit status
@@ -354,9 +410,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
>>> ot.emd2(a,b,M)
0.0
+
+ .. _references-emd2:
References
----------
-
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
(2011, December). Displacement interpolation using Lagrangian mass
transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
@@ -365,15 +422,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT
- ot.optim.cg : General regularized OT"""
+ ot.optim.cg : General regularized OT
+ """
+
+ a, b, M = list_to_array(a, b, M)
+
+ a0, b0, M0 = a, b, M
+ nx = get_backend(M0, a0, b0)
+
+ # convert to numpy
+ M = nx.to_numpy(M)
+ a = nx.to_numpy(a)
+ b = nx.to_numpy(b)
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
-
- # problem with pikling Forks
- if sys.platform.endswith('win32'):
- processes = 1
+ M = np.asarray(M, dtype=np.float64, order='C')
# if empty array given then use uniform distributions
if len(a) == 0:
@@ -386,11 +450,13 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
asel = a != 0
+ numThreads = check_number_threads(numThreads)
+
if log or return_matrix:
def f(b):
bsel = b != 0
-
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
@@ -400,17 +466,20 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
result_code_string = check_result(result_code)
log = {}
+ G = nx.from_numpy(G, type_as=M0)
if return_matrix:
log['G'] = G
- log['u'] = u
- log['v'] = v
+ log['u'] = nx.from_numpy(u, type_as=a0)
+ log['v'] = nx.from_numpy(v, type_as=b0)
log['warning'] = result_code_string
log['result_code'] = result_code
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
def f(b):
bsel = b != 0
- G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
+ G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
if center_dual:
u, v = center_ot_dual(u, v, a, b)
@@ -418,6 +487,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
+ G = nx.from_numpy(G, type_as=M0)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
+ nx.from_numpy(v, type_as=b0), G))
+
check_result(result_code)
return cost
@@ -426,35 +500,53 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
nb = b.shape[1]
if processes > 1:
- res = parmap(f, [b[:, i] for i in range(nb)], processes)
- else:
- res = list(map(f, [b[:, i].copy() for i in range(nb)]))
+ warnings.warn(
+ "The 'processes' parameter has been deprecated. "
+ "Multiprocessing should be done outside of POT."
+ )
+ res = list(map(f, [b[:, i].copy() for i in range(nb)]))
return res
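Because the cost returned by emd2 is now wrapped with nx.set_gradients, it can be differentiated with respect to a, b and M on differentiable backends. A hedged sketch, assuming PyTorch is installed and picked up by ot.backend:

    import torch
    import ot

    a = torch.tensor([0.2] * 5, requires_grad=True)
    b = torch.tensor([1.0 / 6] * 6)
    M = torch.rand(5, 6, requires_grad=True)
    loss = ot.emd2(a, b, M)
    loss.backward()
    print(M.grad.shape)   # (5, 6): the gradient w.r.t. M is the optimal plan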
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
- stopThr=1e-7, verbose=False, log=None):
- """
- Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
+ stopThr=1e-7, verbose=False, log=None, numThreads=1):
+ r"""
+ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally:
+
+ .. math::
+ \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
+
+ where :
+
+ - :math:`w \in (0, 1)^{N}` are the barycenter weights, summing to one
+ - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures' weights, each summing to one
+ - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i \times d}` are the empirical measures' atom locations
+ - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
+
+ This problem is considered in :ref:`[1] <references-free-support-barycenter>` (Algorithm 2).
+ There are two differences with the following codes:
- The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
- This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
- we do not optimize over the weights
- - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
+ - we do not perform a line search for the locations updates, i.e. we use :math:`\theta = 1` in
+ :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ implementation of the fixed-point algorithm of
+ :ref:`[2] <references-free-support-barycenter>` proposed in the continuous setting.
Parameters
----------
- measures_locations : list of (k_i,d) numpy.ndarray
- The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
- measures_weights : list of (k_i,) numpy.ndarray
- Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
+ measures_locations : list of N (k_i,d) numpy.ndarray
+ The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ measures_weights : list of N (k_i,) numpy.ndarray
+ Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
+ representing the weights of each discrete input measure
X_init : (k,d) np.ndarray
- Initialization of the support locations (on k atoms) of the barycenter
+ Initialization of the support locations (on `k` atoms) of the barycenter
b : (k,) np.ndarray
Initialization of the weights of the barycenter (non-negatives, sum to 1)
- weights : (k,) np.ndarray
+ weights : (N,) np.ndarray
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
numItermax : int, optional
@@ -465,15 +557,20 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
Print information along iterations
log : bool, optional
record log if True
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
+
Returns
-------
X : (k,d) np.ndarray
Support locations (on k atoms) of the barycenter
+
+ .. _references-free-support-barycenter:
References
----------
-
.. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
.. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
@@ -504,7 +601,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
weights.tolist()):
M_i = dist(X, measure_locations_i)
- T_i = emd(b, measure_weights_i, M_i)
+ T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
displacement_square_norm = np.sum(np.square(T_sum - X))
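
The hunk above is the whole fixed-point update: each input measure is transported onto the current support, and the support is moved to the weighted barycentric projection of those plans. A rough NumPy sketch of one such iteration, relying only on the documented behaviour of ot.dist and ot.emd (the helper name and its argument layout are illustrative):

import numpy as np
import ot

def barycenter_step(X, b, measures_locations, measures_weights, weights, numThreads=1):
    # X: (k, d) current support, b: (k,) barycenter weights, weights: (N,) measure coefficients
    T_sum = np.zeros_like(X)
    for X_i, a_i, w_i in zip(measures_locations, measures_weights, weights):
        M_i = ot.dist(X, X_i)                             # (k, k_i) squared Euclidean costs
        T_i = ot.emd(b, a_i, M_i, numThreads=numThreads)  # optimal plan, shape (k, k_i)
        # barycentric projection of measure i onto the current support, weighted by w_i
        T_sum += w_i * (T_i @ X_i) / b[:, None]
    return T_sum

Iterating this step and stopping once np.sum((X_new - X) ** 2) falls below stopThr reproduces the structure of the loop above.
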
@@ -523,287 +620,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
return X, log_dict
else:
return X
-
-
-def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- log=False):
- r"""Solves the Earth Movers distance problem between 1d measures and returns
- the OT matrix
-
-
- .. math::
- \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
- where :
-
- - d is the metric
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
- Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
- p: float, optional (default=1.0)
- The p-norm to apply for if metric='minkowski'
- dense: boolean, optional (default=True)
- If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
- Otherwise returns a sparse representation using scipy's `coo_matrix`
- format. Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
- are used.
- log: boolean, optional (default=False)
- If True, returns a dictionary containing the cost.
- Otherwise returns only the optimal transportation matrix.
-
- Returns
- -------
- gamma: (ns, nt) ndarray
- Optimal transportation matrix for the given parameters
- log: dict
- If input log is True, a dictionary containing the cost
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function emd_1d accepts lists and
- performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.emd_1d(x_a, x_b, a, b)
- array([[0. , 0.5],
- [0.5, 0. ]])
- >>> ot.emd_1d(x_a, x_b)
- array([[0. , 0.5],
- [0.5, 0. ]])
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd : EMD for multidimensional distributions
- ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
- transportation matrix)
- """
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- x_a = np.asarray(x_a, dtype=np.float64)
- x_b = np.asarray(x_b, dtype=np.float64)
-
- assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
- "emd_1d should only be used with monodimensional data"
- assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
- "emd_1d should only be used with monodimensional data"
-
- # if empty array given then use uniform distributions
- if a.ndim == 0 or len(a) == 0:
- a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
- if b.ndim == 0 or len(b) == 0:
- b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
-
- x_a_1d = x_a.reshape((-1,))
- x_b_1d = x_b.reshape((-1,))
- perm_a = np.argsort(x_a_1d)
- perm_b = np.argsort(x_b_1d)
-
- G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
- x_a_1d[perm_a], x_b_1d[perm_b],
- metric=metric, p=p)
- G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
- shape=(a.shape[0], b.shape[0]))
- if dense:
- G = G.toarray()
- if log:
- log = {'cost': cost}
- return G, log
- return G
-
-
-def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
- log=False):
- r"""Solves the Earth Movers distance problem between 1d measures and returns
- the loss
-
-
- .. math::
- \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
- where :
-
- - d is the metric
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- metric: str, optional (default='sqeuclidean')
- Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
- Due to implementation details, this function runs faster when
- `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
- are used.
- p: float, optional (default=1.0)
- The p-norm to apply for if metric='minkowski'
- dense: boolean, optional (default=True)
- If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
- Otherwise returns a sparse representation using scipy's `coo_matrix`
- format. Only used if log is set to True. Due to implementation details,
- this function runs faster when dense is set to False.
- log: boolean, optional (default=False)
- If True, returns a dictionary containing the transportation matrix.
- Otherwise returns only the loss.
-
- Returns
- -------
- loss: float
- Cost associated to the optimal transportation
- log: dict
- If input log is True, a dictionary containing the Optimal transportation
- matrix for the given parameters
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function emd2_1d accepts lists and
- performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.emd2_1d(x_a, x_b, a, b)
- 0.5
- >>> ot.emd2_1d(x_a, x_b)
- 0.5
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd2 : EMD for multidimensional distributions
- ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
- instead of the cost)
- """
- # If we do not return G (log==False), then we should not to cast it to dense
- # (useless overhead)
- G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
- dense=dense and log, log=True)
- cost = log_emd['cost']
- if log:
- log_emd = {'G': G}
- return cost, log_emd
- return cost
-
-
-def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.):
- r"""Solves the p-Wasserstein distance problem between 1d measures and returns
- the distance
-
- .. math::
- \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p}
-
- s.t. \gamma 1 = a,
- \gamma^T 1= b,
- \gamma\geq 0
-
- where :
-
- - x_a and x_b are the samples
- - a and b are the sample weights
-
- Uses the algorithm detailed in [1]_
-
- Parameters
- ----------
- x_a : (ns,) or (ns, 1) ndarray, float64
- Source dirac locations (on the real line)
- x_b : (nt,) or (ns, 1) ndarray, float64
- Target dirac locations (on the real line)
- a : (ns,) ndarray, float64, optional
- Source histogram (default is uniform weight)
- b : (nt,) ndarray, float64, optional
- Target histogram (default is uniform weight)
- p: float, optional (default=1.0)
- The order of the p-Wasserstein distance to be computed
-
- Returns
- -------
- dist: float
- p-Wasserstein distance
-
-
- Examples
- --------
-
- Simple example with obvious solution. The function wasserstein_1d accepts
- lists and performs automatic conversion to numpy arrays
-
- >>> import ot
- >>> a=[.5, .5]
- >>> b=[.5, .5]
- >>> x_a = [2., 0.]
- >>> x_b = [0., 3.]
- >>> ot.wasserstein_1d(x_a, x_b, a, b)
- 0.5
- >>> ot.wasserstein_1d(x_a, x_b)
- 0.5
-
- References
- ----------
-
- .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
- See Also
- --------
- ot.lp.emd_1d : EMD for 1d distributions
- """
- cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
- dense=False, log=False)
- return np.power(cost_emd, 1. / p)
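
For equal sample sizes and uniform weights, the 1d p-Wasserstein distance computed here reduces to matching sorted samples; a small sanity-check sketch using the public ot.wasserstein_1d entry point (the numbers reuse the doctest example above):

import numpy as np
import ot

x_a = np.array([2., 0.])
x_b = np.array([0., 3.])
p = 1.
closed_form = np.mean(np.abs(np.sort(x_a) - np.sort(x_b)) ** p) ** (1. / p)
print(closed_form)                       # 0.5
print(ot.wasserstein_1d(x_a, x_b, p=p))  # 0.5
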
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 8e763be..869d450 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -27,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
- """Compute the Wasserstein barycenter of distributions A
+ r"""Compute the Wasserstein barycenter of distributions A
The function solves the following optimization problem [16]:
@@ -76,7 +76,6 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
.. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
-
"""
if weights is None:
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index c167964..42e08f4 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -20,6 +20,7 @@ import warnings
cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
+ int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -38,7 +39,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -97,8 +98,6 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0])
cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0)
- cdef np.ndarray[long, ndim=1, mode="c"] iG=np.zeros(0,dtype=np.int)
- cdef np.ndarray[long, ndim=1, mode="c"] jG=np.zeros(0,dtype=np.int)
if not len(a):
a=np.ones((n1,))/n1
@@ -111,8 +110,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
# calling the function
with nogil:
- result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
-
+ if numThreads == 1:
+ result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
+ else:
+ result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads)
return G, cost, alpha, beta, result_code
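
From Python, the new numThreads argument is forwarded down to this wrapper, and the OpenMP path is only taken when numThreads != 1. A hedged usage sketch (the parallel branch requires the extension to have been built with OpenMP, and the sizes and thread count here are purely illustrative):

import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(500, 2), rng.randn(600, 2)
a, b = ot.unif(500), ot.unif(600)
M = ot.dist(xs, xt)

G = ot.emd(a, b, M)                    # single-threaded network simplex
G_omp = ot.emd(a, b, M, numThreads=4)  # parallel pivot search, when OpenMP is available
# both plans are optimal, so their transport costs agree
assert np.isclose((G * M).sum(), (G_omp * M).sum())
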
@@ -157,22 +158,22 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
cost associated to the optimal transportation
"""
cdef double cost = 0.
- cdef int n = u_weights.shape[0]
- cdef int m = v_weights.shape[0]
+ cdef Py_ssize_t n = u_weights.shape[0]
+ cdef Py_ssize_t m = v_weights.shape[0]
- cdef int i = 0
+ cdef Py_ssize_t i = 0
cdef double w_i = u_weights[0]
- cdef int j = 0
+ cdef Py_ssize_t j = 0
cdef double w_j = v_weights[0]
cdef double m_ij = 0.
cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
dtype=np.float64)
- cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
- dtype=np.int)
- cdef int cur_idx = 0
- while i < n and j < m:
+ cdef np.ndarray[long long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
+ dtype=np.int64)
+ cdef Py_ssize_t cur_idx = 0
+ while True:
if metric == 'sqeuclidean':
m_ij = (u[i] - v[j]) * (u[i] - v[j])
elif metric == 'cityblock' or metric == 'euclidean':
@@ -188,6 +189,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
i += 1
+ if i == n:
+ break
w_j -= w_i
w_i = u_weights[i]
else:
@@ -196,7 +199,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
indices[cur_idx, 0] = i
indices[cur_idx, 1] = j
j += 1
+ if j == m:
+ break
w_i -= w_j
w_j = v_weights[j]
cur_idx += 1
+ cur_idx += 1
return G[:cur_idx], indices[:cur_idx], cost
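
The loop above performs a greedy matching of two sorted discrete measures: it transports as much mass as possible between the current pair of atoms and advances whichever side is exhausted. With the explicit breaks, the weight arrays are never read past their last element, and the trailing cur_idx += 1 keeps the final entry in the G[:cur_idx] slices. A pure-Python sketch of the same idea (the helper name and the min-based tie handling are illustrative; ties may produce harmless zero-mass entries):

def greedy_1d_plan(a, b):
    # a, b: positive weights summing to one, given in sorted-support order
    i = j = 0
    w_i, w_j = a[0], b[0]
    plan = []  # (i, j, mass) triplets forming a sparse transport plan
    while True:
        if w_i <= w_j:              # atom i of the source is exhausted first
            plan.append((i, j, w_i))
            i += 1
            if i == len(a):
                break
            w_j -= w_i
            w_i = a[i]
        else:                       # atom j of the target is exhausted first
            plan.append((i, j, w_j))
            j += 1
            if j == len(b):
                break
            w_i -= w_j
            w_j = b[j]
    return plan
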
diff --git a/ot/lp/full_bipartitegraph.h b/ot/lp/full_bipartitegraph.h
index 87a1bec..713ccb5 100644
--- a/ot/lp/full_bipartitegraph.h
+++ b/ot/lp/full_bipartitegraph.h
@@ -23,10 +23,10 @@
*
*/
-#ifndef LEMON_FULL_BIPARTITE_GRAPH_H
-#define LEMON_FULL_BIPARTITE_GRAPH_H
+#pragma once
#include "core.h"
+#include <cstdint>
///\ingroup graphs
///\file
@@ -44,16 +44,16 @@ namespace lemon {
//class Node;
typedef int Node;
//class Arc;
- typedef long long Arc;
+ typedef int64_t Arc;
protected:
int _node_num;
- long long _arc_num;
+ int64_t _arc_num;
FullBipartiteDigraphBase() {}
- void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = n1 * n2; _n1=n1; _n2=n2;}
+ void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;}
public:
@@ -65,25 +65,25 @@ namespace lemon {
Arc arc(const Node& s, const Node& t) const {
if (s<_n1 && t>=_n1)
- return Arc(s * _n2 + (t-_n1) );
+ return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) );
else
return Arc(-1);
}
int nodeNum() const { return _node_num; }
- long long arcNum() const { return _arc_num; }
+ int64_t arcNum() const { return _arc_num; }
int maxNodeId() const { return _node_num - 1; }
- long long maxArcId() const { return _arc_num - 1; }
+ int64_t maxArcId() const { return _arc_num - 1; }
Node source(Arc arc) const { return arc / _n2; }
Node target(Arc arc) const { return (arc % _n2) + _n1; }
static int id(Node node) { return node; }
- static long long id(Arc arc) { return arc; }
+ static int64_t id(Arc arc) { return arc; }
static Node nodeFromId(int id) { return Node(id);}
- static Arc arcFromId(int id) { return Arc(id);}
+ static Arc arcFromId(int64_t id) { return Arc(id);}
Arc findArc(Node s, Node t, Arc prev = -1) const {
@@ -136,7 +136,7 @@ namespace lemon {
///
/// \brief A directed full graph class.
///
- /// FullBipartiteDigraph is a simple and fast implmenetation of directed full
+ /// FullBipartiteDigraph is a simple and fast implementation of directed full
/// (complete) graphs. It contains an arc from each node to each node
/// (including a loop for each node), therefore the number of arcs
/// is the square of the number of nodes.
@@ -203,13 +203,10 @@ namespace lemon {
/// \brief Number of nodes.
int nodeNum() const { return Parent::nodeNum(); }
/// \brief Number of arcs.
- long long arcNum() const { return Parent::arcNum(); }
+ int64_t arcNum() const { return Parent::arcNum(); }
};
} //namespace lemon
-
-
-#endif //LEMON_FULL_GRAPH_H
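
The point of these casts and of widening Arc to int64_t is that the full bipartite graph has n1 * n2 arcs, which exceeds the 32-bit range well before the node counts themselves do. A small NumPy illustration (sizes are illustrative):

import numpy as np

n1 = n2 = 70000
arcs_32 = np.int32(n1) * np.int32(n2)  # wraps to a wrong, smaller value (NumPy warns about the overflow)
arcs_64 = np.int64(n1) * np.int64(n2)  # 4_900_000_000, the true arc count
print(arcs_32, arcs_64)
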
diff --git a/ot/lp/full_bipartitegraph_omp.h b/ot/lp/full_bipartitegraph_omp.h
new file mode 100644
index 0000000..8cbed0b
--- /dev/null
+++ b/ot/lp/full_bipartitegraph_omp.h
@@ -0,0 +1,234 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+ *
+ * This file has been adapted by Nicolas Bonneel (2013),
+ * from full_graph.h from LEMON, a generic C++ optimization library,
+ * to implement a lightweight fully connected bipartite graph. A previous
+ * version of this file is used as part of the Displacement Interpolation
+ * project,
+ * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
+ *
+ *
+ **** Original file Copyright Notice :
+ * Copyright (C) 2003-2010
+ * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
+ * (Egervary Research Group on Combinatorial Optimization, EGRES).
+ *
+ * Permission to use, modify and distribute this software is granted
+ * provided that this copyright notice appears in all copies. For
+ * precise terms see the accompanying LICENSE file.
+ *
+ * This software is provided "AS IS" with no warranty of any kind,
+ * express or implied, and with no claim as to its suitability for any
+ * purpose.
+ *
+ */
+
+#pragma once
+
+#include <cstdint>
+
+///\ingroup graphs
+///\file
+///\brief FullBipartiteDigraph and FullBipartiteGraph classes.
+
+
+namespace lemon_omp {
+
+ ///This \c \#define creates convenient type definitions for the following
+ ///types of \c Digraph: \c Node, \c NodeIt, \c Arc, \c ArcIt, \c InArcIt,
+ ///\c OutArcIt, \c BoolNodeMap, \c IntNodeMap, \c DoubleNodeMap,
+ ///\c BoolArcMap, \c IntArcMap, \c DoubleArcMap.
+ ///
+ ///\note If the graph type is a dependent type, i.e. the graph type depends
+ ///on a template parameter, then use \c TEMPLATE_DIGRAPH_TYPEDEFS()
+ ///macro.
+#define DIGRAPH_TYPEDEFS(Digraph) \
+ typedef Digraph::Node Node; \
+ typedef Digraph::Arc Arc; \
+
+
+ ///Create convenience typedefs for the digraph types and iterators
+
+ ///\see DIGRAPH_TYPEDEFS
+ ///
+ ///\note Use this macro, if the graph type is a dependent type,
+ ///i.e. the graph type depends on a template parameter.
+#define TEMPLATE_DIGRAPH_TYPEDEFS(Digraph) \
+ typedef typename Digraph::Node Node; \
+ typedef typename Digraph::Arc Arc; \
+
+
+ class FullBipartiteDigraphBase {
+ public:
+
+ typedef FullBipartiteDigraphBase Digraph;
+
+ //class Node;
+ typedef int Node;
+ //class Arc;
+ typedef int64_t Arc;
+
+ protected:
+
+ int _node_num;
+ int64_t _arc_num;
+
+ FullBipartiteDigraphBase() {}
+
+ void construct(int n1, int n2) { _node_num = n1+n2; _arc_num = (int64_t)n1 * (int64_t)n2; _n1=n1; _n2=n2;}
+
+ public:
+
+ int _n1, _n2;
+
+
+ Node operator()(int ix) const { return Node(ix); }
+ static int index(const Node& node) { return node; }
+
+ Arc arc(const Node& s, const Node& t) const {
+ if (s<_n1 && t>=_n1)
+ return Arc((int64_t)s * (int64_t)_n2 + (int64_t)(t-_n1) );
+ else
+ return Arc(-1);
+ }
+
+ int nodeNum() const { return _node_num; }
+ int64_t arcNum() const { return _arc_num; }
+
+ int maxNodeId() const { return _node_num - 1; }
+ int64_t maxArcId() const { return _arc_num - 1; }
+
+ Node source(Arc arc) const { return arc / _n2; }
+ Node target(Arc arc) const { return (arc % _n2) + _n1; }
+
+ static int id(Node node) { return node; }
+ static int64_t id(Arc arc) { return arc; }
+
+ static Node nodeFromId(int id) { return Node(id);}
+ static Arc arcFromId(int64_t id) { return Arc(id);}
+
+
+ Arc findArc(Node s, Node t, Arc prev = -1) const {
+ return prev == -1 ? arc(s, t) : -1;
+ }
+
+ void first(Node& node) const {
+ node = _node_num - 1;
+ }
+
+ static void next(Node& node) {
+ --node;
+ }
+
+ void first(Arc& arc) const {
+ arc = _arc_num - 1;
+ }
+
+ static void next(Arc& arc) {
+ --arc;
+ }
+
+ void firstOut(Arc& arc, const Node& node) const {
+ if (node>=_n1)
+ arc = -1;
+ else
+ arc = (node + 1) * _n2 - 1;
+ }
+
+ void nextOut(Arc& arc) const {
+ if (arc % _n2 == 0) arc = 0;
+ --arc;
+ }
+
+ void firstIn(Arc& arc, const Node& node) const {
+ if (node<_n1)
+ arc = -1;
+ else
+ arc = _arc_num + node - _node_num;
+ }
+
+ void nextIn(Arc& arc) const {
+ arc -= _n2;
+ if (arc < 0) arc = -1;
+ }
+
+ };
+
+ /// \ingroup graphs
+ ///
+ /// \brief A directed full graph class.
+ ///
+ /// FullBipartiteDigraph is a simple and fast implementation of directed full
+ /// (complete) graphs. It contains an arc from each node to each node
+ /// (including a loop for each node), therefore the number of arcs
+ /// is the square of the number of nodes.
+ /// This class is completely static and it needs constant memory space.
+ /// Thus you can neither add nor delete nodes or arcs, however
+ /// the structure can be resized using resize().
+ ///
+ /// This type fully conforms to the \ref concepts::Digraph "Digraph concept".
+ /// Most of its member functions and nested classes are documented
+ /// only in the concept class.
+ ///
+ /// This class provides constant time counting for nodes and arcs.
+ ///
+ /// \note FullBipartiteDigraph and FullBipartiteGraph classes are very similar,
+ /// but there are two differences. While this class conforms only
+ /// to the \ref concepts::Digraph "Digraph" concept, FullBipartiteGraph
+ /// conforms to the \ref concepts::Graph "Graph" concept,
+ /// moreover FullBipartiteGraph does not contain a loop for each
+ /// node as this class does.
+ ///
+ /// \sa FullBipartiteGraph
+ class FullBipartiteDigraph : public FullBipartiteDigraphBase {
+ typedef FullBipartiteDigraphBase Parent;
+
+ public:
+
+ /// \brief Default constructor.
+ ///
+ /// Default constructor. The number of nodes and arcs will be zero.
+ FullBipartiteDigraph() { construct(0,0); }
+
+ /// \brief Constructor
+ ///
+ /// Constructor.
+ /// \param n1, n2 The numbers of nodes on the two sides of the bipartite graph.
+ FullBipartiteDigraph(int n1, int n2) { construct(n1, n2); }
+
+
+ /// \brief Returns the node with the given index.
+ ///
+ /// Returns the node with the given index. Since this structure is
+ /// completely static, the nodes can be indexed with integers from
+ /// the range <tt>[0..nodeNum()-1]</tt>.
+ /// The index of a node is the same as its ID.
+ /// \sa index()
+ Node operator()(int ix) const { return Parent::operator()(ix); }
+
+ /// \brief Returns the index of the given node.
+ ///
+ /// Returns the index of the given node. Since this structure is
+ /// completely static, the nodes can be indexed with integers from
+ /// the range <tt>[0..nodeNum()-1]</tt>.
+ /// The index of a node is the same as its ID.
+ /// \sa operator()()
+ static int index(const Node& node) { return Parent::index(node); }
+
+ /// \brief Returns the arc connecting the given nodes.
+ ///
+ /// Returns the arc connecting the given nodes.
+ /*Arc arc(Node u, Node v) const {
+ return Parent::arc(u, v);
+ }*/
+
+ /// \brief Number of nodes.
+ int nodeNum() const { return Parent::nodeNum(); }
+ /// \brief Number of arcs.
+ int64_t arcNum() const { return Parent::arcNum(); }
+ };
+
+
+
+
+} //namespace lemon_omp
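
The arc indexing used by this class is the usual row-major flattening of the (source, target) pair, with target nodes offset by n1, and source/target are recovered by integer division and modulo. A short Python sketch of the same mapping and its inverse (function names are illustrative):

def arc_id(s, t, n1, n2):
    # arc from source node s (0 <= s < n1) to target node t (n1 <= t < n1 + n2)
    return s * n2 + (t - n1) if (s < n1 and t >= n1) else -1

def source(arc, n2):
    return arc // n2

def target(arc, n1, n2):
    return (arc % n2) + n1

n1, n2 = 3, 4
for s in range(n1):
    for t in range(n1, n1 + n2):
        a = arc_id(s, t, n1, n2)
        assert (source(a, n2), target(a, n1, n2)) == (s, t)
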
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 5d93040..3b46b9b 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -25,15 +25,17 @@
*
*/
-#ifndef LEMON_NETWORK_SIMPLEX_SIMPLE_H
-#define LEMON_NETWORK_SIMPLEX_SIMPLE_H
+#pragma once
+#undef DEBUG_LVL
#define DEBUG_LVL 0
#if DEBUG_LVL>0
#include <iomanip>
#endif
-
+#undef EPSILON
+#undef _EPSILON
+#undef MAX_DEBUG_ITER
#define EPSILON 2.2204460492503131e-15
#define _EPSILON 1e-8
#define MAX_DEBUG_ITER 100000
@@ -50,6 +52,7 @@
#include <vector>
#include <limits>
#include <algorithm>
+#include <iostream>
#include <cstdio>
#ifdef HASHMAP
#include <hash_map>
@@ -63,6 +66,8 @@
//#include "sparse_array_n.h"
#include "full_bipartitegraph.h"
+#undef INVALIDNODE
+#undef INVALID
#define INVALIDNODE -1
#define INVALID (-1)
@@ -76,16 +81,16 @@ namespace lemon {
class SparseValueVector
{
public:
- SparseValueVector(int n=0)
+ SparseValueVector(size_t n=0)
{
}
- void resize(int n=0){};
- T operator[](const int id) const
+ void resize(size_t n=0){};
+ T operator[](const size_t id) const
{
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::const_iterator it = data.find(id);
+ typename stdext::hash_map<size_t,T>::const_iterator it = data.find(id);
#else
- typename std::map<int,T>::const_iterator it = data.find(id);
+ typename std::map<size_t,T>::const_iterator it = data.find(id);
#endif
if (it==data.end())
return 0;
@@ -93,16 +98,16 @@ namespace lemon {
return it->second;
}
- ProxyObject<T> operator[](const int id)
+ ProxyObject<T> operator[](const size_t id)
{
return ProxyObject<T>( this, id );
}
//private:
#ifdef HASHMAP
- stdext::hash_map<int,T> data;
+ stdext::hash_map<size_t,T> data;
#else
- std::map<int,T> data;
+ std::map<size_t,T> data;
#endif
};
@@ -110,7 +115,7 @@ namespace lemon {
template <typename T>
class ProxyObject {
public:
- ProxyObject( SparseValueVector<T> *v, int idx ){_v=v; _idx=idx;};
+ ProxyObject( SparseValueVector<T> *v, size_t idx ){_v=v; _idx=idx;};
ProxyObject<T> & operator=( const T &v ) {
// If we get here, we know that operator[] was called to perform a write access,
// so we can insert an item in the vector if needed
@@ -123,9 +128,9 @@ namespace lemon {
// If we get here, we know that operator[] was called to perform a read access,
// so we can simply return the existing object
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+ typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx);
#else
- typename std::map<int,T>::iterator it = _v->data.find(_idx);
+ typename std::map<size_t,T>::iterator it = _v->data.find(_idx);
#endif
if (it==_v->data.end())
return 0;
@@ -137,9 +142,9 @@ namespace lemon {
{
if (val==0) return;
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+ typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx);
#else
- typename std::map<int,T>::iterator it = _v->data.find(_idx);
+ typename std::map<size_t,T>::iterator it = _v->data.find(_idx);
#endif
if (it==_v->data.end())
_v->data[_idx] = val;
@@ -156,9 +161,9 @@ namespace lemon {
{
if (val==0) return;
#ifdef HASHMAP
- typename stdext::hash_map<int,T>::iterator it = _v->data.find(_idx);
+ typename stdext::hash_map<size_t,T>::iterator it = _v->data.find(_idx);
#else
- typename std::map<int,T>::iterator it = _v->data.find(_idx);
+ typename std::map<size_t,T>::iterator it = _v->data.find(_idx);
#endif
if (it==_v->data.end())
_v->data[_idx] = -val;
@@ -173,7 +178,7 @@ namespace lemon {
}
SparseValueVector<T> *_v;
- int _idx;
+ size_t _idx;
};
@@ -204,7 +209,7 @@ namespace lemon {
///
/// \tparam GR The digraph type the algorithm runs on.
/// \tparam V The number type used for flow amounts, capacity bounds
- /// and supply values in the algorithm. By default, it is \c int.
+ /// and supply values in the algorithm. By default, it is \c int64_t.
/// \tparam C The number type used for costs and potentials in the
/// algorithm. By default, it is the same as \c V.
///
@@ -214,7 +219,7 @@ namespace lemon {
/// \note %NetworkSimplexSimple provides five different pivot rule
/// implementations, from which the most efficient one is used
/// by default. For more information, see \ref PivotRule.
- template <typename GR, typename V = int, typename C = V, typename NodesType = unsigned short int>
+ template <typename GR, typename V = int, typename C = V, typename NodesType = unsigned short int, typename ArcsType = int64_t>
class NetworkSimplexSimple
{
public:
@@ -228,7 +233,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, long long nb_arcs,int maxiters) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -288,11 +293,11 @@ namespace lemon {
private:
- int max_iter;
+ size_t max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
typedef std::vector<int> IntVector;
- typedef std::vector<NodesType> UHalfIntVector;
+ typedef std::vector<ArcsType> ArcVector;
typedef std::vector<Value> ValueVector;
typedef std::vector<Cost> CostVector;
// typedef SparseValueVector<Cost> CostVector;
@@ -315,9 +320,9 @@ namespace lemon {
// Data related to the underlying digraph
const GR &_graph;
int _node_num;
- int _arc_num;
- int _all_arc_num;
- int _search_arc_num;
+ ArcsType _arc_num;
+ ArcsType _all_arc_num;
+ ArcsType _search_arc_num;
// Parameters of the problem
SupplyType _stype;
@@ -325,9 +330,9 @@ namespace lemon {
inline int _node_id(int n) const {return _node_num-n-1;} ;
- //IntArcMap _arc_id;
- UHalfIntVector _source;
- UHalfIntVector _target;
+// IntArcMap _arc_id;
+ IntVector _source; // keep nodes as integers
+ IntVector _target;
bool _arc_mixing;
public:
// Node and arc data
@@ -341,7 +346,7 @@ namespace lemon {
private:
// Data for storing the spanning tree structure
IntVector _parent;
- IntVector _pred;
+ ArcVector _pred;
IntVector _thread;
IntVector _rev_thread;
IntVector _succ_num;
@@ -349,17 +354,17 @@ namespace lemon {
IntVector _dirty_revs;
BoolVector _forward;
StateVector _state;
- int _root;
+ ArcsType _root;
// Temporary data used in the current pivot iteration
- int in_arc, join, u_in, v_in, u_out, v_out;
- int first, second, right, last;
- int stem, par_stem, new_stem;
+ ArcsType in_arc, join, u_in, v_in, u_out, v_out;
+ ArcsType first, second, right, last;
+ ArcsType stem, par_stem, new_stem;
Value delta;
const Value MAX;
- int mixingCoeff;
+ ArcsType mixingCoeff;
public:
@@ -373,27 +378,27 @@ namespace lemon {
private:
// thank you to DVK and MizardX from StackOverflow for this function!
- inline int sequence(int k) const {
- int smallv = (k > num_total_big_subsequence_numbers) & 1;
+ inline ArcsType sequence(ArcsType k) const {
+ ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1;
k -= num_total_big_subsequence_numbers * smallv;
- int subsequence_length2 = subsequence_length- smallv;
- int subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv;
- int subsequence_offset = (k % subsequence_length2) * mixingCoeff;
+ ArcsType subsequence_length2 = subsequence_length- smallv;
+ ArcsType subsequence_num = (k / subsequence_length2) + num_big_subseqiences * smallv;
+ ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff;
return subsequence_offset + subsequence_num;
}
- int subsequence_length;
- int num_big_subseqiences;
- int num_total_big_subsequence_numbers;
+ ArcsType subsequence_length;
+ ArcsType num_big_subseqiences;
+ ArcsType num_total_big_subsequence_numbers;
- inline int getArcID(const Arc &arc) const
+ inline ArcsType getArcID(const Arc &arc) const
{
//int n = _arc_num-arc._id-1;
- int n = _arc_num-GR::id(arc)-1;
+ ArcsType n = _arc_num-GR::id(arc)-1;
- //int a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
- //int b = _arc_id[arc];
+ //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
+ //ArcsType b = _arc_id[arc];
if (_arc_mixing)
return sequence(n);
else
@@ -401,16 +406,16 @@ namespace lemon {
}
// finally unused because too slow
- inline int getSource(const int arc) const
+ inline ArcsType getSource(const ArcsType arc) const
{
- //int a = _source[arc];
+ //ArcsType a = _source[arc];
//return a;
- int n = _arc_num-arc-1;
+ ArcsType n = _arc_num-arc-1;
if (_arc_mixing)
n = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
- int b;
+ ArcsType b;
if (n>=0)
b = _node_id(_graph.source(GR::arcFromId( n ) ));
else
@@ -436,17 +441,17 @@ namespace lemon {
private:
// References to the NetworkSimplexSimple class
- const UHalfIntVector &_source;
- const UHalfIntVector &_target;
+ const IntVector &_source;
+ const IntVector &_target;
const CostVector &_cost;
const StateVector &_state;
const CostVector &_pi;
- int &_in_arc;
- int _search_arc_num;
+ ArcsType &_in_arc;
+ ArcsType _search_arc_num;
// Pivot rule data
- int _block_size;
- int _next_arc;
+ ArcsType _block_size;
+ ArcsType _next_arc;
NetworkSimplexSimple &_ns;
public:
@@ -460,17 +465,16 @@ namespace lemon {
{
// The main parameters of the pivot rule
const double BLOCK_SIZE_FACTOR = 1.0;
- const int MIN_BLOCK_SIZE = 10;
+ const ArcsType MIN_BLOCK_SIZE = 10;
- _block_size = std::max( int(BLOCK_SIZE_FACTOR *
- std::sqrt(double(_search_arc_num))),
- MIN_BLOCK_SIZE );
+ _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE);
}
+
// Find next entering arc
bool findEnteringArc() {
Cost c, min = 0;
- int e;
- int cnt = _block_size;
+ ArcsType e;
+ ArcsType cnt = _block_size;
double a;
for (e = _next_arc; e != _search_arc_num; ++e) {
c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
@@ -516,7 +520,7 @@ namespace lemon {
int _init_nb_nodes;
- long long _init_nb_arcs;
+ ArcsType _init_nb_arcs;
/// \name Parameters
/// The parameters of the algorithm can be specified using these
@@ -736,7 +740,7 @@ namespace lemon {
for (int i = 0; i != _node_num; ++i) {
_supply[i] = 0;
}
- for (int i = 0; i != _arc_num; ++i) {
+ for (ArcsType i = 0; i != _arc_num; ++i) {
_cost[i] = 1;
}
_stype = GEQ;
@@ -745,7 +749,7 @@ namespace lemon {
- int divid (int x, int y)
+ int64_t divid (int64_t x, int64_t y)
{
return (x-x%y)/y;
}
@@ -775,7 +779,7 @@ namespace lemon {
_node_num = _init_nb_nodes;
_arc_num = _init_nb_arcs;
int all_node_num = _node_num + 1;
- int max_arc_num = _arc_num + 2 * _node_num;
+ ArcsType max_arc_num = _arc_num + 2 * _node_num;
_source.resize(max_arc_num);
_target.resize(max_arc_num);
@@ -798,13 +802,13 @@ namespace lemon {
//_arc_mixing=false;
if (_arc_mixing) {
// Store the arcs in a mixed order
- int k = std::max(int(std::sqrt(double(_arc_num))), 10);
+ const ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10));
mixingCoeff = k;
subsequence_length = _arc_num / mixingCoeff + 1;
num_big_subseqiences = _arc_num % mixingCoeff;
num_total_big_subsequence_numbers = subsequence_length * num_big_subseqiences;
- int i = 0, j = 0;
+ ArcsType i = 0, j = 0;
Arc a; _graph.first(a);
for (; a != INVALID; _graph.next(a)) {
_source[i] = _node_id(_graph.source(a));
@@ -814,7 +818,7 @@ namespace lemon {
}
} else {
// Store the arcs in the original order
- int i = 0;
+ ArcsType i = 0;
Arc a; _graph.first(a);
for (; a != INVALID; _graph.next(a), ++i) {
_source[i] = _node_id(_graph.source(a));
@@ -856,7 +860,7 @@ namespace lemon {
Number totalCost() const {
Number c = 0;
for (ArcIt a(_graph); a != INVALID; ++a) {
- int i = getArcID(a);
+ int64_t i = getArcID(a);
c += Number(_flow[i]) * Number(_cost[i]);
}
return c;
@@ -867,15 +871,15 @@ namespace lemon {
Number c = 0;
/*#ifdef HASHMAP
- typename stdext::hash_map<int, Value>::const_iterator it;
+ typename stdext::hash_map<int64_t, Value>::const_iterator it;
#else
- typename std::map<int, Value>::const_iterator it;
+ typename std::map<int64_t, Value>::const_iterator it;
#endif
for (it = _flow.data.begin(); it!=_flow.data.end(); ++it)
c += Number(it->second) * Number(_cost[it->first]);
return c;*/
- for (unsigned long i=0; i<_flow.size(); i++)
+ for (ArcsType i=0; i<_flow.size(); i++)
c += _flow[i] * Number(_cost[i]);
return c;
@@ -944,14 +948,14 @@ namespace lemon {
// Initialize internal data structures
bool init() {
if (_node_num == 0) return false;
-
+
// Check the sum of supply values
_sum_supply = 0;
for (int i = 0; i != _node_num; ++i) {
_sum_supply += _supply[i];
}
if ( fabs(_sum_supply) > _EPSILON ) return false;
-
+
_sum_supply = 0;
// Initialize artifical cost
@@ -960,14 +964,14 @@ namespace lemon {
ART_COST = std::numeric_limits<Cost>::max() / 2 + 1;
} else {
ART_COST = 0;
- for (int i = 0; i != _arc_num; ++i) {
+ for (ArcsType i = 0; i != _arc_num; ++i) {
if (_cost[i] > ART_COST) ART_COST = _cost[i];
}
ART_COST = (ART_COST + 1) * _node_num;
}
// Initialize arc maps
- for (int i = 0; i != _arc_num; ++i) {
+ for (ArcsType i = 0; i != _arc_num; ++i) {
//_flow[i] = 0; //by default, the sparse matrix is empty
_state[i] = STATE_LOWER;
}
@@ -988,7 +992,7 @@ namespace lemon {
// EQ supply constraints
_search_arc_num = _arc_num;
_all_arc_num = _arc_num + _node_num;
- for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
_parent[u] = _root;
_pred[u] = e;
_thread[u] = u + 1;
@@ -1016,8 +1020,8 @@ namespace lemon {
else if (_sum_supply > 0) {
// LEQ supply constraints
_search_arc_num = _arc_num + _node_num;
- int f = _arc_num + _node_num;
- for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
_parent[u] = _root;
_thread[u] = u + 1;
_rev_thread[u + 1] = u;
@@ -1054,8 +1058,8 @@ namespace lemon {
else {
// GEQ supply constraints
_search_arc_num = _arc_num + _node_num;
- int f = _arc_num + _node_num;
- for (int u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
_parent[u] = _root;
_thread[u] = u + 1;
_rev_thread[u + 1] = u;
@@ -1120,9 +1124,9 @@ namespace lemon {
second = _source[in_arc];
}
delta = INF;
- int result = 0;
+ char result = 0;
Value d;
- int e;
+ ArcsType e;
// Search the cycle along the path form the first node to the root
for (int u = first; u != join; u = _parent[u]) {
@@ -1239,7 +1243,7 @@ namespace lemon {
// Update _rev_thread using the new _thread values
for (int i = 0; i != int(_dirty_revs.size()); ++i) {
- u = _dirty_revs[i];
+ int u = _dirty_revs[i];
_rev_thread[_thread[u]] = u;
}
@@ -1257,7 +1261,7 @@ namespace lemon {
u = w;
}
_pred[u_in] = in_arc;
- _forward[u_in] = ((unsigned int)u_in == _source[in_arc]);
+ _forward[u_in] = (u_in == _source[in_arc]);
_succ_num[u_in] = old_succ_num;
// Set limits for updating _last_succ form v_in and v_out
@@ -1328,7 +1332,7 @@ namespace lemon {
if (_sum_supply > 0) total -= _sum_supply;
if (total <= 0) return true;
- IntVector arc_vector;
+ ArcVector arc_vector;
if (_sum_supply >= 0) {
if (supply_nodes.size() == 1 && demand_nodes.size() == 1) {
// Perform a reverse graph search from the sink to the source
@@ -1345,7 +1349,7 @@ namespace lemon {
Arc a; _graph.firstIn(a, v);
for (; a != INVALID; _graph.nextIn(a)) {
if (reached[u = _graph.source(a)]) continue;
- int j = getArcID(a);
+ ArcsType j = getArcID(a);
if (INF >= total) {
arc_vector.push_back(j);
reached[u] = true;
@@ -1355,7 +1359,7 @@ namespace lemon {
}
} else {
// Find the min. cost incomming arc for each demand node
- for (int i = 0; i != int(demand_nodes.size()); ++i) {
+ for (int i = 0; i != demand_nodes.size(); ++i) {
Node v = demand_nodes[i];
Cost c, min_cost = std::numeric_limits<Cost>::max();
Arc min_arc = INVALID;
@@ -1393,7 +1397,7 @@ namespace lemon {
}
// Perform heuristic initial pivots
- for (int i = 0; i != int(arc_vector.size()); ++i) {
+ for (ArcsType i = 0; i != arc_vector.size(); ++i) {
in_arc = arc_vector[i];
// l'erreur est probablement ici...
if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -
@@ -1423,7 +1427,7 @@ namespace lemon {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- int iter_number=0;
+ size_t iter_number=0;
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
@@ -1443,7 +1447,7 @@ namespace lemon {
double a;
a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
- for (int i=0; i<_flow.size(); i++) {
+ for (int64_t i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
@@ -1482,12 +1486,12 @@ namespace lemon {
double a;
a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
- for (int i=0; i<_flow.size(); i++) {
+ for (int64_t i=0; i<_flow.size(); i++) {
sumFlow+=_state[i]*_flow[i];
}
-
+
std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << niter << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
-
+
std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
@@ -1505,9 +1509,9 @@ namespace lemon {
#endif
// Check feasibility
if( retVal == OPTIMAL){
- for (int e = _search_arc_num; e != _all_arc_num; ++e) {
+ for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) {
if (_flow[e] != 0){
- if (abs(_flow[e]) > EPSILON)
+ if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126
return INFEASIBLE;
else
_flow[e]=0;
@@ -1521,20 +1525,20 @@ namespace lemon {
if (_sum_supply == 0) {
if (_stype == GEQ) {
Cost max_pot = -std::numeric_limits<Cost>::max();
- for (int i = 0; i != _node_num; ++i) {
+ for (ArcsType i = 0; i != _node_num; ++i) {
if (_pi[i] > max_pot) max_pot = _pi[i];
}
if (max_pot > 0) {
- for (int i = 0; i != _node_num; ++i)
+ for (ArcsType i = 0; i != _node_num; ++i)
_pi[i] -= max_pot;
}
} else {
Cost min_pot = std::numeric_limits<Cost>::max();
- for (int i = 0; i != _node_num; ++i) {
+ for (ArcsType i = 0; i != _node_num; ++i) {
if (_pi[i] < min_pot) min_pot = _pi[i];
}
if (min_pot < 0) {
- for (int i = 0; i != _node_num; ++i)
+ for (ArcsType i = 0; i != _node_num; ++i)
_pi[i] -= min_pot;
}
}
@@ -1548,5 +1552,3 @@ namespace lemon {
///@}
} //namespace lemon
-
-#endif //LEMON_NETWORK_SIMPLEX_H
diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h
new file mode 100644
index 0000000..87e4c05
--- /dev/null
+++ b/ot/lp/network_simplex_simple_omp.h
@@ -0,0 +1,1699 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+*
+*
+* This file has been adapted by Nicolas Bonneel (2013),
+* from network_simplex.h from LEMON, a generic C++ optimization library,
+* to implement a lightweight network simplex for mass transport, more
+* memory efficient than the original file. A previous version of this file
+* is used as part of the Displacement Interpolation project,
+* Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
+*
+* Revisions:
+* March 2015: added OpenMP parallelization
+* March 2017: included Antoine Rolet's trick to make it more robust
+* April 2018: IMPORTANT bug fix + uses 64bit integers (slightly slower but less risks of overflows), updated to a newer version of the algo by LEMON, sparse flow by default + minor edits.
+*
+*
+**** Original file Copyright Notice :
+*
+* Copyright (C) 2003-2010
+* Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
+* (Egervary Research Group on Combinatorial Optimization, EGRES).
+*
+* Permission to use, modify and distribute this software is granted
+* provided that this copyright notice appears in all copies. For
+* precise terms see the accompanying LICENSE file.
+*
+* This software is provided "AS IS" with no warranty of any kind,
+* express or implied, and with no claim as to its suitability for any
+* purpose.
+*
+*/
+
+#pragma once
+#undef DEBUG_LVL
+#define DEBUG_LVL 0
+
+#if DEBUG_LVL>0
+#include <iomanip>
+#endif
+
+#undef EPSILON
+#undef _EPSILON
+#undef MAX_DEBUG_ITER
+#define EPSILON std::numeric_limits<Cost>::epsilon()*10
+#define _EPSILON 1e-8
+#define MAX_DEBUG_ITER 100000
+
+/// \ingroup min_cost_flow_algs
+///
+/// \file
+/// \brief Network Simplex algorithm for finding a minimum cost flow.
+
+// if your compiler has troubles with unorderedmaps, just comment the following line to use a slower std::map instead
+#define HASHMAP // now handled with unorderedmaps instead of stdext::hash_map. Should be better supported.
+
+#define SPARSE_FLOW // a sparse flow vector will be 10-15% slower for small problems but uses less memory and becomes faster for large problems (40k total nodes)
+
+#include <vector>
+#include <limits>
+#include <algorithm>
+#include <iostream>
+#ifdef HASHMAP
+#include <unordered_map>
+#else
+#include <map>
+#endif
+//#include "core.h"
+//#include "lmath.h"
+
+#ifdef OMP
+#include <omp.h>
+#endif
+#include <cmath>
+
+
+//#include "sparse_array_n.h"
+#include "full_bipartitegraph_omp.h"
+
+#undef INVALIDNODE
+#undef INVALID
+#define INVALIDNODE -1
+#define INVALID (-1)
+
+namespace lemon_omp {
+
+ int64_t max_threads = -1;
+
+ template <typename T>
+ class ProxyObject;
+
+ template<typename T>
+ class SparseValueVector
+ {
+ public:
+ SparseValueVector(size_t n = 0) // parameter n for compatibility with standard vectors
+ {
+ }
+ void resize(size_t n = 0) {};
+ T operator[](const size_t id) const
+ {
+#ifdef HASHMAP
+ typename std::unordered_map<size_t, T>::const_iterator it = data.find(id);
+#else
+ typename std::map<size_t, T>::const_iterator it = data.find(id);
+#endif
+ if (it == data.end())
+ return 0;
+ else
+ return it->second;
+ }
+
+ ProxyObject<T> operator[](const size_t id)
+ {
+ return ProxyObject<T>(this, id);
+ }
+
+ //private:
+#ifdef HASHMAP
+ std::unordered_map<size_t, T> data;
+#else
+ std::map<size_t, T> data;
+#endif
+
+ };
+
+ template <typename T>
+ class ProxyObject {
+ public:
+ ProxyObject(SparseValueVector<T> *v, size_t idx) { _v = v; _idx = idx; };
+ ProxyObject<T> & operator=(const T &v) {
+ // If we get here, we know that operator[] was called to perform a write access,
+ // so we can insert an item in the vector if needed
+ if (v != 0)
+ _v->data[_idx] = v;
+ return *this;
+ }
+
+ operator T() {
+ // If we get here, we know that operator[] was called to perform a read access,
+ // so we can simply return the existing object
+#ifdef HASHMAP
+ typename std::unordered_map<size_t, T>::iterator it = _v->data.find(_idx);
+#else
+ typename std::map<size_t, T>::iterator it = _v->data.find(_idx);
+#endif
+ if (it == _v->data.end())
+ return 0;
+ else
+ return it->second;
+ }
+
+ void operator+=(T val)
+ {
+ if (val == 0) return;
+#ifdef HASHMAP
+ typename std::unordered_map<size_t, T>::iterator it = _v->data.find(_idx);
+#else
+ typename std::map<size_t, T>::iterator it = _v->data.find(_idx);
+#endif
+ if (it == _v->data.end())
+ _v->data[_idx] = val;
+ else
+ {
+ T sum = it->second + val;
+ if (sum == 0)
+ _v->data.erase(it);
+ else
+ it->second = sum;
+ }
+ }
+ void operator-=(T val)
+ {
+ if (val == 0) return;
+#ifdef HASHMAP
+ typename std::unordered_map<size_t, T>::iterator it = _v->data.find(_idx);
+#else
+ typename std::map<size_t, T>::iterator it = _v->data.find(_idx);
+#endif
+ if (it == _v->data.end())
+ _v->data[_idx] = -val;
+ else
+ {
+ T sum = it->second - val;
+ if (sum == 0)
+ _v->data.erase(it);
+ else
+ it->second = sum;
+ }
+ }
+
+ SparseValueVector<T> *_v;
+ size_t _idx;
+ };
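
Together, SparseValueVector and ProxyObject implement a write-through sparse vector: reads of absent entries return zero, and updates that cancel an entry remove it from the map, so memory stays proportional to the support of the flow (the SPARSE_FLOW option above). A minimal Python analogue of that behaviour (class and method names are illustrative):

class SparseVector:
    def __init__(self):
        self.data = {}            # index -> non-zero value

    def __getitem__(self, idx):
        return self.data.get(idx, 0)

    def add(self, idx, val):      # plays the role of ProxyObject's operator+=
        if val == 0:
            return
        new = self.data.get(idx, 0) + val
        if new == 0:
            self.data.pop(idx, None)
        else:
            self.data[idx] = new

flow = SparseVector()
flow.add(7, 3)
flow.add(7, -3)
assert flow[7] == 0 and 7 not in flow.data
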
+
+
+
+ /// \addtogroup min_cost_flow_algs
+ /// @{
+
+ /// \brief Implementation of the primal Network Simplex algorithm
+ /// for finding a \ref min_cost_flow "minimum cost flow".
+ ///
+ /// \ref NetworkSimplexSimple implements the primal Network Simplex algorithm
+ /// for finding a \ref min_cost_flow "minimum cost flow"
+ /// \ref amo93networkflows, \ref dantzig63linearprog,
+ /// \ref kellyoneill91netsimplex.
+ /// This algorithm is a highly efficient specialized version of the
+ /// linear programming simplex method directly for the minimum cost
+ /// flow problem.
+ ///
+ /// In general, %NetworkSimplexSimple is the fastest implementation available
+ /// in LEMON for this problem.
+ /// Moreover, it supports both directions of the supply/demand inequality
+ /// constraints. For more information, see \ref SupplyType.
+ ///
+ /// Most of the parameters of the problem (except for the digraph)
+ /// can be given using separate functions, and the algorithm can be
+ /// executed using the \ref run() function. If some parameters are not
+ /// specified, then default values will be used.
+ ///
+ /// \tparam GR The digraph type the algorithm runs on.
+ /// \tparam V The number type used for flow amounts, capacity bounds
+ /// and supply values in the algorithm. By default, it is \c int.
+ /// \tparam C The number type used for costs and potentials in the
+ /// algorithm. By default, it is the same as \c V.
+ ///
+ /// \warning Both number types must be signed and all input data must
+ /// be integer.
+ ///
+ /// \note %NetworkSimplexSimple provides five different pivot rule
+ /// implementations, from which the most efficient one is used
+ /// by default. For more information, see \ref PivotRule.
+ template <typename GR, typename V = int, typename C = V, typename ArcsType = int64_t>
+ class NetworkSimplexSimple
+ {
+ public:
+
+ /// \brief Constructor.
+ ///
+ /// The constructor of the class.
+ ///
+ /// \param graph The digraph the algorithm runs on.
+ /// \param arc_mixing Indicate if the arcs have to be stored in a
+ /// mixed order in the internal data structure.
+ /// In special cases, it could lead to better overall performance,
+ /// but it is usually slower. Therefore it is disabled by default.
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) :
+ _graph(graph), //_arc_id(graph),
+ _arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
+ MAX(std::numeric_limits<Value>::max()),
+ INF(std::numeric_limits<Value>::has_infinity ?
+ std::numeric_limits<Value>::infinity() : MAX)
+ {
+ // Reset data structures
+ reset();
+ max_iter = maxiters;
+#ifdef OMP
+ if (max_threads < 0) {
+ max_threads = omp_get_max_threads();
+ }
+ if (numThreads > 0 && numThreads<=max_threads){
+ num_threads = numThreads;
+ } else if (numThreads == -1 || numThreads>max_threads) {
+ num_threads = max_threads;
+ } else {
+ num_threads = 1;
+ }
+ omp_set_num_threads(num_threads);
+#else
+ num_threads = 1;
+#endif
+ }
+
+ /// The type of the flow amounts, capacity bounds and supply values
+ typedef V Value;
+ /// The type of the arc costs
+ typedef C Cost;
+
+ public:
+ /// \brief Problem type constants for the \c run() function.
+ ///
+ /// Enum type containing the problem type constants that can be
+ /// returned by the \ref run() function of the algorithm.
+ enum ProblemType {
+ /// The problem has no feasible solution (flow).
+ INFEASIBLE,
+ /// The problem has optimal solution (i.e. it is feasible and
+ /// bounded), and the algorithm has found optimal flow and node
+ /// potentials (primal and dual solutions).
+ OPTIMAL,
+ /// The objective function of the problem is unbounded, i.e.
+ /// there is a directed cycle having negative total cost and
+ /// infinite upper bound.
+ UNBOUNDED,
+ // The maximum number of iterations has been reached
+ MAX_ITER_REACHED
+ };
+
+ /// \brief Constants for selecting the type of the supply constraints.
+ ///
+ /// Enum type containing constants for selecting the supply type,
+ /// i.e. the direction of the inequalities in the supply/demand
+ /// constraints of the \ref min_cost_flow "minimum cost flow problem".
+ ///
+ /// The default supply type is \c GEQ, the \c LEQ type can be
+ /// selected using \ref supplyType().
+ /// The equality form is a special case of both supply types.
+ enum SupplyType {
+ /// This option means that there are <em>"greater or equal"</em>
+ /// supply/demand constraints in the definition of the problem.
+ GEQ,
+ /// This option means that there are <em>"less or equal"</em>
+ /// supply/demand constraints in the definition of the problem.
+ LEQ
+ };
+
+
+
+ private:
+ size_t max_iter;
+ int num_threads;
+ TEMPLATE_DIGRAPH_TYPEDEFS(GR);
+
+ typedef std::vector<int> IntVector;
+ typedef std::vector<ArcsType> ArcVector;
+ typedef std::vector<Value> ValueVector;
+ typedef std::vector<Cost> CostVector;
+ // typedef SparseValueVector<Cost> CostVector;
+ typedef std::vector<char> BoolVector;
+ // Note: vector<char> is used instead of vector<bool> for efficiency reasons
+
+ // State constants for arcs
+ enum ArcState {
+ STATE_UPPER = -1,
+ STATE_TREE = 0,
+ STATE_LOWER = 1
+ };
+
+ typedef std::vector<signed char> StateVector;
+ // Note: vector<signed char> is used instead of vector<ArcState> for
+ // efficiency reasons
+
+ private:
+
+ // Data related to the underlying digraph
+ const GR &_graph;
+ int _node_num;
+ ArcsType _arc_num;
+ ArcsType _all_arc_num;
+ ArcsType _search_arc_num;
+
+ // Parameters of the problem
+ SupplyType _stype;
+ Value _sum_supply;
+
+ inline int _node_id(int n) const { return _node_num - n - 1; };
+
+ //IntArcMap _arc_id;
+ IntVector _source; // keep nodes as integers
+ IntVector _target;
+ bool _arc_mixing;
+
+ // Node and arc data
+ CostVector _cost;
+ ValueVector _supply;
+#ifdef SPARSE_FLOW
+ SparseValueVector<Value> _flow;
+#else
+ ValueVector _flow;
+#endif
+
+ CostVector _pi;
+
+ // Data for storing the spanning tree structure
+ IntVector _parent;
+ ArcVector _pred;
+ IntVector _thread;
+ IntVector _rev_thread;
+ IntVector _succ_num;
+ IntVector _last_succ;
+ IntVector _dirty_revs;
+ BoolVector _forward;
+ StateVector _state;
+ ArcsType _root;
+
+ // Temporary data used in the current pivot iteration
+ ArcsType in_arc, join, u_in, v_in, u_out, v_out;
+ ArcsType first, second, right, last;
+ ArcsType stem, par_stem, new_stem;
+ Value delta;
+
+ const Value MAX;
+
+ ArcsType mixingCoeff;
+
+ public:
+
+ /// \brief Constant for infinite upper bounds (capacities).
+ ///
+ /// Constant for infinite upper bounds (capacities).
+ /// It is \c std::numeric_limits<Value>::infinity() if available,
+ /// \c std::numeric_limits<Value>::max() otherwise.
+ const Value INF;
+
+ private:
+
+ // thank you to DVK and MizardX from StackOverflow for this function!
+ inline ArcsType sequence(ArcsType k) const {
+ ArcsType smallv = (k > num_total_big_subsequence_numbers) & 1;
+
+ k -= num_total_big_subsequence_numbers * smallv;
+ ArcsType subsequence_length2 = subsequence_length - smallv;
+ ArcsType subsequence_num = (k / subsequence_length2) + num_big_subsequences * smallv;
+ ArcsType subsequence_offset = (k % subsequence_length2) * mixingCoeff;
+
+ return subsequence_offset + subsequence_num;
+ }
+ ArcsType subsequence_length;
+ ArcsType num_big_subsequences;
+ ArcsType num_total_big_subsequence_numbers;
+
+ inline ArcsType getArcID(const Arc &arc) const
+ {
+ //int n = _arc_num-arc._id-1;
+ ArcsType n = _arc_num - GR::id(arc) - 1;
+
+ //ArcsType a = mixingCoeff*(n%mixingCoeff) + n/mixingCoeff;
+ //ArcsType b = _arc_id[arc];
+ if (_arc_mixing)
+ return sequence(n);
+ else
+ return n;
+ }
+
+ // finally unused because too slow
+ inline ArcsType getSource(const ArcsType arc) const
+ {
+ //ArcsType a = _source[arc];
+ //return a;
+
+ ArcsType n = _arc_num - arc - 1;
+ if (_arc_mixing)
+ n = mixingCoeff*(n%mixingCoeff) + n / mixingCoeff;
+
+ ArcsType b;
+ if (n >= 0)
+ b = _node_id(_graph.source(GR::arcFromId(n)));
+ else
+ {
+ n = arc + 1 - _arc_num;
+ if (n <= _node_num)
+ b = _node_num;
+ else
+ if (n >= _graph._n1)
+ b = _graph._n1;
+ else
+ b = _graph._n1 - n;
+ }
+
+ return b;
+ }
+
+
+
+ // Implementation of the Block Search pivot rule
+ class BlockSearchPivotRule
+ {
+ private:
+
+ // References to the NetworkSimplexSimple class
+ const IntVector &_source;
+ const IntVector &_target;
+ const CostVector &_cost;
+ const StateVector &_state;
+ const CostVector &_pi;
+ ArcsType &_in_arc;
+ ArcsType _search_arc_num;
+
+ // Pivot rule data
+ ArcsType _block_size;
+ ArcsType _next_arc;
+ NetworkSimplexSimple &_ns;
+
+ public:
+
+ // Constructor
+ BlockSearchPivotRule(NetworkSimplexSimple &ns) :
+ _source(ns._source), _target(ns._target),
+ _cost(ns._cost), _state(ns._state), _pi(ns._pi),
+ _in_arc(ns.in_arc), _search_arc_num(ns._search_arc_num),
+ _next_arc(0), _ns(ns)
+ {
+ // The main parameters of the pivot rule
+ const double BLOCK_SIZE_FACTOR = 1;
+ const ArcsType MIN_BLOCK_SIZE = 10;
+
+ _block_size = std::max(ArcsType(BLOCK_SIZE_FACTOR * std::sqrt(double(_search_arc_num))), MIN_BLOCK_SIZE);
+ }
+
+ // Find next entering arc
+ bool findEnteringArc() {
+ Cost min_val = 0;
+
+ ArcsType N = _ns.num_threads;
+
+ std::vector<Cost> minArray(N, 0);
+ std::vector<ArcsType> arcId(N);
+ ArcsType bs = (ArcsType)ceil(_block_size / (double)N);
+
+ for (ArcsType i = 0; i < _search_arc_num; i += _block_size) {
+
+ ArcsType e;
+ int j;
+#pragma omp parallel
+ {
+#ifdef OMP
+ int t = omp_get_thread_num();
+#else
+ int t = 0;
+#endif
+
+#pragma omp for schedule(static, bs) lastprivate(e)
+ for (j = 0; j < std::min(i + _block_size, _search_arc_num) - i; j++) {
+ e = (_next_arc + i + j); if (e >= _search_arc_num) e -= _search_arc_num;
+ Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < minArray[t]) {
+ minArray[t] = c;
+ arcId[t] = e;
+ }
+ }
+ }
+ for (int j = 0; j < N; j++) {
+ if (minArray[j] < min_val) {
+ min_val = minArray[j];
+ _in_arc = arcId[j];
+ }
+ }
+ Cost a = std::abs(_pi[_source[_in_arc]]) > std::abs(_pi[_target[_in_arc]]) ? std::abs(_pi[_source[_in_arc]]) : std::abs(_pi[_target[_in_arc]]);
+ a = a > std::abs(_cost[_in_arc]) ? a : std::abs(_cost[_in_arc]);
+ if (min_val < -EPSILON*a) {
+ _next_arc = e;
+ return true;
+ }
+ }
+
+ Cost a = fabs(_pi[_source[_in_arc]]) > fabs(_pi[_target[_in_arc]]) ? fabs(_pi[_source[_in_arc]]) : fabs(_pi[_target[_in_arc]]);
+ a = a > fabs(_cost[_in_arc]) ? a : fabs(_cost[_in_arc]);
+ if (min_val >= -EPSILON*a) return false;
+
+ return true;
+ }
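
The parallel findEnteringArc() above is the heart of the pivot rule: arcs are scanned in blocks, each thread keeps the most negative reduced cost it sees, and the scan stops as soon as a block produces a reduced cost below -EPSILON times a problem-dependent scale. A serial Python sketch of the same logic follows (the OpenMP splitting and the class plumbing are omitted; all argument names are illustrative):

    def find_entering_arc(state, cost, pi, source, target, next_arc, block_size, eps=1e-12):
        # Serial sketch of the block-search pivot rule (OpenMP splitting omitted).
        # state[e] is +1/-1/0; reduced cost = state * (cost + pi[src] - pi[dst]).
        n_arcs = len(cost)
        min_val, in_arc, e = 0.0, None, next_arc
        for start in range(0, n_arcs, block_size):
            for j in range(start, min(start + block_size, n_arcs)):
                e = (next_arc + j) % n_arcs
                red_cost = state[e] * (cost[e] + pi[source[e]] - pi[target[e]])
                if red_cost < min_val:
                    min_val, in_arc = red_cost, e
            if in_arc is not None:
                # stop as soon as a block yields a sufficiently negative reduced cost,
                # scaled like the C++ code to stay robust to floating point noise
                scale = max(abs(pi[source[in_arc]]), abs(pi[target[in_arc]]), abs(cost[in_arc]))
                if min_val < -eps * scale:
                    return in_arc, e  # e is where the next search will resume
        return None, next_arc
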
+
+
+ // Find next entering arc
+ /*bool findEnteringArc() {
+ Cost min_val = 0;
+ int N = omp_get_max_threads();
+ std::vector<Cost> minArray(N);
+ std::vector<ArcsType> arcId(N);
+
+ ArcsType bs = (ArcsType)ceil(_block_size / (double)N);
+ for (ArcsType i = 0; i < _search_arc_num; i += _block_size) {
+
+ ArcsType maxJ = std::min(i + _block_size, _search_arc_num) - i;
+ ArcsType j;
+#pragma omp parallel
+ {
+ int t = omp_get_thread_num();
+ Cost minV = 0;
+ ArcsType arcStart = _next_arc + i;
+ ArcsType arc = -1;
+#pragma omp for schedule(static, bs)
+ for (j = 0; j < maxJ; j++) {
+ ArcsType e = arcStart + j; if (e >= _search_arc_num) e -= _search_arc_num;
+ Cost c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < minV) {
+ minV = c;
+ arc = e;
+ }
+ }
+
+ minArray[t] = minV;
+ arcId[t] = arc;
+ }
+ for (int j = 0; j < N; j++) {
+ if (minArray[j] < min_val) {
+ min_val = minArray[j];
+ _in_arc = arcId[j];
+ }
+ }
+
+ //FIX by Antoine Rolet to avoid precision issues
+ Cost a = std::max(std::abs(_cost[_in_arc]), std::max(std::abs(_pi[_source[_in_arc]]), std::abs(_pi[_target[_in_arc]])));
+ if (min_val <-std::numeric_limits<Cost>::epsilon()*a) {
+ _next_arc = _next_arc + i + maxJ - 1;
+ if (_next_arc >= _search_arc_num) _next_arc -= _search_arc_num;
+ return true;
+ }
+ }
+
+ if (min_val >= 0) {
+ return false;
+ }
+
+ return true;
+ }*/
+
+
+ /*bool findEnteringArc() {
+ Cost c, min = 0;
+ int cnt = _block_size;
+ int e, min_arc = _next_arc;
+ for (e = _next_arc; e < _search_arc_num; ++e) {
+ c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < min) {
+ min = c;
+ min_arc = e;
+
+ }
+ if (--cnt == 0) {
+ if (min < 0) break;
+ cnt = _block_size;
+
+ }
+
+ }
+ if (min == 0 || cnt > 0) {
+ for (e = 0; e < _next_arc; ++e) {
+ c = _state[e] * (_cost[e] + _pi[_source[e]] - _pi[_target[e]]);
+ if (c < min) {
+ min = c;
+ min_arc = e;
+
+ }
+ if (--cnt == 0) {
+ if (min < 0) break;
+ cnt = _block_size;
+
+ }
+
+ }
+
+ }
+ if (min >= 0) return false;
+ _in_arc = min_arc;
+ _next_arc = e;
+ return true;
+ }*/
+
+
+
+ }; //class BlockSearchPivotRule
+
+
+
+ public:
+
+
+
+ int _init_nb_nodes;
+ ArcsType _init_nb_arcs;
+
+ /// \name Parameters
+ /// The parameters of the algorithm can be specified using these
+ /// functions.
+
+ /// @{
+
+
+ /// \brief Set the costs of the arcs.
+ ///
+ /// This function sets the costs of the arcs.
+ /// If it is not used before calling \ref run(), the costs
+ /// will be set to \c 1 on all arcs.
+ ///
+ /// \param map An arc map storing the costs.
+ /// Its \c Value type must be convertible to the \c Cost type
+ /// of the algorithm.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename CostMap>
+ NetworkSimplexSimple& costMap(const CostMap& map) {
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a)) {
+ _cost[getArcID(a)] = map[a];
+ }
+ return *this;
+ }
+
+
+ /// \brief Set the costs of one arc.
+ ///
+ /// This function sets the cost of a single arc.
+ /// It is provided to avoid building a full cost map (for memory reasons).
+ ///
+ /// \param arc An arc.
+ /// \param cost The cost to assign to \c arc.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename Value>
+ NetworkSimplexSimple& setCost(const Arc& arc, const Value cost) {
+ _cost[getArcID(arc)] = cost;
+ return *this;
+ }
+
+
+ /// \brief Set the supply values of the nodes.
+ ///
+ /// This function sets the supply values of the nodes.
+ /// If neither this function nor \ref stSupply() is used before
+ /// calling \ref run(), the supply of each node will be set to zero.
+ ///
+ /// \param map A node map storing the supply values.
+ /// Its \c Value type must be convertible to the \c Value type
+ /// of the algorithm.
+ ///
+ /// \return <tt>(*this)</tt>
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMap(const SupplyMap& map) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ _supply[_node_id(n)] = map[n];
+ }
+ return *this;
+ }
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMap(const SupplyMap* map1, int n1, const SupplyMap* map2, int n2) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ if (n<n1)
+ _supply[_node_id(n)] = map1[n];
+ else
+ _supply[_node_id(n)] = map2[n - n1];
+ }
+ return *this;
+ }
+ template<typename SupplyMap>
+ NetworkSimplexSimple& supplyMapAll(SupplyMap val1, int n1, SupplyMap val2, int n2) {
+ Node n; _graph.first(n);
+ for (; n != INVALIDNODE; _graph.next(n)) {
+ if (n<n1)
+ _supply[_node_id(n)] = val1;
+ else
+ _supply[_node_id(n)] = val2;
+ }
+ return *this;
+ }
+
+ /// \brief Set single source and target nodes and a supply value.
+ ///
+ /// This function sets a single source node and a single target node
+ /// and the required flow value.
+ /// If neither this function nor \ref supplyMap() is used before
+ /// calling \ref run(), the supply of each node will be set to zero.
+ ///
+ /// Using this function has the same effect as using \ref supplyMap()
+ /// with such a map in which \c k is assigned to \c s, \c -k is
+ /// assigned to \c t and all other nodes have zero supply value.
+ ///
+ /// \param s The source node.
+ /// \param t The target node.
+ /// \param k The required amount of flow from node \c s to node \c t
+ /// (i.e. the supply of \c s and the demand of \c t).
+ ///
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& stSupply(const Node& s, const Node& t, Value k) {
+ for (int i = 0; i != _node_num; ++i) {
+ _supply[i] = 0;
+ }
+ _supply[_node_id(s)] = k;
+ _supply[_node_id(t)] = -k;
+ return *this;
+ }
+
+ /// \brief Set the type of the supply constraints.
+ ///
+ /// This function sets the type of the supply/demand constraints.
+ /// If it is not used before calling \ref run(), the \ref GEQ supply
+ /// type will be used.
+ ///
+ /// For more information, see \ref SupplyType.
+ ///
+ /// \return <tt>(*this)</tt>
+ NetworkSimplexSimple& supplyType(SupplyType supply_type) {
+ _stype = supply_type;
+ return *this;
+ }
+
+ /// @}
+
+ /// \name Execution Control
+ /// The algorithm can be executed using \ref run().
+
+ /// @{
+
+ /// \brief Run the algorithm.
+ ///
+ /// This function runs the algorithm.
+ /// The parameters can be specified using functions \ref lowerMap(),
+ /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(),
+ /// \ref supplyType().
+ /// For example,
+ /// \code
+ /// NetworkSimplexSimple<ListDigraph> ns(graph);
+ /// ns.lowerMap(lower).upperMap(upper).costMap(cost)
+ /// .supplyMap(sup).run();
+ /// \endcode
+ ///
+ /// This function can be called more than once. All the given parameters
+ /// are kept for the next call, unless \ref resetParams() or \ref reset()
+ /// is used, thus only the modified parameters have to be set again.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class (or the last \ref reset() call), then the \ref reset()
+ /// function must be called.
+ ///
+ /// \param pivot_rule The pivot rule that will be used during the
+ /// algorithm. For more information, see \ref PivotRule.
+ ///
+ /// \return \c INFEASIBLE if no feasible flow exists,
+ /// \n \c OPTIMAL if the problem has optimal solution
+ /// (i.e. it is feasible and bounded), and the algorithm has found
+ /// optimal flow and node potentials (primal and dual solutions),
+ /// \n \c UNBOUNDED if the objective function of the problem is
+ /// unbounded, i.e. there is a directed cycle having negative total
+ /// cost and infinite upper bound.
+ ///
+ /// \see ProblemType, PivotRule
+ /// \see resetParams(), reset()
+ ProblemType run() {
+#if DEBUG_LVL>0
+ std::cout << "OPTIMAL = " << OPTIMAL << "\nINFEASIBLE = " << INFEASIBLE << "\nUNBOUNDED = " << UNBOUNDED << "\nMAX_ITER_REACHED" << MAX_ITER_REACHED << "\n" ;
+#endif
+ if (!init()) return INFEASIBLE;
+#if DEBUG_LVL>0
+ std::cout << "Init done, starting iterations\n";
+#endif
+
+ return start();
+ }
+
+ /// \brief Reset all the parameters that have been given before.
+ ///
+ /// This function resets all the parameters that have been given
+ /// before using functions \ref lowerMap(), \ref upperMap(),
+ /// \ref costMap(), \ref supplyMap(), \ref stSupply(), \ref supplyType().
+ ///
+ /// It is useful for multiple \ref run() calls. Basically, all the given
+ /// parameters are kept for the next \ref run() call, unless
+ /// \ref resetParams() or \ref reset() is used.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class or the last \ref reset() call, then the \ref reset()
+ /// function must be used, otherwise \ref resetParams() is sufficient.
+ ///
+ /// For example,
+ /// \code
+ /// NetworkSimplexSimple<ListDigraph> ns(graph);
+ ///
+ /// // First run
+ /// ns.lowerMap(lower).upperMap(upper).costMap(cost)
+ /// .supplyMap(sup).run();
+ ///
+ /// // Run again with modified cost map (resetParams() is not called,
+ /// // so only the cost map has to be set again)
+ /// cost[e] += 100;
+ /// ns.costMap(cost).run();
+ ///
+ /// // Run again from scratch using resetParams()
+ /// // (the lower bounds will be set to zero on all arcs)
+ /// ns.resetParams();
+ /// ns.upperMap(capacity).costMap(cost)
+ /// .supplyMap(sup).run();
+ /// \endcode
+ ///
+ /// \return <tt>(*this)</tt>
+ ///
+ /// \see reset(), run()
+ NetworkSimplexSimple& resetParams() {
+ for (int i = 0; i != _node_num; ++i) {
+ _supply[i] = 0;
+ }
+ for (ArcsType i = 0; i != _arc_num; ++i) {
+ _cost[i] = 1;
+ }
+ _stype = GEQ;
+ return *this;
+ }
+
+
+ /// \brief Reset the internal data structures and all the parameters
+ /// that have been given before.
+ ///
+ /// This function resets the internal data structures and all the
+ /// parameters that have been given before using functions \ref lowerMap(),
+ /// \ref upperMap(), \ref costMap(), \ref supplyMap(), \ref stSupply(),
+ /// \ref supplyType().
+ ///
+ /// It is useful for multiple \ref run() calls. Basically, all the given
+ /// parameters are kept for the next \ref run() call, unless
+ /// \ref resetParams() or \ref reset() is used.
+ /// If the underlying digraph was also modified after the construction
+ /// of the class or the last \ref reset() call, then the \ref reset()
+ /// function must be used, otherwise \ref resetParams() is sufficient.
+ ///
+ /// See \ref resetParams() for examples.
+ ///
+ /// \return <tt>(*this)</tt>
+ ///
+ /// \see resetParams(), run()
+ NetworkSimplexSimple& reset() {
+ // Resize vectors
+ _node_num = _init_nb_nodes;
+ _arc_num = _init_nb_arcs;
+ int all_node_num = _node_num + 1;
+ ArcsType max_arc_num = _arc_num + 2 * _node_num;
+
+ _source.resize(max_arc_num);
+ _target.resize(max_arc_num);
+
+ _cost.resize(max_arc_num);
+ _supply.resize(all_node_num);
+ _flow.resize(max_arc_num);
+ _pi.resize(all_node_num);
+
+ _parent.resize(all_node_num);
+ _pred.resize(all_node_num);
+ _forward.resize(all_node_num);
+ _thread.resize(all_node_num);
+ _rev_thread.resize(all_node_num);
+ _succ_num.resize(all_node_num);
+ _last_succ.resize(all_node_num);
+ _state.resize(max_arc_num);
+
+
+ //_arc_mixing=false;
+ if (_arc_mixing && _node_num > 1) {
+ // Store the arcs in a mixed order
+ //ArcsType k = std::max(ArcsType(std::sqrt(double(_arc_num))), ArcsType(10));
+ const ArcsType k = std::max(ArcsType(_arc_num / _node_num), ArcsType(3));
+ mixingCoeff = k;
+ subsequence_length = _arc_num / mixingCoeff + 1;
+ num_big_subsequences = _arc_num % mixingCoeff;
+ num_total_big_subsequence_numbers = subsequence_length * num_big_subsequences;
+
+#pragma omp parallel for schedule(static)
+ for (Arc a = 0; a <= _graph.maxArcId(); a++) { // --a <=> _graph.next(a) , -1 == INVALID
+ ArcsType i = sequence(_graph.maxArcId()-a);
+ _source[i] = _node_id(_graph.source(a));
+ _target[i] = _node_id(_graph.target(a));
+ }
+ } else {
+ // Store the arcs in the original order
+ ArcsType i = 0;
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a), ++i) {
+ _source[i] = _node_id(_graph.source(a));
+ _target[i] = _node_id(_graph.target(a));
+ //_arc_id[a] = i;
+ }
+ }
+
+ // Reset parameters
+ resetParams();
+ return *this;
+ }
+
+ /// @}
+
+ /// \name Query Functions
+ /// The results of the algorithm can be obtained using these
+ /// functions.\n
+ /// The \ref run() function must be called before using them.
+
+ /// @{
+
+ /// \brief Return the total cost of the found flow.
+ ///
+ /// This function returns the total cost of the found flow.
+ /// Its complexity is O(e).
+ ///
+ /// \note The return type of the function can be specified as a
+ /// template parameter. For example,
+ /// \code
+ /// ns.totalCost<double>();
+ /// \endcode
+ /// It is useful if the total cost cannot be stored in the \c Cost
+ /// type of the algorithm, which is the default return type of the
+ /// function.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ /*template <typename Number>
+ Number totalCost() const {
+ Number c = 0;
+ for (ArcIt a(_graph); a != INVALID; ++a) {
+ int i = getArcID(a);
+ c += Number(_flow[i]) * Number(_cost[i]);
+ }
+ return c;
+ }*/
+
+ template <typename Number>
+ Number totalCost() const {
+ Number c = 0;
+
+#ifdef SPARSE_FLOW
+ #ifdef HASHMAP
+ typename std::unordered_map<size_t, Value>::const_iterator it;
+ #else
+ typename std::map<size_t, Value>::const_iterator it;
+ #endif
+ for (it = _flow.data.begin(); it!=_flow.data.end(); ++it)
+ c += Number(it->second) * Number(_cost[it->first]);
+ return c;
+#else
+ for (ArcsType i = 0; i<_flow.size(); i++)
+ c += _flow[i] * Number(_cost[i]);
+ return c;
+#endif
+ }
+
+#ifndef DOXYGEN
+ Cost totalCost() const {
+ return totalCost<Cost>();
+ }
+#endif
+
+ /// \brief Return the flow on the given arc.
+ ///
+ /// This function returns the flow on the given arc.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ Value flow(const Arc& a) const {
+ return _flow[getArcID(a)];
+ }
+
+ /// \brief Return the flow map (the primal solution).
+ ///
+ /// This function copies the flow value on each arc into the given
+ /// map. The \c Value type of the algorithm must be convertible to
+ /// the \c Value type of the map.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ template <typename FlowMap>
+ void flowMap(FlowMap &map) const {
+ Arc a; _graph.first(a);
+ for (; a != INVALID; _graph.next(a)) {
+ map.set(a, _flow[getArcID(a)]);
+ }
+ }
+
+ /// \brief Return the potential (dual value) of the given node.
+ ///
+ /// This function returns the potential (dual value) of the
+ /// given node.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ Cost potential(const Node& n) const {
+ return _pi[_node_id(n)];
+ }
+
+ /// \brief Return the potential map (the dual solution).
+ ///
+ /// This function copies the potential (dual value) of each node
+ /// into the given map.
+ /// The \c Cost type of the algorithm must be convertible to the
+ /// \c Value type of the map.
+ ///
+ /// \pre \ref run() must be called before using this function.
+ template <typename PotentialMap>
+ void potentialMap(PotentialMap &map) const {
+ Node n; _graph.first(n);
+ for (; n != INVALID; _graph.next(n)) {
+ map.set(n, _pi[_node_id(n)]);
+ }
+ }
+
+ /// @}
+
+ private:
+
+ // Initialize internal data structures
+ bool init() {
+ if (_node_num == 0) return false;
+
+ // Check the sum of supply values
+ _sum_supply = 0;
+ for (int i = 0; i != _node_num; ++i) {
+ _sum_supply += _supply[i];
+ }
+ /*if (!((_stype == GEQ && _sum_supply <= 0) ||
+ (_stype == LEQ && _sum_supply >= 0))) return false;*/
+
+
+ // Initialize artificial cost
+ Cost ART_COST;
+ if (std::numeric_limits<Cost>::is_exact) {
+ ART_COST = std::numeric_limits<Cost>::max() / 2 + 1;
+ } else {
+ ART_COST = 0;
+ for (ArcsType i = 0; i != _arc_num; ++i) {
+ if (_cost[i] > ART_COST) ART_COST = _cost[i];
+ }
+ ART_COST = (ART_COST + 1) * _node_num;
+ }
+
+ // Initialize arc maps
+ for (ArcsType i = 0; i != _arc_num; ++i) {
+#ifndef SPARSE_FLOW
+ _flow[i] = 0; //by default, the sparse matrix is empty
+#endif
+ _state[i] = STATE_LOWER;
+ }
+#ifdef SPARSE_FLOW
+ _flow = SparseValueVector<Value>();
+#endif
+
+ // Set data for the artificial root node
+ _root = _node_num;
+ _parent[_root] = -1;
+ _pred[_root] = -1;
+ _thread[_root] = 0;
+ _rev_thread[0] = _root;
+ _succ_num[_root] = _node_num + 1;
+ _last_succ[_root] = _root - 1;
+ _supply[_root] = -_sum_supply;
+ _pi[_root] = 0;
+
+ // Add artificial arcs and initialize the spanning tree data structure
+ if (_sum_supply == 0) {
+ // EQ supply constraints
+ _search_arc_num = _arc_num;
+ _all_arc_num = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ _parent[u] = _root;
+ _pred[u] = e;
+ _thread[u] = u + 1;
+ _rev_thread[u + 1] = u;
+ _succ_num[u] = 1;
+ _last_succ[u] = u;
+ _state[e] = STATE_TREE;
+ if (_supply[u] >= 0) {
+ _forward[u] = true;
+ _pi[u] = 0;
+ _source[e] = u;
+ _target[e] = _root;
+ _flow[e] = _supply[u];
+ _cost[e] = 0;
+ } else {
+ _forward[u] = false;
+ _pi[u] = ART_COST;
+ _source[e] = _root;
+ _target[e] = u;
+ _flow[e] = -_supply[u];
+ _cost[e] = ART_COST;
+ }
+ }
+ } else if (_sum_supply > 0) {
+ // LEQ supply constraints
+ _search_arc_num = _arc_num + _node_num;
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ _parent[u] = _root;
+ _thread[u] = u + 1;
+ _rev_thread[u + 1] = u;
+ _succ_num[u] = 1;
+ _last_succ[u] = u;
+ if (_supply[u] >= 0) {
+ _forward[u] = true;
+ _pi[u] = 0;
+ _pred[u] = e;
+ _source[e] = u;
+ _target[e] = _root;
+ _flow[e] = _supply[u];
+ _cost[e] = 0;
+ _state[e] = STATE_TREE;
+ } else {
+ _forward[u] = false;
+ _pi[u] = ART_COST;
+ _pred[u] = f;
+ _source[f] = _root;
+ _target[f] = u;
+ _flow[f] = -_supply[u];
+ _cost[f] = ART_COST;
+ _state[f] = STATE_TREE;
+ _source[e] = u;
+ _target[e] = _root;
+ //_flow[e] = 0; //by default, the sparse matrix is empty
+ _cost[e] = 0;
+ _state[e] = STATE_LOWER;
+ ++f;
+ }
+ }
+ _all_arc_num = f;
+ } else {
+ // GEQ supply constraints
+ _search_arc_num = _arc_num + _node_num;
+ ArcsType f = _arc_num + _node_num;
+ for (ArcsType u = 0, e = _arc_num; u != _node_num; ++u, ++e) {
+ _parent[u] = _root;
+ _thread[u] = u + 1;
+ _rev_thread[u + 1] = u;
+ _succ_num[u] = 1;
+ _last_succ[u] = u;
+ if (_supply[u] <= 0) {
+ _forward[u] = false;
+ _pi[u] = 0;
+ _pred[u] = e;
+ _source[e] = _root;
+ _target[e] = u;
+ _flow[e] = -_supply[u];
+ _cost[e] = 0;
+ _state[e] = STATE_TREE;
+ } else {
+ _forward[u] = true;
+ _pi[u] = -ART_COST;
+ _pred[u] = f;
+ _source[f] = u;
+ _target[f] = _root;
+ _flow[f] = _supply[u];
+ _state[f] = STATE_TREE;
+ _cost[f] = ART_COST;
+ _source[e] = _root;
+ _target[e] = u;
+ //_flow[e] = 0; //by default, the sparse matrix is empty
+ _cost[e] = 0;
+ _state[e] = STATE_LOWER;
+ ++f;
+ }
+ }
+ _all_arc_num = f;
+ }
+
+ return true;
+ }
+
+ // Find the join node
+ void findJoinNode() {
+ int u = _source[in_arc];
+ int v = _target[in_arc];
+ while (u != v) {
+ if (_succ_num[u] < _succ_num[v]) {
+ u = _parent[u];
+ } else {
+ v = _parent[v];
+ }
+ }
+ join = u;
+ }
+
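findJoinNode() locates the apex of the cycle that the entering arc closes with the current spanning tree: both endpoints are walked towards the root, always advancing the one whose subtree (tracked in _succ_num) is smaller. A minimal Python sketch, assuming `parent` and `succ_num` are plain lists indexed by node id:

    def find_join_node(u, v, parent, succ_num):
        # walk both endpoints of the entering arc towards the root,
        # always moving the one with the smaller subtree, until they meet
        while u != v:
            if succ_num[u] < succ_num[v]:
                u = parent[u]
            else:
                v = parent[v]
        return u

    # toy tree: 0 is the root, 1 and 2 are its children, 3 is a child of 1
    parent = [-1, 0, 0, 1]
    succ_num = [4, 2, 1, 1]
    print(find_join_node(3, 2, parent, succ_num))  # -> 0
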
+ // Find the leaving arc of the cycle and returns true if the
+ // leaving arc is not the same as the entering arc
+ bool findLeavingArc() {
+ // Initialize first and second nodes according to the direction
+ // of the cycle
+ if (_state[in_arc] == STATE_LOWER) {
+ first = _source[in_arc];
+ second = _target[in_arc];
+ } else {
+ first = _target[in_arc];
+ second = _source[in_arc];
+ }
+ delta = INF;
+ char result = 0;
+ Value d;
+ ArcsType e;
+
+ // Search the cycle along the path from the first node to the root
+ for (int u = first; u != join; u = _parent[u]) {
+ e = _pred[u];
+ d = _forward[u] ? _flow[e] : INF;
+ if (d < delta) {
+ delta = d;
+ u_out = u;
+ result = 1;
+ }
+ }
+ // Search the cycle along the path from the second node to the root
+ for (int u = second; u != join; u = _parent[u]) {
+ e = _pred[u];
+ d = _forward[u] ? INF : _flow[e];
+ if (d <= delta) {
+ delta = d;
+ u_out = u;
+ result = 2;
+ }
+ }
+
+ if (result == 1) {
+ u_in = first;
+ v_in = second;
+ } else {
+ u_in = second;
+ v_in = first;
+ }
+ return result != 0;
+ }
+
+ // Change _flow and _state vectors
+ void changeFlow(bool change) {
+ // Augment along the cycle
+ if (delta > 0) {
+ Value val = _state[in_arc] * delta;
+ _flow[in_arc] += val;
+ for (int u = _source[in_arc]; u != join; u = _parent[u]) {
+ _flow[_pred[u]] += _forward[u] ? -val : val;
+ }
+ for (int u = _target[in_arc]; u != join; u = _parent[u]) {
+ _flow[_pred[u]] += _forward[u] ? val : -val;
+ }
+ }
+ // Update the state of the entering and leaving arcs
+ if (change) {
+ _state[in_arc] = STATE_TREE;
+ _state[_pred[u_out]] =
+ (_flow[_pred[u_out]] == 0) ? STATE_LOWER : STATE_UPPER;
+ } else {
+ _state[in_arc] = -_state[in_arc];
+ }
+ }
+
+ // Update the tree structure
+ void updateTreeStructure() {
+ int old_rev_thread = _rev_thread[u_out];
+ int old_succ_num = _succ_num[u_out];
+ int old_last_succ = _last_succ[u_out];
+ v_out = _parent[u_out];
+
+ // Check if u_in and u_out coincide
+ if (u_in == u_out) {
+ // Update _parent, _pred, _pred_dir
+ _parent[u_in] = v_in;
+ _pred[u_in] = in_arc;
+ _forward[u_in] = (u_in == _source[in_arc]);
+
+ // Update _thread and _rev_thread
+ if (_thread[v_in] != u_out) {
+ ArcsType after = _thread[old_last_succ];
+ _thread[old_rev_thread] = after;
+ _rev_thread[after] = old_rev_thread;
+ after = _thread[v_in];
+ _thread[v_in] = u_out;
+ _rev_thread[u_out] = v_in;
+ _thread[old_last_succ] = after;
+ _rev_thread[after] = old_last_succ;
+ }
+ } else {
+            // Handle the case when old_rev_thread equals v_in
+ // (it also means that join and v_out coincide)
+ int thread_continue = old_rev_thread == v_in ?
+ _thread[old_last_succ] : _thread[v_in];
+
+ // Update _thread and _parent along the stem nodes (i.e. the nodes
+ // between u_in and u_out, whose parent have to be changed)
+ int stem = u_in; // the current stem node
+ int par_stem = v_in; // the new parent of stem
+ int next_stem; // the next stem node
+ int last = _last_succ[u_in]; // the last successor of stem
+ int before, after = _thread[last];
+ _thread[v_in] = u_in;
+ _dirty_revs.clear();
+ _dirty_revs.push_back(v_in);
+ while (stem != u_out) {
+ // Insert the next stem node into the thread list
+ next_stem = _parent[stem];
+ _thread[last] = next_stem;
+ _dirty_revs.push_back(last);
+
+ // Remove the subtree of stem from the thread list
+ before = _rev_thread[stem];
+ _thread[before] = after;
+ _rev_thread[after] = before;
+
+ // Change the parent node and shift stem nodes
+ _parent[stem] = par_stem;
+ par_stem = stem;
+ stem = next_stem;
+
+ // Update last and after
+ last = _last_succ[stem] == _last_succ[par_stem] ?
+ _rev_thread[par_stem] : _last_succ[stem];
+ after = _thread[last];
+ }
+ _parent[u_out] = par_stem;
+ _thread[last] = thread_continue;
+ _rev_thread[thread_continue] = last;
+ _last_succ[u_out] = last;
+
+ // Remove the subtree of u_out from the thread list except for
+            // the case when old_rev_thread equals v_in
+ if (old_rev_thread != v_in) {
+ _thread[old_rev_thread] = after;
+ _rev_thread[after] = old_rev_thread;
+ }
+
+ // Update _rev_thread using the new _thread values
+ for (int i = 0; i != int(_dirty_revs.size()); ++i) {
+ int u = _dirty_revs[i];
+ _rev_thread[_thread[u]] = u;
+ }
+
+ // Update _pred, _pred_dir, _last_succ and _succ_num for the
+ // stem nodes from u_out to u_in
+ int tmp_sc = 0, tmp_ls = _last_succ[u_out];
+ for (int u = u_out, p = _parent[u]; u != u_in; u = p, p = _parent[u]) {
+ _pred[u] = _pred[p];
+ _forward[u] = !_forward[p];
+ tmp_sc += _succ_num[u] - _succ_num[p];
+ _succ_num[u] = tmp_sc;
+ _last_succ[p] = tmp_ls;
+ }
+ _pred[u_in] = in_arc;
+ _forward[u_in] = (u_in == _source[in_arc]);
+ _succ_num[u_in] = old_succ_num;
+ }
+
+ // Update _last_succ from v_in towards the root
+ int up_limit_out = _last_succ[join] == v_in ? join : -1;
+ int last_succ_out = _last_succ[u_out];
+ for (int u = v_in; u != -1 && _last_succ[u] == v_in; u = _parent[u]) {
+ _last_succ[u] = last_succ_out;
+ }
+
+ // Update _last_succ from v_out towards the root
+ if (join != old_rev_thread && v_in != old_rev_thread) {
+ for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ;
+ u = _parent[u]) {
+ _last_succ[u] = old_rev_thread;
+ }
+ } else if (last_succ_out != old_last_succ) {
+ for (int u = v_out; u != up_limit_out && _last_succ[u] == old_last_succ;
+ u = _parent[u]) {
+ _last_succ[u] = last_succ_out;
+ }
+ }
+
+ // Update _succ_num from v_in to join
+ for (int u = v_in; u != join; u = _parent[u]) {
+ _succ_num[u] += old_succ_num;
+ }
+ // Update _succ_num from v_out to join
+ for (int u = v_out; u != join; u = _parent[u]) {
+ _succ_num[u] -= old_succ_num;
+ }
+ }
+
+ void updatePotential() {
+ Cost sigma = _pi[v_in] - _pi[u_in] -
+ ((_forward[u_in])?_cost[in_arc]:(-_cost[in_arc]));
+ int end = _thread[_last_succ[u_in]];
+ for (int u = u_in; u != end; u = _thread[u]) {
+ _pi[u] += sigma;
+ }
+ }
+
+
+ // Heuristic initial pivots
+ bool initialPivots() {
+ Value curr, total = 0;
+ std::vector<Node> supply_nodes, demand_nodes;
+ Node u; _graph.first(u);
+ for (; u != INVALIDNODE; _graph.next(u)) {
+ curr = _supply[_node_id(u)];
+ if (curr > 0) {
+ total += curr;
+ supply_nodes.push_back(u);
+ } else if (curr < 0) {
+ demand_nodes.push_back(u);
+ }
+ }
+ if (_sum_supply > 0) total -= _sum_supply;
+ if (total <= 0) return true;
+
+ ArcVector arc_vector;
+ if (_sum_supply >= 0) {
+ if (supply_nodes.size() == 1 && demand_nodes.size() == 1) {
+ // Perform a reverse graph search from the sink to the source
+ //typename GR::template NodeMap<bool> reached(_graph, false);
+ BoolVector reached(_node_num, false);
+ Node s = supply_nodes[0], t = demand_nodes[0];
+ std::vector<Node> stack;
+ reached[t] = true;
+ stack.push_back(t);
+ while (!stack.empty()) {
+ Node u, v = stack.back();
+ stack.pop_back();
+ if (v == s) break;
+ Arc a; _graph.firstIn(a, v);
+ for (; a != INVALID; _graph.nextIn(a)) {
+ if (reached[u = _graph.source(a)]) continue;
+ ArcsType j = getArcID(a);
+ arc_vector.push_back(j);
+ reached[u] = true;
+ stack.push_back(u);
+ }
+ }
+ } else {
+ arc_vector.resize(demand_nodes.size());
+ // Find the min. cost incoming arc for each demand node
+#pragma omp parallel for
+ for (int i = 0; i < demand_nodes.size(); ++i) {
+ Node v = demand_nodes[i];
+ Cost min_cost = std::numeric_limits<Cost>::max();
+ Arc min_arc = INVALID;
+ Arc a; _graph.firstIn(a, v);
+ for (; a != INVALID; _graph.nextIn(a)) {
+ Cost c = _cost[getArcID(a)];
+ if (c < min_cost) {
+ min_cost = c;
+ min_arc = a;
+ }
+ }
+ arc_vector[i] = getArcID(min_arc);
+ }
+ arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end());
+ }
+ } else {
+ arc_vector.resize(supply_nodes.size());
+ // Find the min. cost outgoing arc for each supply node
+#pragma omp parallel for
+ for (int i = 0; i < int(supply_nodes.size()); ++i) {
+ Node u = supply_nodes[i];
+ Cost min_cost = std::numeric_limits<Cost>::max();
+ Arc min_arc = INVALID;
+ Arc a; _graph.firstOut(a, u);
+ for (; a != INVALID; _graph.nextOut(a)) {
+ Cost c = _cost[getArcID(a)];
+ if (c < min_cost) {
+ min_cost = c;
+ min_arc = a;
+ }
+ }
+ arc_vector[i] = getArcID(min_arc);
+ }
+ arc_vector.erase(std::remove(arc_vector.begin(), arc_vector.end(), INVALID), arc_vector.end());
+ }
+
+ // Perform heuristic initial pivots
+ for (ArcsType i = 0; i != ArcsType(arc_vector.size()); ++i) {
+ in_arc = arc_vector[i];
+ if (_state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -
+ _pi[_target[in_arc]]) >= 0) continue;
+ findJoinNode();
+ bool change = findLeavingArc();
+ if (delta >= MAX) return false;
+ changeFlow(change);
+ if (change) {
+ updateTreeStructure();
+ updatePotential();
+ }
+ }
+ return true;
+ }
+
+ // Execute the algorithm
+ ProblemType start() {
+ return start<BlockSearchPivotRule>();
+ }
+
+ template <typename PivotRuleImpl>
+ ProblemType start() {
+ PivotRuleImpl pivot(*this);
+ ProblemType retVal = OPTIMAL;
+
+ // Perform heuristic initial pivots
+ if (!initialPivots()) return UNBOUNDED;
+
+ size_t iter_number = 0;
+ // Execute the Network Simplex algorithm
+ while (pivot.findEnteringArc()) {
+ if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) {
+#if DEBUG_LVL>0
+ if(iter_number>MAX_DEBUG_ITER)
+ break;
+ if(iter_number%1000==0||iter_number%1000==1){
+ Cost curCost=totalCost();
+ Value sumFlow=0;
+ Cost a;
+ a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
+ a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ }
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+ std::cout << _cost[in_arc] << "\n";
+ std::cout << _pi[_source[in_arc]] << "\n";
+ std::cout << _pi[_target[in_arc]] << "\n";
+ std::cout << a << "\n";
+ }
+#endif
+
+ findJoinNode();
+ bool change = findLeavingArc();
+ if (delta >= MAX) return UNBOUNDED;
+ changeFlow(change);
+ if (change) {
+ updateTreeStructure();
+ updatePotential();
+ }
+
+#if DEBUG_LVL>0
+ else{
+ std::cout << "No change\n";
+ }
+#endif
+
+#if DEBUG_LVL>1
+ std::cout << "Arc in = (" << _source[in_arc] << ", " << _target[in_arc] << ")\n";
+#endif
+
+
+ } else {
+ char errMess[1000];
+ sprintf( errMess, "RESULT MIGHT BE INACCURATE\nMax number of iterations reached, currently %zu. Sometimes iterations keep cycling even though the solution has been reached; to check whether that is the case here, have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution; if not, try setting the maximum number of iterations a bit higher\n", iter_number );
+ std::cerr << errMess;
+ retVal = MAX_ITER_REACHED;
+ break;
+ }
+
+ }
+
+
+
+#if DEBUG_LVL>0
+ Cost curCost=totalCost();
+ Value sumFlow=0;
+ Cost a;
+ a= (fabs(_pi[_source[in_arc]])>=fabs(_pi[_target[in_arc]])) ? fabs(_pi[_source[in_arc]]) : fabs(_pi[_target[in_arc]]);
+ a=a>=fabs(_cost[in_arc])?a:fabs(_cost[in_arc]);
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ }
+
+ std::cout << "Sum of the flow " << std::setprecision(20) << sumFlow << "\n" << iter_number << " iterations, current cost=" << curCost << "\nReduced cost=" << _state[in_arc] * (_cost[in_arc] + _pi[_source[in_arc]] -_pi[_target[in_arc]]) << "\nPrecision = "<< -EPSILON*(a) << "\n";
+
+ std::cout << "Arc in = (" << _node_id(_source[in_arc]) << ", " << _node_id(_target[in_arc]) <<")\n";
+ std::cout << "Supplies = (" << _supply[_source[in_arc]] << ", " << _supply[_target[in_arc]] << ")\n";
+
+#endif
+
+
+
+#if DEBUG_LVL>1
+ sumFlow=0;
+ for (int i=0; i<_flow.size(); i++) {
+ sumFlow+=_state[i]*_flow[i];
+ if (_state[i]==STATE_TREE) {
+ std::cout << "Non zero value at (" << _node_num+1-_source[i] << ", " << _node_num+1-_target[i] << ")\n";
+ }
+ }
+ std::cout << "Sum of the flow " << sumFlow << "\n"<< iter_number <<" iterations, current cost=" << totalCost() << "\n";
+#endif
+
+
+
+ //Check feasibility
+ if(retVal == OPTIMAL){
+ for (ArcsType e = _search_arc_num; e != _all_arc_num; ++e) {
+ if (_flow[e] != 0){
+ if (fabs(_flow[e]) > _EPSILON) // change of the original code following issue #126
+ return INFEASIBLE;
+ else
+ _flow[e]=0;
+ }
+ }
+ }
+
+ // Shift potentials to meet the requirements of the GEQ/LEQ type
+ // optimality conditions
+ if (_sum_supply == 0) {
+ if (_stype == GEQ) {
+ Cost max_pot = -std::numeric_limits<Cost>::max();
+ for (ArcsType i = 0; i != _node_num; ++i) {
+ if (_pi[i] > max_pot) max_pot = _pi[i];
+ }
+ if (max_pot > 0) {
+ for (ArcsType i = 0; i != _node_num; ++i)
+ _pi[i] -= max_pot;
+ }
+ } else {
+ Cost min_pot = std::numeric_limits<Cost>::max();
+ for (ArcsType i = 0; i != _node_num; ++i) {
+ if (_pi[i] < min_pot) min_pot = _pi[i];
+ }
+ if (min_pot < 0) {
+ for (ArcsType i = 0; i != _node_num; ++i)
+ _pi[i] -= min_pot;
+ }
+ }
+ }
+
+ return retVal;
+ }
+
+ }; //class NetworkSimplexSimple
+
+ ///@}
+
+} //namespace lemon_omp
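
Taken together, start() follows the textbook primal network simplex loop. The Python pseudo-structure below only summarizes the control flow of the methods above; the callables stand in for the private C++ methods and are not part of any real API.

    def network_simplex_loop(find_entering_arc, find_join_node, find_leaving_arc,
                             change_flow, update_tree_structure, update_potential,
                             max_iter=0):
        # Structural sketch of start(); the callables are placeholders.
        iterations = 0
        while find_entering_arc():
            iterations += 1
            if max_iter > 0 and iterations > max_iter:
                return "MAX_ITER_REACHED"
            find_join_node()                 # apex of the cycle closed by the entering arc
            change = find_leaving_arc()      # min-ratio test along that cycle
            change_flow(change)              # push delta units of flow around the cycle
            if change:                       # the basis changed: repair the spanning tree
                update_tree_structure()
                update_potential()           # shift potentials of the re-rooted subtree
        return "OPTIMAL"
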
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
new file mode 100644
index 0000000..8b4d0c3
--- /dev/null
+++ b/ot/lp/solver_1d.py
@@ -0,0 +1,367 @@
+# -*- coding: utf-8 -*-
+"""
+Exact solvers for the 1D Wasserstein distance
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Author: Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import warnings
+
+from .emd_wrap import emd_1d_sorted
+from ..backend import get_backend
+from ..utils import list_to_array
+
+
+def quantile_function(qs, cws, xs):
+ r""" Computes the quantile function of an empirical distribution
+
+ Parameters
+ ----------
+    qs: array-like, shape (n, ...)
+        Quantile levels at which the quantile function is evaluated
+    cws: array-like, shape (m, ...)
+        Cumulative weights of the 1D empirical distribution; if batched, must share the batch shape of `xs`
+    xs: array-like, shape (m, ...)
+        Locations of the 1D empirical distribution, batched along its last `xs.ndim - 1` dimensions
+
+    Returns
+    -------
+    q: array-like, shape (n, ...)
+        The quantiles of the distribution evaluated at the levels `qs`
+ """
+ nx = get_backend(qs, cws)
+ n = xs.shape[0]
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+        # and avoid a warning related to non-contiguous arrays
+ cws = cws.T.contiguous()
+ qs = qs.T.contiguous()
+ else:
+ cws = cws.T
+ qs = qs.T
+ idx = nx.searchsorted(cws, qs).T
+ return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0)
+
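
quantile_function() evaluates the generalized inverse CDF of a weighted empirical distribution through a searchsorted on the cumulative weights. A NumPy-only, unbatched sketch of the same idea (function and variable names are illustrative, not part of the module):

    import numpy as np

    def quantile_function_np(qs, cws, xs):
        # for each level q, pick the first atom whose cumulative weight reaches q
        idx = np.searchsorted(cws, qs)
        return xs[np.clip(idx, 0, xs.shape[0] - 1)]

    xs = np.array([0.0, 1.0, 2.0])    # atom locations
    cws = np.array([0.2, 0.5, 1.0])   # cumulative weights of atoms 0.2, 0.3, 0.5
    print(quantile_function_np(np.array([0.1, 0.4, 0.9]), cws, xs))  # -> [0. 1. 2.]
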
+
+def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
+ r"""
+ Computes the 1 dimensional OT loss [15] between two (batched) empirical
+ distributions
+
+    .. math::
+        OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq
+
+ It is formally the p-Wasserstein distance raised to the power p.
+ We do so in a vectorized way by first building the individual quantile functions then integrating them.
+
+ This function should be preferred to `emd_1d` whenever the backend is
+    different from numpy, and when gradients over
+ either sample positions or weights are required.
+
+ Parameters
+ ----------
+ u_values: array-like, shape (n, ...)
+ locations of the first empirical distribution
+ v_values: array-like, shape (m, ...)
+ locations of the second empirical distribution
+ u_weights: array-like, shape (n, ...), optional
+ weights of the first empirical distribution, if None then uniform weights are used
+ v_weights: array-like, shape (m, ...), optional
+ weights of the second empirical distribution, if None then uniform weights are used
+ p: int, optional
+        order of the ground metric used; should be at least 1 (see [2, Chap. 2]), default is 1
+ require_sort: bool, optional
+        sort the distributions' atom locations; if False they are assumed to have been sorted prior to being passed to
+ the function, default is True
+
+ Returns
+ -------
+ cost: float/array-like, shape (...)
+ the batched EMD
+
+ References
+ ----------
+ .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.
+
+ """
+
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ u_cumweights = nx.cumsum(u_weights, 0)
+ v_cumweights = nx.cumsum(v_weights, 0)
+
+ qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0)
+ u_quantiles = quantile_function(qs, u_cumweights, u_values)
+ v_quantiles = quantile_function(qs, v_cumweights, v_values)
+ qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
+ delta = qs[1:, ...] - qs[:-1, ...]
+ diff_quantiles = nx.abs(u_quantiles - v_quantiles)
+
+ if p == 1:
+ return nx.sum(delta * nx.abs(diff_quantiles), axis=0)
+ return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
+
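
With the quantile functions in hand, wasserstein_1d integrates |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p over the merged quantile levels. Assuming this patch is installed (POT 0.8, where ot.wasserstein_1d points at the backend-agnostic function above), a minimal usage example:

    import numpy as np
    import ot

    x = np.array([0.0, 1.0, 2.0])   # atoms of the first distribution (uniform weights)
    y = np.array([3.0, 4.0, 5.0])   # atoms of the second distribution (uniform weights)

    # every unit of mass travels a distance of 3, so the 1-Wasserstein loss is 3
    print(ot.wasserstein_1d(x, y, p=1))
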
+
+def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the OT matrix
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+    x_b : (nt,) or (nt, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the cost.
+ Otherwise returns only the optimal transportation matrix.
+
+ Returns
+ -------
+ gamma: (ns, nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log: dict
+ If input log is True, a dictionary containing the cost
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd_1d(x_a, x_b, a, b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+ >>> ot.emd_1d(x_a, x_b)
+ array([[0. , 0.5],
+ [0.5, 0. ]])
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd : EMD for multidimensional distributions
+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
+ transportation matrix)
+ """
+ a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
+ nx = get_backend(x_a, x_b)
+
+ assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+ assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
+ "emd_1d should only be used with monodimensional data"
+
+ # if empty array given then use uniform distributions
+ if a is None or a.ndim == 0 or len(a) == 0:
+ a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
+ if b is None or b.ndim == 0 or len(b) == 0:
+ b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
+
+ # ensure that same mass
+ np.testing.assert_almost_equal(
+ nx.to_numpy(nx.sum(a, axis=0)),
+ nx.to_numpy(nx.sum(b, axis=0)),
+ err_msg='a and b vector must have the same sum'
+ )
+ b = b * nx.sum(a) / nx.sum(b)
+
+ x_a_1d = nx.reshape(x_a, (-1,))
+ x_b_1d = nx.reshape(x_b, (-1,))
+ perm_a = nx.argsort(x_a_1d)
+ perm_b = nx.argsort(x_b_1d)
+
+ G_sorted, indices, cost = emd_1d_sorted(
+ nx.to_numpy(a[perm_a]).astype(np.float64),
+ nx.to_numpy(b[perm_b]).astype(np.float64),
+ nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
+ nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
+ metric=metric, p=p
+ )
+
+ G = nx.coo_matrix(
+ G_sorted,
+ perm_a[indices[:, 0]],
+ perm_b[indices[:, 1]],
+ shape=(a.shape[0], b.shape[0]),
+ type_as=x_a
+ )
+ if dense:
+ G = nx.todense(G)
+ elif str(nx) == "jax":
+ warnings.warn("JAX does not support sparse matrices, converting to dense")
+ if log:
+ log = {'cost': nx.from_numpy(cost, type_as=x_a)}
+ return G, log
+ return G
+
+
+def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
+ log=False):
+ r"""Solves the Earth Movers distance problem between 1d measures and returns
+ the loss
+
+
+ .. math::
+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
+
+ s.t. \gamma 1 = a,
+ \gamma^T 1= b,
+ \gamma\geq 0
+ where :
+
+ - d is the metric
+ - x_a and x_b are the samples
+ - a and b are the sample weights
+
+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
+
+ Uses the algorithm detailed in [1]_
+
+ Parameters
+ ----------
+ x_a : (ns,) or (ns, 1) ndarray, float64
+ Source dirac locations (on the real line)
+    x_b : (nt,) or (nt, 1) ndarray, float64
+ Target dirac locations (on the real line)
+ a : (ns,) ndarray, float64, optional
+ Source histogram (default is uniform weight)
+ b : (nt,) ndarray, float64, optional
+ Target histogram (default is uniform weight)
+ metric: str, optional (default='sqeuclidean')
+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
+ Due to implementation details, this function runs faster when
+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
+ are used.
+ p: float, optional (default=1.0)
+ The p-norm to apply for if metric='minkowski'
+ dense: boolean, optional (default=True)
+        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
+ Otherwise returns a sparse representation using scipy's `coo_matrix`
+ format. Only used if log is set to True. Due to implementation details,
+ this function runs faster when dense is set to False.
+ log: boolean, optional (default=False)
+ If True, returns a dictionary containing the transportation matrix.
+ Otherwise returns only the loss.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict
+ If input log is True, a dictionary containing the Optimal transportation
+ matrix for the given parameters
+
+
+ Examples
+ --------
+
+ Simple example with obvious solution. The function emd2_1d accepts lists and
+ performs automatic conversion to numpy arrays
+
+ >>> import ot
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> x_a = [2., 0.]
+ >>> x_b = [0., 3.]
+ >>> ot.emd2_1d(x_a, x_b, a, b)
+ 0.5
+ >>> ot.emd2_1d(x_a, x_b)
+ 0.5
+
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+
+ See Also
+ --------
+ ot.lp.emd2 : EMD for multidimensional distributions
+ ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
+ instead of the cost)
+ """
+    # If we do not return G (log==False), then there is no need to cast it to dense
+ # (useless overhead)
+ G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
+ dense=dense and log, log=True)
+ cost = log_emd['cost']
+ if log:
+ log_emd = {'G': G}
+ return cost, log_emd
+ return cost
diff --git a/ot/optim.py b/ot/optim.py
index b9ca891..bd8ca26 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -12,34 +12,36 @@ import numpy as np
from scipy.optimize.linesearch import scalar_search_armijo
from .lp import emd
from .bregman import sinkhorn
+from ot.utils import list_to_array
+from .backend import get_backend
# The corresponding scipy function does not work for matrices
def line_search_armijo(f, xk, pk, gfk, old_fval,
args=(), c1=1e-4, alpha0=0.99):
- """
+ r"""
Armijo linesearch function that works with matrices
- find an approximate minimum of f(xk+alpha*pk) that satifies the
+ Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
armijo conditions.
Parameters
----------
f : callable
loss function
- xk : ndarray
+ xk : array-like
initial position
- pk : ndarray
+ pk : array-like
descent direction
- gfk : ndarray
- gradient of f at xk
+ gfk : array-like
+ gradient of `f` at :math:`x_k`
old_fval : float
- loss value at xk
+ loss value at :math:`x_k`
args : tuple, optional
- arguments given to f
+ arguments given to `f`
c1 : float, optional
- c1 const in armijo rule (>0)
+ :math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
@@ -53,7 +55,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
loss value at step alpha
"""
- xk = np.atleast_1d(xk)
+
+ xk, pk, gfk = list_to_array(xk, pk, gfk)
+ nx = get_backend(xk, pk)
+
+ if len(xk.shape) == 0:
+ xk = nx.reshape(xk, (-1,))
+
fc = [0]
def phi(alpha1):
@@ -65,10 +73,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
else:
phi0 = old_fval
- derphi0 = np.sum(pk * gfk) # Quickfix for matrices
+ derphi0 = nx.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
+ # scalar_search_armijo can return alpha > 1
+ if alpha is not None:
+ alpha = min(1, alpha)
return alpha, fc[0], phi1
@@ -76,55 +87,64 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
"""
Solve the linesearch in the FW iterations
+
Parameters
----------
cost : method
Cost in the FW for the linesearch
- G : ndarray, shape(ns,nt)
+ G : array-like, shape(ns,nt)
The transport map at a given iteration of the FW
- deltaG : ndarray (ns,nt)
+ deltaG : array-like (ns,nt)
Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
- Mi : ndarray (ns,nt)
+ Mi : array-like (ns,nt)
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
- f_val : float
- Value of the cost at G
+ f_val : float
+ Value of the cost at `G`
armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
- C1 : ndarray (ns,ns), optional
+        If True the step of the line-search is found via an armijo search. Else a closed form is used.
+        If there are convergence issues use False.
+ C1 : array-like (ns,ns), optional
Structure matrix in the source domain. Only used and necessary when armijo=False
- C2 : ndarray (nt,nt), optional
+ C2 : array-like (nt,nt), optional
Structure matrix in the target domain. Only used and necessary when armijo=False
reg : float, optional
- Regularization parameter. Only used and necessary when armijo=False
- Gc : ndarray (ns,nt)
+ Regularization parameter. Only used and necessary when armijo=False
+ Gc : array-like (ns,nt)
Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
- constC : ndarray (ns,nt)
- Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
- M : ndarray (ns,nt), optional
+ constC : array-like (ns,nt)
+ Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
+ M : array-like (ns,nt), optional
Cost matrix between the features. Only used and necessary when armijo=False
+
Returns
-------
alpha : float
- The optimal step size of the FW
+ The optimal step size of the FW
fc : int
- nb of function call. Useless here
- f_val : float
- The value of the cost for the next iteration
+        number of function calls. Not meaningful here
+ f_val : float
+ The value of the cost for the next iteration
+
+
+ .. _references-solve-linesearch:
References
----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
- and Courty Nicolas
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
if armijo:
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
else: # requires symetric matrices
- dot1 = np.dot(C1, deltaG)
- dot12 = dot1.dot(C2)
- a = -2 * reg * np.sum(dot12 * deltaG)
- b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+ G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(G, deltaG, C1, C2, constC)
+ else:
+ nx = get_backend(G, deltaG, C1, C2, constC, M)
+
+ dot = nx.dot(nx.dot(C1, deltaG), C2)
+ a = -2 * reg * nx.sum(dot * deltaG)
+ b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG))
c = cost(G)
alpha = solve_1d_linesearch_quad(a, b, c)
@@ -136,48 +156,49 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
- """
+ r"""
Solve the general regularized OT problem with conditional gradient
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The algorithm used for solving the problem is conditional gradient as discussed in [1]_
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
Parameters
----------
- a : ndarray, shape (ns,)
+ a : array-like, shape (ns,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : array-like, shape (nt,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
- G0 : ndarray, shape (ns,nt), optional
+ G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numItermaxEmd : int, optional
Max number of iterations for emd
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -193,6 +214,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
log dictionary return only if log==True in parameters
+ .. _references-cg:
References
----------
@@ -204,6 +226,11 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
ot.bregman.sinkhorn : Entropic regularized optimal transport
"""
+ a, b, M, G0 = list_to_array(a, b, M, G0)
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(a, b)
+ else:
+ nx = get_backend(a, b, M)
loop = 1
@@ -211,12 +238,12 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
log = {'loss': []}
if G0 is None:
- G = np.outer(a, b)
+ G = nx.outer(a, b)
else:
G = G0
def cost(G):
- return np.sum(M * G) + reg * f(G)
+ return nx.sum(M * G) + reg * f(G)
f_val = cost(G)
if log:
@@ -237,15 +264,17 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
# problem linearization
Mi = M + reg * df(G)
# set M positive
- Mi += Mi.min()
+ Mi += nx.min(Mi)
# solve linear program
- Gc = emd(a, b, Mi, numItermax=numItermaxEmd)
+ Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
deltaG = Gc - G
# line search
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
+ if alpha is None:
+ alpha = 0.0
G = G + alpha * deltaG
@@ -268,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
if log:
+ log.update(logemd)
return G, log
else:
return G
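
The cg() routine above is a Frank-Wolfe (conditional gradient) loop: linearize the regularized objective at the current plan, call the exact EMD solver on the linearized cost, then move towards the oracle solution. The sketch below shows that scheme in a compact, self-contained form; it uses fixed 2/(k+2) steps instead of the library's line search, a toy squared-norm regularizer, and `cg_sketch` is an illustrative name rather than the library function.

    import numpy as np
    import ot

    def cg_sketch(a, b, M, reg, f, df, n_iter=50):
        G = np.outer(a, b)                        # feasible starting point (independent coupling)
        for k in range(n_iter):
            Mi = M + reg * df(G)                  # linearize the objective at G
            Gc = ot.emd(a, b, Mi)                 # linear minimization oracle = an exact OT solve
            G = G + 2.0 / (k + 2.0) * (Gc - G)    # convex combination stays in the polytope
        return G, np.sum(M * G) + reg * f(G)      # final plan and objective value

    a = b = np.ones(4) / 4
    M = ot.dist(np.arange(4.0).reshape(-1, 1), np.arange(4.0).reshape(-1, 1))
    G, obj = cg_sketch(a, b, M, reg=1e-1, f=lambda G: 0.5 * np.sum(G ** 2), df=lambda G: G)
    print(obj, G.sum(axis=1))  # objective and (approximately) preserved row marginals
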
@@ -275,51 +305,52 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
- """
+ r"""
Solve the general regularized OT problem with the generalized conditional gradient
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg1\cdot\Omega(\gamma) + reg2\cdot f(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} &= \mathbf{b}
- \gamma\geq 0
+ \gamma &\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`f` is the regularization term ( and df is its gradient)
- - a and b are source and target weights (sum to 1)
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The algorithm used for solving the problem is the generalized conditional gradient as discussed in [5,7]_
+ The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
- a : ndarray, shape (ns,)
+ a : array-like, shape (ns,)
samples weights in the source domain
- b : ndarrayv (nt,)
+ b : array-like, (nt,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : array-like, shape (ns, nt)
loss matrix
reg1 : float
Entropic Regularization term >0
reg2 : float
Second Regularization term >0
- G0 : ndarray, shape (ns, nt), optional
+ G0 : array-like, shape (ns, nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
Max number of iterations of Sinkhorn
stopThr : float, optional
- Stop threshol on the relative variation (>0)
+ Stop threshold on the relative variation (>0)
stopThr2 : float, optional
- Stop threshol on the absolute variation (>0)
+ Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -332,9 +363,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
log : dict
log dictionary return only if log==True in parameters
+
+ .. _references-gcg:
References
----------
+
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
See Also
@@ -342,6 +377,8 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
ot.optim.cg : conditional gradient
"""
+ a, b, M, G0 = list_to_array(a, b, M, G0)
+ nx = get_backend(a, b, M)
loop = 1
@@ -349,12 +386,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
log = {'loss': []}
if G0 is None:
- G = np.outer(a, b)
+ G = nx.outer(a, b)
else:
G = G0
def cost(G):
- return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
+ return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G)
f_val = cost(G)
if log:
@@ -382,7 +419,7 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
deltaG = Gc - G
# line search
- dcost = Mi + reg1 * (1 + np.log(G)) # ??
+        dcost = Mi + reg1 * (1 + nx.log(G))  # gradient of the cost at G
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
G = G + alpha * deltaG
@@ -413,10 +450,12 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
def solve_1d_linesearch_quad(a, b, c):
- """
- For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem:
+ r"""
+ For any convex or non-convex 1d quadratic function `f`, solve the following problem:
+
.. math::
- \argmin f(x)=a*x^{2}+b*x+c
+
+ \mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c
Parameters
----------
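Editor's note: the closed form described by this docstring is short enough to sketch directly. The snippet below is an illustration, not the `solve_1d_linesearch_quad` body itself: for a > 0 the unconstrained minimizer -b/(2a) is clipped to [0, 1]; otherwise the minimum sits at one of the endpoints.

import numpy as np

def solve_quad_on_unit_interval(a, b, c):
    """Minimize f(x) = a*x**2 + b*x + c over [0, 1] (illustrative sketch)."""
    if a > 0:                                   # convex: clip the stationary point
        return float(np.clip(-b / (2.0 * a), 0.0, 1.0))
    # concave or linear: compare the endpoints f(0) = c and f(1) = a + b + c
    return 0.0 if c <= a + b + c else 1.0

# f(x) = x**2 - x has its minimum on [0, 1] at x = 0.5
assert solve_quad_on_unit_interval(1.0, -1.0, 0.0) == 0.5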
diff --git a/ot/partial.py b/ot/partial.py
index eb707d8..b7093e4 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -20,13 +20,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
The function considers the following problem:
.. math::
- \gamma = \arg\min_\gamma <\gamma,(M-\lambda)>_F
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, (\mathbf{M} - \lambda) \rangle_F
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X.
@@ -34,33 +37,32 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
metrics. Foundations of Computational Mathematics, 18(1), 1-44.)
.. math::
- \gamma = \arg\min_\gamma <\gamma,M>_F + \sqrt(\lambda/2)
- (\|\gamma 1 - a\|_1 + \|\gamma^T 1 - b\|_1)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \sqrt{\frac{\lambda}{2} (\|\gamma \mathbf{1} - \mathbf{a}\|_1 + \|\gamma^T \mathbf{1} - \mathbf{b}\|_1)}
- s.t.
- \gamma\geq 0 \\
+ s.t. \ \gamma \geq 0
where :
- - M is the metric cost matrix
- - a and b are source and target unbalanced distributions
- - :math:`\lambda` is the lagragian cost. Tuning its value allows attaining
- a given mass to be transported m
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
+ - :math:`\lambda` is the lagrangian cost. Tuning its value allows attaining
+ a given mass to be transported `m`
- The formulation of the problem has been proposed in [28]_
+ The formulation of the problem has been proposed in :ref:`[28] <references-partial-wasserstein-lagrange>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,)
- Unnormalized histograms of dimension dim_b
+ Unnormalized histograms of dimension `dim_b`
M : np.ndarray (dim_a, dim_b)
cost matrix for the quadratic cost
reg_m : float, optional
- Lagragian cost
+ Lagrangian cost
nb_dummies : int, optional, default:1
number of reservoir points to be added (to avoid numerical
instabilities, increase its value if an error is raised)
@@ -69,6 +71,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
**kwargs : dict
parameters can be directly passed to the emd solver
+
.. warning::
When dealing with a large number of points, the EMD solver may face
some instabilities, especially when the mass associated to the dummy
@@ -77,7 +80,7 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
Returns
-------
- gamma : (dim_a x dim_b) ndarray
+ gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
@@ -97,9 +100,10 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
array([[0.1, 0. ],
[0. , 0. ]])
+
+ .. _references-partial-wasserstein-lagrange:
References
----------
-
.. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
optimal transport and Monge-Ampere obstacle problems. Annals of
mathematics, 673-730.
@@ -162,27 +166,30 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
The function considers the following problem:
.. math::
- \gamma = \arg\min_\gamma <\gamma,M>_F
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - M is the metric cost matrix
- - a and b are source and target unbalanced distributions
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
+ - `m` is the amount of mass to be transported
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,)
- Unnormalized histograms of dimension dim_b
+ Unnormalized histograms of dimension `dim_b`
M : np.ndarray (dim_a, dim_b)
cost matrix for the quadratic cost
m : float, optional
@@ -205,7 +212,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
Returns
-------
- :math:`gamma` : (dim_a x dim_b) ndarray
+ gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
@@ -230,9 +237,9 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
.. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
optimal transport and Monge-Ampere obstacle problems. Annals of
mathematics, 673-730.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
See Also
--------
@@ -254,7 +261,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-1, -1] = np.max(M) * 1e5
+ M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
M_extended[:len(a), :len(b)] = M
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
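Editor's note: the fix above spreads the prohibitive cost over the whole dummy block (all `nb_dummies` rows and columns) rather than a single entry. A toy sketch of the extension, with made-up histograms and the same construction as the lines above:

import numpy as np

a = np.array([0.4, 0.6])
b = np.array([0.5, 0.5])
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])
m, nb_dummies = 0.5, 2

b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
M_extended = np.zeros((len(a_extended), len(b_extended)))
M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5   # forbid dummy-to-dummy transport
M_extended[:len(a), :len(b)] = M
print(M_extended.shape)   # (4, 4): two dummy points added on each side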
@@ -278,27 +285,30 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
The function considers the following problem:
.. math::
- \gamma = \arg\min_\gamma <\gamma,M>_F
+ \gamma = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
+
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - M is the metric cost matrix
- - a and b are source and target unbalanced distributions
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
+ - `m` is the amount of mass to be transported
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,)
- Unnormalized histograms of dimension dim_b
+ Unnormalized histograms of dimension `dim_b`
M : np.ndarray (dim_a, dim_b)
cost matrix for the quadratic cost
m : float, optional
@@ -321,8 +331,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
Returns
-------
- :math:`gamma` : (dim_a x dim_b) ndarray
- Optimal transportation matrix for the given parameters
+ GW: float
+ partial GW discrepancy
log : dict
log dictionary returned only if `log` is `True`
@@ -344,14 +354,13 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
.. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in
optimal transport and Monge-Ampere obstacle problems. Annals of
mathematics, 673-730.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
**kwargs)
-
log_w['T'] = partial_gw
if log:
@@ -361,8 +370,8 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
def gwgrad_partial(C1, C2, T):
- """Compute the GW gradient. Note: we can not use the trick in [12]_ as
- the marginals may not sum to 1.
+ """Compute the GW gradient. Note: we can not use the trick in :ref:`[12] <references-gwgrad-partial>`
+ as the marginals may not sum to 1.
Parameters
----------
@@ -380,6 +389,8 @@ def gwgrad_partial(C1, C2, T):
numpy.array of shape (n_p+nb_dummies, n_u)
gradient
+
+ .. _references-gwgrad-partial:
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
@@ -426,22 +437,25 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
The function considers the following problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
- s.t. \gamma 1 \leq a \\
- \gamma^T 1 \leq b \\
- \gamma\geq 0 \\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)
- =\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are the sample weights
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+ - `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in [29]_
+ The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein>`
Parameters
@@ -455,7 +469,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -477,7 +491,7 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
Returns
-------
- gamma : (dim_a x dim_b) ndarray
+ gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
@@ -501,14 +515,16 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
>>> np.round(partial_gromov_wasserstein(C1, C2, a, b, m=0.25),2)
array([[0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. ],
- [0. , 0. , 0. , 0. ],
- [0. , 0. , 0. , 0.25]])
+ [0. , 0. , 0.25, 0. ],
+ [0. , 0. , 0. , 0. ]])
+
+ .. _references-partial-gromov-wasserstein:
References
----------
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
@@ -530,20 +546,18 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
cpt = 0
err = 1
- eps = 1e-20
+
if log:
log = {'err': []}
while (err > tol and cpt < numItermax):
- Gprev = G0
+ Gprev = np.copy(G0)
M = gwgrad_partial(C1, C2, G0)
- M[M < eps] = np.quantile(M, thres)
-
M_emd = np.zeros(dim_G_extended)
M_emd[:len(p), :len(q)] = M
- M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5
+ M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2
M_emd = np.asarray(M_emd, dtype=np.float64)
Gc, logemd = emd(p_extended, q_extended, M_emd, log=True, **kwargs)
@@ -565,6 +579,22 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
print('{:5d}|{:8e}|{:8e}'.format(cpt, err,
gwloss_partial(C1, C2, G0)))
+ deltaG = G0 - Gprev
+ a = gwloss_partial(C1, C2, deltaG)
+ b = 2 * np.sum(M * deltaG)
+ if b > 0: # due to numerical precision
+ gamma = 0
+ cpt = numItermax
+ elif a > 0:
+ gamma = min(1, np.divide(-b, 2.0 * a))
+ else:
+ if (a + b) < 0:
+ gamma = 1
+ else:
+ gamma = 0
+ cpt = numItermax
+
+ G0 = Gprev + gamma * deltaG
cpt += 1
if log:
@@ -584,22 +614,25 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
The function considers the following problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F
+ GW = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
- s.t. \gamma 1 \leq a \\
- \gamma^T 1 \leq b \\
- \gamma\geq 0 \\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\
+ .. math::
+ s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \gamma &\geq 0
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are the sample weights
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+ - `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in [29]_
+ The formulation of the problem has been proposed in :ref:`[29] <references-partial-gromov-wasserstein2>`
Parameters
@@ -613,7 +646,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
q : ndarray, shape (nt,)
Distribution in the target space
m : float, optional
- Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
nb_dummies : int, optional
Number of dummy points to add (avoid instabilities in the EMD solver)
G0 : ndarray, shape (ns, nt), optional
@@ -642,7 +675,7 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
Returns
-------
- partial_gw_dist : (dim_a x dim_b) ndarray
+ partial_gw_dist : float
partial GW discrepancy
log : dict
log dictionary returned only if `log` is `True`
@@ -663,11 +696,13 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None,
>>> np.round(partial_gromov_wasserstein2(C1, C2, a, b, m=0.25),2)
0.0
+
+ .. _references-partial-gromov-wasserstein2:
References
----------
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
@@ -693,30 +728,29 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
The function considers the following problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 \leq a \\
- \gamma^T 1 \leq b \\
- \gamma\geq 0 \\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\} \\
+ s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\
+ \gamma^T \mathbf{1} &\leq \mathbf{b} \\
+ \gamma &\geq 0 \\
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\
where :
- - M is the metric cost matrix
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are the sample weights
- - m is the amount of mass to be transported
+ - :math:`\mathbf{M}` is the metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+ - `m` is the amount of mass to be transported
- The formulation of the problem has been proposed in [3]_ (prop. 5)
+ The formulation of the problem has been proposed in :ref:`[3] <references-entropic-partial-wasserstein>` (prop. 5)
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,)
- Unnormalized histograms of dimension dim_b
+ Unnormalized histograms of dimension `dim_b`
M : np.ndarray (dim_a, dim_b)
cost matrix
reg : float
@@ -735,7 +769,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
Returns
-------
- gamma : (dim_a x dim_b) ndarray
+ gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
@@ -751,6 +785,8 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
array([[0.06, 0.02],
[0.01, 0. ]])
+
+ .. _references-entropic-partial-wasserstein:
References
----------
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
@@ -825,32 +861,34 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein transport between (C1,p) and (C2,q)
+ Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot
- \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_{\gamma} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot
+ \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ .. math::
+ s.t. \ \gamma &\geq 0
+
+ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - C1 is the metric cost matrix in the source space
- - C2 is the metric cost matrix in the target space
- - p and q are the sample weights
- - L : quadratic loss function
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - m is the amount of mass to be transported
+ - :math:`\mathbf{C_1}` is the metric cost matrix in the source space
+ - :math:`\mathbf{C_2}` is the metric cost matrix in the target space
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
+ - `L`: quadratic loss function
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in [12]_ and the
- partial GW in [29]_.
+    The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wasserstein>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wasserstein>`
Parameters
----------
@@ -865,7 +903,7 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
@@ -887,12 +925,12 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
>>> y = np.array([3,2,98,199]).reshape((-1,1))
>>> C1 = sp.spatial.distance.cdist(x, x)
>>> C2 = sp.spatial.distance.cdist(y, y)
- >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b,50), 2)
+ >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50), 2)
array([[0.12, 0.13, 0. , 0. ],
[0.13, 0.12, 0. , 0. ],
[0. , 0. , 0.25, 0. ],
[0. , 0. , 0. , 0.25]])
- >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50, m=0.25), 2)
+    >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50, 0.25), 2)
array([[0.02, 0.03, 0. , 0.03],
[0.03, 0.03, 0. , 0.03],
[0. , 0. , 0.03, 0. ],
@@ -900,19 +938,22 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None,
Returns
-------
- :math: `gamma` : (dim_a x dim_b) ndarray
+    gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if `log` is `True`
+
+    .. _references-entropic-partial-gromov-wasserstein:
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
See Also
--------
@@ -964,33 +1005,33 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
numItermax=1000, tol=1e-7, log=False,
verbose=False):
r"""
- Returns the partial Gromov-Wasserstein discrepancy between (C1,p) and
- (C2,q)
+ Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
The function solves the following optimization problem:
.. math::
- GW = \arg\min_{\gamma} \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})\cdot
- \gamma_{i,j}\cdot\gamma_{k,l} + reg\cdot\Omega(\gamma)
+ GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot
+ \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)
- s.t.
- \gamma\geq 0 \\
- \gamma 1 \leq a\\
- \gamma^T 1 \leq b\\
- 1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
+ .. math::
+ s.t. \ \gamma &\geq 0
+
+ \gamma \mathbf{1} &\leq \mathbf{a}
+
+ \gamma^T \mathbf{1} &\leq \mathbf{b}
+
+ \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
where :
- - C1 is the metric cost matrix in the source space
- - C2 is the metric cost matrix in the target space
- - p and q are the sample weights
- - L : quadratic loss function
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - m is the amount of mass to be transported
+ - :math:`\mathbf{C_1}` is the metric cost matrix in the source space
+ - :math:`\mathbf{C_2}` is the metric cost matrix in the target space
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
+ - `L` : quadratic loss function
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - `m` is the amount of mass to be transported
- The formulation of the GW problem has been proposed in [12]_ and the
- partial GW in [29]_.
+    The formulation of the GW problem has been proposed in :ref:`[12] <references-entropic-partial-gromov-wasserstein2>` and the partial GW in :ref:`[29] <references-entropic-partial-gromov-wasserstein2>`
Parameters
@@ -1006,7 +1047,7 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
reg: float
entropic regularization parameter
m : float, optional
- Amount of mass to be transported (default: min (|p|_1, |q|_1))
+ Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
G0 : ndarray, shape (ns, nt), optional
Initialisation of the transportation matrix
numItermax : int, optional
@@ -1039,14 +1080,17 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None,
>>> np.round(entropic_partial_gromov_wasserstein2(C1, C2, a, b,50), 2)
1.87
+
+    .. _references-entropic-partial-gromov-wasserstein2:
References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.
- .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov-
- Wasserstein with Applications on Positive-Unlabeled Learning".
- arXiv preprint arXiv:2002.08276.
+
+ .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
+ Transport with Applications on Positive-Unlabeled Learning".
+ NeurIPS.
"""
partial_gw, log_gw = entropic_partial_gromov_wasserstein(C1, C2, p, q, reg,
diff --git a/ot/plot.py b/ot/plot.py
index ad436b4..3e3bed7 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -18,10 +18,10 @@ from matplotlib import gridspec
def plot1D_mat(a, b, M, title=''):
- """ Plot matrix M with the source and target 1D distribution
+ """ Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution
- Creates a subplot with the source distribution a on the left and
- target distribution b on the tot. The matrix M is shown in between.
+ Creates a subplot with the source distribution :math:`\mathbf{a}` on the left and
+ target distribution :math:`\mathbf{b}` on the top. The matrix :math:`\mathbf{M}` is shown in between.
Parameters
@@ -61,10 +61,10 @@ def plot1D_mat(a, b, M, title=''):
def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
- """ Plot matrix M in 2D with lines using alpha values
+ """ Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values
Plot lines between source and target 2D samples with a color
- proportional to the value of the matrix G between samples.
+ proportional to the value of the matrix :math:`\mathbf{G}` between samples.
Parameters
diff --git a/ot/regpath.py b/ot/regpath.py
new file mode 100644
index 0000000..269937a
--- /dev/null
+++ b/ot/regpath.py
@@ -0,0 +1,827 @@
+# -*- coding: utf-8 -*-
+"""
+Regularization path OT solvers
+"""
+
+# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
+# License: MIT License
+
+import numpy as np
+import scipy.sparse as sp
+
+
+def recast_ot_as_lasso(a, b, C):
+ r"""This function recasts the l2-penalized UOT problem as a Lasso problem
+
+ Recall the l2-penalized UOT problem defined in [Chapel et al., 2021]
+ .. math::
+ UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2 +
+ \lambda \|T^T 1_n - b\|_2^2
+ s.t.
+ T \geq 0
+ where :
+ - C is the (dim_a, dim_b) metric cost matrix
+ - :math:`\lambda` is the l2-regularization coefficient
+ - a and b are source and target distributions
+ - T is the transport plan to optimize
+
+ The problem above can be reformulated to a non-negative penalized
+ linear regression problem, particularly Lasso
+ .. math::
+ UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ s.t.
+ t \geq 0
+ where :
+ - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
+ - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
+ see [Chapel et al., 2021] for the design of H. The matrix product H t
+ computes both the source marginal and the target marginal.
+ - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ Returns
+ -------
+ H : np.ndarray (dim_a+dim_b, dim_a*dim_b)
+ Auxiliary matrix constituted by 0 and 1
+ y : np.ndarray (ns + nt, )
+ Concatenation of histogram a and histogram b
+ c : np.ndarray (ns * nt, )
+ Flattened array of cost matrix
+ Examples
+ --------
+ >>> import ot
+ >>> a = np.array([0.2, 0.3, 0.5])
+ >>> b = np.array([0.1, 0.9])
+ >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]])
+ >>> H, y, c = ot.regpath.recast_ot_as_lasso(a, b, C)
+ >>> H.toarray()
+ array([[1., 1., 0., 0., 0., 0.],
+ [0., 0., 1., 1., 0., 0.],
+ [0., 0., 0., 0., 1., 1.],
+ [1., 0., 1., 0., 1., 0.],
+ [0., 1., 0., 1., 0., 1.]])
+ >>> y
+ array([0.2, 0.3, 0.5, 0.1, 0.9])
+ >>> c
+ array([16., 25., 28., 16., 40., 36.])
+
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ dim_a = np.shape(a)[0]
+ dim_b = np.shape(b)[0]
+ y = np.concatenate((a, b))
+ c = C.flatten()
+ jHa = np.arange(dim_a * dim_b)
+ iHa = np.repeat(np.arange(dim_a), dim_b)
+ jHb = np.arange(dim_a * dim_b)
+ iHb = np.tile(np.arange(dim_b), dim_a) + dim_a
+ j = np.concatenate((jHa, jHb))
+ i = np.concatenate((iHa, iHb))
+ H = sp.csc_matrix((np.ones(dim_a * dim_b * 2), (i, j)),
+ shape=(dim_a + dim_b, dim_a * dim_b))
+ return H, y, c
+
+
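Editor's note: one way to read `H` is that its product with a flattened plan stacks the source and target marginals. A quick numerical check on the docstring data, using only functions added in this diff (a sketch, not part of the POT test suite; the plan `T` is an arbitrary feasible-shaped matrix):

import numpy as np
import ot

a = np.array([0.2, 0.3, 0.5])
b = np.array([0.1, 0.9])
C = np.array([[16., 25.], [28., 16.], [40., 36.]])
H, y, c = ot.regpath.recast_ot_as_lasso(a, b, C)

T = np.outer(a, b)                                  # any plan of the right shape
marginals = H.dot(T.flatten())
assert np.allclose(marginals[:3], T.sum(axis=1))    # source marginal (row sums)
assert np.allclose(marginals[3:], T.sum(axis=0))    # target marginal (column sums)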
+def recast_semi_relaxed_as_lasso(a, b, C):
+ r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem
+
+ .. math::
+ semi-relaxed UOT = \min_T <C, T> + \lambda \|T 1_m - a\|_2^2
+ s.t.
+ T^T 1_n = b
+ t \geq 0
+ where :
+ - C is the (dim_a, dim_b) metric cost matrix
+ - :math:`\lambda` is the l2-regularization coefficient
+ - a and b are source and target distributions
+ - T is the transport plan to optimize
+
+ The problem above can be reformulated as follows
+ .. math::
+ semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+ s.t.
+ H_c t = b
+ t \geq 0
+ where :
+ - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - H_r is a (dim_a, dim_a * dim_b) metric matrix,
+ which computes the sum along the rows of transport plan T
+ - H_c is a (dim_b, dim_a * dim_b) metric matrix,
+ which computes the sum along the columns of transport plan T
+ - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ Returns
+ -------
+ Hr : np.ndarray (dim_a, dim_a * dim_b)
+ Auxiliary matrix constituted by 0 and 1, which computes
+ the sum along the rows of transport plan T
+ Hc : np.ndarray (dim_b, dim_a * dim_b)
+ Auxiliary matrix constituted by 0 and 1, which computes
+ the sum along the columns of transport plan T
+ c : np.ndarray (ns * nt, )
+ Flattened array of cost matrix
+ Examples
+ --------
+ >>> import ot
+ >>> a = np.array([0.2, 0.3, 0.5])
+ >>> b = np.array([0.1, 0.9])
+ >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]])
+ >>> Hr,Hc,c = ot.regpath.recast_semi_relaxed_as_lasso(a, b, C)
+ >>> Hr.toarray()
+ array([[1., 1., 0., 0., 0., 0.],
+ [0., 0., 1., 1., 0., 0.],
+ [0., 0., 0., 0., 1., 1.]])
+ >>> Hc.toarray()
+ array([[1., 0., 1., 0., 1., 0.],
+ [0., 1., 0., 1., 0., 1.]])
+ >>> c
+ array([16., 25., 28., 16., 40., 36.])
+ """
+
+ dim_a = np.shape(a)[0]
+ dim_b = np.shape(b)[0]
+
+ c = C.flatten()
+ jHr = np.arange(dim_a * dim_b)
+ iHr = np.repeat(np.arange(dim_a), dim_b)
+ jHc = np.arange(dim_a * dim_b)
+ iHc = np.tile(np.arange(dim_b), dim_a)
+
+ Hr = sp.csc_matrix((np.ones(dim_a * dim_b), (iHr, jHr)),
+ shape=(dim_a, dim_a * dim_b))
+ Hc = sp.csc_matrix((np.ones(dim_a * dim_b), (iHc, jHc)),
+ shape=(dim_b, dim_a * dim_b))
+
+ return Hr, Hc, c
+
+
+def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma):
+ r""" This function computes the next value of gamma if a variable
+    will be added in the next iteration of the regularization path
+
+ We look for the largest value of gamma such that
+ the gradient of an inactive variable vanishes
+ .. math::
+ \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i}
+ where :
+ - A is the current active set
+ - h_i is the ith column of auxiliary matrix H
+ - H_A is the sub-matrix constructed by the columns of H
+ whose indices belong to the active set A
+ - c_i is the ith element of cost vector c
+ - y is the concatenation of source and target distribution
+ - :math:`\phi` is the intercept of the solutions in current iteration
+ - :math:`\delta` is the slope of the solutions in current iteration
+ Parameters
+ ----------
+ phi : np.ndarray (|A|, )
+ Intercept of the solutions in current iteration (t is piecewise linear)
+ delta : np.ndarray (|A|, )
+ Slope of the solutions in current iteration (t is piecewise linear)
+ HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b)
+ Matrix product of H^T H
+ Hty : np.ndarray (dim_a + dim_b, )
+ Matrix product of H^T y
+ c: np.ndarray (dim_a * dim_b, )
+ Flattened array of cost matrix C
+ active_index : list
+ Indices of active variables
+ current_gamma : float
+ Value of regularization coefficient at the start of current iteration
+ Returns
+ -------
+ next_gamma : float
+ Value of gamma if a variable is added to active set in next iteration
+ next_active_index : int
+ Index of variable to be activated
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+ M = (HtH[:, active_index].dot(phi) - Hty) / \
+ (HtH[:, active_index].dot(delta) - c + 1e-16)
+ M[active_index] = 0
+ M[M > (current_gamma - 1e-10 * current_gamma)] = 0
+ return np.max(M), np.argmax(M)
+
+
+def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra,
+ c, active_index, current_gamma):
+ r""" This function computes the next value of gamma when a variable is
+ active in the regularization path of semi-relaxed UOT.
+
+ By taking the Lagrangian form of the problem, we obtain a similar update
+ as the two-sided relaxed UOT
+ .. math::
+ \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T
+        \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i}^T \delta_u - c_i}
+ where :
+ - A is the current active set
+ - h_{r i} is the ith column of the matrix H_r
+ - h_{c i} is the ith column of the matrix H_c
+ - H_{r A} is the sub-matrix constructed by the columns of H_r
+ whose indices belong to the active set A
+ - c_i is the ith element of cost vector c
+ - y is the concatenation of source and target distribution
+ - :math:`\phi` is the intercept of the solutions in current iteration
+ - :math:`\delta` is the slope of the solutions in current iteration
+ - :math:`\phi_u` is the intercept of Lagrange parameter in current
+ iteration
+ - :math:`\delta_u` is the slope of Lagrange parameter in current iteration
+ Parameters
+ ----------
+ phi : np.ndarray (|A|, )
+ Intercept of the solutions in current iteration (t is piecewise linear)
+ delta : np.ndarray (|A|, )
+ Slope of the solutions in current iteration (t is piecewise linear)
+ phi_u : np.ndarray (dim_b, )
+ Intercept of the Lagrange parameter in current iteration (also linear)
+ delta_u : np.ndarray (dim_b, )
+ Slope of the Lagrange parameter in current iteration (also linear)
+ HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
+ Matrix product of H_r^T H_r
+ Hc : np.ndarray (dim_b, dim_a * dim_b)
+ Matrix that computes the sum along the columns of transport plan T
+ Hra : np.ndarray (dim_a * dim_b, )
+ Matrix product of H_r^T a
+ c: np.ndarray (dim_a * dim_b, )
+ Flattened array of cost matrix C
+ active_index : list
+ Indices of active variables
+ current_gamma : float
+ Value of regularization coefficient at the start of current iteration
+ Returns
+ -------
+ next_gamma : float
+ Value of gamma if a variable is added to active set in next iteration
+ next_active_index : int
+ Index of variable to be activated
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \
+ (HrHr[:, active_index].dot(delta) - c + Hc.T.dot(delta_u) + 1e-16)
+ M[active_index] = 0
+ M[M > (current_gamma - 1e-10 * current_gamma)] = 0
+ return np.max(M), np.argmax(M)
+
+
+def compute_next_removal(phi, delta, current_gamma):
+ r""" This function computes the next value of gamma if a variable
+    is removed in the next iteration of the regularization path
+
+ We look for the largest value of gamma such that
+ an element of current solution vanishes
+ .. math::
+ \max_{j \in A} \frac{\phi_j}{\delta_j}
+ where :
+ - A is the current active set
+ - phi_j is the jth element of the intercept of current solution
+    - delta_j is the jth element of the slope of current solution
+ Parameters
+ ----------
+ phi : np.ndarray (|A|, )
+ Intercept of the solutions in current iteration (t is piecewise linear)
+ delta : np.ndarray (|A|, )
+ Slope of the solutions in current iteration (t is piecewise linear)
+ current_gamma : float
+ Value of regularization coefficient at the start of current iteration
+ Returns
+ -------
+ next_removal_gamma : float
+ Value of gamma if a variable is removed in next iteration
+ next_removal_index : int
+ Index of the variable to remove in next iteration
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+ r_candidate = phi / (delta - 1e-16)
+ r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0
+ return np.max(r_candidate), np.argmax(r_candidate)
+
+
+def complement_schur(M_current, b, d, id_pop):
+ r""" This function computes the inverse of matrix in regularization path
+ using Schur complement
+
+ Two cases may arise: Firstly one variable is added to the active set
+ .. math::
+ M_{k+1}^{-1} =
+ \begin{bmatrix}
+        M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & - s^{-1} M_{k}^{-1} b \\
+ - s^{-1} b^T M_{k}^{-1} & s^{-1}
+ \end{bmatrix}
+ where :
+ - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and
+ :math:`M_k` is the upper left block matrix in Schur formulation
+ - b is the upper right block matrix in Schur formulation. In our case,
+ b is reduced to a column vector and b^T is the lower left block matrix
+ - s is the Schur complement, given by
+ :math:`s = d - b^T M_{k}^{-1} b` in our case
+
+ Secondly, one variable is removed from the active set
+ .. math::
+ M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} -
+ \frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}}
+ where :
+ - q is the index of column and row to delete
+ - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix
+ without qth column and qth row
+ - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element
+ - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}`
+ Parameters
+ ----------
+ M_current : np.ndarray (|A|-1, |A|-1)
+ Inverse matrix in previous iteration
+ b : np.ndarray (|A|-1, )
+ Upper right matrix in Schur complement, a column vector in our case
+ d : float
+ Lower right matrix in Schur complement, a scalar in our case
+    id_pop : int
+ Index of the variable to be removed, equal to -1
+ if none of the variables is deleted in current iteration
+ Returns
+ -------
+ M : np.ndarray (|A|, |A|)
+ Inverse matrix needed in current iteration
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+ if b is None:
+ b = M_current[id_pop, :]
+ b = np.delete(b, id_pop)
+ M_del = np.delete(M_current, id_pop, 0)
+ a = M_del[:, id_pop]
+ M_del = np.delete(M_del, id_pop, 1)
+ M = M_del - np.outer(a, b) / M_current[id_pop, id_pop]
+ else:
+ n = b.shape[0] + 1
+ if np.shape(b)[0] == 0:
+ M = np.array([[0.5]])
+ else:
+ X = M_current.dot(b)
+ s = d - b.T.dot(X)
+ M = np.zeros((n, n))
+ M[:-1, :-1] = M_current + X.dot(X.T) / s
+ X_ravel = X.ravel()
+ M[-1, :-1] = -X_ravel / s
+ M[:-1, -1] = -X_ravel / s
+ M[-1, -1] = 1 / s
+ return M
+
+
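Editor's note: the growing branch of `complement_schur` can be checked against a direct inverse of the augmented matrix. A small sketch with made-up numbers (the values carry no special meaning):

import numpy as np
from ot.regpath import complement_schur

K = np.array([[2.0, 0.5],
              [0.5, 1.0]])                 # current (invertible, symmetric) active-set matrix
K_inv = np.linalg.inv(K)
b = np.array([[0.3], [0.2]])               # new column, as a column vector
d = 3.0                                    # new diagonal entry

M_inv = complement_schur(K_inv, b, d, id_pop=-1)
M_full = np.block([[K, b], [b.T, np.array([[d]])]])
assert np.allclose(M_inv, np.linalg.inv(M_full))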
+def construct_augmented_H(active_index, m, Hc, HrHr):
+ r""" This function construct an augmented matrix for the first iteration of
+ semi-relaxed regularization path
+
+ .. math::
+ Augmented_H =
+ \begin{bmatrix}
+ 0 & H_{c A} \\
+ H_{c A}^T & H_{r A}^T H_{r A}
+ \end{bmatrix}
+ where :
+ - H_{r A} is the sub-matrix constructed by the columns of H_r
+ whose indices belong to the active set A
+ - H_{c A} is the sub-matrix constructed by the columns of H_c
+ whose indices belong to the active set A
+ Parameters
+ ----------
+ active_index : list
+ Indices of active variables
+ m : int
+ Length of the target distribution
+ Hc : np.ndarray (dim_b, dim_a * dim_b)
+ Matrix that computes the sum along the columns of transport plan T
+ HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b)
+ Matrix product of H_r^T H_r
+ Returns
+ -------
+ H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|)
+ Augmented matrix for the first iteration of the semi-relaxed
+ regularization path
+ """
+ Hc_sub = Hc[:, active_index].toarray()
+ HrHr_sub = HrHr[:, active_index]
+ HrHr_sub = HrHr_sub[active_index, :].toarray()
+ H_augmented = np.block([[np.zeros((m, m)), Hc_sub], [Hc_sub.T, HrHr_sub]])
+ return H_augmented
+
+
+def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
+ itmax=50000):
+ r"""This function gives the regularization path of l2-penalized UOT problem
+
+ The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+ .. math::
+ \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2
+ s.t.
+ t \geq 0
+ where :
+ - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T]
+ - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix,
+ see [Chapel et al., 2021] for the design of H. The matrix product Ht
+ computes both the source marginal and the target marginal.
+ - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ reg: float
+ l2-regularization coefficient
+ itmax: int
+ Maximum number of iteration
+ Returns
+ -------
+ t : np.ndarray (dim_a*dim_b, )
+ Flattened vector of optimal transport matrix
+ t_list : list
+ List of solutions in regularization path
+ gamma_list : list
+ List of regularization coefficient in regularization path
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> n = 3
+ >>> xs = np.array([1., 2., 3.]).reshape((n, 1))
+ >>> xt = np.array([5., 6., 7.]).reshape((n, 1))
+ >>> C = ot.dist(xs, xt)
+ >>> C /= C.max()
+ >>> a = np.array([0.2, 0.5, 0.3])
+ >>> b = np.array([0.2, 0.5, 0.3])
+ >>> t, _, _ = ot.regpath.fully_relaxed_path(a, b, C, 1e-4)
+ >>> t
+ array([1.99958333e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05,
+ 4.99938889e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05,
+ 2.99958333e-01])
+
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ n = np.shape(a)[0]
+ m = np.shape(b)[0]
+ H, y, c = recast_ot_as_lasso(a, b, C)
+ HtH = H.T.dot(H)
+ Hty = H.T.dot(y)
+ n_iter = 1
+
+ # initialization
+ M0 = Hty / c
+ gamma_list = [np.max(M0)]
+ active_index = [np.argmax(M0)]
+ t_list = [np.zeros((n * m,))]
+ H_inv = np.array([[]])
+ add_col = np.array([])
+ id_pop = -1
+
+ while n_iter < itmax and gamma_list[-1] > reg:
+ H_inv = complement_schur(H_inv, add_col, 2., id_pop)
+ current_gamma = gamma_list[-1]
+
+ # compute the intercept and slope of solutions in current iteration
+ # t = phi - gamma * delta
+ phi = H_inv.dot(Hty[active_index])
+ delta = H_inv.dot(c[active_index])
+ gamma, ik = ot_next_gamma(phi, delta, HtH, Hty, c, active_index,
+ current_gamma)
+
+ # compute the next lambda when removing a point from the active set
+ alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma)
+
+ # if the positivity constraint is violated, we remove id_pop
+ # from active set, otherwise we add ik to active set
+ if alt_gamma > gamma:
+ gamma = alt_gamma
+ else:
+ id_pop = -1
+
+ # compute the solution of current segment
+ tA = phi - gamma * delta
+ sol = np.zeros((n * m, ))
+ sol[active_index] = tA
+
+ if id_pop != -1:
+ active_index.pop(id_pop)
+ add_col = None
+ else:
+ active_index.append(ik)
+ add_col = HtH[active_index[:-1], ik].toarray()
+
+ gamma_list.append(gamma)
+ t_list.append(sol)
+ n_iter += 1
+
+ if itmax <= n_iter:
+        print('maximum number of iterations has been reached!')
+
+ # correct the last solution and gamma
+ if len(t_list) > 1:
+ t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) *
+ (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2]))
+ t_list[-1] = t_final
+ gamma_list[-1] = reg
+ else:
+ gamma_list[-1] = reg
+ print('Regularization path does not exist !')
+
+ return t_list[-1], t_list, gamma_list
+
+
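Editor's note: for small `reg`, the end of the fully-relaxed path is expected to be close to the balanced plan when both histograms carry the same mass. A sketch comparing with the exact solver on the toy data of the doctest above (the tolerance in the comment is an observation on this example, not a general guarantee):

import numpy as np
import ot

xs = np.array([1., 2., 3.]).reshape((3, 1))
xt = np.array([5., 6., 7.]).reshape((3, 1))
C = ot.dist(xs, xt)
C /= C.max()
a = np.array([0.2, 0.5, 0.3])
b = np.array([0.2, 0.5, 0.3])

t, _, _ = ot.regpath.fully_relaxed_path(a, b, C, reg=1e-4)
T_path = t.reshape((3, 3))
T_emd = ot.emd(a, b, C)                      # exact balanced plan
print(np.abs(T_path - T_emd).max())          # small (about 1e-4 or less on this example)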
+def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
+ itmax=50000):
+ r"""This function gives the regularization path of semi-relaxed
+ l2-UOT problem
+
+ The problem to optimize is the Lasso reformulation of the l2-penalized UOT:
+ .. math::
+ \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2
+ s.t.
+ H_c t = b
+ t \geq 0
+ where :
+ - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C)
+ - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient
+ - H_r is a (dim_a, dim_a * dim_b) metric matrix,
+ which computes the sum along the rows of transport plan T
+ - H_c is a (dim_b, dim_a * dim_b) metric matrix,
+ which computes the sum along the columns of transport plan T
+ - t is a (dim_a * dim_b, ) metric vector (flattened version of T)
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ reg: float (optional)
+ l2-regularization coefficient
+ itmax: int (optional)
+ Maximum number of iteration
+ Returns
+ -------
+ t : np.ndarray (dim_a*dim_b, )
+ Flattened vector of optimal transport matrix
+ t_list : list
+ List of solutions in regularization path
+ gamma_list : list
+ List of regularization coefficient in regularization path
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> n = 3
+ >>> xs = np.array([1., 2., 3.]).reshape((n, 1))
+ >>> xt = np.array([5., 6., 7.]).reshape((n, 1))
+ >>> C = ot.dist(xs, xt)
+ >>> C /= C.max()
+ >>> a = np.array([0.2, 0.5, 0.3])
+ >>> b = np.array([0.2, 0.5, 0.3])
+ >>> t, _, _ = ot.regpath.semi_relaxed_path(a, b, C, 1e-4)
+ >>> t
+ array([1.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05,
+ 4.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05,
+ 3.00000000e-01])
+
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ n = np.shape(a)[0]
+ m = np.shape(b)[0]
+ Hr, Hc, c = recast_semi_relaxed_as_lasso(a, b, C)
+ Hra = Hr.T.dot(a)
+ HrHr = Hr.T.dot(Hr)
+ n_iter = 1
+ active_index = []
+
+ # initialization
+ for j in range(np.shape(C)[1]):
+ i = np.argmin(C[:, j])
+ active_index.append(i * m + j)
+ gamma_list = []
+ t_list = []
+ current_gamma = np.Inf
+ augmented_H0 = construct_augmented_H(active_index, m, Hc, HrHr)
+ add_col = np.array([])
+ id_pop = -1
+
+ while n_iter < itmax and current_gamma > reg:
+ if n_iter == 1:
+ H_inv = np.linalg.inv(augmented_H0)
+ else:
+ H_inv = complement_schur(H_inv, add_col, 1., id_pop + m)
+ # compute the intercept and slope of solutions in current iteration
+ augmented_phi = H_inv.dot(np.concatenate((b, Hra[active_index])))
+ augmented_delta = H_inv[:, m:].dot(c[active_index])
+ phi = augmented_phi[m:]
+ delta = augmented_delta[m:]
+ phi_u = augmented_phi[0:m]
+ delta_u = augmented_delta[0:m]
+ gamma, ik = semi_relaxed_next_gamma(phi, delta, phi_u, delta_u,
+ HrHr, Hc, Hra, c, active_index,
+ current_gamma)
+
+ # compute the next lambda when removing a point from the active set
+ alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma)
+
+ # if the positivity constraint is violated, we remove id_pop
+ # from active set, otherwise we add ik to active set
+ if alt_gamma > gamma:
+ gamma = alt_gamma
+ else:
+ id_pop = -1
+
+ # compute the solution of current segment
+ tA = phi - gamma * delta
+ sol = np.zeros((n * m, ))
+ sol[active_index] = tA
+ if id_pop != -1:
+ active_index.pop(id_pop)
+ add_col = None
+ else:
+ active_index.append(ik)
+ add_col = np.concatenate((Hc.toarray()[:, ik],
+ HrHr.toarray()[active_index[:-1], ik]))
+ add_col = add_col[:, np.newaxis]
+
+ gamma_list.append(gamma)
+ t_list.append(sol)
+ current_gamma = gamma
+ n_iter += 1
+
+ if itmax <= n_iter:
+        print('maximum number of iterations has been reached!')
+
+ # correct the last solution and gamma
+ if len(t_list) > 1:
+ t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) *
+ (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2]))
+ t_list[-1] = t_final
+ gamma_list[-1] = reg
+ else:
+ gamma_list[-1] = reg
+ print('Regularization path does not exist !')
+
+ return t_list[-1], t_list, gamma_list
+
+
+def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4,
+ semi_relaxed=False, itmax=50000):
+ r"""This function combines both the semi-relaxed and the fully-relaxed
+ regularization paths of l2-UOT problem
+
+ Parameters
+ ----------
+ a : np.ndarray (dim_a,)
+ Histogram of dimension dim_a
+ b : np.ndarray (dim_b,)
+ Histogram of dimension dim_b
+ C : np.ndarray, shape (dim_a, dim_b)
+ Cost matrix
+ reg: float (optional)
+ l2-regularization coefficient
+ semi_relaxed : bool (optional)
+ Give the semi-relaxed path if true
+ itmax: int (optional)
+ Maximum number of iteration
+ Returns
+ -------
+ t : np.ndarray (dim_a*dim_b, )
+ Flattened vector of optimal transport matrix
+ t_list : list
+ List of solutions in regularization path
+ gamma_list : list
+ List of regularization coefficient in regularization path
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+ if semi_relaxed:
+ t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg,
+ itmax=itmax)
+ else:
+ t, t_list, gamma_list = fully_relaxed_path(a, b, C, reg=reg,
+ itmax=itmax)
+ return t, t_list, gamma_list
+
+
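Editor's note: a usage sketch of the wrapper in its semi-relaxed mode, where the target marginal is enforced exactly (the constraint H_c t = b) while the source marginal is only l2-penalized. The data mirrors the doctests above.

import numpy as np
import ot

xs = np.array([1., 2., 3.]).reshape((3, 1))
xt = np.array([5., 6., 7.]).reshape((3, 1))
C = ot.dist(xs, xt)
C /= C.max()
a = np.array([0.2, 0.5, 0.3])
b = np.array([0.2, 0.5, 0.3])

t, t_list, gamma_list = ot.regpath.regularization_path(a, b, C, reg=1e-4,
                                                       semi_relaxed=True)
T = t.reshape((len(a), len(b)))
print(np.allclose(T.sum(axis=0), b))         # True: target marginal is exact
print(np.round(T.sum(axis=1), 4))            # close to a, but only approximately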
+def compute_transport_plan(gamma, gamma_list, Pi_list):
+ r""" Given the regularization path, this function computes the transport
+ plan for any value of gamma by the piecewise linearity of the path
+
+ .. math::
+ t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma)
+ where :
+ - :math:`\gamma` is the regularization coefficient
+ - :math:`\phi(\gamma)` is the corresponding intercept
+ - :math:`\delta(\gamma)` is the corresponding slope
+ - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix)
+ Parameters
+ ----------
+ gamma : float
+ Regularization coefficient
+ gamma_list : list
+ List of regularization coefficients in regularization path
+ Pi_list : list
+ List of solutions in regularization path
+ Returns
+ -------
+ t : np.ndarray (dim_a*dim_b, )
+ Transport vector corresponding to the given value of gamma
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> n = 3
+ >>> xs = np.array([1., 2., 3.]).reshape((n, 1))
+ >>> xt = np.array([5., 6., 7.]).reshape((n, 1))
+ >>> C = ot.dist(xs, xt)
+ >>> C /= C.max()
+ >>> a = np.array([0.2, 0.5, 0.3])
+ >>> b = np.array([0.2, 0.5, 0.3])
+ >>> t, pi_list, g_list = ot.regpath.regularization_path(a, b, C, reg=1e-4)
+ >>> gamma = 1
+ >>> t2 = ot.regpath.compute_transport_plan(gamma, g_list, pi_list)
+ >>> t2
+ array([0. , 0. , 0. , 0.19722222, 0.05555556,
+ 0. , 0. , 0.24722222, 0. ])
+
+ References
+ ----------
+ [Chapel et al., 2021]:
+ Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+ linear regression.
+ """
+
+ if gamma >= gamma_list[0]:
+ Pi = Pi_list[0]
+ elif gamma <= gamma_list[-1]:
+ Pi = Pi_list[-1]
+ else:
+ idx = np.where(gamma <= np.array(gamma_list))[0][-1]
+ gamma_k0 = gamma_list[idx]
+ gamma_k1 = gamma_list[idx + 1]
+ pi_k0 = Pi_list[idx]
+ pi_k1 = Pi_list[idx + 1]
+ Pi = pi_k0 + (pi_k1 - pi_k0) * (gamma - gamma_k0) \
+ / (gamma_k1 - gamma_k0)
+ return Pi
diff --git a/ot/sliced.py b/ot/sliced.py
new file mode 100644
index 0000000..cf2d3be
--- /dev/null
+++ b/ot/sliced.py
@@ -0,0 +1,258 @@
+"""
+Sliced OT Distances
+
+"""
+
+# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+
+import numpy as np
+from .backend import get_backend, NumpyBackend
+from .utils import list_to_array
+
+
+def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None):
+ r"""
+ Generates n_projections samples from the uniform on the unit sphere of dimension :math:`d-1`: :math:`\mathcal{U}(\mathcal{S}^{d-1})`
+
+ Parameters
+ ----------
+ d : int
+ dimension of the space
+ n_projections : int
+ number of samples requested
+ seed: int or RandomState, optional
+ Seed used for numpy random number generator
+ backend:
+        Backend to use for random generation
+
+ Returns
+ -------
+ out: ndarray, shape (d, n_projections)
+ The uniform unit vectors on the sphere
+
+ Examples
+ --------
+ >>> n_projections = 100
+ >>> d = 5
+ >>> projs = get_random_projections(d, n_projections)
+ >>> np.allclose(np.sum(np.square(projs), 0), 1.) # doctest: +NORMALIZE_WHITESPACE
+ True
+
+ """
+
+ if backend is None:
+ nx = NumpyBackend()
+ else:
+ nx = backend
+
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ projections = seed.randn(d, n_projections)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ projections = nx.randn(d, n_projections, type_as=type_as)
+
+ projections = projections / nx.sqrt(nx.sum(projections**2, 0, keepdims=True))
+ return projections
+
+
+def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
+ projections=None, seed=None, log=False):
+ r"""
+ Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance
+
+ .. math::
+ \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}\left(\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)\right)^{\frac{1}{p}}
+
+
+ where :
+
+ - :math:`\theta_\# \mu` stands for the pushforwards of the projection :math:`X \in \mathbb{R}^d \mapsto \langle \theta, X \rangle`
+
+
+ Parameters
+ ----------
+ X_s : ndarray, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : ndarray, shape (n_samples_b, dim)
+ samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+    p: float, optional
+ Power p used for computing the sliced Wasserstein
+ projections: shape (dim, n_projections), optional
+ Projection matrix (n_projections and seed are not used in this case)
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Sliced Wasserstein Cost
+ log : dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_samples_a = 20
+ >>> reg = 0.1
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+ >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+
+ .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+ """
+ from .lp import wasserstein_1d
+
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ if a is not None and b is not None and projections is None:
+ nx = get_backend(X_s, X_t, a, b)
+ elif a is not None and b is not None and projections is not None:
+ nx = get_backend(X_s, X_t, a, b, projections)
+ elif a is None and b is None and projections is not None:
+ nx = get_backend(X_s, X_t, projections)
+ else:
+ nx = get_backend(X_s, X_t)
+
+ n = X_s.shape[0]
+ m = X_t.shape[0]
+
+    if X_s.shape[1] != X_t.shape[1]:
+        raise ValueError(
+            "X_s and X_t must have the same number of dimensions, "
+            "but {} and {} were given".format(X_s.shape[1], X_t.shape[1]))
+
+ if a is None:
+ a = nx.full(n, 1 / n, type_as=X_s)
+ if b is None:
+ b = nx.full(m, 1 / m, type_as=X_s)
+
+ d = X_s.shape[1]
+
+ if projections is None:
+ projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+
+ X_s_projections = nx.dot(X_s, projections)
+ X_t_projections = nx.dot(X_t, projections)
+
+ projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p)
+
+ res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p)
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
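
Beyond the doctest above, a short usage sketch with illustrative data: two Gaussian clouds with explicit weights, requesting the log dictionary to inspect the projections actually used.

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    X_s = rng.normal(0.0, 1.0, (50, 3))
    X_t = rng.normal(1.0, 1.0, (60, 3))
    a, b = ot.unif(50), ot.unif(60)

    swd, log = ot.sliced_wasserstein_distance(X_s, X_t, a, b,
                                              n_projections=200, seed=0, log=True)
    print(swd)                        # Monte-Carlo estimate of SWD_2
    print(log["projections"].shape)   # (3, 200): one column per direction
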
+
+
+def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
+ projections=None, seed=None, log=False):
+ r"""
+ Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance
+
+ .. math::
+        \mathcal{Max-SWD}_p(\mu, \nu) = \underset{\theta \sim
+ \mathcal{U}(\mathbb{S}^{d-1})}{\max} [\mathcal{W}_p^p(\theta_\#
+ \mu, \theta_\# \nu)]^{\frac{1}{p}}
+
+ where :
+
+    - :math:`\theta_\# \mu` stands for the pushforwards of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle`
+
+
+ Parameters
+ ----------
+ X_s : ndarray, shape (n_samples_a, dim)
+ samples in the source domain
+ X_t : ndarray, shape (n_samples_b, dim)
+ samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+    p: float, optional
+ Power p used for computing the sliced Wasserstein
+ projections: shape (dim, n_projections), optional
+ Projection matrix (n_projections and seed are not used in this case)
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+        if True, max_sliced_wasserstein_distance returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Sliced Wasserstein Cost
+ log : dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+
+ >>> n_samples_a = 20
+ >>> reg = 0.1
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+    >>> max_sliced_wasserstein_distance(X, X, seed=0)  # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+
+ .. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). Max-sliced wasserstein distance and its use for gans. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
+ """
+ from .lp import wasserstein_1d
+
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ if a is not None and b is not None and projections is None:
+ nx = get_backend(X_s, X_t, a, b)
+ elif a is not None and b is not None and projections is not None:
+ nx = get_backend(X_s, X_t, a, b, projections)
+ elif a is None and b is None and projections is not None:
+ nx = get_backend(X_s, X_t, projections)
+ else:
+ nx = get_backend(X_s, X_t)
+
+ n = X_s.shape[0]
+ m = X_t.shape[0]
+
+    if X_s.shape[1] != X_t.shape[1]:
+        raise ValueError(
+            "X_s and X_t must have the same number of dimensions, "
+            "but {} and {} were given".format(X_s.shape[1], X_t.shape[1]))
+
+ if a is None:
+ a = nx.full(n, 1 / n, type_as=X_s)
+ if b is None:
+ b = nx.full(m, 1 / m, type_as=X_s)
+
+ d = X_s.shape[1]
+
+ if projections is None:
+ projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+
+ X_s_projections = nx.dot(X_s, projections)
+ X_t_projections = nx.dot(X_t, projections)
+
+ projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p)
+
+ res = nx.max(projected_emd) ** (1.0 / p)
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
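
When both estimators share the same projection matrix, the max-sliced value upper-bounds the averaged sliced value, since the maximum of the projected costs dominates their mean. A small sketch with illustrative data:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s = rng.normal(0.0, 1.0, (40, 4))
    X_t = rng.normal(0.5, 1.0, (40, 4))
    projs = ot.sliced.get_random_projections(4, 100, seed=1)

    swd = ot.sliced_wasserstein_distance(X_s, X_t, projections=projs)
    mswd = ot.max_sliced_wasserstein_distance(X_s, X_t, projections=projs)
    print(float(swd), float(mswd), swd <= mswd)  # max of projected costs >= their mean
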
diff --git a/ot/smooth.py b/ot/smooth.py
index 81f6a3e..6855005 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -47,15 +47,24 @@ from scipy.optimize import minimize
def projection_simplex(V, z=1, axis=None):
- """ Projection of x onto the simplex, scaled by z
+ r""" Projection of :math:`\mathbf{V}` onto the simplex, scaled by `z`
- P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
+ .. math::
+ P\left(\mathbf{V}, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z}} \quad \|\mathbf{y} - \mathbf{V}\|^2
+
+ Parameters
+ ----------
+ V: ndarray, rank 2
z: float or array
- If array, len(z) must be compatible with V
+ If array, len(z) must be compatible with :math:`\mathbf{V}`
axis: None or int
- - axis=None: project V by P(V.ravel(); z)
- - axis=1: project each V[i] by P(V[i]; z[i])
- - axis=0: project each V[:, j] by P(V[:, j]; z[j])
+ - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), z)`
+ - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, z_i)`
+ - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, z_j)`
+
+ Returns
+ -------
+ projection: ndarray, shape :math:`\mathbf{V}`.shape
"""
if axis == 1:
n_features = V.shape[1]
@@ -77,12 +86,12 @@ def projection_simplex(V, z=1, axis=None):
class Regularization(object):
- """Base class for Regularization objects
+ r"""Base class for Regularization objects
Notes
-----
- This class is not intended for direct use but as aparent for true
- regularizatiojn implementation.
+    This class is not intended for direct use but as a parent class for
+    actual regularization implementations.
"""
def __init__(self, gamma=1.0):
@@ -98,40 +107,48 @@ class Regularization(object):
self.gamma = gamma
def delta_Omega(X):
- """
- Compute delta_Omega(X[:, j]) for each X[:, j].
- delta_Omega(x) = sup_{y >= 0} y^T x - Omega(y).
+ r"""
+ Compute :math:`\delta_\Omega(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`.
+
+ .. math::
+ \delta_\Omega(\mathbf{x}) = \sup_{\mathbf{y} >= 0} \
+ \mathbf{y}^T \mathbf{x} - \Omega(\mathbf{y})
Parameters
----------
- X: array, shape = len(a) x len(b)
+ X: array, shape = (len(a), len(b))
Input array.
Returns
-------
- v: array, len(b)
- Values: v[j] = delta_Omega(X[:, j])
- G: array, len(a) x len(b)
- Gradients: G[:, j] = nabla delta_Omega(X[:, j])
+ v: array, (len(b), )
+ Values: :math:`\mathbf{v}_j = \delta_\Omega(\mathbf{X}_{:, j})`
+ G: array, (len(a), len(b))
+ Gradients: :math:`\mathbf{G}_{:, j} = \nabla \delta_\Omega(\mathbf{X}_{:, j})`
"""
raise NotImplementedError
def max_Omega(X, b):
- """
- Compute max_Omega_j(X[:, j]) for each X[:, j].
- max_Omega_j(x) = sup_{y >= 0, sum(y) = 1} y^T x - Omega(b[j] y) / b[j].
+ r"""
+ Compute :math:`\mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})` for each :math:`\mathbf{X}_{:, j}`.
+
+ .. math::
+ \mathrm{max}_{\Omega, j}(\mathbf{x}) =
+        \sup_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = 1}}
+ \mathbf{y}^T \mathbf{x} - \frac{1}{\mathbf{b}_j} \Omega(\mathbf{b}_j \mathbf{y})
Parameters
----------
- X: array, shape = len(a) x len(b)
+ X: array, shape = (len(a), len(b))
Input array.
+ b: array, shape = (len(b), )
Returns
-------
- v: array, len(b)
- Values: v[j] = max_Omega_j(X[:, j])
- G: array, len(a) x len(b)
- Gradients: G[:, j] = nabla max_Omega_j(X[:, j])
+ v: array, (len(b), )
+ Values: :math:`\mathbf{v}_j = \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})`
+ G: array, (len(a), len(b))
+ Gradients: :math:`\mathbf{G}_{:, j} = \nabla \mathrm{max}_{\Omega, j}(\mathbf{X}_{:, j})`
"""
raise NotImplementedError
@@ -192,7 +209,7 @@ class SquaredL2(Regularization):
def dual_obj_grad(alpha, beta, a, b, C, regul):
- """
+ r"""
Compute objective value and gradients of dual objective.
Parameters
@@ -203,19 +220,19 @@ def dual_obj_grad(alpha, beta, a, b, C, regul):
a: array, shape = len(a)
b: array, shape = len(b)
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
Returns
-------
obj: float
Objective value (higher is better).
grad_alpha: array, shape = len(a)
- Gradient w.r.t. alpha.
+ Gradient w.r.t. `alpha`.
grad_beta: array, shape = len(b)
- Gradient w.r.t. beta.
+ Gradient w.r.t. `beta`.
"""
obj = np.dot(alpha, a) + np.dot(beta, b)
grad_alpha = a.copy()
@@ -242,13 +259,13 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Parameters
----------
- a: array, shape = len(a)
- b: array, shape = len(b)
+ a: array, shape = (len(a), )
+ b: array, shape = (len(b), )
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
method: str
Solver to be used (passed to `scipy.optimize.minimize`).
tol: float
@@ -258,8 +275,8 @@ def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Returns
-------
- alpha: array, shape = len(a)
- beta: array, shape = len(b)
+ alpha: array, shape = (len(a), )
+ beta: array, shape = (len(b), )
Dual potentials.
"""
@@ -302,10 +319,10 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
a: array, shape = len(a)
b: array, shape = len(b)
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a max_Omega(X) method.
+ Should implement a `max_Omega(X)` method.
Returns
-------
@@ -337,13 +354,13 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Parameters
----------
- a: array, shape = len(a)
- b: array, shape = len(b)
+ a: array, shape = (len(a), )
+ b: array, shape = (len(b), )
Input histograms (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a max_Omega(X) method.
+ Should implement a `max_Omega(X)` method.
method: str
Solver to be used (passed to `scipy.optimize.minimize`).
tol: float
@@ -353,7 +370,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
Returns
-------
- alpha: array, shape = len(a)
+ alpha: array, shape = (len(a), )
Semi-dual potentials.
"""
@@ -371,7 +388,7 @@ def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
def get_plan_from_dual(alpha, beta, C, regul):
- """
+ r"""
Retrieve optimal transportation plan from optimal dual potentials.
Parameters
@@ -379,14 +396,14 @@ def get_plan_from_dual(alpha, beta, C, regul):
alpha: array, shape = len(a)
beta: array, shape = len(b)
Optimal dual potentials.
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
Returns
-------
- T: array, shape = len(a) x len(b)
+ T: array, shape = (len(a), len(b))
Optimal transportation plan.
"""
X = alpha[:, np.newaxis] + beta - C
@@ -394,7 +411,7 @@ def get_plan_from_dual(alpha, beta, C, regul):
def get_plan_from_semi_dual(alpha, b, C, regul):
- """
+ r"""
Retrieve optimal transportation plan from optimal semi-dual potentials.
Parameters
@@ -403,14 +420,14 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
Optimal semi-dual potentials.
b: array, shape = len(b)
Second input histogram (should be non-negative and sum to 1).
- C: array, shape = len(a) x len(b)
+ C: array, shape = (len(a), len(b))
Ground cost matrix.
regul: Regularization object
- Should implement a delta_Omega(X) method.
+ Should implement a `delta_Omega(X)` method.
Returns
-------
- T: array, shape = len(a) x len(b)
+ T: array, shape = (len(a), len(b))
Optimal transportation plan.
"""
X = alpha[:, np.newaxis] - C
@@ -422,19 +439,21 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
r"""
Solve the regularized OT problem in the dual and return the OT matrix
- The function solves the smooth relaxed dual formulation (7) in [17]_ :
+ The function solves the smooth relaxed dual formulation (7) in
+ :ref:`[17] <references-smooth-ot-dual>`:
.. math::
- \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j)
+ \max_{\alpha,\beta}\quad \mathbf{a}^T\alpha + \mathbf{b}^T\beta -
+ \sum_j \delta_\Omega \left(\alpha+\beta_j-\mathbf{m}_j \right)
where :
- - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+ - :math:`\mathbf{m}_j` is the j-th column of the cost matrix
- :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
    The OT matrix can be reconstructed from the gradient of :math:`\delta_\Omega`
- (See [17]_ Proposition 1).
+ (See :ref:`[17] <references-smooth-ot-dual>` Proposition 1).
    The optimization algorithm uses gradient descent (L-BFGS by default).
@@ -444,21 +463,25 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
samples weights in the source domain
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
Regularization term >0
reg_type : str
- Regularization type, can be the following (default ='l2'):
- - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
- - 'l2' : Squared Euclidean regularization
+ Regularization type, can be the following (default ='l2'):
+
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn
+ :ref:`[2] <references-smooth-ot-dual>`)
+
+ - 'l2' : Squared Euclidean regularization
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -467,15 +490,15 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-smooth-ot-dual:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
@@ -514,21 +537,23 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
r"""
Solve the regularized OT problem in the semi-dual and return the OT matrix
- The function solves the smooth relaxed dual formulation (10) in [17]_ :
+ The function solves the smooth relaxed dual formulation (10) in
+ :ref:`[17] <references-smooth-ot-semi-dual>`:
.. math::
- \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b)
+ \max_{\alpha}\quad \mathbf{a}^T\alpha- \mathrm{OT}_\Omega^*(\alpha, \mathbf{b})
where :
.. math::
- OT_\Omega^*(\alpha,b)=\sum_j b_j
+        \mathrm{OT}_\Omega^*(\alpha, \mathbf{b}) = \sum_j \mathbf{b}_j \, \mathrm{max}_{\Omega, j}(\alpha - \mathbf{m}_j)
- - :math:`\mathbf{m}_j` is the jth column of the cost matrix
- - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. (9) in [17]
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{m}_j` is the j-th column of the cost matrix
+ - :math:`\mathrm{OT}_\Omega^*(\alpha,b)` is defined in Eq. (9) in
+ :ref:`[17] <references-smooth-ot-semi-dual>`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
- The OT matrix can is reconstructed using [17]_ Proposition 2.
+    The OT matrix can be reconstructed using :ref:`[17] <references-smooth-ot-semi-dual>` Proposition 2.
    The optimization algorithm uses gradient descent (L-BFGS by default).
@@ -538,21 +563,25 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
samples weights in the source domain
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
+        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
Regularization term >0
reg_type : str
- Regularization type, can be the following (default ='l2'):
- - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn [2]_)
- - 'l2' : Squared Euclidean regularization
+ Regularization type, can be the following (default ='l2'):
+
+ - 'kl' : Kullback Leibler (~ Neg-entropy used in sinkhorn
+ :ref:`[2] <references-smooth-ot-semi-dual>`)
+
+ - 'l2' : Squared Euclidean regularization
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -561,15 +590,15 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
+ .. _references-smooth-ot-semi-dual:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
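
A hedged usage sketch for the two solvers whose docstrings are reworked above (toy histograms, default 'l2' regularization): the dual plan matches the marginals up to solver tolerance, while the semi-dual plan matches the column marginal exactly by construction of max_Omega.

    import numpy as np
    import ot

    a = np.array([0.4, 0.6])
    b = np.array([0.5, 0.5])
    M = np.array([[0.0, 1.0],
                  [1.0, 0.0]])

    G_dual = ot.smooth.smooth_ot_dual(a, b, M, reg=1e-2, reg_type='l2')
    G_semi = ot.smooth.smooth_ot_semi_dual(a, b, M, reg=1e-2, reg_type='l2')
    print(np.abs(G_dual.sum(axis=1) - a).max())  # small residual marginal violation
    print(np.abs(G_semi.sum(axis=0) - b).max())  # ~0 for the semi-dual plan
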
diff --git a/ot/stochastic.py b/ot/stochastic.py
index 13ed9cc..693675f 100644
--- a/ot/stochastic.py
+++ b/ot/stochastic.py
@@ -18,22 +18,25 @@ import numpy as np
def coordinate_grad_semi_dual(b, M, reg, beta, i):
r'''
- Compute the coordinate gradient update for regularized discrete distributions for (i, :)
+ Compute the coordinate gradient update for regularized discrete distributions for :math:`(i, :)`
The function computes the gradient of the semi dual problem:
.. math::
- \max_v \sum_i (\sum_j v_j * b_j - reg * log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i
+ \max_\mathbf{v} \ \sum_i \mathbf{a}_i \left[ \sum_j \mathbf{v}_j \mathbf{b}_j - \mathrm{reg}
+ \cdot \log \left( \sum_j \mathbf{b}_j
+ \exp \left( \frac{\mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}}
+ \right) \right) \right]
Where :
- - M is the (ns,nt) metric cost matrix
- - v is a dual variable in R^J
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{v}` is a dual variable in :math:`\mathbb{R}^{nt}`
- reg is the regularization term
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
    The algorithms used for solving the problem are the ASGD and SAG algorithms
- as proposed in [18]_ [alg.1 & alg.2]
+ as proposed in :ref:`[18] <references-coordinate-grad-semi-dual>` [alg.1 & alg.2]
Parameters
@@ -47,7 +50,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
v : ndarray, shape (nt,)
Dual variable.
i : int
- Picked number i.
+ Picked number `i`.
Returns
-------
@@ -74,12 +77,10 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+ .. _references-coordinate-grad-semi-dual:
References
----------
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
'''
r = M[i, :] - beta
exp_beta = np.exp(-r / reg) * b
@@ -88,29 +89,29 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
- r'''
- Compute the SAG algorithm to solve the regularized discrete measures
- optimal transport max problem
+ r"""
+ Compute the SAG algorithm to solve the regularized discrete measures optimal transport max problem
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1 = b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the SAG algorithm
- as proposed in [18]_ [alg.1]
+ as proposed in :ref:`[18] <references-sag-entropic-transport>` [alg.1]
Parameters
@@ -131,7 +132,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
Returns
-------
- v : ndarray, shape (nt,)
+ v : ndarray, shape (`nt`,)
Dual variable.
Examples
@@ -154,14 +155,12 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ .. _references-sag-entropic-transport:
References
----------
-
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
- '''
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
+ """
if lr is None:
lr = 1. / max(a / reg)
@@ -187,22 +186,23 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg}\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the ASGD algorithm
- as proposed in [18]_ [alg.2]
+ as proposed in :ref:`[18] <references-averaged-sgd-entropic-transport>` [alg.2]
Parameters
@@ -220,7 +220,7 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
Returns
-------
- ave_v : ndarray, shape (nt,)
+ ave_v : ndarray, shape (`nt`,)
dual variable
Examples
@@ -243,13 +243,11 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ .. _references-averaged-sgd-entropic-transport:
References
----------
-
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
'''
if lr is None:
@@ -271,20 +269,21 @@ def c_transform_entropic(b, M, reg, beta):
r'''
The goal is to recover u from the c-transform.
- The function computes the c_transform of a dual variable from the other
+ The function computes the c-transform of a dual variable from the other
dual variable:
.. math::
- u = v^{c,reg} = -reg \sum_j exp((v - M)/reg) b_j
+ \mathbf{u} = \mathbf{v}^{c,reg} = - \mathrm{reg} \sum_j \mathbf{b}_j
+ \exp\left( \frac{\mathbf{v} - \mathbf{M}}{\mathrm{reg}} \right)
Where :
- - M is the (ns,nt) metric cost matrix
- - u, v are dual variables in R^IxR^J
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{u}`, :math:`\mathbf{v}` are dual variables in :math:`\mathbb{R}^{ns} \times \mathbb{R}^{nt}`
- reg is the regularization term
It is used to recover an optimal u from optimal v solving the semi dual
- problem, see Proposition 2.1 of [18]_
+ problem, see Proposition 2.1 of :ref:`[18] <references-c-transform-entropic>`
Parameters
@@ -300,7 +299,7 @@ def c_transform_entropic(b, M, reg, beta):
Returns
-------
- u : ndarray, shape (ns,)
+ u : ndarray, shape (`ns`,)
Dual variable.
Examples
@@ -323,13 +322,11 @@ def c_transform_entropic(b, M, reg, beta):
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ .. _references-c-transform-entropic:
References
----------
-
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
'''
n_source = np.shape(M)[0]
@@ -345,27 +342,28 @@ def c_transform_entropic(b, M, reg, beta):
def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
log=False):
r'''
- Compute the transportation matrix to solve the regularized discrete
- measures optimal transport max problem
+ Compute the transportation matrix to solve the regularized discrete measures optimal transport max problem
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
    The algorithm used for solving the problem is either the SAG or the ASGD algorithm
- as proposed in [18]_
+ as proposed in :ref:`[18] <references-solve-semi-dual-entropic>`
Parameters
@@ -419,13 +417,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
[2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
[4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
+
+ .. _references-solve-semi-dual-entropic:
References
----------
-
- [Genevay et al., 2016] :
- Stochastic Optimization for Large-scale Optimal Transport,
- Advances in Neural Information Processing Systems (2016),
- arXiv preprint arxiv:1605.08527.
+ .. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
'''
if method.lower() == "sag":
@@ -459,26 +455,30 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
r'''
Computes the partial gradient of the dual optimal transport problem.
- For each (i,j) in a batch of coordinates, the partial gradients are :
+ For each :math:`(i,j)` in a batch of coordinates, the partial gradients are :
.. math::
- \partial_{u_i} F = u_i * b_s/l_{v} - \sum_{j \in B_v} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j
+ \partial_{\mathbf{u}_i} F = \frac{b_s}{l_v} \mathbf{u}_i -
+ \sum_{j \in B_v} \mathbf{a}_i \mathbf{b}_j
+ \exp\left( \frac{\mathbf{u}_i + \mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} \right)
- \partial_{v_j} F = v_j * b_s/l_{u} - \sum_{i \in B_u} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j
+ \partial_{\mathbf{v}_j} F = \frac{b_s}{l_u} \mathbf{v}_j -
+ \sum_{i \in B_u} \mathbf{a}_i \mathbf{b}_j
+ \exp\left( \frac{\mathbf{u}_i + \mathbf{v}_j - \mathbf{M}_{i,j}}{\mathrm{reg}} \right)
Where :
- - M is the (ns,nt) metric cost matrix
- - u, v are dual variables in R^ixR^J
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{u}`, :math:`\mathbf{v}` are dual variables in :math:`\mathbb{R}^{ns} \times \mathbb{R}^{nt}`
- reg is the regularization term
    - :math:`B_u` and :math:`B_v` are lists of indices
- - :math:`b_s` is the size of the batchs :math:`B_u` and :math:`B_v`
- - :math:`l_u` and :math:`l_v` are the lenghts of :math:`B_u` and :math:`B_v`
- - a and b are source and target weights (sum to 1)
+ - :math:`b_s` is the size of the batches :math:`B_u` and :math:`B_v`
+ - :math:`l_u` and :math:`l_v` are the lengths of :math:`B_u` and :math:`B_v`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the dual problem is the SGD algorithm
- as proposed in [19]_ [alg.1]
+ as proposed in :ref:`[19] <references-batch-grad-dual>` [alg.1]
Parameters
@@ -504,7 +504,7 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
Returns
-------
- grad : ndarray, shape (ns,)
+ grad : ndarray, shape (`ns`,)
partial grad F
Examples
@@ -533,12 +533,11 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
[5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04],
[6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]])
+
+ .. _references-batch-grad-dual:
References
----------
-
- [Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
'''
G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
M[batch_alpha, :][:, batch_beta]) / reg) *
@@ -555,25 +554,25 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
r'''
- Compute the sgd algorithm to solve the regularized discrete measures
- optimal transport dual problem
+ Compute the sgd algorithm to solve the regularized discrete measures optimal transport dual problem
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
@@ -632,9 +631,7 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
References
----------
- [Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
'''
n_source = np.shape(M)[0]
@@ -657,25 +654,25 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
log=False):
r'''
- Compute the transportation matrix to solve the regularized discrete measures
- optimal transport dual problem
+ Compute the transportation matrix to solve the regularized discrete measures optimal transport dual problem
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma \mathbf{1} = \mathbf{a}
- \gamma^T 1= b
+ \gamma^T \mathbf{1} = \mathbf{b}
\gamma \geq 0
Where :
- - M is the (ns,nt) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (sum to 1)
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
@@ -736,10 +733,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
References
----------
-
- [Seguy et al., 2018] :
- International Conference on Learning Representation (2018),
- arXiv preprint arxiv:1711.02283.
+ .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018)
'''
opt_alpha, opt_beta = sgd_entropic_regularization(a, b, M, reg, batch_size,
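
An illustrative sketch combining the semi-dual and dual stochastic solvers documented above on a small random problem; iteration counts and learning rate are arbitrary choices, not recommended defaults.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    n_s, n_t, reg = 7, 4, 1.0
    a, b = ot.unif(n_s), ot.unif(n_t)
    M = ot.dist(rng.randn(n_s, 2), rng.randn(n_t, 2))

    # semi-dual solver, SAG variant ([18], alg. 1)
    G_sag = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, "sag",
                                                   numItermax=20000)
    # dual solver with mini-batches ([19], alg. 1)
    G_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size=3,
                                              numItermax=20000, lr=0.1)
    # both plans should approximately match the source marginal
    print(np.abs(G_sag.sum(1) - a).max(), np.abs(G_sgd.sum(1) - a).max())
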
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index e37f10c..15e180b 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -23,29 +23,31 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+ W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})
s.t.
- \gamma\geq 0
+ \gamma \geq 0
+
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization
- term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target unbalanced distributions
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
- One or multiple unnormalized histograms of dimension dim_b
- If many, compute all the OT distances (a, b_i)
+ One or multiple unnormalized histograms of dimension `dim_b`.
+ If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
@@ -58,7 +60,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -68,14 +70,14 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Returns
-------
if n_hists == 1:
- gamma : (dim_a x dim_b) ndarray
+ - gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
- log : dict
+ - log : dict
log dictionary returned only if `log` is `True`
else:
- ot_distance : (n_hists,) ndarray
- the OT distance between `a` and each of the histograms `b_i`
- log : dict
+ - ot_distance : (n_hists,) ndarray
+ the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
+ - log : dict
log dictionary returned only if `log` is `True`
Examples
@@ -90,9 +92,9 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
[0.18807035, 0.51122823]])
+ .. _references-sinkhorn-unbalanced:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
Transport, Advances in Neural Information Processing Systems
(NIPS) 26, 2013
@@ -111,11 +113,11 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
See Also
--------
- ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
+ ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced>`
ot.unbalanced.sinkhorn_stabilized_unbalanced:
- Unbalanced Stabilized sinkhorn [9][10]
+ Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced>`
ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
- Unbalanced Sinkhorn with epslilon scaling [9][10]
+        Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced>`
"""
@@ -151,29 +153,30 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})
s.t.
\gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term
- :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target unbalanced distributions
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced2>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
- One or multiple unnormalized histograms of dimension dim_b
- If many, compute all the OT distances (a, b_i)
+ One or multiple unnormalized histograms of dimension `dim_b`.
+ If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
@@ -186,7 +189,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -196,7 +199,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Returns
-------
ot_distance : (n_hists,) ndarray
- the OT distance between `a` and each of the histograms `b_i`
+ the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
log : dict
log dictionary returned only if `log` is `True`
@@ -211,10 +214,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
array([0.31912866])
-
+ .. _references-sinkhorn-unbalanced2:
References
----------
-
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
Transport, Advances in Neural Information Processing Systems
(NIPS) 26, 2013
@@ -232,9 +234,9 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
See Also
--------
- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
- ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
- ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
+ ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced2>`
+ ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
+    ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
"""
b = np.asarray(b, dtype=np.float64)
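
To complement the doctests in these docstrings, a short sketch (illustrative regularization values) contrasting the plain, stabilized and loss-returning entry points; the two plans should agree up to solver tolerance.

    import numpy as np
    import ot

    a = np.array([0.5, 0.5])
    b = np.array([0.5, 0.5])
    M = np.array([[0.0, 1.0],
                  [1.0, 0.0]])
    reg, reg_m = 0.1, 1.0

    G = ot.sinkhorn_unbalanced(a, b, M, reg, reg_m)                          # plan
    G_stab = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m)
    loss = ot.sinkhorn_unbalanced2(a, b, M, reg, reg_m)                      # OT loss
    print(np.abs(G - G_stab).max(), loss)
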
@@ -270,26 +272,29 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
The function solves the following optimization problem:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + \reg_m KL(\gamma 1, a) + \reg_m KL(\gamma^T 1, b)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})
s.t.
- \gamma\geq 0
+ \gamma \geq 0
+
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target unbalanced distributions
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-knopp-unbalanced>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
- One or multiple unnormalized histograms of dimension dim_b
+ One or multiple unnormalized histograms of dimension `dim_b`
If many, compute all the OT distances (a, b_i)
M : np.ndarray (dim_a, dim_b)
loss matrix
@@ -300,7 +305,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -310,15 +315,16 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Returns
-------
if n_hists == 1:
- gamma : (dim_a x dim_b) ndarray
+ - gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
- log : dict
+ - log : dict
log dictionary returned only if `log` is `True`
else:
- ot_distance : (n_hists,) ndarray
- the OT distance between `a` and each of the histograms `b_i`
- log : dict
+ - ot_distance : (n_hists,) ndarray
+ the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
+ - log : dict
log dictionary returned only if `log` is `True`
+
Examples
--------
@@ -330,9 +336,10 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
array([[0.51122823, 0.18807035],
[0.18807035, 0.51122823]])
+
+ .. _references-sinkhorn-knopp-unbalanced:
References
----------
-
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems. arXiv preprint
arXiv:1607.05816.
@@ -445,32 +452,34 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
problem and return the loss
The function solves the following optimization problem using log-domain
- stabilization as proposed in [10]:
+ stabilization as proposed in :ref:`[10] <references-sinkhorn-stabilized-unbalanced>`:
.. math::
- W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+ W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
+ \mathrm{reg_m} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})
s.t.
- \gamma\geq 0
+ \gamma \geq 0
+
where :
- - M is the (dim_a, dim_b) metric cost matrix
- - :math:`\Omega` is the entropic regularization
- term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target unbalanced distributions
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-stabilized-unbalanced>`
Parameters
----------
a : np.ndarray (dim_a,)
- Unnormalized histogram of dimension dim_a
+ Unnormalized histogram of dimension `dim_a`
b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
- One or multiple unnormalized histograms of dimension dim_b
- If many, compute all the OT distances (a, b_i)
+ One or multiple unnormalized histograms of dimension `dim_b`.
+ If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
M : np.ndarray (dim_a, dim_b)
loss matrix
reg : float
@@ -482,7 +491,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -492,14 +501,14 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Returns
-------
if n_hists == 1:
- gamma : (dim_a x dim_b) ndarray
+ - gamma : (dim_a, dim_b) ndarray
Optimal transportation matrix for the given parameters
- log : dict
+ - log : dict
log dictionary returned only if `log` is `True`
else:
- ot_distance : (n_hists,) ndarray
- the OT distance between `a` and each of the histograms `b_i`
- log : dict
+ - ot_distance : (n_hists,) ndarray
+ the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
+ - log : dict
log dictionary returned only if `log` is `True`
Examples
--------
@@ -512,9 +521,10 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
array([[0.51122823, 0.18807035],
[0.18807035, 0.51122823]])
+
+ .. _references-sinkhorn-stabilized-unbalanced:
References
----------
-
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
@@ -654,29 +664,27 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
numItermax=1000, stopThr=1e-6,
verbose=False, log=False):
- r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
+ r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}` with stabilization.
The function solves the following optimization problem:
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
- Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of
- matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and
- the cost matrix for OT
+ - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter
- The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+
+ The algorithm used for solving the problem is the generalized
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-stabilized>`
Parameters
----------
A : np.ndarray (dim, n_hists)
- `n_hists` training distributions a_i of dimension dim
+ `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
M : np.ndarray (dim, dim)
ground metric matrix for OT.
reg : float
@@ -691,7 +699,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -706,9 +714,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
log dictionary return only if log==True in parameters
+ .. _references-barycenter-unbalanced-stabilized:
References
----------
-
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
G. (2015). Iterative Bregman projections for regularized transportation
problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
@@ -806,29 +814,27 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
numItermax=1000, stopThr=1e-6,
verbose=False, log=False):
- r"""Compute the entropic unbalanced wasserstein barycenter of A.
+ r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`.
- The function solves the following optimization problem with a
+    The function solves the following optimization problem for the barycenter :math:`\mathbf{a}`
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
- Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
- :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and
- the cost matrix for OT
+ - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter
+
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-sinkhorn>`
Parameters
----------
A : np.ndarray (dim, n_hists)
- `n_hists` training distributions a_i of dimension dim
+ `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
M : np.ndarray (dim, dim)
ground metric matrix for OT.
reg : float
@@ -841,7 +847,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -856,9 +862,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
log dictionary return only if log==True in parameters
+ .. _references-barycenter-unbalanced-sinkhorn:
References
----------
-
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
(2015). Iterative Bregman projections for regularized transportation
problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
@@ -936,29 +942,27 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
numItermax=1000, stopThr=1e-6,
verbose=False, log=False, **kwargs):
- r"""Compute the entropic unbalanced wasserstein barycenter of A.
+ r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`.
- The function solves the following optimization problem with a
+    The function solves the following optimization problem for the barycenter :math:`\mathbf{a}`
.. math::
- \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)
where :
- - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized
- Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
- :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and
- the cost matrix for OT
+ - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter
+
The algorithm used for solving the problem is the generalized
- Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
+ Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced>`
Parameters
----------
A : np.ndarray (dim, n_hists)
- `n_hists` training distributions a_i of dimension dim
+ `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
M : np.ndarray (dim, dim)
ground metric matrix for OT.
reg : float
@@ -971,7 +975,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -986,9 +990,9 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
log dictionary returned only if log==True in parameters
+ .. _references-barycenter-unbalanced:
References
----------
-
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
(2015). Iterative Bregman projections for regularized transportation
problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
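The generic front-end above dispatches on `method`; a minimal sketch with invented data showing a weighted call with logging enabled:

import numpy as np
import ot

A = np.array([[0.5, 0.1],
              [0.2, 0.2],
              [0.1, 0.5]])
M = ot.utils.dist0(3)
M /= M.max()

# uneven weighting of the two inputs; log=True also returns a dict with
# convergence information (weights and hyperparameters are made up)
bary, log_dict = ot.unbalanced.barycenter_unbalanced(
    A, M, reg=0.05, reg_m=1.0, weights=np.array([0.7, 0.3]), log=True)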
diff --git a/ot/utils.py b/ot/utils.py
index f9911a1..c878563 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -7,7 +7,6 @@ Various useful functions
#
# License: MIT License
-import multiprocessing
from functools import reduce
import time
@@ -15,67 +14,127 @@ import numpy as np
from scipy.spatial.distance import cdist
import sys
import warnings
-try:
- from inspect import signature
-except ImportError:
- from .externals.funcsigs import signature
+from inspect import signature
+from .backend import get_backend
__time_tic_toc = time.time()
def tic():
- """ Python implementation of Matlab tic() function """
+ r""" Python implementation of Matlab tic() function """
global __time_tic_toc
__time_tic_toc = time.time()
def toc(message='Elapsed time : {} s'):
- """ Python implementation of Matlab toc() function """
+ r""" Python implementation of Matlab toc() function """
t = time.time()
print(message.format(t - __time_tic_toc))
return t - __time_tic_toc
def toq():
- """ Python implementation of Julia toc() function """
+ r""" Python implementation of Julia toc() function """
t = time.time()
return t - __time_tic_toc
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
- """Compute kernel matrix"""
+ r"""Compute kernel matrix"""
+
+ nx = get_backend(x1, x2)
+
if method.lower() in ['gaussian', 'gauss', 'rbf']:
- K = np.exp(-dist(x1, x2) / (2 * sigma**2))
+ K = nx.exp(-dist(x1, x2) / (2 * sigma**2))
return K
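A brief sketch of the now backend-agnostic kernel helper (sample points and sigma are invented):

import numpy as np
from ot.utils import kernel

x = np.random.RandomState(0).randn(5, 2)
K = kernel(x, x, method='gaussian', sigma=0.5)
# K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2)); the diagonal equals 1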
def laplacian(x):
- """Compute Laplacian matrix"""
+ r"""Compute Laplacian matrix"""
L = np.diag(np.sum(x, axis=0)) - x
return L
-def unif(n):
- """ return a uniform histogram of length n (simplex)
+def list_to_array(*lst):
+ r""" Convert a list if in numpy format """
+ if len(lst) > 1:
+ return [np.array(a) if isinstance(a, list) else a for a in lst]
+ else:
+ return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
+
+
+def proj_simplex(v, z=1):
+ r"""Compute the closest point (orthogonal projection) on the
+ generalized `(n-1)`-simplex of a vector :math:`\mathbf{v}` w.r.t. the Euclidean
+ distance, thus solving:
+
+ .. math::
+ \mathcal{P}(\mathbf{v}) \in \mathop{\arg \min}_\gamma \| \gamma - \mathbf{v} \|_2
+
+ s.t. \ \gamma^T \mathbf{1} = z
+
+ \gamma \geq 0
+
+ If :math:`\mathbf{v}` is a 2d array, compute all the projections wrt. axis 0
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
+ v : array-like, shape (n, d)
+ z : int, optional
+ 'size' of the simplex (each vector sums to z, 1 by default)
+
+ Returns
+ -------
+ h : ndarray, shape (`n`, `d`)
+ Array of projections on the simplex
+ """
+ nx = get_backend(v)
+ n = v.shape[0]
+ if v.ndim == 1:
+ d1 = 1
+ v = v[:, None]
+ else:
+ d1 = 0
+ d = v.shape[1]
+
+ # sort u in ascending order
+ u = nx.sort(v, axis=0)
+ # take the descending order
+ u = nx.flip(u, 0)
+ cssv = nx.cumsum(u, axis=0) - z
+ ind = nx.arange(n, type_as=v)[:, None] + 1
+ cond = u - cssv / ind > 0
+ rho = nx.sum(cond, 0)
+ theta = cssv[rho - 1, nx.arange(d)] / rho
+ w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v))
+ if d1:
+ return w[:, 0]
+ else:
+ return w
+
+
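To make the projection concrete, a small sketch with invented vectors checking that each projected column lies on the simplex of mass z:

import numpy as np
from ot.utils import proj_simplex

v = np.array([0.8, -0.2, 1.5])
w = proj_simplex(v)            # 1-d case: w >= 0 and w sums to 1

V = np.random.RandomState(0).randn(4, 3)
W = proj_simplex(V, z=1)       # 2-d case: each column is projected
assert W.min() >= 0 and np.allclose(W.sum(axis=0), 1.0)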
+def unif(n):
+ r"""
+ Return a uniform histogram of length `n` (simplex).
+ Parameters
+ ----------
n : int
number of bins in the histogram
Returns
-------
- h : np.array (n,)
- histogram of length n such that h_i=1/n for all i
-
-
+ h : np.array (`n`,)
+ histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
"""
return np.ones((n,)) / n
def clean_zeros(a, b, M):
- """ Remove all components with zeros weights in a and b
+ r""" Remove all components with zeros weights in :math:`\mathbf{a}` and :math:`\mathbf{b}`
"""
M2 = M[a > 0, :][:, b > 0].copy()  # copy forces a C-style matrix (for emd)
a2 = a[a > 0]
@@ -84,55 +143,71 @@ def clean_zeros(a, b, M):
def euclidean_distances(X, Y, squared=False):
- """
- Considering the rows of X (and Y=X) as vectors, compute the
+ r"""
+ Considering the rows of :math:`\mathbf{X}` (and :math:`\mathbf{Y} = \mathbf{X}`) as vectors, compute the
distance matrix between each pair of vectors.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
+
Parameters
----------
- X : {array-like}, shape (n_samples_1, n_features)
- Y : {array-like}, shape (n_samples_2, n_features)
+ X : array-like, shape (n_samples_1, n_features)
+ Y : array-like, shape (n_samples_2, n_features)
squared : boolean, optional
Return squared Euclidean distances.
+
Returns
-------
- distances : {array}, shape (n_samples_1, n_samples_2)
+ distances : array-like, shape (`n_samples_1`, `n_samples_2`)
"""
- XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
- YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
- distances = np.dot(X, Y.T)
- distances *= -2
- distances += XX
- distances += YY
- np.maximum(distances, 0, out=distances)
+
+ nx = get_backend(X, Y)
+
+ a2 = nx.einsum('ij,ij->i', X, X)
+ b2 = nx.einsum('ij,ij->i', Y, Y)
+
+ c = -2 * nx.dot(X, Y.T)
+ c += a2[:, None]
+ c += b2[None, :]
+
+ c = nx.maximum(c, 0)
+
+ if not squared:
+ c = nx.sqrt(c)
+
if X is Y:
- # Ensure that distances between vectors and themselves are set to 0.0.
- # This may not be the case due to floating point rounding errors.
- distances.flat[::distances.shape[0] + 1] = 0.0
- return distances if squared else np.sqrt(distances, out=distances)
+ c = c * (1 - nx.eye(X.shape[0], type_as=c))
+
+ return c
+
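As a quick sanity check of the rewritten helper, the numpy path can be compared against scipy's cdist (data below is invented):

import numpy as np
from scipy.spatial.distance import cdist
from ot.utils import euclidean_distances

rng = np.random.RandomState(0)
X, Y = rng.randn(4, 3), rng.randn(5, 3)

D2 = euclidean_distances(X, Y, squared=True)
assert np.allclose(D2, cdist(X, Y, metric='sqeuclidean'))
# when the same array object is passed twice, the diagonal is forced to exactly 0
assert np.all(np.diag(euclidean_distances(X, X)) == 0)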
+def dist(x1, x2=None, metric='sqeuclidean', p=2):
+ r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`
-def dist(x1, x2=None, metric='sqeuclidean'):
- """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends.
Parameters
----------
- x1 : ndarray, shape (n1,d)
- matrix with n1 samples of size d
- x2 : array, shape (n2,d), optional
- matrix with n2 samples of size d (if None then x2=x1)
+ x1 : array-like, shape (n1,d)
+ matrix with `n1` samples of size `d`
+ x2 : array-like, shape (n2,d), optional
+ matrix with `n2` samples of size `d` (if None then :math:`\mathbf{x_2} = \mathbf{x_1}`)
metric : str | callable, optional
- Name of the metric to be computed (full list in the doc of scipy), If a string,
- the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock',
- 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski',
- 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
+ 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also
+ accepts the metrics from scipy.spatial.distance.cdist: 'braycurtis',
+ 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice',
+ 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
+ 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean',
'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
+ p : float, optional
+ p-norm for the 'minkowski' and 'wminkowski' metrics, default value is 2.
Returns
-------
- M : np.array (n1,n2)
+ M : array-like, shape (`n1`, `n2`)
distance matrix computed with given metric
"""
@@ -140,11 +215,17 @@ def dist(x1, x2=None, metric='sqeuclidean'):
x2 = x1
if metric == "sqeuclidean":
return euclidean_distances(x1, x2, squared=True)
- return cdist(x1, x2, metric=metric)
+ elif metric == "euclidean":
+ return euclidean_distances(x1, x2, squared=False)
+ else:
+ if not get_backend(x1, x2).__name__ == 'numpy':
+ raise NotImplementedError()
+ else:
+ return cdist(x1, x2, metric=metric, p=p)
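A short sketch of the new dispatch logic: the two Euclidean metrics work on every backend, anything else goes through scipy's cdist and is therefore numpy-only (samples are invented):

import numpy as np
import ot

x = np.random.RandomState(0).randn(6, 2)

M_sq = ot.dist(x, x)                           # squared Euclidean (default), any backend
M_eu = ot.dist(x, x, metric='euclidean')       # Euclidean, any backend
M_l1 = ot.dist(x, x, metric='minkowski', p=1)  # numpy only: delegates to cdist
assert np.allclose(M_eu ** 2, M_sq)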
def dist0(n, method='lin_square'):
- """Compute standard cost matrices of size (n, n) for OT problems
+ r"""Compute standard cost matrices of size (`n`, `n`) for OT problems
Parameters
----------
@@ -153,11 +234,11 @@ def dist0(n, method='lin_square'):
method : str, optional
Type of loss matrix chosen from:
- * 'lin_square' : linear sampling between 0 and n-1, quadratic loss
+ * 'lin_square' : linear sampling between 0 and `n-1`, quadratic loss
Returns
-------
- M : ndarray, shape (n1,n2)
+ M : ndarray, shape (`n`, `n`)
Distance matrix computed with given metric.
"""
res = 0
@@ -168,7 +249,7 @@ def dist0(n, method='lin_square'):
def cost_normalization(C, norm=None):
- """ Apply normalization to the loss matrix
+ r""" Apply normalization to the loss matrix
Parameters
----------
@@ -180,7 +261,7 @@ def cost_normalization(C, norm=None):
Returns
-------
- C : ndarray, shape (n1, n2)
+ C : ndarray, shape (`n1`, `n2`)
The input cost matrix normalized according to given norm.
"""
@@ -202,23 +283,23 @@ def cost_normalization(C, norm=None):
def dots(*args):
- """ dots function for multiple matrix multiply """
+ r""" dots function for multiple matrix multiply """
return reduce(np.dot, args)
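For completeness, a one-line illustration of dots (matrices invented):

import numpy as np
from ot.utils import dots

A, B, C = np.eye(3), 2 * np.eye(3), np.ones((3, 2))
assert np.allclose(dots(A, B, C), A @ B @ C)   # chained matrix product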
def label_normalization(y, start=0):
- """ Transform labels to start at a given value
+ r""" Transform labels to start at a given value
Parameters
----------
y : array-like, shape (n, )
The vector of labels to be normalized.
start : int
- Desired value for the smallest label in y (default=0)
+ Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
Returns
-------
- y : array-like, shape (n1, )
+ y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
"""
@@ -228,42 +309,15 @@ def label_normalization(y, start=0):
return y
-def fun(f, q_in, q_out):
- """ Utility function for parmap with no serializing problems """
- while True:
- i, x = q_in.get()
- if i is None:
- break
- q_out.put((i, f(x)))
-
-
-def parmap(f, X, nprocs=multiprocessing.cpu_count()):
- """ paralell map for multiprocessing (only map on windows)"""
-
- if not sys.platform.endswith('win32'):
-
- q_in = multiprocessing.Queue(1)
- q_out = multiprocessing.Queue()
-
- proc = [multiprocessing.Process(target=fun, args=(f, q_in, q_out))
- for _ in range(nprocs)]
- for p in proc:
- p.daemon = True
- p.start()
-
- sent = [q_in.put((i, x)) for i, x in enumerate(X)]
- [q_in.put((None, None)) for _ in range(nprocs)]
- res = [q_out.get() for _ in range(len(sent))]
-
- [p.join() for p in proc]
-
- return [x for i, x in sorted(res)]
- else:
- return list(map(f, X))
+def parmap(f, X, nprocs="default"):
+ r""" parallel map for multiprocessing.
+ The function has been deprecated and only performs a regular map.
+ """
+ return list(map(f, X))
def check_params(**kwargs):
- """check_params: check whether some parameters are missing
+ r"""check_params: check whether some parameters are missing
"""
missing_params = []
@@ -284,14 +338,14 @@ def check_params(**kwargs):
def check_random_state(seed):
- """Turn seed into a np.random.RandomState instance
+ r"""Turn `seed` into a np.random.RandomState instance
Parameters
----------
seed : None | int | instance of RandomState
- If seed is None, return the RandomState singleton used by np.random.
- If seed is an int, return a new RandomState instance seeded with seed.
- If seed is already a RandomState instance, return it.
+ If `seed` is None, return the RandomState singleton used by np.random.
+ If `seed` is an int, return a new RandomState instance seeded with `seed`.
+ If `seed` is already a RandomState instance, return it.
Otherwise raise ValueError.
"""
if seed is None or seed is np.random:
@@ -305,18 +359,21 @@ def check_random_state(seed):
class deprecated(object):
- """Decorator to mark a function or class as deprecated.
+ r"""Decorator to mark a function or class as deprecated.
deprecated class from scikit-learn package
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
Issue a warning when the function is called/the class is instantiated and
adds a warning to the docstring.
The optional extra argument will be appended to the deprecation message
- and the docstring. Note: to use this with the default value for extra, put
- in an empty of parentheses:
- >>> from ot.deprecation import deprecated # doctest: +SKIP
- >>> @deprecated() # doctest: +SKIP
- ... def some_function(): pass # doctest: +SKIP
+ and the docstring.
+
+ .. note::
+ To use this with the default value for extra, use empty parentheses:
+
+ >>> from ot.deprecation import deprecated # doctest: +SKIP
+ >>> @deprecated() # doctest: +SKIP
+ ... def some_function(): pass # doctest: +SKIP
Parameters
----------
@@ -331,7 +388,7 @@ class deprecated(object):
self.extra = extra
def __call__(self, obj):
- """Call method
+ r"""Call method
Parameters
----------
obj : object
@@ -362,7 +419,7 @@ class deprecated(object):
return cls
def _decorate_fun(self, fun):
- """Decorate function fun"""
+ r"""Decorate function fun"""
msg = "Function %s is deprecated" % fun.__name__
if self.extra:
@@ -388,7 +445,7 @@ class deprecated(object):
def _is_deprecated(func):
- """Helper to check if func is wraped by our deprecated decorator"""
+ r"""Helper to check if func is wraped by our deprecated decorator"""
if sys.version_info < (3, 5):
raise NotImplementedError("This is only available for python3.5 "
"or above")
@@ -402,7 +459,7 @@ def _is_deprecated(func):
class BaseEstimator(object):
- """Base class for most objects in POT
+ r"""Base class for most objects in POT
Code adapted from sklearn BaseEstimator class
@@ -415,7 +472,7 @@ class BaseEstimator(object):
@classmethod
def _get_param_names(cls):
- """Get parameter names for the estimator"""
+ r"""Get parameter names for the estimator"""
# fetch the constructor or the original constructor before
# deprecation wrapping if any
@@ -442,7 +499,7 @@ class BaseEstimator(object):
return sorted([p.name for p in parameters])
def get_params(self, deep=True):
- """Get parameters for this estimator.
+ r"""Get parameters for this estimator.
Parameters
----------
@@ -479,7 +536,7 @@ class BaseEstimator(object):
return out
def set_params(self, **params):
- """Set the parameters of this estimator.
+ r"""Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects
(such as pipelines). The latter have parameters of the form
@@ -519,7 +576,7 @@ class BaseEstimator(object):
class UndefinedParameter(Exception):
- """
+ r"""
Aims at raising an Exception when an undefined parameter is called
"""