author     ncassereau-idris <84033440+ncassereau-idris@users.noreply.github.com>  2021-10-25 11:36:21 +0200
committer  GitHub <noreply@github.com>  2021-10-25 11:36:21 +0200
commit     7a65086dd340265d0223eb8ffb5c9a5152a82dff
tree       300f4a1cd645516fba1e440691fe48830d781b5c
parent     7af8c2147d61349f4d99ca33318a8a125e4569aa
[MRG] Bregman backend (#280)
* Bregman
* Resolve conflicts
* Bug solve
* Bregman updated for JAX compatibility
* Tests coherence between backend improved
* No longer enforcing 64 bits operations on Jax except for tests
* Now using mixtures, to make backend dependent tests with less code
* Better test skipping code
* Pep8 + test optimizations
* redundancy removed
* Docs
* Typo corrected
* Typo
* Typo
* Docs
* Docs
* pep8
* Backend docs
* Prettier docs
* Mistake corrected
* small changes
* Better wording

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--   docs/source/all.rst         1
-rw-r--r--   ot/backend.py             581
-rw-r--r--   ot/bregman.py             970
-rw-r--r--   ot/gromov.py                6
-rw-r--r--   ot/smooth.py                4
-rw-r--r--   ot/unbalanced.py           14
-rw-r--r--   test/conftest.py           49
-rw-r--r--   test/test_backend.py      102
-rw-r--r--   test/test_bregman.py      217
-rwxr-xr-x   test/test_partial.py        6
-rw-r--r--   test/test_smooth.py        12
-rw-r--r--   test/test_stochastic.py    12
12 files changed, 1423 insertions, 551 deletions
diff --git a/docs/source/all.rst b/docs/source/all.rst
index f1f7075..6a07599 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -14,6 +14,7 @@ API and modules
:template: module.rst
lp
+ backend
bregman
smooth
gromov
diff --git a/ot/backend.py b/ot/backend.py
index 2ed40af..a4a4757 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1,6 +1,22 @@
# -*- coding: utf-8 -*-
"""
Multi-lib backend for POT
+
+The goal is to write backend-agnostic code. Whether you're using NumPy, PyTorch,
+or JAX, POT code should work nonetheless.
+To achieve that, POT provides backend classes, each of which implements the functions
+of its respective framework while imitating the NumPy API. As a convention, we use nx instead of np to refer to the backend.
+
+Examples
+--------
+
+>>> from ot.utils import list_to_array
+>>> from ot.backend import get_backend
+>>> def f(a, b): # the function does not know which backend to use
+... a, b = list_to_array(a, b) # if a list is given, convert it to an array
+... nx = get_backend(a, b) # infer the backend from the arguments
+... c = nx.dot(a, b) # now use the backend to do any calculation
+... return c
"""
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
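
A minimal end-to-end check of the pattern documented above; this is a hedged sketch assuming only NumPy is installed (the same function would work unchanged on torch or jax arrays):

    import numpy as np
    from ot.backend import get_backend

    a = np.array([1.0, 2.0, 3.0])
    b = np.array([4.0, 5.0, 6.0])
    nx = get_backend(a, b)   # infers NumpyBackend from the inputs
    c = nx.dot(a, b)         # result stays a NumPy scalar (32.0)
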
@@ -9,6 +25,7 @@ Multi-lib backend for POT
# License: MIT License
import numpy as np
+import scipy.special as scipy
try:
import torch
@@ -20,6 +37,7 @@ except ImportError:
try:
import jax
import jax.numpy as jnp
+ import jax.scipy.special as jscipy
jax_type = jax.numpy.ndarray
except ImportError:
jax = False
@@ -29,7 +47,7 @@ str_type_error = "All array should be from the same type/backend. Current types
def get_backend_list():
- """ returns the list of available backends)"""
+ """Returns the list of available backends"""
lst = [NumpyBackend(), ]
if torch:
@@ -42,7 +60,7 @@ def get_backend_list():
def get_backend(*args):
- """returns the proper backend for a list of input arrays
+ """Returns the proper backend for a list of input arrays
Also raises ValueError if all arrays are not from the same backend
"""
@@ -50,14 +68,12 @@ def get_backend(*args):
if not len(args) > 0:
raise ValueError(" The function takes at least one parameter")
# check all same type
+ if not len(set(type(a) for a in args)) == 1:
+ raise ValueError(str_type_error.format([type(a) for a in args]))
if isinstance(args[0], np.ndarray):
- if not len(set(type(a) for a in args)) == 1:
- raise ValueError(str_type_error.format([type(a) for a in args]))
return NumpyBackend()
- elif torch and isinstance(args[0], torch_type):
- if not len(set(type(a) for a in args)) == 1:
- raise ValueError(str_type_error.format([type(a) for a in args]))
+ elif isinstance(args[0], torch_type):
return TorchBackend()
elif isinstance(args[0], jax_type):
return JaxBackend()
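
A quick sketch of the same-type guard above, assuming both NumPy and PyTorch are available; mixing arrays from two backends raises the ValueError built from str_type_error:

    import numpy as np
    import torch
    from ot.backend import get_backend

    try:
        get_backend(np.ones(3), torch.ones(3))  # arrays from different backends
    except ValueError as err:
        print(err)  # "All array should be from the same type/backend. ..."
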
@@ -66,7 +82,7 @@ def get_backend(*args):
def to_numpy(*args):
- """returns numpy arrays from any compatible backend"""
+ """Returns numpy arrays from any compatible backend"""
if len(args) == 1:
return get_backend(args[0]).to_numpy(args[0])
@@ -75,6 +91,13 @@ def to_numpy(*args):
class Backend():
+ """
+ Backend abstract class.
+ Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+
+ - The `__name__` class attribute refers to the name of the backend.
+ - The `__type__` class attribute refers to the data structure used by the backend.
+ """
__name__ = None
__type__ = None
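
For illustration only, a hypothetical backend stub would subclass Backend, set the two class attributes, and override the methods it needs; everything left untouched keeps raising NotImplementedError:

    import numpy as np
    from ot.backend import Backend

    class EchoBackend(Backend):      # hypothetical, not part of POT
        __name__ = 'echo'
        __type__ = np.ndarray        # data structure this backend handles

        def sum(self, a, axis=None, keepdims=False):
            return np.sum(a, axis=axis, keepdims=keepdims)
        # any method not overridden still raises NotImplementedError
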
@@ -84,90 +107,426 @@ class Backend():
# convert to numpy
def to_numpy(self, a):
+ """Returns the numpy version of a tensor"""
raise NotImplementedError()
# convert from numpy
def from_numpy(self, a, type_as=None):
+ """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
raise NotImplementedError()
def set_gradients(self, val, inputs, grads):
- """ define the gradients for the value val wrt the inputs """
+ """Define the gradients for the value val wrt the inputs """
raise NotImplementedError()
def zeros(self, shape, type_as=None):
+ r"""
+ Creates a tensor full of zeros.
+
+ This function follows the API from :any:`numpy.zeros`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html
+ """
raise NotImplementedError()
def ones(self, shape, type_as=None):
+ r"""
+ Creates a tensor full of ones.
+
+ This function follows the API from :any:`numpy.ones`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html
+ """
raise NotImplementedError()
def arange(self, stop, start=0, step=1, type_as=None):
+ r"""
+ Returns evenly spaced values within a given interval.
+
+ This function follows the API from :any:`numpy.arange`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
+ """
raise NotImplementedError()
def full(self, shape, fill_value, type_as=None):
+ r"""
+ Creates a tensor with given shape, filled with given value.
+
+ This function follows the API from :any:`numpy.full`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.full.html
+ """
raise NotImplementedError()
def eye(self, N, M=None, type_as=None):
+ r"""
+ Creates the identity matrix of given size.
+
+ This function follows the API from :any:`numpy.eye`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html
+ """
raise NotImplementedError()
def sum(self, a, axis=None, keepdims=False):
+ r"""
+ Sums tensor elements over given dimensions.
+
+ This function follows the API from :any:`numpy.sum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html
+ """
raise NotImplementedError()
def cumsum(self, a, axis=None):
+ r"""
+ Returns the cumulative sum of tensor elements over given dimensions.
+
+ This function follows the API from :any:`numpy.cumsum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
+ """
raise NotImplementedError()
def max(self, a, axis=None, keepdims=False):
+ r"""
+ Returns the maximum of an array or maximum along given dimensions.
+
+ This function follows the API from :any:`numpy.amax`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html
+ """
raise NotImplementedError()
def min(self, a, axis=None, keepdims=False):
+ r"""
+ Returns the minimum of an array or minimum along given dimensions.
+
+ This function follows the API from :any:`numpy.amin`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html
+ """
raise NotImplementedError()
def maximum(self, a, b):
+ r"""
+ Returns element-wise maximum of array elements.
+
+ This function follows the API from :any:`numpy.maximum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
+ """
raise NotImplementedError()
def minimum(self, a, b):
+ r"""
+ Returns element-wise minimum of array elements.
+
+ This function follows the API from :any:`numpy.minimum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
+ """
raise NotImplementedError()
def dot(self, a, b):
+ r"""
+ Returns the dot product of two tensors.
+
+ This function follows the API from :any:`numpy.dot`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
+ """
raise NotImplementedError()
def abs(self, a):
+ r"""
+ Computes the absolute value element-wise.
+
+ This function follows the API from :any:`numpy.absolute`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html
+ """
raise NotImplementedError()
def exp(self, a):
+ r"""
+ Computes the exponential value element-wise.
+
+ This function follows the API from :any:`numpy.exp`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html
+ """
raise NotImplementedError()
def log(self, a):
+ r"""
+ Computes the natural logarithm, element-wise.
+
+ This function follows the API from :any:`numpy.log`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.log.html
+ """
raise NotImplementedError()
def sqrt(self, a):
+ r"""
+ Returns the non-negative square root of a tensor, element-wise.
+
+ This function follows the API from :any:`numpy.sqrt`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html
+ """
+ raise NotImplementedError()
+
+ def power(self, a, exponents):
+ r"""
+ First tensor elements raised to powers from second tensor, element-wise.
+
+ This function follows the API from :any:`numpy.power`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.power.html
+ """
raise NotImplementedError()
def norm(self, a):
+ r"""
+ Computes the matrix Frobenius norm.
+
+ This function follows the API from :any:`numpy.linalg.norm`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
+ """
raise NotImplementedError()
def any(self, a):
+ r"""
+ Tests whether any tensor element along given dimensions evaluates to True.
+
+ This function follows the API from :any:`numpy.any`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.any.html
+ """
raise NotImplementedError()
def isnan(self, a):
+ r"""
+ Tests element-wise for NaN and returns result as a boolean tensor.
+
+ This function follows the API from :any:`numpy.isnan`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html
+ """
raise NotImplementedError()
def isinf(self, a):
+ r"""
+ Tests element-wise for positive or negative infinity and returns result as a boolean tensor.
+
+ This function follows the API from :any:`numpy.isinf`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html
+ """
raise NotImplementedError()
def einsum(self, subscripts, *operands):
+ r"""
+ Evaluates the Einstein summation convention on the operands.
+
+ This function follows the API from :any:`numpy.einsum`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
+ """
raise NotImplementedError()
def sort(self, a, axis=-1):
+ r"""
+ Returns a sorted copy of a tensor.
+
+ This function follows the API from :any:`numpy.sort`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html
+ """
raise NotImplementedError()
def argsort(self, a, axis=None):
+ r"""
+ Returns the indices that would sort a tensor.
+
+ This function follows the API from :any:`numpy.argsort`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
+ """
+ raise NotImplementedError()
+
+ def searchsorted(self, a, v, side='left'):
+ r"""
+ Finds indices where elements should be inserted to maintain order in given tensor.
+
+ This function follows the API from :any:`numpy.searchsorted`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
+ """
raise NotImplementedError()
def flip(self, a, axis=None):
+ r"""
+ Reverses the order of elements in a tensor along given dimensions.
+
+ This function follows the API from :any:`numpy.flip`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html
+ """
+ raise NotImplementedError()
+
+ def clip(self, a, a_min, a_max):
+ """
+ Limits the values in a tensor.
+
+ This function follows the API from :any:`numpy.clip`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html
+ """
+ raise NotImplementedError()
+
+ def repeat(self, a, repeats, axis=None):
+ r"""
+ Repeats elements of a tensor.
+
+ This function follows the API from :any:`numpy.repeat`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html
+ """
+ raise NotImplementedError()
+
+ def take_along_axis(self, arr, indices, axis):
+ r"""
+ Gathers elements of a tensor along given dimensions.
+
+ This function follows the API from :any:`numpy.take_along_axis`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
+ """
+ raise NotImplementedError()
+
+ def concatenate(self, arrays, axis=0):
+ r"""
+ Joins a sequence of tensors along an existing dimension.
+
+ This function follows the API from :any:`numpy.concatenate`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html
+ """
+ raise NotImplementedError()
+
+ def zero_pad(self, a, pad_width):
+ r"""
+ Pads a tensor.
+
+ This function follows the API from :any:`numpy.pad`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
+ """
+ raise NotImplementedError()
+
+ def argmax(self, a, axis=None):
+ r"""
+ Returns the indices of the maximum values of a tensor along given dimensions.
+
+ This function follows the API from :any:`numpy.argmax`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
+ """
+ raise NotImplementedError()
+
+ def mean(self, a, axis=None):
+ r"""
+ Computes the arithmetic mean of a tensor along given dimensions.
+
+ This function follows the API from :any:`numpy.mean`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html
+ """
+ raise NotImplementedError()
+
+ def std(self, a, axis=None):
+ r"""
+ Computes the standard deviation of a tensor along given dimensions.
+
+ This function follows the API from :any:`numpy.std`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.std.html
+ """
+ raise NotImplementedError()
+
+ def linspace(self, start, stop, num):
+ r"""
+ Returns a specified number of evenly spaced values over a given interval.
+
+ This function follows the API from :any:`numpy.linspace`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html
+ """
+ raise NotImplementedError()
+
+ def meshgrid(self, a, b):
+ r"""
+ Returns coordinate matrices from coordinate vectors (Numpy convention).
+
+ This function follows the API from :any:`numpy.meshgrid`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
+ """
+ raise NotImplementedError()
+
+ def diag(self, a, k=0):
+ r"""
+ Extracts or constructs a diagonal tensor.
+
+ This function follows the API from :any:`numpy.diag`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html
+ """
+ raise NotImplementedError()
+
+ def unique(self, a):
+ r"""
+ Finds unique elements of given tensor.
+
+ This function follows the API from :any:`numpy.unique`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
+ """
+ raise NotImplementedError()
+
+ def logsumexp(self, a, axis=None):
+ r"""
+ Computes the log of the sum of exponentials of input elements.
+
+ This function follows the API from :any:`scipy.special.logsumexp`
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
+ """
+ raise NotImplementedError()
+
+ def stack(self, arrays, axis=0):
+ r"""
+ Joins a sequence of tensors along a new dimension.
+
+ This function follows the API from :any:`numpy.stack`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html
+ """
raise NotImplementedError()
class NumpyBackend(Backend):
+ """
+ NumPy implementation of the backend
+
+ - `__name__` is "numpy"
+ - `__type__` is np.ndarray
+ """
__name__ = 'numpy'
__type__ = np.ndarray
@@ -184,7 +543,7 @@ class NumpyBackend(Backend):
return a.astype(type_as.dtype)
def set_gradients(self, val, inputs, grads):
- # no gradients for numpy
+ # No gradients for numpy
return val
def zeros(self, shape, type_as=None):
@@ -247,6 +606,9 @@ class NumpyBackend(Backend):
def sqrt(self, a):
return np.sqrt(a)
+ def power(self, a, exponents):
+ return np.power(a, exponents)
+
def norm(self, a):
return np.sqrt(np.sum(np.square(a)))
@@ -268,11 +630,70 @@ class NumpyBackend(Backend):
def argsort(self, a, axis=-1):
return np.argsort(a, axis)
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return np.searchsorted(a, v, side)
+ else:
+ # this is not a very efficient way to make numpy
+ # searchsorted work on 2d arrays
+ ret = np.empty(v.shape, dtype=int)
+ for i in range(a.shape[0]):
+ ret[i, :] = np.searchsorted(a[i, :], v[i, :], side)
+ return ret
+
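
A small worked example of the row-wise fallback above (hypothetical values), since NumPy's own searchsorted only accepts 1d arrays:

    import numpy as np
    from ot.backend import NumpyBackend

    a = np.array([[1., 3., 5.], [2., 4., 6.]])
    v = np.array([[2.], [5.]])
    NumpyBackend().searchsorted(a, v)
    # array([[1], [2]]): 2 inserts after 1 in row 0, 5 inserts after 4 in row 1
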
def flip(self, a, axis=None):
return np.flip(a, axis)
+ def clip(self, a, a_min, a_max):
+ return np.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return np.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return np.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return np.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return np.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return np.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return np.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return np.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return np.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return np.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return np.diag(a, k)
+
+ def unique(self, a):
+ return np.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ return scipy.logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return np.stack(arrays, axis)
+
class JaxBackend(Backend):
+ """
+ JAX implementation of the backend
+
+ - `__name__` is "jax"
+ - `__type__` is jax.numpy.ndarray
+ """
__name__ = 'jax'
__type__ = jax_type
@@ -359,6 +780,9 @@ class JaxBackend(Backend):
def sqrt(self, a):
return jnp.sqrt(a)
+ def power(self, a, exponents):
+ return jnp.power(a, exponents)
+
def norm(self, a):
return jnp.sqrt(jnp.sum(jnp.square(a)))
@@ -380,11 +804,67 @@ class JaxBackend(Backend):
def argsort(self, a, axis=-1):
return jnp.argsort(a, axis)
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return jnp.searchsorted(a, v, side)
+ else:
+ # this is not a very efficient way to make jax numpy
+ # searchsorted work on 2d arrays
+ return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])])
+
def flip(self, a, axis=None):
return jnp.flip(a, axis)
+ def clip(self, a, a_min, a_max):
+ return jnp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return jnp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return jnp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return jnp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return jnp.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return jnp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return jnp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return jnp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return jnp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return jnp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return jnp.diag(a, k)
+
+ def unique(self, a):
+ return jnp.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ return jscipy.logsumexp(a, axis=axis)
+
+ def stack(self, arrays, axis=0):
+ return jnp.stack(arrays, axis)
+
class TorchBackend(Backend):
+ """
+ PyTorch implementation of the backend
+
+ - `__name__` is "torch"
+ - `__type__` is torch.Tensor
+ """
__name__ = 'torch'
__type__ = torch_type
@@ -487,22 +967,23 @@ class TorchBackend(Backend):
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
if isinstance(b, int) or isinstance(b, float):
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
- return torch.maximum(a, b)
+ if tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 7):
+ return torch.maximum(a, b)
+ else:
+ return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
def minimum(self, a, b):
if isinstance(a, int) or isinstance(a, float):
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
if isinstance(b, int) or isinstance(b, float):
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
- return torch.minimum(a, b)
+ if tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 7):
+ return torch.minimum(a, b)
+ else:
+ return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
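
The pre-1.7 branch above emulates the missing element-wise kernels by broadcasting both tensors to a common shape, stacking them along a new leading dimension, and reducing over it; a hedged equivalence check on a recent torch (where torch.maximum exists):

    import torch

    a, b = torch.tensor([1., 5.]), torch.tensor([3.])
    stacked = torch.stack(torch.broadcast_tensors(a, b))      # shape (2, 2)
    assert torch.equal(torch.max(stacked, dim=0)[0], torch.maximum(a, b))
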
def dot(self, a, b):
- if len(a.shape) == len(b.shape) == 1:
- return torch.dot(a, b)
- elif len(a.shape) == 2 and len(b.shape) == 1:
- return torch.mv(a, b)
- else:
- return torch.mm(a, b)
+ return torch.matmul(a, b)
def abs(self, a):
return torch.abs(a)
@@ -516,6 +997,9 @@ class TorchBackend(Backend):
def sqrt(self, a):
return torch.sqrt(a)
+ def power(self, a, exponents):
+ return torch.pow(a, exponents)
+
def norm(self, a):
return torch.sqrt(torch.sum(torch.square(a)))
@@ -539,6 +1023,10 @@ class TorchBackend(Backend):
sorted, indices = torch.sort(a, dim=axis)
return indices
+ def searchsorted(self, a, v, side='left'):
+ right = (side != 'left')
+ return torch.searchsorted(a, v, right=right)
+
def flip(self, a, axis=None):
if axis is None:
return torch.flip(a, tuple(i for i in range(len(a.shape))))
@@ -546,3 +1034,60 @@ class TorchBackend(Backend):
return torch.flip(a, (axis,))
else:
return torch.flip(a, dims=axis)
+
+ def clip(self, a, a_min, a_max):
+ return torch.clamp(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return torch.repeat_interleave(a, repeats, dim=axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return torch.gather(arr, axis, indices)
+
+ def concatenate(self, arrays, axis=0):
+ return torch.cat(arrays, dim=axis)
+
+ def zero_pad(self, a, pad_width):
+ from torch.nn.functional import pad
+ # pad_width is an array of ndim tuples indicating how many zeros to add
+ # before and after each dimension. We first need to make it compliant with
+ # the torch syntax, which starts with the last dim, then second last, etc.
+ how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
+ return pad(a, how_pad)
+
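
A worked example of the tuple flattening above (hypothetical input): for a 2d tensor, torch.nn.functional.pad expects the amounts for the last dimension first:

    pad_width = [(1, 0), (0, 2)]                   # numpy order: dim 0, then dim 1
    how_pad = tuple(e for tupl in pad_width[::-1] for e in tupl)
    # how_pad == (0, 2, 1, 0): last dim padded by (0, 2), first dim by (1, 0)
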
+ def argmax(self, a, axis=None):
+ return torch.argmax(a, dim=axis)
+
+ def mean(self, a, axis=None):
+ if axis is not None:
+ return torch.mean(a, dim=axis)
+ else:
+ return torch.mean(a)
+
+ def std(self, a, axis=None):
+ if axis is not None:
+ return torch.std(a, dim=axis, unbiased=False)
+ else:
+ return torch.std(a, unbiased=False)
+
+ def linspace(self, start, stop, num):
+ return torch.linspace(start, stop, num, dtype=torch.float64)
+
+ def meshgrid(self, a, b):
+ X, Y = torch.meshgrid(a, b)
+ return X.T, Y.T
+
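
torch.meshgrid follows 'ij' (matrix) indexing while numpy.meshgrid defaults to 'xy' (Cartesian), hence the transposes above; a quick hedged check:

    import numpy as np
    import torch

    a, b = np.arange(3.), np.arange(2.)
    Xn, Yn = np.meshgrid(a, b)                                 # 'xy', shapes (2, 3)
    Xt, Yt = torch.meshgrid(torch.tensor(a), torch.tensor(b))  # 'ij', shapes (3, 2)
    assert np.allclose(Xn, Xt.T.numpy()) and np.allclose(Yn, Yt.T.numpy())
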
+ def diag(self, a, k=0):
+ return torch.diag(a, diagonal=k)
+
+ def unique(self, a):
+ return torch.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ if axis is not None:
+ return torch.logsumexp(a, dim=axis)
+ else:
+ return torch.logsumexp(a, dim=tuple(range(len(a.shape))))
+
+ def stack(self, arrays, axis=0):
+ return torch.stack(arrays, dim=axis)
diff --git a/ot/bregman.py b/ot/bregman.py
index 317c902..b59ee1b 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -19,7 +19,6 @@ import warnings
import numpy as np
from scipy.optimize import fmin_l_bfgs_b
-from scipy.special import logsumexp
from ot.utils import unif, dist, list_to_array
from .backend import get_backend
@@ -35,36 +34,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in [2]_
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn>`
**Choosing a Sinkhorn solver**
By default and when using a regularization parameter that is not too small
the default sinkhorn solver should be enough. If you need to use a small
regularization to get sharper OT matrices, you should use the
- :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
+ :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
errors. This last solver can be very slow in practice and might not even
converge to a reasonable OT matrix in a finite time. This is why
- :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
+ :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
of the regularization (and using warm start) sometimes leads to better
solutions. Note that the greedy version of the sinkhorn
- :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
- version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a
+ :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
+ version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a
fast approximation of the Sinkhorn problem.
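
A hedged usage sketch of the guidance above, on toy data:

    import numpy as np
    import ot

    a = b = np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])
    G = ot.sinkhorn(a, b, M, reg=1.)                    # default solver is fine here
    G_stab = ot.sinkhorn(a, b, M, reg=1e-2,
                         method='sinkhorn_stabilized')  # safer for small reg
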
@@ -74,7 +73,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
samples weights in the source domain
b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
@@ -85,7 +84,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -109,7 +108,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
-
+ .. _references-sinkhorn:
References
----------
@@ -125,9 +124,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
+ ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
"""
@@ -161,21 +160,21 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
.. math::
W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>`
**Choosing a Sinkhorn solver**
@@ -199,17 +198,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
samples weights in the source domain
b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
method : str
- method used for the solver either 'sinkhorn', 'sinkhorn_stabilized', see those function for specific parameters
+ method used for the solver, either 'sinkhorn' or 'sinkhorn_stabilized'; see those functions for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -234,7 +233,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
array([0.26894142])
-
+ .. _references-sinkhorn2:
References
----------
@@ -244,7 +243,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
- [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
@@ -252,9 +251,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
- ot.bregman.greenkhorn : Greenkhorn [21]
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
+ ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>`
+ ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>`
"""
@@ -291,21 +290,21 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
\gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
@@ -320,7 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -337,6 +336,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
[0.13447071, 0.36552929]])
+ .. _references-sinkhorn-knopp:
References
----------
@@ -388,7 +388,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
-
KtransposeU = nx.dot(K.T, u)
v = b / KtransposeU
u = 1. / nx.dot(Kp, v)
@@ -444,53 +443,46 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
- The algorithm used is based on the paper
-
- Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
- by Jason Altschuler, Jonathan Weed, Philippe Rigollet
- appeared at NIPS 2017
-
- which is a stochastic version of the Sinkhorn-Knopp algorithm [2].
+ The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>`, which is a greedy version of the Sinkhorn-Knopp algorithm :ref:`[2] <references-greenkhorn>`
The function solves the following optimization problem:
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
-
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
+ b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (dim_a, dim_b)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
log : bool, optional
record log if True
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -507,11 +499,13 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
[0.13447071, 0.36552929]])
+ .. _references-greenkhorn:
References
----------
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+
+ .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
See Also
@@ -521,60 +515,58 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+ if nx.__name__ == "jax":
+ raise TypeError("JAX arrays have been received. Greenkhorn is not compatible with JAX")
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
dim_a = a.shape[0]
dim_b = b.shape[0]
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty_like(M)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
- u = np.full(dim_a, 1. / dim_a)
- v = np.full(dim_b, 1. / dim_b)
- G = u[:, np.newaxis] * K * v[np.newaxis, :]
+ u = nx.full((dim_a,), 1. / dim_a, type_as=K)
+ v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ G = u[:, None] * K * v[None, :]
- viol = G.sum(1) - a
- viol_2 = G.sum(0) - b
+ viol = nx.sum(G, axis=1) - a
+ viol_2 = nx.sum(G, axis=0) - b
stopThr_val = 1
-
if log:
log = dict()
log['u'] = u
log['v'] = v
for i in range(numItermax):
- i_1 = np.argmax(np.abs(viol))
- i_2 = np.argmax(np.abs(viol_2))
- m_viol_1 = np.abs(viol[i_1])
- m_viol_2 = np.abs(viol_2[i_2])
- stopThr_val = np.maximum(m_viol_1, m_viol_2)
+ i_1 = nx.argmax(nx.abs(viol))
+ i_2 = nx.argmax(nx.abs(viol_2))
+ m_viol_1 = nx.abs(viol[i_1])
+ m_viol_2 = nx.abs(viol_2[i_2])
+ stopThr_val = nx.maximum(m_viol_1, m_viol_2)
if m_viol_1 > m_viol_2:
old_u = u[i_1]
- u[i_1] = a[i_1] / (K[i_1, :].dot(v))
- G[i_1, :] = u[i_1] * K[i_1, :] * v
-
- viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
- viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v)
+ new_u = a[i_1] / (K[i_1, :].dot(v))
+ G[i_1, :] = new_u * K[i_1, :] * v
+ viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1]
+ viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
+ u[i_1] = new_u
else:
old_v = v[i_2]
- v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
- G[:, i_2] = u * K[:, i_2] * v[i_2]
+ new_v = b[i_2] / (K[:, i_2].T.dot(u))
+ G[:, i_2] = u * K[:, i_2] * new_v
# aviol = (G@one_m - a)
# aviol_2 = (G.T@one_n - b)
- viol += (-old_v + v[i_2]) * K[:, i_2] * u
- viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
-
+ viol += (-old_v + new_v) * K[:, i_2] * u
+ viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2]
+ v[i_2] = new_v
# print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
if stopThr_val <= stopThr:
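
For reference, the greedy step above always rescales the single row (or column) with the largest marginal violation; in the notation of the code, the row update reads

    u_{i_1} \leftarrow a_{i_1} / (K v)_{i_1}, \qquad G_{i_1, \cdot} = u_{i_1} K_{i_1, \cdot} \odot v

and the column update is symmetric, after which only the affected entries of viol and viol_2 are refreshed.
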
@@ -603,41 +595,41 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in [2]_ but with the log stabilization
- proposed in [10]_ an defined in [9]_ (Algo 3.1) .
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>` but with the log stabilization
+ proposed in :ref:`[10] <references-sinkhorn-stabilized>` and defined in :ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1).
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,)
+ b : array-like, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
- thershold for max value in u or v for log scaling
- warmstart : tible of vectors
- if given then sarting values for alpha an beta log scalings
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling
+ warmstart : tuple of vectors
+ if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -645,7 +637,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -662,6 +654,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
[0.13447071, 0.36552929]])
+ .. _references-sinkhorn-stabilized:
References
----------
@@ -679,19 +672,19 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# test if multiple target
if len(b.shape) > 1:
n_hists = b.shape[1]
- a = a[:, np.newaxis]
+ a = a[:, None]
else:
n_hists = 0
@@ -706,25 +699,25 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+ u, v = nx.ones(dim_a, type_as=M) / dim_a, nx.ones(dim_b, type_as=M) / dim_b
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1))
+ return nx.exp(-(M - alpha.reshape((dim_a, 1))
- beta.reshape((1, dim_b))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
- return np.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
- / reg + np.log(u.reshape((dim_a, 1))) + np.log(v.reshape((1, dim_b))))
+ return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
+ / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b))))
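
In the informal math style of the docstrings, the two helpers above compute

    K_{ij} = \exp\big(-(M_{ij} - \alpha_i - \beta_j)/reg\big), \qquad
    \Gamma_{ij} = \exp\big(-(M_{ij} - \alpha_i - \beta_j)/reg + \log u_i + \log v_j\big)

so that absorbing large scalings into \alpha and \beta keeps u and v bounded.
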
# print(np.min(K))
@@ -739,33 +732,35 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
vprev = v
# sinkhorn update
- v = b / (np.dot(K.T, u) + 1e-16)
- u = a / (np.dot(K, v) + 1e-16)
+ v = b / (nx.dot(K.T, u) + 1e-16)
+ u = a / (nx.dot(K, v) + 1e-16)
# remove numerical problems and store them in K
- if np.abs(u).max() > tau or np.abs(v).max() > tau:
+ if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
- alpha, beta = alpha + reg * np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
+ alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
- alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
+ alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
- u, v = np.ones((dim_a, n_hists)) / dim_a, np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u, v = np.ones(dim_a) / dim_a, np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
K = get_K(alpha, beta)
if cpt % print_period == 0:
# we can speed up the process by checking for the error only every
# 10th iteration
if n_hists:
- err_u = abs(u - uprev).max()
- err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max()
- err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err_u = nx.max(nx.abs(u - uprev))
+ err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0)
+ err_v = nx.max(nx.abs(v - vprev))
+ err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0)
err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))
+ err = nx.norm(nx.sum(transp, axis=0) - b)
if log:
log['err'].append(err)
@@ -781,7 +776,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
if cpt >= numItermax:
loop = False
- if np.any(np.isnan(u)) or np.any(np.isnan(v)):
+ if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)):
# we have reached the machine precision
# come back to previous solution and quit loop
print('Warning: numerical errors at iteration', cpt)
@@ -795,26 +790,28 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
if n_hists:
alpha = alpha[:, None]
beta = beta[:, None]
- logu = alpha / reg + np.log(u)
- logv = beta / reg + np.log(v)
+ logu = alpha / reg + nx.log(u)
+ logv = beta / reg + nx.log(v)
log['logu'] = logu
log['logv'] = logv
- log['alpha'] = alpha + reg * np.log(u)
- log['beta'] = beta + reg * np.log(v)
+ log['alpha'] = alpha + reg * nx.log(u)
+ log['beta'] = beta + reg * nx.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
if n_hists:
- res = np.zeros((n_hists))
- for i in range(n_hists):
- res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ res = nx.stack([
+ nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ for i in range(n_hists)
+ ])
return res, log
else:
return get_Gamma(alpha, beta, u, v), log
else:
if n_hists:
- res = np.zeros((n_hists))
- for i in range(n_hists):
- res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ res = nx.stack([
+ nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
+ for i in range(n_hists)
+ ])
return res
else:
return get_Gamma(alpha, beta, u, v)
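
The nx.stack rewrite above replaces the preallocate-and-fill idiom, which relies on in-place indexed assignment and therefore fails on backends with immutable arrays such as JAX; a minimal sketch of the pattern (loss is a hypothetical per-histogram function):

    # in-place (NumPy only):   res = np.zeros(n_hists); res[i] = loss(i)
    # backend-agnostic:        res = nx.stack([loss(i) for i in range(n_hists)])
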
@@ -833,45 +830,45 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- - M is the (dim_a, dim_b) metric cost matrix
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - a and b are source and target weights (histograms, both sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in [2]_ but with the log stabilization
- proposed in [10]_ and the log scaling proposed in [9]_ algorithm 3.2
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>` but with the log stabilization
+ proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2
Parameters
----------
- a : ndarray, shape (dim_a,)
+ a : array-like, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (dim_b,)
+ b : array-like, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (dim_a, dim_b)
+ M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
- thershold for max value in u or v for log scaling
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling
warmstart : tuple of vectors
- if given then sarting values for alpha an beta log scalings
+ if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
- Max number of iterationsin the inner slog stabilized sinkhorn
+ Max number of iterations in the inner slog stabilized sinkhorn
epsilon0 : float, optional
first epsilon regularization value (then exponential decrease to reg)
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -879,7 +876,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
Returns
-------
- gamma : ndarray, shape (dim_a, dim_b)
+ gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -895,7 +892,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
-
+ .. _references-sinkhorn-epsilon-scaling:
References
----------
@@ -903,6 +900,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
.. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
See Also
--------
ot.lp.emd : Unregularized OT
@@ -910,14 +910,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
if len(a) == 0:
- a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+ a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
- b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+ b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# init data
dim_a = len(a)
@@ -934,7 +934,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
+ alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
@@ -964,15 +964,13 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
# we can speed up the process by checking for the error only every
# 10th iteration
transp = G
- err = np.linalg.norm(
- (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2
+ err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2
if log:
log['err'].append(err)
if verbose:
if cpt % (print_period * 10) == 0:
- print(
- '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
if err <= stopThr and cpt > numItermin:
@@ -991,23 +989,31 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def geometricBar(weights, alldistribT):
"""return the weighted geometric mean of distributions"""
+ weights, alldistribT = list_to_array(weights, alldistribT)
+ nx = get_backend(weights, alldistribT)
assert (len(weights) == alldistribT.shape[1])
- return np.exp(np.dot(np.log(alldistribT), weights.T))
+ return nx.exp(nx.dot(nx.log(alldistribT), weights.T))
def geometricMean(alldistribT):
"""return the geometric mean of distributions"""
- return np.exp(np.mean(np.log(alldistribT), axis=1))
+ alldistribT = list_to_array(alldistribT)
+ nx = get_backend(alldistribT)
+ return nx.exp(nx.mean(nx.log(alldistribT), axis=1))
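
Both helpers take the (weighted) geometric mean across the columns of alldistribT, computed in log space; a toy NumPy check (hypothetical values):

    import numpy as np

    A = np.array([[1., 4.], [9., 1.]])
    w = np.array([.5, .5])
    np.exp(np.log(A).dot(w))   # per-row weighted geometric mean -> array([2., 3.])
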
def projR(gamma, p):
"""return the KL projection on the row constrints """
- return np.multiply(gamma.T, p / np.maximum(np.sum(gamma, axis=1), 1e-10)).T
+ gamma, p = list_to_array(gamma, p)
+ nx = get_backend(gamma, p)
+ return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T
def projC(gamma, q):
"""return the KL projection on the column constrints """
- return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
+ gamma, q = list_to_array(gamma, q)
+ nx = get_backend(gamma, q)
+ return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10)
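
In matrix form, the two KL (Bregman) projections above simply rescale the rows (resp. columns) of \gamma so that its marginals match p (resp. q):

    projR(\gamma, p)_{ij} = \gamma_{ij} \, p_i / \max(\sum_k \gamma_{ik}, 10^{-10}), \qquad
    projC(\gamma, q)_{ij} = \gamma_{ij} \, q_j / \max(\sum_k \gamma_{kj}, 10^{-10})
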
def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
@@ -1021,28 +1027,28 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] <references-barycenter>`
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`a_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
method : str (optional)
method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`a_i` on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1051,12 +1057,13 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+ .. _references-barycenter:
References
----------
@@ -1089,26 +1096,26 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] <references-barycenter-sinkhorn>`
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`a_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`a_i` on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1117,12 +1124,13 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
+ .. _references-barycenter-sinkhorn:
References
----------
@@ -1130,8 +1138,12 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
"""
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
if weights is None:
- weights = np.ones(A.shape[1]) / A.shape[1]
+ weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
else:
assert (len(weights) == A.shape[1])
@@ -1139,21 +1151,22 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
log = {'err': []}
# M = M/np.median(M) # suggested by G. Peyre
- K = np.exp(-M / reg)
+ K = nx.exp(-M / reg)
cpt = 0
err = 1
- UKv = np.dot(K, np.divide(A.T, np.sum(K, axis=0)).T)
+ UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)
+
u = (geometricMean(UKv) / UKv.T).T
while (err > stopThr and cpt < numItermax):
cpt = cpt + 1
- UKv = u * np.dot(K, np.divide(A, np.dot(K, u)))
+ UKv = u * nx.dot(K, A / nx.dot(K, u))
u = (u.T * geometricBar(weights, UKv)).T / UKv
if cpt % 10 == 1:
- err = np.sum(np.std(UKv, axis=1))
+ err = nx.sum(nx.std(UKv, axis=1))
# log and verbose print
if log:
@@ -1174,8 +1187,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions A
- with stabilization.
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A with stabilization.
The function solves the following optimization problem:
@@ -1184,28 +1196,28 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] <references-barycenter-stabilized>`
Parameters
----------
- A : ndarray, shape (dim, n_hists)
- n_hists training distributions a_i of size dim
- M : ndarray, shape (dim, dim)
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`a_i` of size `dim`
+ M : array-like, shape (dim, dim)
loss matrix for OT
reg : float
Regularization term > 0
tau : float
- thershold for max value in u or v for log scaling
- weights : ndarray, shape (n_hists,)
- Weights of each histogram a_i on the simplex (barycentric coodinates)
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`a_i` on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1214,12 +1226,13 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
+ .. _references-barycenter-stabilized:
References
----------
@@ -1227,49 +1240,48 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
"""
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones((n_hists,), type_as=M) / n_hists
else:
assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
- u = np.ones((dim, n_hists)) / dim
- v = np.ones((dim, n_hists)) / dim
+ u = nx.ones((dim, n_hists), type_as=M) / dim
+ v = nx.ones((dim, n_hists), type_as=M) / dim
- # print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
cpt = 0
err = 1.
- alpha = np.zeros(dim)
- beta = np.zeros(dim)
- q = np.ones(dim) / dim
+ alpha = nx.zeros((dim,), type_as=M)
+ beta = nx.zeros((dim,), type_as=M)
+ q = nx.ones((dim,), type_as=M) / dim
while (err > stopThr and cpt < numItermax):
qprev = q
- Kv = K.dot(v)
+ Kv = nx.dot(K, v)
u = A / (Kv + 1e-16)
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
q = geometricBar(weights, Ktu)
Q = q[:, None]
v = Q / (Ktu + 1e-16)
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha += reg * nx.log(nx.max(u, 1))
+ beta += reg * nx.log(nx.max(v, 1))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(tuple(v.shape), type_as=v)
+ Kv = nx.dot(K, v)
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
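
The absorbing branch above is the usual log-domain stabilization: when the multiplicative scalings outgrow `tau`, their magnitude is moved into the dual potentials `alpha` and `beta` and the kernel is recomputed. A rough numpy sketch of the same idea (illustrative only, not the library code):

    import numpy as np

    rng = np.random.RandomState(0)
    dim, n_hists, reg, tau = 4, 2, 1e-1, 1e3
    M = rng.rand(dim, dim)
    alpha, beta = np.zeros(dim), np.zeros(dim)
    u = rng.rand(dim, n_hists) * 1e6     # scaling vectors that outgrew tau
    v = rng.rand(dim, n_hists)
    if (u > tau).any() or (v > tau).any():
        alpha += reg * np.log(u.max(axis=1))    # absorb the magnitude into the potentials
        beta += reg * np.log(v.max(axis=1))
        K = np.exp((alpha[:, None] + beta[None, :] - M) / reg)  # rescaled kernel
        v = np.ones_like(v)                     # restart the scalings at a safe value
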
@@ -1278,7 +1290,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
if (cpt % 10 == 0 and not absorbing) or cpt == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(u * Kv - A).max()
+ err = nx.max(nx.abs(u * Kv - A))
if log:
log['err'].append(err)
if verbose:
@@ -1314,24 +1326,24 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
- :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions of matrix :math:`\mathbf{A}`
- - reg is the regularization strength scalar value
+ - `reg` is the regularization strength scalar value
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [21]_
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] <references-convolutional-barycenter-2d>`
Parameters
----------
- A : ndarray, shape (n_hists, width, height)
- n distributions (2D images) of size width x height
+ A : array-like, shape (n_hists, width, height)
+ `n_hists` distributions (2D images) of size `width` x `height`
reg : float
Regularization term >0
- weights : ndarray, shape (n_hists,)
+ weights : array-like, shape (n_hists,)
Weights of each image on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
stabThr : float, optional
Stabilization threshold to avoid numerical precision issue
verbose : bool, optional
@@ -1341,64 +1353,73 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
Returns
-------
- a : ndarray, shape (width, height)
+ a : array-like, shape (width, height)
2D Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
+
+ .. _references-convolutional-barycenter-2d:
References
----------
- .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015).
- Convolutional wasserstein distances: Efficient optimal transportation on geometric domains
- ACM Transactions on Graphics (TOG), 34(4), 66
+ .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66
"""
+ A = list_to_array(A)
+
+ nx = get_backend(A)
+
if weights is None:
- weights = np.ones(A.shape[0]) / A.shape[0]
+ weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
else:
assert (len(weights) == A.shape[0])
if log:
log = {'err': []}
- b = np.zeros_like(A[0, :, :])
- U = np.ones_like(A)
- KV = np.ones_like(A)
+ b = nx.zeros(A.shape[1:], type_as=A)
+ U = nx.ones(A.shape, type_as=A)
+ KV = nx.ones(A.shape, type_as=A)
cpt = 0
err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
- t = np.linspace(0, 1, A.shape[1])
- [Y, X] = np.meshgrid(t, t)
- xi1 = np.exp(-(X - Y) ** 2 / reg)
+ t = nx.linspace(0, 1, A.shape[1])
+ [Y, X] = nx.meshgrid(t, t)
+ xi1 = nx.exp(-(X - Y) ** 2 / reg)
- t = np.linspace(0, 1, A.shape[2])
- [Y, X] = np.meshgrid(t, t)
- xi2 = np.exp(-(X - Y) ** 2 / reg)
+ t = nx.linspace(0, 1, A.shape[2])
+ [Y, X] = nx.meshgrid(t, t)
+ xi2 = nx.exp(-(X - Y) ** 2 / reg)
def K(x):
- return np.dot(np.dot(xi1, x), xi2)
+ return nx.dot(nx.dot(xi1, x), xi2)
while (err > stopThr and cpt < numItermax):
bold = b
cpt = cpt + 1
- b = np.zeros_like(A[0, :, :])
- for r in range(A.shape[0]):
- KV[r, :, :] = K(A[r, :, :] / np.maximum(stabThr, K(U[r, :, :])))
- b += weights[r] * np.log(np.maximum(stabThr, U[r, :, :] * KV[r, :, :]))
- b = np.exp(b)
+ b = nx.zeros(A.shape[1:], type_as=A)
+ KV_cols = []
for r in range(A.shape[0]):
- U[r, :, :] = b / np.maximum(stabThr, KV[r, :, :])
-
+ KV_col_r = K(A[r, :, :] / nx.maximum(stabThr, K(U[r, :, :])))
+ b += weights[r] * nx.log(nx.maximum(stabThr, U[r, :, :] * KV_col_r))
+ KV_cols.append(KV_col_r)
+ KV = nx.stack(KV_cols)
+ b = nx.exp(b)
+
+ U = nx.stack([
+ b / nx.maximum(stabThr, KV[r, :, :])
+ for r in range(A.shape[0])
+ ])
if cpt % 10 == 1:
- err = np.sum(np.abs(bold - b))
+ err = nx.sum(nx.abs(bold - b))
# log and verbose print
if log:
log['err'].append(err)
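
The two `xi` matrices implement a separable Gaussian convolution: the 2D kernel exp(-((x - x')^2 + (y - y')^2) / reg) factorizes over the axes, so one blur along each axis is enough, which is exactly what `K(x)` computes. A small numpy check of this identity under the same grid assumptions:

    import numpy as np

    reg, w, h = 4e-2, 8, 6
    tx, ty = np.linspace(0, 1, w), np.linspace(0, 1, h)
    xi1 = np.exp(-(tx[:, None] - tx[None, :]) ** 2 / reg)   # blur along the first axis
    xi2 = np.exp(-(ty[:, None] - ty[None, :]) ** 2 / reg)   # blur along the second axis
    img = np.random.RandomState(0).rand(w, h)
    sep = xi1 @ img @ xi2                                   # separable form, as in K(x)
    full = np.einsum('ik,jl,kl->ij', xi1, xi2, img)         # dense 2D Gaussian kernel
    assert np.allclose(sep, full)
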
@@ -1424,34 +1445,35 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
The function solves the following optimization problem:
.. math::
- \mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h})
+
+ \mathbf{h} = arg\min_\mathbf{h} (1- \alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\alpha W_{M_0,reg_0}(\mathbf{h}_0,\mathbf{h})
where :
- - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
- - :math: `\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
+ - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
- :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
- :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
- - :math:`\mathbf{h}_0` is a prior on `h` of dimension `dim_prior`
- - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (dim_a, dim_a) for OT data fitting
- - reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix (dim_prior, n_atoms) regularization
- - :math:`\\alpha`weight data fitting and regularization
+ - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (`dim_a`, `dim_a`) for OT data fitting
+ - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization term and the cost matrix (`dim_prior`, `n_atoms`) for the regularization
+ - :math:`\alpha` weights data fitting and regularization
- The optimization problem is solved suing the algorithm described in [4]
+ The optimization problem is solved following the algorithm described in :ref:`[4] <references-unmix>`
Parameters
----------
- a : ndarray, shape (dim_a)
+ a : array-like, shape (dim_a)
observed distribution (histogram, sums to 1)
- D : ndarray, shape (dim_a, n_atoms)
+ D : array-like, shape (dim_a, n_atoms)
dictionary matrix
- M : ndarray, shape (dim_a, dim_a)
+ M : array-like, shape (dim_a, dim_a)
loss matrix
- M0 : ndarray, shape (n_atoms, dim_prior)
+ M0 : array-like, shape (n_atoms, dim_prior)
loss matrix
- h0 : ndarray, shape (n_atoms,)
+ h0 : array-like, shape (n_atoms,)
prior on the estimated unmixing h
reg : float
Regularization term >0 (Wasserstein data fitting)
@@ -1462,7 +1484,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -1471,11 +1493,13 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Returns
-------
- h : ndarray, shape (n_atoms,)
+ h : array-like, shape (n_atoms,)
Estimated unmixing over the dictionary atoms
log : dict
log dictionary returned only if log==True in parameters
+
+ .. _references-unmix:
References
----------
@@ -1483,11 +1507,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
"""
+ a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0)
+
+ nx = get_backend(a, D, M, M0, h0)
+
# M = M/np.median(M)
- K = np.exp(-M / reg)
+ K = nx.exp(-M / reg)
# M0 = M0/np.median(M0)
- K0 = np.exp(-M0 / reg0)
+ K0 = nx.exp(-M0 / reg0)
old = h0
err = 1
@@ -1499,16 +1527,16 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
while (err > stopThr and cpt < numItermax):
K = projC(K, a)
K0 = projC(K0, h0)
- new = np.sum(K0, axis=1)
+ new = nx.sum(K0, axis=1)
# we recombine the current selection from dictionary
- inv_new = np.dot(D, new)
- other = np.sum(K, axis=1)
+ inv_new = nx.dot(D, new)
+ other = nx.sum(K, axis=1)
# geometric interpolation
- delta = np.exp(alpha * np.log(other) + (1 - alpha) * np.log(inv_new))
+ delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
- K0 = np.dot(np.diag(np.dot(D.T, delta / inv_new)), K0)
+ K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0)
- err = np.linalg.norm(np.sum(K0, axis=1) - old)
+ err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
log['err'].append(err)
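
The geometric interpolation above is a pointwise weighted geometric mean, since exp(a log x + (1 - a) log y) = x^a y^(1-a). A one-line numpy check:

    import numpy as np

    x, y, alpha = np.array([1., 2.]), np.array([3., 4.]), 0.25
    assert np.allclose(np.exp(alpha * np.log(x) + (1 - alpha) * np.log(y)),
                       x ** alpha * y ** (1 - alpha))
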
@@ -1522,14 +1550,14 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
if log:
log['niter'] = cpt
- return np.sum(K0, axis=1), log
+ return nx.sum(K0, axis=1), log
else:
- return np.sum(K0, axis=1)
+ return nx.sum(K0, axis=1)
def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
stopThr=1e-6, verbose=False, log=False, **kwargs):
- r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
+ r'''Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] <references-jcpot-barycenter>`
The function solves the following optimization problem:
@@ -1542,12 +1570,12 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
where :
- - :math:`\lambda_k` is the weight of k-th source domain
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
- - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C
+ - :math:`\lambda_k` is the weight of `k`-th source domain
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source domain and `C` is the number of classes
+ - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C`
- :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
- - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)`
The problem consists in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
@@ -1556,11 +1584,11 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Parameters
----------
- Xs : list of K np.ndarray(nsk,d)
+ Xs : list of K array-like(nsk,d)
features of all source domains' samples
- Ys : list of K np.ndarray(nsk,)
+ Ys : list of K array-like(nsk,)
labels of all source domains' samples
- Xt : np.ndarray (nt,d)
+ Xt : array-like (nt,d)
samples in the target domain
reg : float
Regularization term > 0
@@ -1577,12 +1605,13 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Returns
-------
- h : (C,) ndarray
+ h : (C,) array-like
proportion estimation in the target domain
log : dict
log dictionary returned only if log==True in parameters
+ .. _references-jcpot-barycenter:
References
----------
@@ -1591,7 +1620,14 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
'''
- nbclasses = len(np.unique(Ys[0]))
+
+ Xs = list_to_array(*Xs)
+ Ys = list_to_array(*Ys)
+ Xt = list_to_array(Xt)
+
+ nx = get_backend(*Xs, *Ys, Xt)
+
+ nbclasses = len(nx.unique(Ys[0]))
nbdomains = len(Xs)
# log dictionary
@@ -1608,19 +1644,19 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
dom = {}
nsk = Xs[d].shape[0] # get number of elements for this domain
dom['nbelem'] = nsk
- classes = np.unique(Ys[d]) # get number of classes for this domain
+ classes = nx.unique(Ys[d]) # get number of classes for this domain
# format classes to start from 0 for convenience
- if np.min(classes) != 0:
- Ys[d] = Ys[d] - np.min(classes)
- classes = np.unique(Ys[d])
+ if nx.min(classes) != 0:
+ Ys[d] -= nx.min(classes)
+ classes = nx.unique(Ys[d])
# build the corresponding D_1 and D_2 matrices
- Dtmp1 = np.zeros((nbclasses, nsk))
- Dtmp2 = np.zeros((nbclasses, nsk))
+ Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
+ Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
for c in classes:
- nbelemperclass = np.sum(Ys[d] == c)
+ nbelemperclass = nx.sum(Ys[d] == c)
if nbelemperclass != 0:
Dtmp1[int(c), Ys[d] == c] = 1.
Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
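
For intuition, with a source domain holding three samples and labels [0, 0, 1], the loop above produces (a small hand-worked example):

    import numpy as np

    ys = np.array([0, 0, 1])         # three samples, two classes
    D1 = np.array([[1., 1., 0.],     # D1[c, i] = 1 if sample i belongs to class c
                   [0., 0., 1.]])
    D2 = np.array([[.5, .5, 0.],     # D2[c, i] = 1 / (number of samples of class c)
                   [0., 0., 1.]])
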
@@ -1631,36 +1667,34 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Mtmp = dist(Xs[d], Xt, metric=metric)
M.append(Mtmp)
- Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
- np.divide(Mtmp, -reg, out=Ktmp)
- np.exp(Ktmp, out=Ktmp)
+ Ktmp = nx.exp(-Mtmp / reg)
K.append(Ktmp)
# uniform target distribution
- a = unif(np.shape(Xt)[0])
+ a = nx.from_numpy(unif(np.shape(Xt)[0]))
cpt = 0 # iterations count
err = 1
- old_bary = np.ones((nbclasses))
+ old_bary = nx.ones((nbclasses,), type_as=Xs[0])
while (err > stopThr and cpt < numItermax):
- bary = np.zeros((nbclasses))
+ bary = nx.zeros((nbclasses,), type_as=Xs[0])
# update coupling matrices for marginal constraints w.r.t. uniform target distribution
for d in range(nbdomains):
K[d] = projC(K[d], a)
- other = np.sum(K[d], axis=1)
- bary = bary + np.log(np.dot(D1[d], other)) / nbdomains
+ other = nx.sum(K[d], axis=1)
+ bary += nx.log(nx.dot(D1[d], other)) / nbdomains
- bary = np.exp(bary)
+ bary = nx.exp(bary)
# update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
for d in range(nbdomains):
- new = np.dot(D2[d].T, bary)
+ new = nx.dot(D2[d].T, bary)
K[d] = projR(K[d], new)
- err = np.linalg.norm(bary - old_bary)
+ err = nx.norm(bary - old_bary)
cpt = cpt + 1
old_bary = bary
@@ -1672,7 +1706,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
- bary = bary / np.sum(bary)
+ bary = bary / nx.sum(bary)
if log:
log['niter'] = cpt
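
A hedged usage sketch with two toy source domains (all shapes and values here are made up for illustration):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xs = [rng.randn(20, 2), rng.randn(30, 2) + 1]           # two source domains
    Ys = [rng.randint(0, 2, 20), rng.randint(0, 2, 30)]     # their labels
    Xt = rng.randn(25, 2)                                   # unlabeled target samples
    h = ot.bregman.jcpot_barycenter(Xs, Ys, Xt, reg=1.0)    # estimated class proportions, shape (2,)
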
@@ -1697,39 +1731,38 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
+ - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
isLazy: boolean, optional
- If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory)
- If False, calculate full cost matrix and return outputs of sinkhorn function.
+ If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function.
batchSize: int or tuple of 2 int, optional
- Size of the batcheses used to compute the sinkhorn update without memory overhead.
+ Size of the batches used to compute the sinkhorn update without memory overhead.
When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
@@ -1739,7 +1772,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
Returns
-------
- gamma : ndarray, shape (n_samples_a, n_samples_b)
+ gamma : array-like, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
@@ -1766,18 +1799,23 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
+
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+
ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(ns)
+ a = nx.from_numpy(unif(ns))
if b is None:
- b = unif(nt)
+ b = nx.from_numpy(unif(nt))
if isLazy:
if log:
dict_log = {"err": []}
- log_a, log_b = np.log(a), np.log(b)
- f, g = np.zeros(ns), np.zeros(nt)
+ log_a, log_b = nx.log(a), nx.log(b)
+ f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
if isinstance(batchSize, int):
bs, bt = batchSize, batchSize
@@ -1788,27 +1826,44 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
range_s, range_t = range(0, ns, bs), range(0, nt, bt)
- lse_f = np.zeros(ns)
- lse_g = np.zeros(nt)
+ lse_f = nx.zeros((ns,), type_as=a)
+ lse_g = nx.zeros((nt,), type_as=a)
+
+ X_s_np = nx.to_numpy(X_s)
+ X_t_np = nx.to_numpy(X_t)
for i_ot in range(numIterMax):
+ lse_f_cols = []
for i in range_s:
- M = dist(X_s[i:i + bs, :], X_t, metric=metric)
- lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1)
+ M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ lse_f_cols.append(
+ nx.logsumexp(g[None, :] - M / reg, axis=1)
+ )
+ lse_f = nx.concatenate(lse_f_cols, axis=0)
f = log_a - lse_f
+ lse_g_cols = []
for j in range_t:
- M = dist(X_s, X_t[j:j + bt, :], metric=metric)
- lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0)
+ M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ lse_g_cols.append(
+ nx.logsumexp(f[:, None] - M / reg, axis=0)
+ )
+ lse_g = nx.concatenate(lse_g_cols, axis=0)
g = log_b - lse_g
if (i_ot + 1) % 10 == 0:
- m1 = np.zeros_like(a)
+ m1_cols = []
for i in range_s:
- M = dist(X_s[i:i + bs, :], X_t, metric=metric)
- m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1)
- err = np.abs(m1 - a).sum()
+ M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M = nx.from_numpy(M, type_as=a)
+ m1_cols.append(
+ nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1)
+ )
+ m1 = nx.concatenate(m1_cols, axis=0)
+ err = nx.sum(nx.abs(m1 - a))
if log:
dict_log["err"].append(err)
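
The lazy branch is a block-wise version of the standard log-domain Sinkhorn sweep. Without the batching, the update reads as follows (a minimal numpy sketch, not the batched library code):

    import numpy as np
    from scipy.special import logsumexp

    rng = np.random.RandomState(0)
    a, b = np.full(5, 1 / 5), np.full(7, 1 / 7)
    M = rng.rand(5, 7)
    reg = 1e-1
    f, g = np.zeros(5), np.zeros(7)
    for _ in range(100):
        f = np.log(a) - logsumexp(g[None, :] - M / reg, axis=1)  # row potential update
        g = np.log(b) - logsumexp(f[:, None] - M / reg, axis=0)  # column potential update
    pi = np.exp(f[:, None] + g[None, :] - M / reg)               # recovered transport plan
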
@@ -1826,8 +1881,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
return (f, g)
else:
- M = dist(X_s, X_t, metric=metric)
-
+ M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
+ M = nx.from_numpy(M, type_as=a)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
return pi, log
@@ -1848,39 +1903,38 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
.. math::
W = \min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
\gamma\geq 0
where :
- - :math:`M` is the (n_samples_a, n_samples_b) metric cost matrix
+ - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
isLazy: boolean, optional
- If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory)
- If False, calculate full cost matrix and return outputs of sinkhorn function.
+ If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function.
batchSize: int or tuple of 2 int, optional
- Size of the batcheses used to compute the sinkhorn update without memory overhead.
+ Size of the batches used to compute the sinkhorn update without memory overhead.
When a tuple is provided it sets the size of the left/right batches.
verbose : bool, optional
Print information along iterations
@@ -1890,7 +1944,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- W : (n_hists) ndarray or float
+ W : (n_hists) array-like or float
Optimal transportation loss for the given parameters
log : dict
log dictionary returned only if log==True in parameters
@@ -1918,11 +1972,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+
ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
- a = unif(ns)
+ a = nx.from_numpy(unif(ns))
if b is None:
- b = unif(nt)
+ b = nx.from_numpy(unif(nt))
if isLazy:
if log:
@@ -1936,10 +1994,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
range_s = range(0, ns, bs)
loss = 0
+
+ X_s_np = nx.to_numpy(X_s)
+ X_t_np = nx.to_numpy(X_t)
+
for i in range_s:
- M_block = dist(X_s[i:i + bs, :], X_t, metric=metric)
- pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
- loss += np.sum(M_block * pi_block)
+ M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
+ M_block = nx.from_numpy(M_block, type_as=a)
+ pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
+ loss += nx.sum(M_block * pi_block)
if log:
return loss, dict_log
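
Once the dual potentials are known, the transport cost itself is recovered block by block from

.. math::
    W = \sum_{i,j} M_{i,j} \ e^{f_i + g_j - M_{i,j}/reg}

so the full plan is never stored in memory.
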
@@ -1947,7 +2010,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
return loss
else:
- M = dist(X_s, X_t, metric=metric)
+ M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
+ M = nx.from_numpy(M, type_as=a)
if log:
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
@@ -1975,10 +2039,10 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
W_b &= \min_{\gamma_b} <\gamma_b,M_b>_F + reg\cdot\Omega(\gamma_b)
- S &= W - 1/2 * (W_a + W_b)
+ S &= W - \frac{W_a + W_b}{2}
.. math::
- s.t. \gamma 1 = a
+ s.t. \ \gamma 1 = a
\gamma^T 1= b
@@ -1997,27 +2061,27 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
\gamma_b\geq 0
where :
- - :math:`M` (resp. :math:`M_a, M_b`) is the (n_samples_a, n_samples_b) metric cost matrix (resp (n_samples_a, n_samples_a) and (n_samples_b, n_samples_b))
+ - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp. (`n_samples_a`, `n_samples_a`) and (`n_samples_b`, `n_samples_b`))
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`a` and :math:`b` are source and target weights (sum to 1)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (n_samples_a, dim)
+ X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (n_samples_b, dim)
+ X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (n_samples_a,)
+ a : array-like, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (n_samples_b,)
+ b : array-like, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -2025,7 +2089,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
Returns
-------
- W : (1,) ndarray
+ W : (1,) array-like
Optimal transportation symmetrized loss for the given parameters
log : dict
log dictionary returned only if log==True in parameters
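
The symmetrized loss can equally be assembled from three `empirical_sinkhorn2` calls, one per term of the formula above (a hedged sketch on toy data):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s, X_t = rng.randn(15, 2), rng.randn(20, 2) + 1
    reg = 1.0
    W_ab = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg)
    W_aa = ot.bregman.empirical_sinkhorn2(X_s, X_s, reg)
    W_bb = ot.bregman.empirical_sinkhorn2(X_t, X_t, reg)
    S = W_ab - 0.5 * (W_aa + W_bb)   # symmetrized loss, as in the formula above
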
@@ -2083,47 +2147,54 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
- r""""
+ r"""
Screening Sinkhorn Algorithm for Regularized Optimal Transport
- The function solves an approximate dual of Sinkhorn divergence [2] which is written as the following optimization problem:
+ The function solves an approximate dual of Sinkhorn divergence :ref:`[2] <references-screenkhorn>` which is written as the following optimization problem:
- ..math::
- (u, v) = \argmin_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - <v/\kappa, b>
+ .. math::
- where B(u,v) = \diag(e^u) K \diag(e^v), with K = e^{-M/reg} and
+ (u, v) = arg\min_{u, v} 1_{ns}^T B(u,v) 1_{nt} - <\kappa u, a> - <v/\kappa, b>
- s.t. e^{u_i} \geq \epsilon / \kappa, for all i \in {1, ..., ns}
+ where:
- e^{v_j} \geq \epsilon \kappa, for all j \in {1, ..., nt}
+ .. math::
- The parameters \kappa and \epsilon are determined w.r.t the couple number budget of points (ns_budget, nt_budget), see Equation (5) in [26]
+ B(u,v) = \mathrm{diag}(e^u) K \mathrm{diag}(e^v) \text{, with } K = e^{-M/reg} \text{ and}
+
+ .. math::
+
+ s.t. \ e^{u_i} \geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\}
+
+ e^{v_j} \geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\}
+
+ The parameters `kappa` and `epsilon` are determined w.r.t. the couple of point budgets (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] <references-screenkhorn>`
Parameters
----------
- a : `numpy.ndarray`, shape=(ns,)
+ a : array-like, shape=(ns,)
samples weights in the source domain
- b : `numpy.ndarray`, shape=(nt,)
+ b : array-like, shape=(nt,)
samples weights in the target domain
- M : `numpy.ndarray`, shape=(ns, nt)
+ M : array-like, shape=(ns, nt)
Cost matrix
reg : `float`
Level of the entropy regularisation
- ns_budget : `int`, deafult=None
- Number budget of points to be keeped in the source domain
- If it is None then 50% of the source sample points will be keeped
+ ns_budget : `int`, default=None
+ Number budget of points to be kept in the source domain.
+ If it is None then 50% of the source sample points will be kept
- nt_budget : `int`, deafult=None
- Number budget of points to be keeped in the target domain
- If it is None then 50% of the target sample points will be keeped
+ nt_budget : `int`, default=None
+ Number budget of points to be kept in the target domain.
+ If it is None then 50% of the target sample points will be kept
uniform : `bool`, default=False
- If `True`, the source and target distribution are supposed to be uniform, i.e., a_i = 1 / ns and b_j = 1 / nt
+ If `True`, the source and target distribution are supposed to be uniform, i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt`
restricted : `bool`, default=True
If `True`, a warm-start initialization for the L-BFGS-B solver
@@ -2133,15 +2204,16 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
Maximum number of iterations in LBFGS solver
maxfun : `int`, default=10000
- Maximum number of function evaluations in LBFGS solver
+ Maximum number of function evaluations in LBFGS solver
pgtol : `float`, default=1e-09
Final objective function accuracy in LBFGS solver
verbose : `bool`, default=False
- If `True`, dispaly informations about the cardinals of the active sets and the paramerters kappa
+ If `True`, display information about the cardinalities of the active sets and the parameters kappa
and epsilon
+
Dependency
----------
To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
@@ -2151,15 +2223,19 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
Returns
-------
- gamma : `numpy.ndarray`, shape=(ns, nt)
+ gamma : array-like, shape=(ns, nt)
Screened optimal transportation matrix for the given parameters
log : `dict`, default=False
Log dictionary returned only if log==True in parameters
+ .. _references-screenkhorn:
References
-----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
.. [26] Alaya M. Z., BĂ©rar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
"""
@@ -2171,9 +2247,12 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
"Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
bottleneck = np
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(M, a, b)
+ if nx.__name__ == "jax":
+ raise TypeError("JAX arrays have been received but screenkhorn is not compatible with JAX.")
+
ns, nt = M.shape
# by default, we keep only 50% of the sample data points
@@ -2183,9 +2262,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
nt_budget = int(np.floor(0.5 * nt))
# calculate the Gibbs kernel
- K = np.empty_like(M)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
def projection(u, epsilon):
u[u <= epsilon] = epsilon
@@ -2197,8 +2274,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
if ns_budget == ns and nt_budget == nt:
# full number of budget points (ns, nt) = (ns_budget, nt_budget)
- Isel = np.ones(ns, dtype=bool)
- Jsel = np.ones(nt, dtype=bool)
+ Isel = nx.from_numpy(np.ones(ns, dtype=bool))
+ Jsel = nx.from_numpy(np.ones(nt, dtype=bool))
epsilon = 0.0
kappa = 1.0
@@ -2214,57 +2291,61 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
K_IJc = []
K_IcJ = []
- vec_eps_IJc = np.zeros(nt)
- vec_eps_IcJ = np.zeros(ns)
+ vec_eps_IJc = nx.zeros((nt,), type_as=M)
+ vec_eps_IcJ = nx.zeros((ns,), type_as=M)
else:
# sum of rows and columns of K
- K_sum_cols = K.sum(axis=1)
- K_sum_rows = K.sum(axis=0)
+ K_sum_cols = nx.sum(K, axis=1)
+ K_sum_rows = nx.sum(K, axis=0)
if uniform:
if ns / ns_budget < 4:
- aK_sort = np.sort(K_sum_cols)
+ aK_sort = nx.sort(K_sum_cols)
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
- aK_sort = bottleneck.partition(K_sum_cols, ns_budget - 1)[ns_budget - 1]
+ aK_sort = nx.from_numpy(
+ bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1]
+ )
epsilon_u_square = a[0] / aK_sort
if nt / nt_budget < 4:
- bK_sort = np.sort(K_sum_rows)
+ bK_sort = nx.sort(K_sum_rows)
epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
else:
- bK_sort = bottleneck.partition(K_sum_rows, nt_budget - 1)[nt_budget - 1]
+ bK_sort = nx.from_numpy(
+ bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1]
+ )
epsilon_v_square = b[0] / bK_sort
else:
aK = a / K_sum_cols
bK = b / K_sum_rows
- aK_sort = np.sort(aK)[::-1]
+ aK_sort = nx.flip(nx.sort(aK), axis=0)
epsilon_u_square = aK_sort[ns_budget - 1]
- bK_sort = np.sort(bK)[::-1]
+ bK_sort = nx.flip(nx.sort(bK), axis=0)
epsilon_v_square = bK_sort[nt_budget - 1]
# active sets I and J (see Lemma 1 in [26])
Isel = a >= epsilon_u_square * K_sum_cols
Jsel = b >= epsilon_v_square * K_sum_rows
- if sum(Isel) != ns_budget:
+ if nx.sum(Isel) != ns_budget:
if uniform:
aK = a / K_sum_cols
- aK_sort = np.sort(aK)[::-1]
- epsilon_u_square = aK_sort[ns_budget - 1:ns_budget + 1].mean()
+ aK_sort = nx.flip(nx.sort(aK), axis=0)
+ epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1])
Isel = a >= epsilon_u_square * K_sum_cols
- ns_budget = sum(Isel)
+ ns_budget = nx.sum(Isel)
- if sum(Jsel) != nt_budget:
+ if nx.sum(Jsel) != nt_budget:
if uniform:
bK = b / K_sum_rows
- bK_sort = np.sort(bK)[::-1]
- epsilon_v_square = bK_sort[nt_budget - 1:nt_budget + 1].mean()
+ bK_sort = nx.flip(nx.sort(bK), axis=0)
+ epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1])
Jsel = b >= epsilon_v_square * K_sum_rows
- nt_budget = sum(Jsel)
+ nt_budget = nx.sum(Jsel)
epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)
@@ -2282,7 +2363,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
K_IcJ = K[np.ix_(Ic, Jsel)]
K_IJc = K[np.ix_(Isel, Jc)]
- K_min = K_IJ.min()
+ K_min = nx.min(K_IJ)
if K_min == 0:
K_min = np.finfo(float).tiny
@@ -2290,10 +2372,10 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
a_I = a[Isel]
b_J = b[Jsel]
if not uniform:
- a_I_min = a_I.min()
- a_I_max = a_I.max()
- b_J_max = b_J.max()
- b_J_min = b_J.min()
+ a_I_min = nx.min(a_I)
+ a_I_max = nx.max(a_I)
+ b_J_max = nx.max(b_J)
+ b_J_min = nx.min(b_J)
else:
a_I_min = a_I[0]
a_I_max = a_I[0]
@@ -2309,24 +2391,30 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
# pre-calculated constants for the objective
- vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
- vec_eps_IcJ = (epsilon / kappa) * (np.ones(ns - ns_budget).reshape((-1, 1)) * K_IcJ).sum(axis=0)
+ vec_eps_IJc = epsilon * kappa * nx.sum(
+ K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :],
+ axis=1
+ )
+ vec_eps_IcJ = (epsilon / kappa) * nx.sum(
+ nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ,
+ axis=0
+ )
# initialisation
- u0 = np.full(ns_budget, (1. / ns_budget) + epsilon / kappa)
- v0 = np.full(nt_budget, (1. / nt_budget) + epsilon * kappa)
+ u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M)
+ v0 = nx.full((nt_budget,), 1. / nt_budget + epsilon * kappa, type_as=M)
# pre-calculated constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26])
if restricted:
if ns_budget != ns or nt_budget != nt:
- cst_u = kappa * epsilon * K_IJc.sum(axis=1)
- cst_v = epsilon * K_IcJ.sum(axis=0) / kappa
+ cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1)
+ cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa
cpt = 1
while cpt < 5: # 5 iterations
- K_IJ_v = np.dot(K_IJ.T, u0) + cst_v
+ K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v
v0 = b_J / (kappa * K_IJ_v)
- KIJ_u = np.dot(K_IJ, v0) + cst_u
+ KIJ_u = nx.dot(K_IJ, v0) + cst_u
u0 = (kappa * a_I) / KIJ_u
cpt += 1
@@ -2343,9 +2431,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
"""
cpt = 1
while cpt < max_iter:
- K_IJ_v = np.dot(K_IJ.T, usc) + cst_v
+ K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v
vsc = b_J / (kappa * K_IJ_v)
- KIJ_u = np.dot(K_IJ, vsc) + cst_u
+ KIJ_u = nx.dot(K_IJ, vsc) + cst_u
usc = (kappa * a_I) / KIJ_u
cpt += 1
@@ -2355,17 +2443,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
return usc, vsc
def screened_obj(usc, vsc):
- part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J,
- np.log(vsc))
- part_IJc = np.dot(usc, vec_eps_IJc)
- part_IcJ = np.dot(vec_eps_IcJ, vsc)
+ part_IJ = (
+ nx.dot(nx.dot(usc, K_IJ), vsc)
+ - kappa * nx.dot(a_I, nx.log(usc))
+ - (1. / kappa) * nx.dot(b_J, nx.log(vsc))
+ )
+ part_IJc = nx.dot(usc, vec_eps_IJc)
+ part_IcJ = nx.dot(vec_eps_IcJ, vsc)
psi_epsilon = part_IJ + part_IJc + part_IcJ
return psi_epsilon
def screened_grad(usc, vsc):
# gradients of Psi_(kappa,epsilon) w.r.t u and v
- grad_u = np.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
- grad_v = np.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
+ grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
+ grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
return grad_u, grad_v
def bfgspost(theta):
@@ -2375,20 +2466,20 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
f = screened_obj(u, v)
# gradient
g_u, g_v = screened_grad(u, v)
- g = np.hstack([g_u, g_v])
- return f, g
+ g = nx.concatenate([g_u, g_v], axis=0)
+ return nx.to_numpy(f), nx.to_numpy(g)
# ----------------------------------------------------------------------------------------------------------------#
# Step 2: L-BFGS-B solver #
# ----------------------------------------------------------------------------------------------------------------#
u0, v0 = restricted_sinkhorn(u0, v0)
- theta0 = np.hstack([u0, v0])
+ theta0 = nx.concatenate([u0, v0], axis=0)
bounds = bounds_u + bounds_v # constraint bounds
def obj(theta):
- return bfgspost(theta)
+ return bfgspost(nx.from_numpy(theta, type_as=M))
theta, _, _ = fmin_l_bfgs_b(func=obj,
x0=theta0,
@@ -2396,12 +2487,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
maxfun=maxfun,
pgtol=pgtol,
maxiter=maxiter)
+ theta = nx.from_numpy(theta)
usc = theta[:ns_budget]
vsc = theta[ns_budget:]
- usc_full = np.full(ns, epsilon / kappa)
- vsc_full = np.full(nt, epsilon * kappa)
+ usc_full = nx.full((ns,), epsilon / kappa, type_as=M)
+ vsc_full = nx.full((nt,), epsilon * kappa, type_as=M)
usc_full[Isel] = usc
vsc_full[Jsel] = vsc
@@ -2413,7 +2505,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
log['Jsel'] = Jsel
gamma = usc_full[:, None] * K * vsc_full[None, :]
- gamma = gamma / gamma.sum()
+ gamma = gamma / nx.sum(gamma)
if log:
return gamma, log
diff --git a/ot/gromov.py b/ot/gromov.py
index a27217a..85b1549 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -1161,7 +1161,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations.
log : bool, optional
@@ -1267,7 +1267,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0).
+ Stop threshold on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
@@ -1365,7 +1365,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
max_iter : int, optional
Max number of iterations
tol : float, optional
- Stop threshol on error (>0).
+ Stop threshold on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
diff --git a/ot/smooth.py b/ot/smooth.py
index 81f6a3e..ea26bae 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -458,7 +458,7 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -552,7 +552,7 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index e37f10c..6a61aa1 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -58,7 +58,7 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -186,7 +186,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -300,7 +300,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -482,7 +482,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -691,7 +691,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -841,7 +841,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
@@ -971,7 +971,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (> 0)
+ Stop threshold on error (> 0)
verbose : bool, optional
Print information along iterations
log : bool, optional
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..876b525
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+
+# Configuration file for pytest
+
+# License: MIT License
+
+import pytest
+from ot.backend import jax
+from ot.backend import get_backend_list
+import functools
+
+if jax:
+ from jax.config import config
+
+backend_list = get_backend_list()
+
+
+@pytest.fixture(params=backend_list)
+def nx(request):
+ backend = request.param
+ if backend.__name__ == "jax":
+ config.update("jax_enable_x64", True)
+
+ yield backend
+
+ if backend.__name__ == "jax":
+ config.update("jax_enable_x64", False)
+
+
+def skip_arg(arg, value, reason=None, getter=lambda x: x):
+ if reason is None:
+ reason = f"Param {arg} should be skipped for value {value}"
+
+ def wrapper(function):
+
+ @functools.wraps(function)
+ def wrapped(*args, **kwargs):
+ if arg in kwargs.keys() and getter(kwargs[arg]) == value:
+ pytest.skip(reason)
+ return function(*args, **kwargs)
+
+ return wrapped
+
+ return wrapper
+
+
+def pytest_configure(config):
+ pytest.skip_arg = skip_arg
+ pytest.skip_backend = functools.partial(skip_arg, "nx", getter=str)
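
A hedged example of how these helpers are meant to be used in a test module; the `method` argument below is hypothetical and only illustrates `skip_arg`:

    import pytest

    @pytest.skip_backend("jax")                  # skip when the nx fixture is the jax backend
    @pytest.skip_arg("method", "greenkhorn")     # hypothetical: skip one parametrized value
    def test_something(nx, method):
        ...
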
diff --git a/test/test_backend.py b/test/test_backend.py
index cbfaf94..859da5a 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -1,6 +1,7 @@
"""Tests for backend module """
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
@@ -156,6 +157,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.sqrt(M)
with pytest.raises(NotImplementedError):
+ nx.power(v, 2)
+ with pytest.raises(NotImplementedError):
nx.dot(v, v)
with pytest.raises(NotImplementedError):
nx.norm(M)
@@ -174,7 +177,37 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.argsort(M)
with pytest.raises(NotImplementedError):
+ nx.searchsorted(v, v)
+ with pytest.raises(NotImplementedError):
nx.flip(M)
+ with pytest.raises(NotImplementedError):
+ nx.clip(M, -1, 1)
+ with pytest.raises(NotImplementedError):
+ nx.repeat(M, 0, 1)
+ with pytest.raises(NotImplementedError):
+ nx.take_along_axis(M, v, 0)
+ with pytest.raises(NotImplementedError):
+ nx.concatenate([v, v])
+ with pytest.raises(NotImplementedError):
+ nx.zero_pad(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.argmax(M)
+ with pytest.raises(NotImplementedError):
+ nx.mean(M)
+ with pytest.raises(NotImplementedError):
+ nx.std(M)
+ with pytest.raises(NotImplementedError):
+ nx.linspace(0, 1, 50)
+ with pytest.raises(NotImplementedError):
+ nx.meshgrid(v, v)
+ with pytest.raises(NotImplementedError):
+ nx.diag(M)
+ with pytest.raises(NotImplementedError):
+ nx.unique([M, M])
+ with pytest.raises(NotImplementedError):
+ nx.logsumexp(M)
+ with pytest.raises(NotImplementedError):
+ nx.stack([M, M])
@pytest.mark.parametrize('backend', backend_list)
@@ -278,6 +311,10 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('sqrt')
+ A = nx.power(Mb, 2)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('power')
+
A = nx.dot(vb, vb)
lst_b.append(nx.to_numpy(A))
lst_name.append('dot(v,v)')
@@ -326,10 +363,75 @@ def test_func_backends(backend):
lst_b.append(nx.to_numpy(A))
lst_name.append('argsort')
+ A = nx.searchsorted(Mb, Mb, 'right')
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('searchsorted')
+
A = nx.flip(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('flip')
+ A = nx.clip(vb, 0, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('clip')
+
+ A = nx.repeat(Mb, 0)
+ A = nx.repeat(Mb, 2, -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('repeat')
+
+ A = nx.take_along_axis(vb, nx.arange(3), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('take_along_axis')
+
+ A = nx.concatenate((Mb, Mb), -1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('concatenate')
+
+ A = nx.zero_pad(Mb, len(Mb.shape) * [(3, 3)])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('zero_pad')
+
+ A = nx.argmax(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmax')
+
+ A = nx.mean(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('mean')
+
+ A = nx.std(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('std')
+
+ A = nx.linspace(0, 1, 50)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('linspace')
+
+ X, Y = nx.meshgrid(vb, vb)
+ lst_b.append(np.stack([nx.to_numpy(X), nx.to_numpy(Y)]))
+ lst_name.append('meshgrid')
+
+ A = nx.diag(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag2D')
+
+ A = nx.diag(vb, 1)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('diag1D')
+
+ A = nx.unique(nx.from_numpy(np.stack([M, M])))
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('unique')
+
+ A = nx.logsumexp(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('logsumexp')
+
+ A = nx.stack([Mb, Mb])
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('stack')
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 88166a5..942cb6d 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -10,11 +10,8 @@ import numpy as np
import pytest
import ot
-from ot.backend import get_backend_list
from ot.backend import torch
-backend_list = get_backend_list()
-
def test_sinkhorn():
# test sinkhorn
@@ -28,14 +25,13 @@ def test_sinkhorn():
G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
-@pytest.mark.parametrize('nx', backend_list)
def test_sinkhorn_backends(nx):
n_samples = 100
n_features = 2
@@ -57,7 +53,6 @@ def test_sinkhorn_backends(nx):
np.allclose(G, nx.to_numpy(Gb))
-@pytest.mark.parametrize('nx', backend_list)
def test_sinkhorn2_backends(nx):
n_samples = 100
n_features = 2
@@ -116,20 +111,20 @@ def test_sinkhorn_empty():
M = ot.dist(x, x)
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
method='sinkhorn_stabilized', verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
G, log = ot.sinkhorn(
[], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
verbose=True, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
@@ -137,7 +132,8 @@ def test_sinkhorn_empty():
ot.sinkhorn([], [], M, 1, method='greenkhorn', stopThr=1e-10, log=True)
-def test_sinkhorn_variants():
+@pytest.skip_backend("jax")
+def test_sinkhorn_variants(nx):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -147,13 +143,18 @@ def test_sinkhorn_variants():
M = ot.dist(x, x)
- G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
- Ges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
- G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)
+ ub = nx.from_numpy(u)
+ Mb = nx.from_numpy(M)
+
+ G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
+ G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Ges = nx.to_numpy(ot.sinkhorn(
+ ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10))
# check values
+ np.testing.assert_allclose(G, G0, atol=1e-05)
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
@@ -184,7 +185,7 @@ def test_sinkhorn_variants_log():
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_barycenter(method):
+def test_barycenter(nx, method):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -201,16 +202,23 @@ def test_barycenter(method):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ Ab = nx.from_numpy(A)
+ Mb = nx.from_numpy(M)
+ weightsb = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bary_wass, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+ bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
+ bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
- ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
+ ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True)
-def test_barycenter_stabilization():
+def test_barycenter_stabilization(nx):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -227,17 +235,26 @@ def test_barycenter_stabilization():
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
+ Ab = nx.from_numpy(A)
+ Mb = nx.from_numpy(M)
+ weights_b = nx.from_numpy(weights)
+
# wasserstein
reg = 1e-2
- bar_stable = ot.bregman.barycenter(A, M, reg, weights,
- method="sinkhorn_stabilized",
- stopThr=1e-8, verbose=True)
- bar = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn",
- stopThr=1e-8, verbose=True)
+ bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
+ bar_stable = nx.to_numpy(ot.bregman.barycenter(
+ Ab, Mb, reg, weights_b, method="sinkhorn_stabilized",
+ stopThr=1e-8, verbose=True
+ ))
+ bar = nx.to_numpy(ot.bregman.barycenter(
+ Ab, Mb, reg, weights_b, method="sinkhorn",
+ stopThr=1e-8, verbose=True
+ ))
np.testing.assert_allclose(bar, bar_stable)
+ np.testing.assert_allclose(bar, bar_np)
-def test_wasserstein_bary_2d():
+def test_wasserstein_bary_2d(nx):
size = 100 # size of a square image
a1 = np.random.randn(size, size)
a1 += a1.min()
@@ -250,17 +267,21 @@ def test_wasserstein_bary_2d():
A[0, :, :] = a1
A[1, :, :] = a2
+ Ab = nx.from_numpy(A)
+
# wasserstein
reg = 1e-2
- bary_wass = ot.bregman.convolutional_barycenter2d(A, reg)
+ bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg))
np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
# help in checking if log and verbose do not bug the function
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
-def test_unmix():
+def test_unmix(nx):
n_bins = 50 # nb bins
# Gaussian distributions
@@ -280,18 +301,26 @@ def test_unmix():
M0 /= M0.max()
h0 = ot.unif(2)
+ ab = nx.from_numpy(a)
+ Db = nx.from_numpy(D)
+ Mb = nx.from_numpy(M)
+ M0b = nx.from_numpy(M0)
+ h0b = nx.from_numpy(h0)
+
# wasserstein
reg = 1e-3
- um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01, )
+ um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
+ um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
+ np.testing.assert_allclose(um, um_np)
- ot.bregman.unmix(a, D, M, M0, h0, reg,
+ ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg,
1, alpha=0.01, log=True, verbose=True)
-def test_empirical_sinkhorn():
+def test_empirical_sinkhorn(nx):
# test sinkhorn
n = 10
a = ot.unif(n)
@@ -302,19 +331,28 @@ def test_empirical_sinkhorn():
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='minkowski')
- G_sqe = ot.bregman.empirical_sinkhorn(X_s, X_t, 1)
- sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ Mb = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
+
+ G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
- G_log, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, log=True)
- sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+ G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True)
+ G_log = nx.to_numpy(G_log)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
- G_m = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski')
- sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+ G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski'))
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1)
- loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+ loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
- # check constratints
+ # check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
np.testing.assert_allclose(
@@ -330,7 +368,7 @@ def test_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
-def test_lazy_empirical_sinkhorn():
+def test_lazy_empirical_sinkhorn(nx):
# test sinkhorn
n = 10
a = ot.unif(n)
@@ -342,22 +380,34 @@ def test_lazy_empirical_sinkhorn():
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='minkowski')
- f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ Mb = nx.from_numpy(M, type_as=ab)
+ M_mb = nx.from_numpy(M_m, type_as=ab)
+
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
- sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
- f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
- sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+ sinkhorn_log = nx.to_numpy(sinkhorn_log)
- f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+ f, g = nx.to_numpy(f), nx.to_numpy(g)
G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
- sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+ sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
- loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
- loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+ loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+ loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
- # check constratints
+ # check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
np.testing.assert_allclose(
@@ -373,7 +423,7 @@ def test_lazy_empirical_sinkhorn():
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
-def test_empirical_sinkhorn_divergence():
+def test_empirical_sinkhorn_divergence(nx):
# Test sinkhorn divergence
n = 10
a = np.linspace(1, n, n)
@@ -385,22 +435,31 @@ def test_empirical_sinkhorn_divergence():
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
- sinkhorn_div = (ot.sinkhorn2(a, b, M, 1) - 1 / 2 * ot.sinkhorn2(a, a, M_s, 1) - 1 / 2 * ot.sinkhorn2(b, b, M_t, 1))
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ X_sb = nx.from_numpy(X_s)
+ X_tb = nx.from_numpy(X_t)
+ Mb = nx.from_numpy(M, type_as=ab)
+ M_sb = nx.from_numpy(M_s, type_as=ab)
+ M_tb = nx.from_numpy(M_t, type_as=ab)
+
+ emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
+ sinkhorn_div = nx.to_numpy(
+ ot.sinkhorn2(ab, bb, Mb, 1)
+ - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
+ - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
+ )
+ emp_sinkhorn_div_np = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b)
- emp_sinkhorn_div_log, log_es = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1, a=a, b=b, log=True)
- sink_div_log_ab, log_s_ab = ot.sinkhorn2(a, b, M, 1, log=True)
- sink_div_log_a, log_s_a = ot.sinkhorn2(a, a, M_s, 1, log=True)
- sink_div_log_b, log_s_b = ot.sinkhorn2(b, b, M_t, 1, log=True)
- sink_div_log = sink_div_log_ab - 1 / 2 * (sink_div_log_a + sink_div_log_b)
# check constraints
+ np.testing.assert_allclose(emp_sinkhorn_div, emp_sinkhorn_div_np, atol=1e-05)
np.testing.assert_allclose(
emp_sinkhorn_div, sinkhorn_div, atol=1e-05) # cf conv emp sinkhorn
- np.testing.assert_allclose(
- emp_sinkhorn_div_log, sink_div_log, atol=1e-05) # cf conv emp sinkhorn
+ ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb, log=True)
-def test_stabilized_vs_sinkhorn_multidim():
+
+def test_stabilized_vs_sinkhorn_multidim(nx):
# test if stable version matches sinkhorn
# for multidimensional inputs
n = 100
@@ -416,12 +475,21 @@ def test_stabilized_vs_sinkhorn_multidim():
M = ot.utils.dist0(n)
M /= np.median(M)
epsilon = 0.1
- G, log = ot.bregman.sinkhorn(a, b, M, reg=epsilon,
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
+ G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
+ G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
- G2, log2 = ot.bregman.sinkhorn(a, b, M, epsilon,
+ G = nx.to_numpy(G)
+ G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon,
method="sinkhorn", log=True)
+ G2 = nx.to_numpy(G2)
+ np.testing.assert_allclose(G_np, G2)
np.testing.assert_allclose(G, G2)
@@ -458,8 +526,9 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+@pytest.skip_backend("jax")
@pytest.mark.filterwarnings("ignore:Bottleneck")
-def test_screenkhorn():
+def test_screenkhorn(nx):
# test screenkhorn
rng = np.random.RandomState(0)
n = 100
@@ -468,17 +537,31 @@ def test_screenkhorn():
x = rng.randn(n, 2)
M = ot.dist(x, x)
+
+ ab = nx.from_numpy(a)
+ bb = nx.from_numpy(b)
+ Mb = nx.from_numpy(M, type_as=ab)
+
+ # np sinkhorn
+ G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = ot.sinkhorn(a, b, M, 1e-03)
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03))
# screenkhorn
- G_screen = ot.bregman.screenkhorn(a, b, M, 1e-03, uniform=True, verbose=True)
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True))
# check marginals
+ np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
-def test_convolutional_barycenter_non_square():
+def test_convolutional_barycenter_non_square(nx):
# test for image with height not equal width
A = np.ones((2, 2, 3)) / (2 * 3)
- b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ Ab = nx.from_numpy(A)
+
+ b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03)
+ b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03))
+
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
+ np.testing.assert_allclose(b, b_np)
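
The pattern repeated throughout test_bregman.py above is: build the NumPy inputs, push them into the backend with nx.from_numpy, run the same solver on both, and pull the backend result back with nx.to_numpy before comparing. The explicit @pytest.mark.parametrize('nx', backend_list) decorators disappear because a shared nx fixture now parametrizes every test over the available backends. A hedged sketch of both pieces (the fixture's exact definition is assumed here, not quoted from the patch):

    import numpy as np
    import pytest
    import ot
    from ot.backend import get_backend_list

    @pytest.fixture(params=get_backend_list())
    def nx(request):
        # every test taking an `nx` argument runs once per available backend
        return request.param

    def test_convert_compute_compare(nx):
        rng = np.random.RandomState(0)
        n = 10
        u = ot.unif(n)
        M = ot.dist(rng.randn(n, 2), rng.randn(n, 2))

        G_ref = ot.sinkhorn(u, u, M, 1)              # plain NumPy reference
        ub, Mb = nx.from_numpy(u), nx.from_numpy(M)  # move inputs to the backend
        Gb = ot.sinkhorn(ub, ub, Mb, 1)              # same call, backend arrays
        np.testing.assert_allclose(G_ref, nx.to_numpy(Gb), atol=1e-5)
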
diff --git a/test/test_partial.py b/test/test_partial.py
index 3571e2a..97c611b 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -104,7 +104,7 @@ def test_partial_wasserstein():
w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=1, m=m,
log=True, verbose=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
w0.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
@@ -127,7 +127,7 @@ def test_partial_wasserstein():
np.testing.assert_allclose(w0, w0_val, atol=1e-1, rtol=1e-1)
- # check constratints
+ # check constraints
np.testing.assert_equal(
G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
@@ -194,7 +194,7 @@ def test_partial_gromov_wasserstein():
100, m=m,
log=True)
- # check constratints
+ # check constraints
np.testing.assert_equal(
res0.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein
np.testing.assert_equal(
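
Unlike the equality checks in the Sinkhorn tests, the partial-Wasserstein assertions above are one-sided: a partial plan transports only a chosen mass m, so row and column sums may fall below the marginals p and q but never exceed them. A small sketch of that invariant (variable sizes and the value of m are illustrative, not from the patch):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    n_s, n_t, m = 20, 30, 0.5
    p, q = ot.unif(n_s), ot.unif(n_t)
    M = ot.dist(rng.randn(n_s, 2), rng.randn(n_t, 2))

    w0 = ot.partial.partial_wasserstein(p, q, M, m=m)
    assert np.all(w0.sum(1) <= p + 1e-5)                # row sums bounded by p
    assert np.all(w0.sum(0) <= q + 1e-5)                # column sums bounded by q
    np.testing.assert_allclose(w0.sum(), m, atol=1e-5)  # exactly m mass moved
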
diff --git a/test/test_smooth.py b/test/test_smooth.py
index 2afa4f8..31e0b2e 100644
--- a/test/test_smooth.py
+++ b/test/test_smooth.py
@@ -25,16 +25,16 @@ def test_smooth_ot_dual():
Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
@@ -60,16 +60,16 @@ def test_smooth_ot_semi_dual():
Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, Gl2.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
u, Gl2.sum(0), atol=1e-05) # cf convergence sinkhorn
- # kl regyularisation
+ # kl regularisation
G = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='kl', stopThr=1e-10)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-05) # cf convergence sinkhorn
np.testing.assert_allclose(
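
Both regularisers in the hunks above go through the same check: whatever the penalty (squared-l2 or KL), the smooth-OT plan must still reproduce the marginals. A compact sketch of that loop (a condensed restatement of the tests above, not new API):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    n = 100
    u = ot.unif(n)
    x = rng.randn(n, 2)
    M = ot.dist(x, x)

    for reg_type in ('l2', 'kl'):
        G = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type=reg_type, stopThr=1e-10)
        np.testing.assert_allclose(u, G.sum(1), atol=1e-5)  # row marginals
        np.testing.assert_allclose(u, G.sum(0), atol=1e-5)  # column marginals
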
diff --git a/test/test_stochastic.py b/test/test_stochastic.py
index 98e93ec..736df32 100644
--- a/test/test_stochastic.py
+++ b/test/test_stochastic.py
@@ -43,7 +43,7 @@ def test_stochastic_sag():
G = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
numItermax=numItermax)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-03) # cf convergence sag
np.testing.assert_allclose(
@@ -73,7 +73,7 @@ def test_stochastic_asgd():
G, log = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-02) # cf convergence asgd
np.testing.assert_allclose(
@@ -105,7 +105,7 @@ def test_sag_asgd_sinkhorn():
numItermax=nb_iter)
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
@@ -148,7 +148,7 @@ def test_stochastic_dual_sgd():
G, log = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
numItermax=numItermax, log=True)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
u, G.sum(1), atol=1e-03) # cf convergence sgd
np.testing.assert_allclose(
@@ -181,7 +181,7 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(u, u, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
np.testing.assert_allclose(
@@ -206,7 +206,7 @@ def test_dual_sgd_sinkhorn():
G_sinkhorn = ot.sinkhorn(a, b, M, reg)
- # check constratints
+ # check constraints
np.testing.assert_allclose(
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
np.testing.assert_allclose(
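
The stochastic tests close the file with the same idea at a looser tolerance: SAG, ASGD and dual SGD solve the same entropic problem as Sinkhorn, so their plans should agree up to optimisation noise. A condensed sketch of the comparison (mirroring the calls above; the iteration count is illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    n, reg = 10, 1
    u = ot.unif(n)
    x = rng.randn(n, 2)
    M = ot.dist(x, x)

    G_sag = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "sag",
                                                   numItermax=300000)
    G_sinkhorn = ot.sinkhorn(u, u, M, reg)
    np.testing.assert_allclose(G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-2)
    np.testing.assert_allclose(G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-2)
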