Diffstat (limited to 'ot')
-rw-r--r--  ot/__init__.py                        24
-rw-r--r--  ot/backend.py                        252
-rw-r--r--  ot/bregman.py                        378
-rw-r--r--  ot/coot.py                           434
-rw-r--r--  ot/da.py                             136
-rw-r--r--  ot/dr.py                              32
-rw-r--r--  ot/gaussian.py                       333
-rw-r--r--  ot/gromov.py                        2835
-rw-r--r--  ot/gromov/__init__.py                 48
-rw-r--r--  ot/gromov/_bregman.py                348
-rw-r--r--  ot/gromov/_dictionary.py            1008
-rw-r--r--  ot/gromov/_estimators.py             425
-rw-r--r--  ot/gromov/_gw.py                     978
-rw-r--r--  ot/gromov/_semirelaxed.py            543
-rw-r--r--  ot/gromov/_utils.py                  413
-rw-r--r--  ot/helpers/pre_build_helpers.py       24
-rw-r--r--  ot/lp/EMD.h                            5
-rw-r--r--  ot/lp/EMD_wrapper.cpp                 40
-rw-r--r--  ot/lp/__init__.py                    161
-rw-r--r--  ot/lp/cvx.py                           2
-rw-r--r--  ot/lp/emd_wrap.pyx                     9
-rw-r--r--  ot/lp/network_simplex_simple.h        12
-rw-r--r--  ot/lp/network_simplex_simple_omp.h    20
-rw-r--r--  ot/lp/solver_1d.py                   629
-rw-r--r--  ot/optim.py                          496
-rwxr-xr-x  ot/partial.py                        122
-rw-r--r--  ot/sliced.py                         187
-rw-r--r--  ot/smooth.py                          11
-rw-r--r--  ot/solvers.py                        347
-rw-r--r--  ot/unbalanced.py                     189
-rw-r--r--  ot/utils.py                          238
-rw-r--r--  ot/weak.py                             6
32 files changed, 7255 insertions, 3430 deletions
diff --git a/ot/__init__.py b/ot/__init__.py
index 86ed94e..1a685b6 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -8,7 +8,6 @@
, :py:mod:`ot.unbalanced`.
The following sub-modules are not imported due to additional dependencies:
- :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`.
- - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU.
- :any:`ot.plot` : depends on :code:`matplotlib`
"""
@@ -34,32 +33,39 @@ from . import backend
from . import regpath
from . import weak
from . import factored
+from . import solvers
+from . import gaussian
# OT functions
-from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
+from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
+ binary_search_circle, wasserstein_circle,
+ semidiscrete_wasserstein2_unif_circle)
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
-from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
+from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance,
+ sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif)
from .gromov import (gromov_wasserstein, gromov_wasserstein2,
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
-
+from .solvers import solve
# utils functions
from .utils import dist, unif, tic, toc, toq
-__version__ = "0.8.2"
+__version__ = "0.9.0"
__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
- 'emd2_1d', 'wasserstein_1d', 'backend',
+ 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
- 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
+ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
- 'factored_optimal_transport',
- 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
+ 'factored_optimal_transport', 'solve',
+ 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
+ 'binary_search_circle', 'wasserstein_circle',
+ 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
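Editor's note: the 0.9.0 API above adds a unified entry point, :code:`ot.solve`, alongside the new circle and sphere solvers. A minimal, hedged sketch of how the new exports are meant to be called (random data and illustrative parameters; the returned result object is expected to expose at least :code:`plan` and :code:`value` in 0.9.0):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(20, 2), rng.randn(30, 2)
    a, b = ot.unif(20), ot.unif(30)
    M = ot.dist(xs, xt)

    # unified solver: exact OT by default, entropic OT when reg is given
    res = ot.solve(M, a, b)
    res_reg = ot.solve(M, a, b, reg=1e-1)
    print(res.plan.shape, res_reg.value)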
diff --git a/ot/backend.py b/ot/backend.py
index 361ffba..0779243 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -534,9 +534,9 @@ class Backend():
"""
raise NotImplementedError()
- def zero_pad(self, a, pad_width):
+ def zero_pad(self, a, pad_width, value=0):
r"""
- Pads a tensor.
+ Pads a tensor with a given value (0 by default).
This function follows the api from :any:`numpy.pad`
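Editor's note: a minimal sketch of the extended signature through the NumPy backend (values illustrative); the extra :code:`value` argument is simply forwarded as the constant fill value.

    import numpy as np
    from ot.backend import NumpyBackend

    nx = NumpyBackend()
    a = np.array([[1., 2.], [3., 4.]])
    # pad one row/column on each side, filled with -1 instead of 0
    padded = nx.zero_pad(a, [(1, 1), (1, 1)], value=-1)
    print(padded.shape)  # (4, 4)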
@@ -854,6 +854,21 @@ class Backend():
"""
raise NotImplementedError()
+ def kl_div(self, p, q, eps=1e-16):
+ r"""
+ Computes the Kullback-Leibler divergence.
+
+ This function follows the api from :any:`scipy.stats.entropy`.
+
+ Parameter eps is used to avoid numerical errors and is added in the log.
+
+ .. math::
+ KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
+ """
+ raise NotImplementedError()
+
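Editor's note: the :math:`\epsilon` term is added inside the logarithm to avoid :code:`log(0)`. A plain NumPy sketch of the same computation (toy histograms, for illustration only):

    import numpy as np

    p = np.array([0.4, 0.6])
    q = np.array([0.5, 0.5])
    eps = 1e-16
    kl = np.sum(p * np.log(p / q + eps))  # same definition as the backend kl_div
    print(kl)  # ~0.0201, close to scipy.stats.entropy(p, q)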
def isfinite(self, a):
r"""
Tests element-wise for finiteness (not infinity and not Not a Number).
@@ -880,6 +895,62 @@ class Backend():
"""
raise NotImplementedError()
+ def tile(self, a, reps):
+ r"""
+ Construct an array by repeating a the number of times given by reps
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html
+ """
+ raise NotImplementedError()
+
+ def floor(self, a):
+ r"""
+ Return the floor of the input element-wise
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html
+ """
+ raise NotImplementedError()
+
+ def prod(self, a, axis=None):
+ r"""
+ Return the product of all elements.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html
+ """
+ raise NotImplementedError()
+
+ def sort2(self, a, axis=None):
+ r"""
+ Return the sorted array and the indices to sort the array
+
+ See: https://pytorch.org/docs/stable/generated/torch.sort.html
+ """
+ raise NotImplementedError()
+
+ def qr(self, a):
+ r"""
+ Return the QR factorization
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html
+ """
+ raise NotImplementedError()
+
+ def atan2(self, a, b):
+ r"""
+ Element wise arctangent
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html
+ """
+ raise NotImplementedError()
+
+ def transpose(self, a, axes=None):
+ r"""
+ Returns a tensor that is a transposed version of a. The given dimensions dim0 and dim1 are swapped.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -1024,8 +1095,8 @@ class NumpyBackend(Backend):
def concatenate(self, arrays, axis=0):
return np.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return np.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return np.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return np.argmax(a, axis=axis)
@@ -1158,6 +1229,9 @@ class NumpyBackend(Backend):
def sqrtm(self, a):
return scipy.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return np.sum(p * np.log(p / q + eps))
+
def isfinite(self, a):
return np.isfinite(a)
@@ -1167,6 +1241,44 @@ class NumpyBackend(Backend):
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return np.tile(a, reps)
+
+ def floor(self, a):
+ return np.floor(a)
+
+ def prod(self, a, axis=0):
+ return np.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ np_version = tuple([int(k) for k in np.__version__.split(".")])
+ if np_version < (1, 22, 0):
+ M, N = a.shape[-2], a.shape[-1]
+ K = min(M, N)
+
+ if len(a.shape) >= 3:
+ n = a.shape[0]
+
+ qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N))
+
+ for i in range(a.shape[0]):
+ qs[i], rs[i] = np.linalg.qr(a[i])
+
+ else:
+ return np.linalg.qr(a)
+
+ return qs, rs
+ return np.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return np.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return np.transpose(a, axes)
+
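Editor's note: the branch above exists because stacked (3D) inputs to :code:`numpy.linalg.qr` require NumPy >= 1.22; on older versions the factorization is done slice by slice. A quick sketch checking that the two paths agree on a recent NumPy (shapes illustrative):

    import numpy as np

    a = np.random.RandomState(0).randn(5, 4, 3)   # batch of 5 matrices
    q_batched, r_batched = np.linalg.qr(a)        # stacked QR, NumPy >= 1.22
    q_loop = np.stack([np.linalg.qr(a[i])[0] for i in range(a.shape[0])])
    print(np.allclose(q_batched, q_loop))         # True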
class JaxBackend(Backend):
"""
@@ -1333,8 +1445,8 @@ class JaxBackend(Backend):
def concatenate(self, arrays, axis=0):
return jnp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return jnp.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return jnp.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return jnp.argmax(a, axis=axis)
@@ -1481,6 +1593,9 @@ class JaxBackend(Backend):
L, V = jnp.linalg.eigh(a)
return (V * jnp.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return jnp.sum(p * jnp.log(p / q + eps))
+
def isfinite(self, a):
return jnp.isfinite(a)
@@ -1490,6 +1605,27 @@ class JaxBackend(Backend):
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return jnp.tile(a, reps)
+
+ def floor(self, a):
+ return jnp.floor(a)
+
+ def prod(self, a, axis=0):
+ return jnp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return jnp.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return jnp.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return jnp.transpose(a, axes)
+
class TorchBackend(Backend):
"""
@@ -1507,15 +1643,19 @@ class TorchBackend(Backend):
def __init__(self):
- self.rng_ = torch.Generator()
+ self.rng_ = torch.Generator("cpu")
self.rng_.seed()
self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
torch.tensor(1, dtype=torch.float64)]
if torch.cuda.is_available():
+ self.rng_cuda_ = torch.Generator("cuda")
+ self.rng_cuda_.seed()
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+ else:
+ self.rng_cuda_ = torch.Generator("cpu")
from torch.autograd import Function
@@ -1704,13 +1844,13 @@ class TorchBackend(Backend):
def concatenate(self, arrays, axis=0):
return torch.cat(arrays, dim=axis)
- def zero_pad(self, a, pad_width):
+ def zero_pad(self, a, pad_width, value=0):
from torch.nn.functional import pad
# pad_width is an array of ndim tuples indicating how many 0 before and after
# we need to add. We first need to make it compliant with torch syntax, that
# starts with the last dim, then second last, etc.
how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
- return pad(a, how_pad)
+ return pad(a, how_pad, value=value)
def argmax(self, a, axis=None):
return torch.argmax(a, dim=axis)
@@ -1761,20 +1901,26 @@ class TorchBackend(Backend):
def seed(self, seed=None):
if isinstance(seed, int):
self.rng_.manual_seed(seed)
+ self.rng_cuda_.manual_seed(seed)
elif isinstance(seed, torch.Generator):
- self.rng_ = seed
+ if self.device_type(seed) == "GPU":
+ self.rng_cuda_ = seed
+ else:
+ self.rng_ = seed
else:
raise ValueError("Non compatible seed : {}".format(seed))
def rand(self, *size, type_as=None):
if type_as is not None:
- return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+ generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+ return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device)
else:
return torch.rand(size=size, generator=self.rng_)
def randn(self, *size, type_as=None):
if type_as is not None:
- return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+ generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+ return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device)
else:
return torch.randn(size=size, generator=self.rng_)
@@ -1891,6 +2037,9 @@ class TorchBackend(Backend):
L, V = torch.linalg.eigh(a)
return (V * torch.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return torch.sum(p * torch.log(p / q + eps))
+
def isfinite(self, a):
return torch.isfinite(a)
@@ -1900,6 +2049,29 @@ class TorchBackend(Backend):
def is_floating_point(self, a):
return a.dtype.is_floating_point
+ def tile(self, a, reps):
+ return a.repeat(reps)
+
+ def floor(self, a):
+ return torch.floor(a)
+
+ def prod(self, a, axis=0):
+ return torch.prod(a, dim=axis)
+
+ def sort2(self, a, axis=-1):
+ return torch.sort(a, axis)
+
+ def qr(self, a):
+ return torch.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return torch.atan2(a, b)
+
+ def transpose(self, a, axes=None):
+ if axes is None:
+ axes = tuple(range(a.ndim)[::-1])
+ return a.permute(axes)
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -2062,8 +2234,8 @@ class CupyBackend(Backend): # pragma: no cover
def concatenate(self, arrays, axis=0):
return cp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return cp.pad(a, pad_width)
+ def zero_pad(self, a, pad_width, value=0):
+ return cp.pad(a, pad_width, constant_values=value)
def argmax(self, a, axis=None):
return cp.argmax(a, axis=axis)
@@ -2238,6 +2410,9 @@ class CupyBackend(Backend): # pragma: no cover
L, V = cp.linalg.eigh(a)
return (V * self.sqrt(L)[None, :]) @ V.T
+ def kl_div(self, p, q, eps=1e-16):
+ return cp.sum(p * cp.log(p / q + eps))
+
def isfinite(self, a):
return cp.isfinite(a)
@@ -2247,6 +2422,27 @@ class CupyBackend(Backend): # pragma: no cover
def is_floating_point(self, a):
return a.dtype.kind == "f"
+ def tile(self, a, reps):
+ return cp.tile(a, reps)
+
+ def floor(self, a):
+ return cp.floor(a)
+
+ def prod(self, a, axis=0):
+ return cp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return cp.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return cp.arctan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return cp.transpose(a, axes)
+
class TensorflowBackend(Backend):
@@ -2417,8 +2613,8 @@ class TensorflowBackend(Backend):
def concatenate(self, arrays, axis=0):
return tnp.concatenate(arrays, axis)
- def zero_pad(self, a, pad_width):
- return tnp.pad(a, pad_width, mode="constant")
+ def zero_pad(self, a, pad_width, value=0):
+ return tnp.pad(a, pad_width, mode="constant", constant_values=value)
def argmax(self, a, axis=None):
return tnp.argmax(a, axis=axis)
@@ -2598,6 +2794,9 @@ class TensorflowBackend(Backend):
def sqrtm(self, a):
return tf.linalg.sqrtm(a)
+ def kl_div(self, p, q, eps=1e-16):
+ return tnp.sum(p * tnp.log(p / q + eps))
+
def isfinite(self, a):
return tnp.isfinite(a)
@@ -2606,3 +2805,24 @@ class TensorflowBackend(Backend):
def is_floating_point(self, a):
return a.dtype.is_floating
+
+ def tile(self, a, reps):
+ return tnp.tile(a, reps)
+
+ def floor(self, a):
+ return tf.floor(a)
+
+ def prod(self, a, axis=0):
+ return tnp.prod(a, axis=axis)
+
+ def sort2(self, a, axis=-1):
+ return self.sort(a, axis), self.argsort(a, axis)
+
+ def qr(self, a):
+ return tf.linalg.qr(a)
+
+ def atan2(self, a, b):
+ return tf.math.atan2(a, b)
+
+ def transpose(self, a, axes=None):
+ return tf.transpose(a, perm=axes)
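Editor's note: taken together, the new methods keep solver code backend-agnostic: the same calls dispatch to NumPy, PyTorch, JAX, CuPy or TensorFlow arrays. A small sketch with a NumPy input (illustrative values only):

    import numpy as np
    from ot.backend import get_backend

    x = np.array([[3., 1.], [2., 4.]])
    nx = get_backend(x)              # resolves to the NumPy backend here

    vals, idx = nx.sort2(x, axis=1)  # sorted values and the sorting indices
    q, r = nx.qr(x)                  # QR factorization
    xt = nx.transpose(x)             # axes=None reverses the axes, as in numpy
    print(nx.prod(x, axis=0), nx.atan2(x[0, 0], x[0, 1]))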
diff --git a/ot/bregman.py b/ot/bregman.py
index c06af2f..20bef7e 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -24,9 +24,8 @@ from ot.utils import unif, dist, list_to_array
from .backend import get_backend
-def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=True,
- **kwargs):
+def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -101,6 +100,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
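Editor's note: because the warmstart potentials are the logarithms of the Sinkhorn scaling vectors, the :code:`u`/:code:`v` entries of a previous run's log can be reused directly. A hedged sketch (random data; the log keys are those returned by the existing Sinkhorn solvers):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(10), ot.unif(15)
    M = ot.dist(rng.randn(10, 2), rng.randn(15, 2))

    # first solve records the scaling vectors u, v in the log
    G0, log0 = ot.sinkhorn(a, b, M, reg=1e-1, log=True)

    # reuse them (in log-domain) to warmstart a second call, e.g. with a smaller reg
    warmstart = (np.log(log0['u']), np.log(log0['v']))
    G1 = ot.sinkhorn(a, b, M, reg=5e-2, warmstart=warmstart)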
@@ -156,34 +158,33 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn,
+ warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn)
+ warn=warn, warmstart=warmstart)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- return sinkhorn_epsilon_scaling(a, b, M, reg,
- numItermax=numItermax,
- stopThr=stopThr, verbose=verbose,
- log=log, warn=warn,
+ return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -207,6 +208,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
weights (histograms, both sum to 1)
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
+ the entropic contribution).
+
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
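Editor's note: in other words, :code:`sinkhorn2` returns the transport cost of the regularized plan without the entropy term. A small consistency sketch (toy data, illustrative regularization):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(8), ot.unif(12)
    M = ot.dist(rng.randn(8, 2), rng.randn(12, 2))

    loss = ot.sinkhorn2(a, b, M, reg=1e-1)
    G = ot.sinkhorn(a, b, M, reg=1e-1)
    print(np.allclose(loss, np.sum(G * M)))  # True up to numerical tolerance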
@@ -257,6 +261,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -320,15 +327,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if len(b.shape) < 2:
if method.lower() == 'sinkhorn':
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
@@ -341,23 +351,25 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, warmstart=warmstart,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log,
+ stopThr=stopThr, warmstart=warmstart,
+ verbose=verbose, log=log, warn=warn,
**kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -406,6 +418,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -465,12 +480,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
# we assume that no distances are null except those of the diagonal of
# distances
- if n_hists:
- u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
- v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ if warmstart is None:
+ if n_hists:
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
+ else:
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
else:
- u = nx.ones(dim_a, type_as=M) / dim_a
- v = nx.ones(dim_b, type_as=M) / dim_b
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
K = nx.exp(M / (-reg))
@@ -538,7 +556,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r"""
Solve the entropic regularization optimal transport problem in log space
and return the OT matrix
@@ -587,6 +605,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -647,6 +668,10 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
else:
n_hists = 0
+    # in case of multiple histograms
+ if n_hists > 1 and warmstart is None:
+ warmstart = [None] * n_hists
+
if n_hists: # we do not want to use tensors so we do a loop
lst_loss = []
@@ -654,8 +679,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
lst_v = []
for k in range(n_hists):
- res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr,
+ verbose=verbose, log=log, warmstart=warmstart[k], **kwargs)
if log:
lst_loss.append(nx.sum(M * res[0]))
@@ -682,9 +707,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
# we assume that no distances are null except those of the diagonal of
# distances
-
- u = nx.zeros(dim_a, type_as=M)
- v = nx.zeros(dim_b, type_as=M)
+ if warmstart is None:
+ u = nx.zeros(dim_a, type_as=M)
+ v = nx.zeros(dim_b, type_as=M)
+ else:
+ u, v = warmstart
def get_logT(u, v):
if n_hists:
@@ -738,7 +765,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
- log=False, warn=True):
+ log=False, warn=True, warmstart=None):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -786,6 +813,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -844,8 +874,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
K = nx.exp(-M / reg)
- u = nx.full((dim_a,), 1. / dim_a, type_as=K)
- v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ if warmstart is None:
+ u = nx.full((dim_a,), 1. / dim_a, type_as=K)
+ v = nx.full((dim_b,), 1. / dim_b, type_as=K)
+ else:
+ u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
G = u[:, None] * K * v[None, :]
viol = nx.sum(G, axis=1) - a
@@ -1065,7 +1098,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
- alpha, beta = alpha + reg * nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
+ alpha, beta = alpha + reg * \
+ nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
@@ -1278,7 +1312,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
regi = get_reg(ii)
G, logi = sinkhorn_stabilized(a, b, M, regi,
- numItermax=numInnerItermax, stopThr=1e-9,
+ numItermax=numInnerItermax, stopThr=stopThr,
warmstart=(alpha, beta), verbose=False,
print_period=20, tau=tau, log=True)
@@ -1289,13 +1323,15 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
# we can speed up the process by checking for the error only all
# the 10th iterations
transp = G
- err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + nx.norm(nx.sum(transp, axis=1) - a) ** 2
+ err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \
+ nx.norm(nx.sum(transp, axis=1) - a) ** 2
if log:
log['err'].append(err)
if verbose:
if ii % (print_period * 10) == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err <= stopThr and ii > numItermin:
@@ -1511,7 +1547,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
for ii in range(numItermax):
- UKv = u * nx.dot(K, A / nx.dot(K, u))
+ UKv = u * nx.dot(K.T, A / nx.dot(K, u))
u = (u.T * geometricBar(weights, UKv)).T / UKv
if ii % 10 == 1:
@@ -1540,6 +1576,129 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
+def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
+ numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
+ **kwargs):
+ r"""
+ Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:
+
+ .. math::
+ \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
+
+ where :
+
+    - :math:`w \in (0, 1)^{N}` are the barycenter weights, summing to one
+    - `measures_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measure weights (on the simplex)
+    - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i \times d}`: empirical measure atom locations
+ - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
+
+ This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+    Our implementation differs from it in three ways:
+
+ - we do not optimize over the weights
+    - we do not perform line search for the location updates, i.e. we use :math:`\theta = 1` in
+ :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ implementation of the fixed-point algorithm of
+ :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
+ - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
+ transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+
+ Parameters
+ ----------
+ measures_locations : list of N (k_i,d) array-like
+ The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ measures_weights : list of N (k_i,) array-like
+        Numpy arrays where each numpy array has :math:`k_i` non-negative values summing to one
+ representing the weights of each discrete input measure
+
+ X_init : (k,d) array-like
+ Initialization of the support locations (on `k` atoms) of the barycenter
+ reg : float
+ Regularization term >0
+ b : (k,) array-like
+        Initialization of the weights of the barycenter (non-negative, sum to 1)
+    weights : (N,) array-like
+        Initialization of the coefficients of the barycenter (non-negative, sum to 1)
+
+ numItermax : int, optional
+ Max number of iterations
+ numInnerItermax : int, optional
+ Max number of iterations when calculating the transport plans with Sinkhorn
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+ X : (k,d) array-like
+ Support locations (on k atoms) of the barycenter
+
+ See Also
+ --------
+ ot.bregman.sinkhorn : Entropic regularized OT solver
+ ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming
+
+ .. _references-free-support-barycenter:
+ References
+ ----------
+ .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+ """
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
+
+ iter_count = 0
+
+ N = len(measures_locations)
+ k = X_init.shape[0]
+ d = X_init.shape[1]
+ if b is None:
+ b = nx.ones((k,), type_as=X_init) / k
+ if weights is None:
+ weights = nx.ones((N,), type_as=X_init) / N
+
+ X = X_init
+
+ log_dict = {}
+ displacement_square_norms = []
+
+ displacement_square_norm = stopThr + 1.
+
+ while (displacement_square_norm > stopThr and iter_count < numItermax):
+
+ T_sum = nx.zeros((k, d), type_as=X_init)
+
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
+ M_i = dist(X, measure_locations_i)
+ T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg,
+ numItermax=numInnerItermax, **kwargs)
+ T_sum = T_sum + weight_i * 1. / \
+ b[:, None] * nx.dot(T_i, measure_locations_i)
+
+ displacement_square_norm = nx.sum((T_sum - X) ** 2)
+ if log:
+ displacement_square_norms.append(displacement_square_norm)
+
+ X = T_sum
+
+ if verbose:
+            print('iteration %d, displacement_square_norm=%f\n' %
+                  (iter_count, displacement_square_norm))
+
+ iter_count += 1
+
+ if log:
+ log_dict['displacement_square_norms'] = displacement_square_norms
+ return X, log_dict
+ else:
+ return X
+
+
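Editor's note: a minimal usage sketch of the new free-support Sinkhorn barycenter (random measures and illustrative parameters only; the call signature is the one defined above):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    measures_locations = [rng.randn(15, 2), rng.randn(20, 2) + 2.]
    measures_weights = [ot.unif(15), ot.unif(20)]
    X_init = rng.randn(10, 2)   # barycenter supported on 10 atoms

    X = ot.bregman.free_support_sinkhorn_barycenter(
        measures_locations, measures_weights, X_init, reg=1e-1,
        numItermax=20, numInnerItermax=500)
    print(X.shape)  # (10, 2)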
def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False, warn=True):
r"""Compute the entropic wasserstein barycenter in log-domain
@@ -2084,7 +2243,8 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2162,7 +2322,8 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr:
break
@@ -2321,7 +2482,8 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
# debiased Sinkhorn does not converge monotonically
@@ -2401,7 +2563,8 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10
if verbose:
if ii % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))
if err < stopThr and ii > 20:
break
@@ -2729,7 +2892,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
- log=False, warn=True, **kwargs):
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -2782,6 +2945,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
@@ -2832,14 +2998,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log = {"err": []}
log_a, log_b = nx.log(a), nx.log(b)
- f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ if warmstart is None:
+ f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
+ else:
+ f, g = warmstart
if isinstance(batchSize, int):
bs, bt = batchSize, batchSize
elif isinstance(batchSize, tuple) and len(batchSize) == 2:
bs, bt = batchSize[0], batchSize[1]
else:
- raise ValueError("Batch size must be in integer or a tuple of two integers")
+ raise ValueError(
+            "Batch size must be an integer or a tuple of two integers")
range_s, range_t = range(0, ns, bs), range(0, nt, bt)
@@ -2877,7 +3047,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
M = nx.from_numpy(M, type_as=a)
m1_cols.append(
- nx.sum(nx.exp(f[i:i + bs, None] + g[None, :] - M / reg), axis=1)
+ nx.sum(nx.exp(f[i:i + bs, None] +
+ g[None, :] - M / reg), axis=1)
)
m1 = nx.concatenate(m1_cols, axis=0)
err = nx.sum(nx.abs(m1 - a))
@@ -2885,7 +3056,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
dict_log["err"].append(err)
if verbose and (i_ot + 1) % 100 == 0:
- print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+ print("Error in marginal at iteration {} = {}".format(
+ i_ot + 1, err))
if err <= stopThr:
break
@@ -2905,17 +3077,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=True, **kwargs)
+ verbose=verbose, log=True, warmstart=warmstart, **kwargs)
return pi, log
else:
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
- verbose=verbose, log=False, **kwargs)
+ verbose=verbose, log=False, warmstart=warmstart, **kwargs)
return pi
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9, isLazy=False,
- batchSize=100, verbose=False, log=False, warn=True, **kwargs):
+ numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100,
+ verbose=False, log=False, warn=True, warmstart=None, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -2939,6 +3111,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
+ the entropic contribution).
+
Parameters
----------
@@ -2969,7 +3144,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
-
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -3025,13 +3202,16 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
isLazy=isLazy,
batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
else:
f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
- numIterMax=numIterMax, stopThr=stopThr,
+ numIterMax=numIterMax,
+ stopThr=stopThr,
isLazy=isLazy, batchSize=batchSize,
verbose=verbose, log=log,
- warn=warn)
+ warn=warn,
+ warmstart=warmstart)
bs = batchSize if isinstance(batchSize, int) else batchSize[0]
range_s = range(0, ns, bs)
@@ -3053,25 +3233,23 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
return loss
else:
- M = dist(nx.to_numpy(X_s), nx.to_numpy(X_t), metric=metric)
- M = nx.from_numpy(M, type_as=a)
+ M = dist(X_s, X_t, metric=metric)
if log:
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss, log
else:
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
stopThr=stopThr, verbose=verbose, log=log,
- warn=warn, **kwargs)
+ warn=warn, warmstart=warmstart, **kwargs)
return sinkhorn_loss
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
- numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, warn=True,
- **kwargs):
+ numIterMax=10000, stopThr=1e-9, verbose=False,
+ log=False, warn=True, warmstart=None, **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -3118,6 +3296,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+ and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F -(\langle \gamma^*_a, \mathbf{M_a} \rangle_F + \langle
+ \gamma^*_b , \mathbf{M_b} \rangle_F)/2`.
+
+    .. note:: The current implementation does not account for the entropic contributions and thus differs from the
+ Sinkhorn divergence as introduced in the literature. The possibility to account for the entropic contributions
+ will be provided in a future release.
+
Parameters
----------
@@ -3141,6 +3326,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
+ Initialization of dual potentials. If provided, the dual potentials should be given
+ (that is the logarithm of the u,v sinkhorn scaling vectors)
Returns
-------
@@ -3167,23 +3355,34 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
International Conference on Artficial Intelligence and Statistics,
(AISTATS) 21, 2018
'''
+ X_s, X_t = list_to_array(X_s, X_t)
+
+ nx = get_backend(X_s, X_t)
+ if warmstart is None:
+ warmstart_a, warmstart_b = None, None
+ else:
+ u, v = warmstart
+ warmstart_a = (u, u)
+ warmstart_b = (v, v)
+
if log:
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose,
- log=log, warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
log = {}
log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
@@ -3193,26 +3392,27 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
log['log_sinkhorn_a'] = log_a
log['log_sinkhorn_b'] = log_b
- return max(0, sinkhorn_div), log
+ return nx.maximum(0, sinkhorn_div), log
else:
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart, **kwargs)
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_a, **kwargs)
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
- numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log,
- warn=warn, **kwargs)
+ numIterMax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=log, warn=warn,
+ warmstart=warmstart_b, **kwargs)
- sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
- return max(0, sinkhorn_div)
+ sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
+ (sinkhorn_loss_a + sinkhorn_loss_b)
+ return nx.maximum(0, sinkhorn_div)
def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
@@ -3379,7 +3579,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
else:
aK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_cols), ns_budget - 1)[ns_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_cols), ns_budget - 1)[ns_budget - 1],
type_as=M
)
epsilon_u_square = a[0] / aK_sort
@@ -3389,7 +3590,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
else:
bK_sort = nx.from_numpy(
- bottleneck.partition(nx.to_numpy(K_sum_rows), nt_budget - 1)[nt_budget - 1],
+ bottleneck.partition(nx.to_numpy(
+ K_sum_rows), nt_budget - 1)[nt_budget - 1],
type_as=M
)
epsilon_v_square = b[0] / bK_sort
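Editor's note: the warmstart added to :code:`empirical_sinkhorn_divergence` is split between the cross term and the two auto-correlation terms, as shown above. A hedged usage sketch (random samples; the zero potentials are just one possible initialization):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s, X_t = rng.randn(20, 2), rng.randn(30, 2) + 1.

    d0 = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1e-1)
    warmstart = (np.zeros(20), np.zeros(30))
    d1 = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1e-1,
                                                  warmstart=warmstart)
    print(float(d0), float(d1))  # both runs converge to essentially the same value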
diff --git a/ot/coot.py b/ot/coot.py
new file mode 100644
index 0000000..66dd2c8
--- /dev/null
+++ b/ot/coot.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+"""
+CO-Optimal Transport solver
+"""
+
+# Author: Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
+#
+# License: MIT License
+
+import warnings
+from .lp import emd
+from .utils import list_to_array
+from .backend import get_backend
+from .bregman import sinkhorn
+
+
+def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None,
+ epsilon=0, alpha=0, M_samp=None, M_feat=None,
+ warmstart=None, nits_bcd=100, tol_bcd=1e-7, eval_bcd=1,
+ nits_ot=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn",
+ early_stopping_tol=1e-6, log=False, verbose=False):
+ r"""Compute the CO-Optimal Transport between two matrices.
+
+ Return the sample and feature transport plans between
+ :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and
+ :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`.
+
+ The function solves the following CO-Optimal Transport (COOT) problem:
+
+ .. math::
+ \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}}
+ &\quad \sum_{i,j,k,l}
+ (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l}
+ + \alpha_s \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\
+ &+ \alpha_f \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l}
+ + \varepsilon_s \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T)
+ + \varepsilon_f \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T)
+
+ Where :
+
+ - :math:`\mathbf{X}`: Data matrix in the source space
+ - :math:`\mathbf{Y}`: Data matrix in the target space
+ - :math:`\mathbf{M^{(s)}}`: Additional sample matrix
+ - :math:`\mathbf{M^{(f)}}`: Additional feature matrix
+ - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space
+ - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space
+ - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space
+ - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space
+
+ .. note:: This function allows epsilon to be zero.
+ In that case, the :any:`ot.lp.emd` solver of POT will be used.
+
+ Parameters
+ ----------
+ X : (n_sample_x, n_feature_x) array-like, float
+ First input matrix.
+ Y : (n_sample_y, n_feature_y) array-like, float
+ Second input matrix.
+ wx_samp : (n_sample_x, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix X.
+ Uniform distribution by default.
+ wx_feat : (n_feature_x, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix X.
+ Uniform distribution by default.
+ wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix Y.
+ Uniform distribution by default.
+ wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix Y.
+ Uniform distribution by default.
+ epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
+ Regularization parameters for entropic approximation of sample and feature couplings.
+ Allow the case where epsilon contains 0. In that case, the EMD solver is used instead of
+ Sinkhorn solver. If epsilon is scalar, then the same epsilon is applied to
+ both regularization of sample and feature couplings.
+ alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Coefficient parameter of the linear terms with respect to the sample and feature couplings.
+ If alpha is scalar, then the same alpha is applied to both linear terms.
+ M_samp : (n_sample_x, n_sample_y), float, optional (default = None)
+ Sample matrix with respect to the linear term on sample coupling.
+ M_feat : (n_feature_x, n_feature_y), float, optional (default = None)
+ Feature matrix with respect to the linear term on feature coupling.
+ warmstart : dictionary, optional (default = None)
+ Contains 4 keys:
+ - "duals_sample" and "duals_feature" whose values are
+ tuples of 2 vectors of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature dual vectors
+ if using Sinkhorn algorithm. Zero vectors by default.
+
+ - "pi_sample" and "pi_feature" whose values are matrices
+ of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature couplings.
+ Uniform distributions by default.
+ nits_bcd : int, optional (default = 100)
+ Number of Block Coordinate Descent (BCD) iterations to solve COOT.
+ tol_bcd : float, optional (default = 1e-7)
+ Tolerance of BCD scheme. If the L1-norm between the current and previous
+ sample couplings is under this threshold, then stop BCD scheme.
+ eval_bcd : int, optional (default = 1)
+ Multiplier of iteration at which the COOT cost is evaluated. For example,
+ if `eval_bcd = 8`, then the cost is calculated at iterations 8, 16, 24, etc...
+    nits_ot : int, optional (default = 500)
+ Number of iterations to solve each of the
+ two optimal transport problems in each BCD iteration.
+ tol_sinkhorn : float, optional (default = 1e-7)
+ Tolerance of Sinkhorn algorithm to stop the Sinkhorn scheme for
+ entropic optimal transport problem (if any) in each BCD iteration.
+ Only triggered when Sinkhorn solver is used.
+ method_sinkhorn : string, optional (default = "sinkhorn")
+ Method used in POT's `ot.sinkhorn` solver.
+        Only "sinkhorn" and "sinkhorn_log" are supported.
+ early_stopping_tol : float, optional (default = 1e-6)
+ Tolerance for the early stopping. If the absolute difference between
+ the last 2 recorded COOT distances is under this tolerance, then stop BCD scheme.
+ log : bool, optional (default = False)
+ If True then the cost and 4 dual vectors, including
+ 2 from sample and 2 from feature couplings, are recorded.
+ verbose : bool, optional (default = False)
+ If True then print the COOT cost at every multiplier of `eval_bcd`-th iteration.
+
+ Returns
+ -------
+ pi_samp : (n_sample_x, n_sample_y) array-like, float
+ Sample coupling matrix.
+ pi_feat : (n_feature_x, n_feature_y) array-like, float
+ Feature coupling matrix.
+ log : dictionary, optional
+ Returned if `log` is True. The keys are:
+ duals_sample : (n_sample_x, n_sample_y) tuple, float
+ Pair of dual vectors when solving OT problem w.r.t the sample coupling.
+ duals_feature : (n_feature_x, n_feature_y) tuple, float
+ Pair of dual vectors when solving OT problem w.r.t the feature coupling.
+ distances : list, float
+ List of COOT distances.
+
+ References
+ ----------
+ .. [49] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
+       Advances in Neural Information Processing Systems, 33 (2020).
+ """
+
+ def compute_kl(p, q):
+ kl = nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))
+ return kl
+
+ # Main function
+
+ if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
+ raise ValueError(
+ "Method {} is not supported in CO-Optimal Transport.".format(method_sinkhorn))
+
+ X, Y = list_to_array(X, Y)
+ nx = get_backend(X, Y)
+
+ if isinstance(epsilon, float) or isinstance(epsilon, int):
+ eps_samp, eps_feat = epsilon, epsilon
+ else:
+ if len(epsilon) != 2:
+ raise ValueError("Epsilon must be either a scalar or an indexable object of length 2.")
+ else:
+ eps_samp, eps_feat = epsilon[0], epsilon[1]
+
+ if isinstance(alpha, float) or isinstance(alpha, int):
+ alpha_samp, alpha_feat = alpha, alpha
+ else:
+ if len(alpha) != 2:
+ raise ValueError("Alpha must be either a scalar or an indexable object of length 2.")
+ else:
+ alpha_samp, alpha_feat = alpha[0], alpha[1]
+
+ # constant input variables
+ if M_samp is None or alpha_samp == 0:
+ M_samp, alpha_samp = 0, 0
+ if M_feat is None or alpha_feat == 0:
+ M_feat, alpha_feat = 0, 0
+
+ nx_samp, nx_feat = X.shape
+ ny_samp, ny_feat = Y.shape
+
+ # measures on rows and columns
+ if wx_samp is None:
+ wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp
+ if wx_feat is None:
+ wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat
+ if wy_samp is None:
+ wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp
+ if wy_feat is None:
+ wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat
+
+ wxy_samp = wx_samp[:, None] * wy_samp[None, :]
+ wxy_feat = wx_feat[:, None] * wy_feat[None, :]
+
+ # pre-calculate cost constants
+ XY_sqr = (X ** 2 @ wx_feat)[:, None] + (Y ** 2 @
+ wy_feat)[None, :] + alpha_samp * M_samp
+ XY_sqr_T = ((X.T)**2 @ wx_samp)[:, None] + ((Y.T)
+ ** 2 @ wy_samp)[None, :] + alpha_feat * M_feat
+
+ # initialize coupling and dual vectors
+ if warmstart is None:
+ pi_samp, pi_feat = wxy_samp, wxy_feat # shape nx_samp x ny_samp and nx_feat x ny_feat
+ duals_samp = (nx.zeros(nx_samp, type_as=X), nx.zeros(
+ ny_samp, type_as=Y)) # shape nx_samp, ny_samp
+ duals_feat = (nx.zeros(nx_feat, type_as=X), nx.zeros(
+ ny_feat, type_as=Y)) # shape nx_feat, ny_feat
+ else:
+ pi_samp, pi_feat = warmstart["pi_sample"], warmstart["pi_feature"]
+ duals_samp, duals_feat = warmstart["duals_sample"], warmstart["duals_feature"]
+
+ # initialize log
+ list_coot = [float("inf")]
+ err = tol_bcd + 1e-3
+
+ for idx in range(nits_bcd):
+ pi_samp_prev = nx.copy(pi_samp)
+
+ # update sample coupling
+ ot_cost = XY_sqr - 2 * X @ pi_feat @ Y.T # size nx_samp x ny_samp
+ if eps_samp > 0:
+ pi_samp, dict_log = sinkhorn(a=wx_samp, b=wy_samp, M=ot_cost, reg=eps_samp, method=method_sinkhorn,
+ numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_samp)
+ duals_samp = (nx.log(dict_log["u"]), nx.log(dict_log["v"]))
+ elif eps_samp == 0:
+ pi_samp, dict_log = emd(
+ a=wx_samp, b=wy_samp, M=ot_cost, numItermax=nits_ot, log=True)
+ duals_samp = (dict_log["u"], dict_log["v"])
+ # update feature coupling
+ ot_cost = XY_sqr_T - 2 * X.T @ pi_samp @ Y # size nx_feat x ny_feat
+ if eps_feat > 0:
+ pi_feat, dict_log = sinkhorn(a=wx_feat, b=wy_feat, M=ot_cost, reg=eps_feat, method=method_sinkhorn,
+ numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_feat)
+ duals_feat = (nx.log(dict_log["u"]), nx.log(dict_log["v"]))
+ elif eps_feat == 0:
+ pi_feat, dict_log = emd(
+ a=wx_feat, b=wy_feat, M=ot_cost, numItermax=nits_ot, log=True)
+ duals_feat = (dict_log["u"], dict_log["v"])
+
+ if idx % eval_bcd == 0:
+ # update error
+ err = nx.sum(nx.abs(pi_samp - pi_samp_prev))
+
+ # COOT part
+ coot = nx.sum(ot_cost * pi_feat)
+ if alpha_samp != 0:
+ coot = coot + alpha_samp * nx.sum(M_samp * pi_samp)
+ # Entropic part
+ if eps_samp != 0:
+ coot = coot + eps_samp * compute_kl(pi_samp, wxy_samp)
+ if eps_feat != 0:
+ coot = coot + eps_feat * compute_kl(pi_feat, wxy_feat)
+ list_coot.append(coot)
+
+ if err < tol_bcd or abs(list_coot[-2] - list_coot[-1]) < early_stopping_tol:
+ break
+
+ if verbose:
+ print(
+ "CO-Optimal Transport cost at iteration {}: {}".format(idx + 1, coot))
+
+ # sanity check
+ if nx.sum(nx.isnan(pi_samp)) > 0 or nx.sum(nx.isnan(pi_feat)) > 0:
+ warnings.warn("There is NaN in coupling.")
+
+ if log:
+ dict_log = {"duals_sample": duals_samp,
+ "duals_feature": duals_feat,
+ "distances": list_coot[1:]}
+
+ return pi_samp, pi_feat, dict_log
+
+ else:
+ return pi_samp, pi_feat
+
+
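Editor's note: a minimal usage sketch of the solver above (random matrices, illustrative epsilon; with the default :code:`epsilon=0` the exact EMD solver is used inside the BCD loop):

    import numpy as np
    from ot.coot import co_optimal_transport

    rng = np.random.RandomState(0)
    X = rng.randn(20, 5)    # 20 samples, 5 features
    Y = rng.randn(30, 8)    # 30 samples, 8 features

    pi_samp, pi_feat = co_optimal_transport(X, Y)                    # exact inner OT
    pi_samp_e, pi_feat_e = co_optimal_transport(X, Y, epsilon=1e-1)  # entropic inner OT
    print(pi_samp.shape, pi_feat.shape)  # (20, 30) (5, 8)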
+def co_optimal_transport2(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None,
+ epsilon=0, alpha=0, M_samp=None, M_feat=None,
+ warmstart=None, log=False, verbose=False, early_stopping_tol=1e-6,
+ nits_bcd=100, tol_bcd=1e-7, eval_bcd=1,
+ nits_ot=500, tol_sinkhorn=1e-7,
+ method_sinkhorn="sinkhorn"):
+ r"""Compute the CO-Optimal Transport distance between two measures.
+
+ Returns the CO-Optimal Transport distance between
+ :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and
+ :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`.
+
+ The function solves the following CO-Optimal Transport (COOT) problem:
+
+ .. math::
+ \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}}
+ &\quad \sum_{i,j,k,l}
+ (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l}
+ + \alpha_1 \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\
+ &+ \alpha_2 \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l}
+ + \varepsilon_1 \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T)
+ + \varepsilon_2 \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T)
+
+ Where :
+
+ - :math:`\mathbf{X}`: Data matrix in the source space
+ - :math:`\mathbf{Y}`: Data matrix in the target space
+ - :math:`\mathbf{M^{(s)}}`: Additional sample matrix
+ - :math:`\mathbf{M^{(f)}}`: Additional feature matrix
+ - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space
+ - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space
+ - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space
+ - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space
+
+ .. note:: This function allows epsilon to be zero.
+ In that case, the :any:`ot.lp.emd` solver of POT will be used.
+
+ Parameters
+ ----------
+ X : (n_sample_x, n_feature_x) array-like, float
+ First input matrix.
+ Y : (n_sample_y, n_feature_y) array-like, float
+ Second input matrix.
+ wx_samp : (n_sample_x, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix X.
+ Uniform distribution by default.
+ wx_feat : (n_feature_x, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix X.
+ Uniform distribution by default.
+ wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
+ Histogram assigned on rows (samples) of matrix Y.
+ Uniform distribution by default.
+ wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
+ Histogram assigned on columns (features) of matrix Y.
+ Uniform distribution by default.
+ epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
+ Regularization parameters for entropic approximation of sample and feature couplings.
+ Allow the case where epsilon contains 0. In that case, the EMD solver is used instead of
+ Sinkhorn solver. If epsilon is scalar, then the same epsilon is applied to
+ both regularization of sample and feature couplings.
+ alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Coefficient parameter of the linear terms with respect to the sample and feature couplings.
+ If alpha is scalar, then the same alpha is applied to both linear terms.
+ M_samp : (n_sample_x, n_sample_y) array-like, float, optional (default = None)
+ Sample cost matrix used in the linear term on the sample coupling.
+ M_feat : (n_feature_x, n_feature_y) array-like, float, optional (default = None)
+ Feature cost matrix used in the linear term on the feature coupling.
+ warmstart : dictionary, optional (default = None)
+ Contains 4 keys:
+ - "duals_sample" and "duals_feature" whose values are
+ tuples of 2 vectors of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature dual vectors
+ if using Sinkhorn algorithm. Zero vectors by default.
+
+ - "pi_sample" and "pi_feature" whose values are matrices
+ of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+ Initialization of sample and feature couplings.
+ Uniform distributions by default.
+ nits_bcd : int, optional (default = 100)
+ Number of Block Coordinate Descent (BCD) iterations to solve COOT.
+ tol_bcd : float, optional (default = 1e-7)
+ Tolerance of the BCD scheme. If the L1 norm of the difference between the current
+ and previous sample couplings is below this threshold, then the BCD scheme stops.
+ eval_bcd : int, optional (default = 1)
+ The COOT cost is evaluated every `eval_bcd` iterations. For example,
+ if `eval_bcd = 8`, then the cost is calculated at iterations 8, 16, 24, etc.
+ nits_ot : int, optional (default = 500)
+ Number of iterations to solve each of the
+ two optimal transport problems in each BCD iteration.
+ tol_sinkhorn : float, optional (default = 1e-7)
+ Tolerance of the Sinkhorn algorithm used to solve the entropic optimal
+ transport problems (if any) in each BCD iteration.
+ Only relevant when the Sinkhorn solver is used.
+ method_sinkhorn : string, optional (default = "sinkhorn")
+ Method used in POT's `ot.sinkhorn` solver.
+ Only "sinkhorn" and "sinkhorn_log" are supported.
+ early_stopping_tol : float, optional (default = 1e-6)
+ Tolerance for early stopping. If the absolute difference between
+ the last 2 recorded COOT distances is below this tolerance, then the BCD scheme stops.
+ log : bool, optional (default = False)
+ If True, then the cost and the 4 dual vectors (2 for the sample coupling
+ and 2 for the feature coupling) are recorded.
+ verbose : bool, optional (default = False)
+ If True, then print the COOT cost every `eval_bcd` iterations.
+
+ Returns
+ -------
+ float
+ CO-Optimal Transport distance.
+ dict
+ Contains logged information from the :any:`co_optimal_transport` solver.
+ Only returned if the `log` parameter is True.
+
+ References
+ ----------
+ .. [47] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
+ Advances in Neural Information Processing Systems, 33 (2020).
+ """
+
+ pi_samp, pi_feat, dict_log = co_optimal_transport(X=X, Y=Y, wx_samp=wx_samp, wx_feat=wx_feat, wy_samp=wy_samp,
+ wy_feat=wy_feat, epsilon=epsilon, alpha=alpha, M_samp=M_samp,
+ M_feat=M_feat, warmstart=warmstart, nits_bcd=nits_bcd,
+ tol_bcd=tol_bcd, eval_bcd=eval_bcd, nits_ot=nits_ot,
+ tol_sinkhorn=tol_sinkhorn, method_sinkhorn=method_sinkhorn,
+ early_stopping_tol=early_stopping_tol,
+ log=True, verbose=verbose)
+
+ X, Y = list_to_array(X, Y)
+ nx = get_backend(X, Y)
+
+ nx_samp, nx_feat = X.shape
+ ny_samp, ny_feat = Y.shape
+
+ # measures on rows and columns
+ if wx_samp is None:
+ wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp
+ if wx_feat is None:
+ wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat
+ if wy_samp is None:
+ wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp
+ if wy_feat is None:
+ wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat
+
+ vx_samp, vy_samp = dict_log["duals_sample"]
+ vx_feat, vy_feat = dict_log["duals_feature"]
+
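+ # gradients of the COOT cost w.r.t. the data matrices X and Y; they are
+ # registered below through nx.set_gradients so that differentiable backends
+ # (e.g. torch) can backpropagate through the returned value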
+ gradX = 2 * X * (wx_samp[:, None] * wx_feat[None, :]) - \
+ 2 * pi_samp @ Y @ pi_feat.T # shape (nx_samp, nx_feat)
+ gradY = 2 * Y * (wy_samp[:, None] * wy_feat[None, :]) - \
+ 2 * pi_samp.T @ X @ pi_feat # shape (ny_samp, ny_feat)
+
+ coot = dict_log["distances"][-1]
+ coot = nx.set_gradients(coot, (wx_samp, wx_feat, wy_samp, wy_feat, X, Y),
+ (vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY))
+
+ if log:
+ return coot, dict_log
+
+ else:
+ return coot
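A minimal usage sketch of the new COOT solvers added above (hedged: it assumes the two functions are importable from ot.coot, consistently with the internal call in co_optimal_transport2, and it uses synthetic data):

import numpy as np
from ot.coot import co_optimal_transport, co_optimal_transport2

rng = np.random.RandomState(42)
X = rng.randn(20, 5)   # source: 20 samples, 5 features
Y = rng.randn(30, 8)   # target: 30 samples, 8 features

# sample coupling of shape (20, 30) and feature coupling of shape (5, 8)
pi_samp, pi_feat = co_optimal_transport(X, Y)

# scalar COOT value; gradients w.r.t. X, Y and the weights are registered
# through nx.set_gradients, so it is differentiable on autodiff backends
coot_value = co_optimal_transport2(X, Y)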
diff --git a/ot/da.py b/ot/da.py
index 0b9737e..5067a69 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -17,8 +17,9 @@ from .backend import get_backend
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import list_to_array, check_params, BaseEstimator
+from .utils import list_to_array, check_params, BaseEstimator, deprecated
from .unbalanced import sinkhorn_unbalanced
+from .gaussian import empirical_bures_wasserstein_mapping
from .optim import cg
from .optim import gcg
@@ -126,8 +127,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
W = nx.zeros(M.shape, type_as=M)
for cpt in range(numItermax):
Mreg = M + eta * W
- transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
- stopThr=stopInnerThr)
+ if log:
+ transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr, log=True)
+ else:
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really
# separated
W = nx.ones(M.shape, type_as=M)
@@ -136,7 +141,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
- return transp
+ if log:
+ return transp, log
+ else:
+ return transp
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
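The sinkhorn_lpl1_mm hunks above propagate the optional Sinkhorn log to the caller. A small usage sketch (hedged: the helpers ot.unif and ot.dist and the rest of the signature are assumed from the library, not shown in this hunk):

import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(20, 2)
labels_a = np.array([0] * 10 + [1] * 10)   # class labels of the source samples
Xt = rng.randn(15, 2)

a, b = ot.unif(20), ot.unif(15)
M = ot.dist(Xs, Xt)

# class-regularized coupling; with log=True the log of the last inner
# Sinkhorn call is returned as well
G, log = ot.da.sinkhorn_lpl1_mm(a, labels_a, b, M, reg=1e-1, log=True)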
@@ -672,112 +680,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
-def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
- wt=None, bias=True, log=False):
- r"""Return OT linear operator between samples.
-
- The function estimates the optimal linear operator that aligns the two
- empirical distributions. This is equivalent to estimating the closed
- form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
- and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
- :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in
- :ref:`[15] <references-OT-mapping-linear>`.
-
- The linear operator from source to target :math:`M`
-
- .. math::
- M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
-
- where :
-
- .. math::
- \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
- \Sigma_s^{-1/2}
-
- \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
-
- Parameters
- ----------
- xs : array-like (ns,d)
- samples in the source domain
- xt : array-like (nt,d)
- samples in the target domain
- reg : float,optional
- regularization added to the diagonals of covariances (>0)
- ws : array-like (ns,1), optional
- weights for the source samples
- wt : array-like (ns,1), optional
- weights for the target samples
- bias: boolean, optional
- estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
- log : bool, optional
- record log if True
-
-
- Returns
- -------
- A : (d, d) array-like
- Linear operator
- b : (1, d) array-like
- bias
- log : dict
- log dictionary return only if log==True in parameters
-
-
- .. _references-OT-mapping-linear:
- References
- ----------
- .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of
- distributions", Journal of Optimization Theory and Applications
- Vol 43, 1984
-
- .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
- Transport", 2018.
-
-
- """
- xs, xt = list_to_array(xs, xt)
- nx = get_backend(xs, xt)
-
- d = xs.shape[1]
-
- if bias:
- mxs = nx.mean(xs, axis=0)[None, :]
- mxt = nx.mean(xt, axis=0)[None, :]
-
- xs = xs - mxs
- xt = xt - mxt
- else:
- mxs = nx.zeros((1, d), type_as=xs)
- mxt = nx.zeros((1, d), type_as=xs)
-
- if ws is None:
- ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
-
- if wt is None:
- wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
-
- Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
- Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
-
- Cs12 = nx.sqrtm(Cs)
- Cs_12 = nx.inv(Cs12)
-
- M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
-
- A = dots(Cs_12, M0, Cs_12)
-
- b = mxt - nx.dot(mxs, A)
-
- if log:
- log = {}
- log['Cs'] = Cs
- log['Ct'] = Ct
- log['Cs12'] = Cs12
- log['Cs_12'] = Cs_12
- return A, b, log
- else:
- return A, b
+OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping)
def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,
@@ -1371,10 +1274,10 @@ class LinearTransport(BaseTransport):
self.mu_t = self.distribution_estimation(Xt)
# coupling estimation
- returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
- ws=nx.reshape(self.mu_s, (-1, 1)),
- wt=nx.reshape(self.mu_t, (-1, 1)),
- bias=self.bias, log=self.log)
+ returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg,
+ ws=nx.reshape(self.mu_s, (-1, 1)),
+ wt=nx.reshape(self.mu_t, (-1, 1)),
+ bias=self.bias, log=self.log)
# deal with the value of log
if self.log:
@@ -1514,12 +1417,13 @@ class SinkhornTransport(BaseTransport):
Sciences, 7(3), 1853-1882.
"""
- def __init__(self, reg_e=1., max_iter=1000,
+ def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
tol=10e-9, verbose=False, log=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
self.reg_e = reg_e
+ self.method = method
self.max_iter = max_iter
self.tol = tol
self.verbose = verbose
@@ -1560,7 +1464,7 @@ class SinkhornTransport(BaseTransport):
# coupling estimation
returned_ = sinkhorn(
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
- numItermax=self.max_iter, stopThr=self.tol,
+ method=self.method, numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
# deal with the value of log
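A short sketch of the new `method` option on ot.da.SinkhornTransport introduced above (synthetic data; all other parameters keep their defaults):

import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(50, 2)          # source samples
Xt = rng.randn(60, 2) + 4.0    # shifted target samples

# "sinkhorn_log" is the log-domain solver, typically more stable for small reg_e
mapping = ot.da.SinkhornTransport(reg_e=1e-1, method="sinkhorn_log", max_iter=2000)
mapping.fit(Xs=Xs, Xt=Xt)
Xs_mapped = mapping.transform(Xs=Xs)   # barycentric mapping of the source samples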
diff --git a/ot/dr.py b/ot/dr.py
index 0955c55..b92cd14 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -17,10 +17,10 @@ Dimension reduction with OT
from scipy import linalg
import autograd.numpy as np
-from pymanopt.function import Autograd
-from pymanopt.manifolds import Stiefel
-from pymanopt import Problem
-from pymanopt.solvers import SteepestDescent, TrustRegions
+
+import pymanopt
+import pymanopt.manifolds
+import pymanopt.optimizers
def dist(x1, x2):
@@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k):
ui = np.ones((M.shape[0],))
vi = np.ones((M.shape[1],))
for i in range(k):
- vi = w2 / (np.dot(K.T, ui))
- ui = w1 / (np.dot(K, vi))
+ vi = w2 / (np.dot(K.T, ui) + 1e-50)
+ ui = w1 / (np.dot(K, vi) + 1e-50)
G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
return G
@@ -167,7 +167,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
Size of dimensionnality reduction.
reg : float, optional
Regularization term >0 (entropic regularization)
- solver : None | str, optional
+ solver : None | str, optional
None for steepest descent or 'TrustRegions' for trust regions algorithm
else should be a pymanopt.solvers
sinkhorn_method : str
@@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
else:
regmean = np.ones((len(xc), len(xc)))
- @Autograd
+ manifold = pymanopt.manifolds.Stiefel(d, p)
+
+ @pymanopt.function.autograd(manifold)
def cost(P):
# wda loss
loss_b = 0
@@ -243,21 +245,21 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
return loss_w / loss_b
# declare manifold and problem
- manifold = Stiefel(d, p)
- problem = Problem(manifold=manifold, cost=cost)
+
+ problem = pymanopt.Problem(manifold=manifold, cost=cost)
# declare solver and solve
if solver is None:
- solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
+ solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose)
elif solver in ['tr', 'TrustRegions']:
- solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
+ solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose)
- Popt = solver.solve(problem, x=P0)
+ Popt = solver.run(problem, initial_point=P0)
def proj(X):
- return (X - mx.reshape((1, -1))).dot(Popt)
+ return (X - mx.reshape((1, -1))).dot(Popt.point)
- return Popt, proj
+ return Popt.point, proj
def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
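The dr.py changes above follow the pymanopt 2.x API (module-level Problem, optimizers instead of solvers, run() returning a result whose .point holds the optimum). A standalone sketch of that calling pattern, assuming pymanopt >= 2.0 and autograd are installed:

import autograd.numpy as anp
import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers

# minimize trace(P^T A P) over the Stiefel manifold St(5, 2)
manifold = pymanopt.manifolds.Stiefel(5, 2)
A = anp.diag(anp.arange(1.0, 6.0))

@pymanopt.function.autograd(manifold)
def cost(P):
    return anp.trace(P.T @ A @ P)

problem = pymanopt.Problem(manifold=manifold, cost=cost)
optimizer = pymanopt.optimizers.SteepestDescent(max_iterations=100)
result = optimizer.run(problem)
P_opt = result.point   # (5, 2) matrix with orthonormal columns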
diff --git a/ot/gaussian.py b/ot/gaussian.py
new file mode 100644
index 0000000..4ffb726
--- /dev/null
+++ b/ot/gaussian.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+"""
+Optimal transport for Gaussian distributions
+"""
+
+# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
+ # Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+from .backend import get_backend
+from .utils import dots
+from .utils import list_to_array
+
+
+def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
+ r"""Return OT linear operator between samples.
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[1] <references-OT-mapping-linear>` and discussed in remark 2.29 in
+ :ref:`[2] <references-OT-mapping-linear>`.
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
+
+ where :
+
+ .. math::
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
+ \Sigma_s^{-1/2}
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
+
+ Parameters
+ ----------
+ ms : array-like (d,)
+ mean of the source distribution
+ mt : array-like (d,)
+ mean of the target distribution
+ Cs : array-like (d, d)
+ covariance of the source distribution
+ Ct : array-like (d, d)
+ covariance of the target distribution
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d, d) array-like
+ Linear operator
+ b : (1, d) array-like
+ bias
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-OT-mapping-linear:
+ References
+ ----------
+ .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
+ nx = get_backend(ms, mt, Cs, Ct)
+
+ Cs12 = nx.sqrtm(Cs)
+ Cs12inv = nx.inv(Cs12)
+
+ M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
+
+ A = dots(Cs12inv, M0, Cs12inv)
+
+ b = mt - nx.dot(ms, A)
+
+ if log:
+ log = {}
+ log['Cs12'] = Cs12
+ log['Cs12inv'] = Cs12inv
+ return A, b, log
+ else:
+ return A, b
+
+
+def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
+ r"""Return OT linear operator between samples.
+
+ The function estimates the optimal linear operator that aligns the two
+ empirical distributions. This is equivalent to estimating the closed
+ form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)`
+ and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in
+ :ref:`[1] <references-OT-mapping-linear>` and discussed in remark 2.29 in
+ :ref:`[2] <references-OT-mapping-linear>`.
+
+ The linear operator from source to target :math:`M`
+
+ .. math::
+ M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
+
+ where :
+
+ .. math::
+ \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
+ \Sigma_s^{-1/2}
+
+ \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
+
+ Parameters
+ ----------
+ xs : array-like (ns,d)
+ samples in the source domain
+ xt : array-like (nt,d)
+ samples in the target domain
+ reg : float, optional
+ regularization added to the diagonals of covariances (>0)
+ ws : array-like (ns,1), optional
+ weights for the source samples
+ wt : array-like (nt,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ A : (d, d) array-like
+ Linear operator
+ b : (1, d) array-like
+ bias
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-OT-mapping-linear:
+ References
+ ----------
+ .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
+ distributions", Journal of Optimization Theory and Applications
+ Vol 43, 1984
+
+ .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
+
+ d = xs.shape[1]
+
+ if bias:
+ mxs = nx.mean(xs, axis=0)[None, :]
+ mxt = nx.mean(xt, axis=0)[None, :]
+
+ xs = xs - mxs
+ xt = xt - mxt
+ else:
+ mxs = nx.zeros((1, d), type_as=xs)
+ mxt = nx.zeros((1, d), type_as=xs)
+
+ if ws is None:
+ ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
+
+ Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
+ Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
+
+ if log:
+ A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log)
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ return A, b, log
+ else:
+ A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct)
+ return A, b
+
+
+def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
+ r"""Return Bures Wasserstein distance between samples.
+
+ The function estimates the Bures-Wasserstein distance between two
+ empirical distributions source :math:`\mu_s` and target :math:`\mu_t`,
+ discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
+
+ The Bures-Wasserstein distance between the source and target distributions :math:`\mathcal{W}`
+
+ .. math::
+ \mathcal{W}(\mu_s, \mu_t)_2^2 = \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
+
+ where :
+
+ .. math::
+ \mathcal{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
+
+ Parameters
+ ----------
+ ms : array-like (d,)
+ mean of the source distribution
+ mt : array-like (d,)
+ mean of the target distribution
+ Cs : array-like (d, d)
+ covariance of the source distribution
+ Ct : array-like (d, d)
+ covariance of the target distribution
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ W : float
+ Bures Wasserstein distance
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-bures-wasserstein-distance:
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
+ nx = get_backend(ms, mt, Cs, Ct)
+
+ Cs12 = nx.sqrtm(Cs)
+
+ B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
+ W = nx.sqrt(nx.norm(ms - mt)**2 + B)
+ if log:
+ log = {}
+ log['Cs12'] = Cs12
+ return W, log
+ else:
+ return W
+
+
+def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
+ wt=None, bias=True, log=False):
+ r"""Return Bures Wasserstein distance from mean and covariance of distribution.
+
+ The function estimates the Bures-Wasserstein distance between two
+ empirical distributions source :math:`\mu_s` and target :math:`\mu_t`,
+ discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
+
+ The Bures-Wasserstein distance between the source and target distributions :math:`\mathcal{W}`
+
+ .. math::
+ \mathcal{W}(\mu_s, \mu_t)_2^2 = \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
+
+ where :
+
+ .. math::
+ \mathcal{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
+
+ Parameters
+ ----------
+ xs : array-like (ns,d)
+ samples in the source domain
+ xt : array-like (nt,d)
+ samples in the target domain
+ reg : float, optional
+ regularization added to the diagonals of covariances (>0)
+ ws : array-like (ns,1), optional
+ weights for the source samples
+ wt : array-like (nt,1), optional
+ weights for the target samples
+ bias: boolean, optional
+ if True, estimate the means and center the data, else assume zero means (default: True)
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ W : float
+ Bures Wasserstein distance
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-bures-wasserstein-distance:
+ References
+ ----------
+
+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+ Transport", 2018.
+ """
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
+
+ d = xs.shape[1]
+
+ if bias:
+ mxs = nx.mean(xs, axis=0)[None, :]
+ mxt = nx.mean(xt, axis=0)[None, :]
+
+ xs = xs - mxs
+ xt = xt - mxt
+ else:
+ mxs = nx.zeros((1, d), type_as=xs)
+ mxt = nx.zeros((1, d), type_as=xs)
+
+ if ws is None:
+ ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
+
+ if wt is None:
+ wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
+
+ Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
+ Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
+
+ if log:
+ W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log)
+ log['Cs'] = Cs
+ log['Ct'] = Ct
+ return W, log
+ else:
+ W = bures_wasserstein_distance(mxs, mxt, Cs, Ct)
+ return W
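A minimal usage sketch of the new ot.gaussian helpers (assuming the module is importable as ot.gaussian; the affine map follows the x @ A + b convention used in the code above):

import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(100, 2)                                               # source samples
xt = rng.randn(80, 2) @ np.diag([2.0, 0.5]) + np.array([3.0, 1.0])   # target samples

# closed-form affine map between the two fitted Gaussians
A, b = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt)
xs_mapped = xs @ A + b

# Bures-Wasserstein distance, either from samples or from given moments
W_emp = ot.gaussian.empirical_bures_wasserstein_distance(xs, xt)
ms, mt = xs.mean(axis=0), xt.mean(axis=0)
Cs, Ct = np.cov(xs, rowvar=False), np.cov(xt, rowvar=False)
W = ot.gaussian.bures_wasserstein_distance(ms, mt, Cs, Ct)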
diff --git a/ot/gromov.py b/ot/gromov.py
deleted file mode 100644
index 55ab0bd..0000000
--- a/ot/gromov.py
+++ /dev/null
@@ -1,2835 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers
-"""
-
-# Author: Erwan Vautier <erwan.vautier@gmail.com>
-# Nicolas Courty <ncourty@irisa.fr>
-# Rémi Flamary <remi.flamary@unice.fr>
-# Titouan Vayer <titouan.vayer@irisa.fr>
-# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
-#
-# License: MIT License
-
-import numpy as np
-
-
-from .bregman import sinkhorn
-from .utils import dist, UndefinedParameter, list_to_array
-from .optim import cg
-from .lp import emd_1d, emd
-from .utils import check_random_state, unif
-from .backend import get_backend
-
-
-def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
- r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation
-
- Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
- selected loss function as the loss function of Gromow-Wasserstein discrepancy.
-
- The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{T}`: A coupling between those two spaces
-
- The square-loss function :math:`L(a, b) = |a - b|^2` is read as :
-
- .. math::
-
- L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
-
- \mathrm{with} \ f_1(a) &= a^2
-
- f_2(b) &= b^2
-
- h_1(a) &= a
-
- h_2(b) &= 2b
-
- The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :
-
- .. math::
-
- L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
-
- \mathrm{with} \ f_1(a) &= a \log(a) - a
-
- f_2(b) &= b
-
- h_1(a) &= a
-
- h_2(b) &= \log(b)
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- T : array-like, shape (ns, nt)
- Coupling between source and target spaces
- p : array-like, shape (ns,)
-
- Returns
- -------
- constC : array-like, shape (ns, nt)
- Constant :math:`\mathbf{C}` matrix in Eq. (6)
- hC1 : array-like, shape (ns, ns)
- :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
- hC2 : array-like, shape (nt, nt)
- :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
-
-
- .. _references-init-matrix:
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- if loss_fun == 'square_loss':
- def f1(a):
- return (a**2)
-
- def f2(b):
- return (b**2)
-
- def h1(a):
- return a
-
- def h2(b):
- return 2 * b
- elif loss_fun == 'kl_loss':
- def f1(a):
- return a * nx.log(a + 1e-15) - a
-
- def f2(b):
- return b
-
- def h1(a):
- return a
-
- def h2(b):
- return nx.log(b + 1e-15)
-
- constC1 = nx.dot(
- nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
- nx.ones((1, len(q)), type_as=q)
- )
- constC2 = nx.dot(
- nx.ones((len(p), 1), type_as=p),
- nx.dot(nx.reshape(q, (1, -1)), f2(C2).T)
- )
- constC = constC1 + constC2
- hC1 = h1(C1)
- hC2 = h2(C2)
-
- return constC, hC1, hC2
-
-
-def tensor_product(constC, hC1, hC2, T):
- r"""Return the tensor for Gromov-Wasserstein fast computation
-
- The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-tensor-product>`
-
- Parameters
- ----------
- constC : array-like, shape (ns, nt)
- Constant :math:`\mathbf{C}` matrix in Eq. (6)
- hC1 : array-like, shape (ns, ns)
- :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
- hC2 : array-like, shape (nt, nt)
- :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
-
- Returns
- -------
- tens : array-like, shape (`ns`, `nt`)
- :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result
-
-
- .. _references-tensor-product:
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
- nx = get_backend(constC, hC1, hC2, T)
-
- A = - nx.dot(
- nx.dot(hC1, T), hC2.T
- )
- tens = constC + A
- # tens -= tens.min()
- return tens
-
-
-def gwloss(constC, hC1, hC2, T):
- r"""Return the Loss for Gromov-Wasserstein
-
- The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
-
- Parameters
- ----------
- constC : array-like, shape (ns, nt)
- Constant :math:`\mathbf{C}` matrix in Eq. (6)
- hC1 : array-like, shape (ns, ns)
- :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
- hC2 : array-like, shape (nt, nt)
- :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
- T : array-like, shape (ns, nt)
- Current value of transport matrix :math:`\mathbf{T}`
-
- Returns
- -------
- loss : float
- Gromov Wasserstein loss
-
-
- .. _references-gwloss:
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
-
- tens = tensor_product(constC, hC1, hC2, T)
-
- tens, T = list_to_array(tens, T)
- nx = get_backend(tens, T)
-
- return nx.sum(tens * T)
-
-
-def gwggrad(constC, hC1, hC2, T):
- r"""Return the gradient for Gromov-Wasserstein
-
- The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
-
- Parameters
- ----------
- constC : array-like, shape (ns, nt)
- Constant :math:`\mathbf{C}` matrix in Eq. (6)
- hC1 : array-like, shape (ns, ns)
- :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
- hC2 : array-like, shape (nt, nt)
- :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
- T : array-like, shape (ns, nt)
- Current value of transport matrix :math:`\mathbf{T}`
-
- Returns
- -------
- grad : array-like, shape (`ns`, `nt`)
- Gromov Wasserstein gradient
-
-
- .. _references-gwggrad:
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- return 2 * tensor_product(constC, hC1, hC2,
- T) # [12] Prop. 2 misses a 2 factor
-
-
-def update_square_loss(p, lambdas, T, Cs):
- r"""
- Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
- couplings calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- Masses in the targeted barycenter.
- lambdas : list of float
- List of the `S` spaces' weights.
- T : list of S array-like of shape (ns,N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
- Cs : list of S array-like, shape(ns,ns)
- Metric cost matrices.
-
- Returns
- ----------
- C : array-like, shape (`nt`, `nt`)
- Updated :math:`\mathbf{C}` matrix.
- """
- T = list_to_array(*T)
- Cs = list_to_array(*Cs)
- p = list_to_array(p)
- nx = get_backend(p, *T, *Cs)
-
- tmpsum = sum([
- lambdas[s] * nx.dot(
- nx.dot(T[s].T, Cs[s]),
- T[s]
- ) for s in range(len(T))
- ])
- ppt = nx.outer(p, p)
-
- return tmpsum / ppt
-
-
-def update_kl_loss(p, lambdas, T, Cs):
- r"""
- Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
-
-
- Parameters
- ----------
- p : array-like, shape (N,)
- Weights in the targeted barycenter.
- lambdas : list of float
- List of the `S` spaces' weights
- T : list of S array-like of shape (ns,N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
- Cs : list of S array-like, shape(ns,ns)
- Metric cost matrices.
-
- Returns
- ----------
- C : array-like, shape (`ns`, `ns`)
- updated :math:`\mathbf{C}` matrix
- """
- Cs = list_to_array(*Cs)
- T = list_to_array(*T)
- p = list_to_array(p)
- nx = get_backend(p, *T, *Cs)
-
- tmpsum = sum([
- lambdas[s] * nx.dot(
- nx.dot(T[s].T, Cs[s]),
- T[s]
- ) for s in range(len(T))
- ])
- ppt = nx.outer(p, p)
-
- return nx.exp(tmpsum / ppt)
-
-
-def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
- r"""
- Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
-
- The function solves the following optimization problem:
-
- .. math::
- \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
-
- .. note:: This function is backend-compatible and will work on arrays
- from all compatible backends. But the algorithm uses the C++ CPU backend
- which can lead to copy overhead on GPU arrays.
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str
- loss function used for the solver either 'square_loss' or 'kl_loss'
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
- armijo : bool, optional
- If True the step of the line-search is found via an armijo research. Else closed form is used.
- If there are convergence issues use False.
- G0: array-like, shape (ns,nt), optional
- If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
- **kwargs : dict
- parameters can be directly passed to the ot.optim.cg solver
-
- Returns
- -------
- T : array-like, shape (`ns`, `nt`)
- Coupling between the two spaces that minimizes:
-
- :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
- log : dict
- Convergence information and loss.
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
- metric approach to object matching. Foundations of computational
- mathematics 11.4 (2011): 417-487.
-
- """
- p, q = list_to_array(p, q)
- p0, q0, C10, C20 = p, q, C1, C2
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20)
- else:
- G0_ = G0
- nx = get_backend(p0, q0, C10, C20, G0_)
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
- C1 = nx.to_numpy(C10)
- C2 = nx.to_numpy(C20)
-
- if G0 is None:
- G0 = p[:, None] * q[None, :]
- else:
- G0 = nx.to_numpy(G0_)
- # Check marginals of G0
- np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
- np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- def f(G):
- return gwloss(constC, hC1, hC2, G)
-
- def df(G):
- return gwggrad(constC, hC1, hC2, G)
-
- if log:
- res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
- log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
- log['u'] = nx.from_numpy(log['u'], type_as=C10)
- log['v'] = nx.from_numpy(log['v'], type_as=C10)
- return nx.from_numpy(res, type_as=C10), log
- else:
- return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
-
-
-def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs):
- r"""
- Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
-
- The function solves the following optimization problem:
-
- .. math::
- GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity
- matrices
-
- Note that when using backends, this loss function is differentiable wrt the
- marices and weights for quadratic loss using the gradients from [38]_.
-
- .. note:: This function is backend-compatible and will work on arrays
- from all compatible backends. But the algorithm uses the C++ CPU backend
- which can lead to copy overhead on GPU arrays.
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space.
- q : array-like, shape (nt,)
- Distribution in the target space.
- loss_fun : str
- loss function used for the solver either 'square_loss' or 'kl_loss'
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- record log if True
- armijo : bool, optional
- If True the step of the line-search is found via an armijo research. Else closed form is used.
- If there are convergence issues use False.
- G0: array-like, shape (ns,nt), optional
- If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
-
- Returns
- -------
- gw_dist : float
- Gromov-Wasserstein distance
- log : dict
- convergence information and Coupling marix
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
- metric approach to object matching. Foundations of computational
- mathematics 11.4 (2011): 417-487.
-
- .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
- Graph Dictionary Learning, International Conference on Machine Learning
- (ICML), 2021.
-
- """
- p, q = list_to_array(p, q)
- p0, q0, C10, C20 = p, q, C1, C2
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20)
- else:
- G0_ = G0
- nx = get_backend(p0, q0, C10, C20, G0_)
-
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
- C1 = nx.to_numpy(C10)
- C2 = nx.to_numpy(C20)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- if G0 is None:
- G0 = p[:, None] * q[None, :]
- else:
- G0 = nx.to_numpy(G0_)
- # Check marginals of G0
- np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
- np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
-
- def f(G):
- return gwloss(constC, hC1, hC2, G)
-
- def df(G):
- return gwggrad(constC, hC1, hC2, G)
-
- T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-
- T0 = nx.from_numpy(T, type_as=C10)
-
- log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10)
- log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10)
- log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10)
- log_gw['T'] = T0
-
- gw = log_gw['gw_dist']
-
- if loss_fun == 'square_loss':
- gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
- gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
- gC1 = nx.from_numpy(gC1, type_as=C10)
- gC2 = nx.from_numpy(gC2, type_as=C10)
- gw = nx.set_gradients(gw, (p0, q0, C10, C20),
- (log_gw['u'] - nx.mean(log_gw['u']),
- log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
-
- if log:
- return gw, log_gw
- else:
- return gw
-
-
-def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
- r"""
- Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
-
- .. math::
- \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
- \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
-
- \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{\gamma} &\geq 0
-
- where :
-
- - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- - `L` is a loss function to account for the misfit between the similarity matrices
-
- .. note:: This function is backend-compatible and will work on arrays
- from all compatible backends. But the algorithm uses the C++ CPU backend
- which can lead to copy overhead on GPU arrays.
-
- The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
-
- Parameters
- ----------
- M : array-like, shape (ns, nt)
- Metric cost matrix between features across domains
- C1 : array-like, shape (ns, ns)
- Metric cost matrix representative of the structure in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix representative of the structure in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str, optional
- Loss function used for the solver
- alpha : float, optional
- Trade-off parameter (0 < alpha < 1)
- armijo : bool, optional
- If True the step of the line-search is found via an armijo research. Else closed form is used.
- If there are convergence issues use False.
- G0: array-like, shape (ns,nt), optional
- If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
- log : bool, optional
- record log if True
- **kwargs : dict
- parameters can be directly passed to the ot.optim.cg solver
-
- Returns
- -------
- gamma : array-like, shape (`ns`, `nt`)
- Optimal transportation matrix for the given parameters.
- log : dict
- Log dictionary return only if log==True in parameters.
-
-
- .. _references-fused-gromov-wasserstein:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
- and Courty Nicolas "Optimal Transport for structured data with
- application on graphs", International Conference on Machine Learning
- (ICML). 2019.
- """
- p, q = list_to_array(p, q)
- p0, q0, C10, C20, M0 = p, q, C1, C2, M
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20, M0)
- else:
- G0_ = G0
- nx = get_backend(p0, q0, C10, C20, M0, G0_)
-
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
- C1 = nx.to_numpy(C10)
- C2 = nx.to_numpy(C20)
- M = nx.to_numpy(M0)
- if G0 is None:
- G0 = p[:, None] * q[None, :]
- else:
- G0 = nx.to_numpy(G0_)
- # Check marginals of G0
- np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
- np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- def f(G):
- return gwloss(constC, hC1, hC2, G)
-
- def df(G):
- return gwggrad(constC, hC1, hC2, G)
-
- if log:
- res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
- fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
- log['fgw_dist'] = fgw_dist
- log['u'] = nx.from_numpy(log['u'], type_as=C10)
- log['v'] = nx.from_numpy(log['v'], type_as=C10)
- return nx.from_numpy(res, type_as=C10), log
- else:
- return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
-
-
-def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs):
- r"""
- Computes the FGW distance between two graphs see (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
-
- .. math::
- \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
-
- \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{\gamma} &\geq 0
-
- where :
-
- - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
- - `L` is a loss function to account for the misfit between the similarity matrices
-
- The algorithm used for solving the problem is conditional gradient as
- discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
-
- .. note:: This function is backend-compatible and will work on arrays
- from all compatible backends. But the algorithm uses the C++ CPU backend
- which can lead to copy overhead on GPU arrays.
-
- Note that when using backends, this loss function is differentiable wrt the
- marices and weights for quadratic loss using the gradients from [38]_.
-
- Parameters
- ----------
- M : array-like, shape (ns, nt)
- Metric cost matrix between features across domains
- C1 : array-like, shape (ns, ns)
- Metric cost matrix representative of the structure in the source space.
- C2 : array-like, shape (nt, nt)
- Metric cost matrix representative of the structure in the target space.
- p : array-like, shape (ns,)
- Distribution in the source space.
- q : array-like, shape (nt,)
- Distribution in the target space.
- loss_fun : str, optional
- Loss function used for the solver.
- alpha : float, optional
- Trade-off parameter (0 < alpha < 1)
- armijo : bool, optional
- If True the step of the line-search is found via an armijo research.
- Else closed form is used. If there are convergence issues use False.
- G0: array-like, shape (ns,nt), optional
- If None the initial transport plan of the solver is pq^T.
- Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
- log : bool, optional
- Record log if True.
- **kwargs : dict
- Parameters can be directly passed to the ot.optim.cg solver.
-
- Returns
- -------
- fgw-distance : float
- Fused gromov wasserstein distance for the given parameters.
- log : dict
- Log dictionary return only if log==True in parameters.
-
-
- .. _references-fused-gromov-wasserstein2:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
- and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
-
- .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
- Graph Dictionary Learning, International Conference on Machine Learning
- (ICML), 2021.
- """
- p, q = list_to_array(p, q)
-
- p0, q0, C10, C20, M0 = p, q, C1, C2, M
- if G0 is None:
- nx = get_backend(p0, q0, C10, C20, M0)
- else:
- G0_ = G0
- nx = get_backend(p0, q0, C10, C20, M0, G0_)
-
- p = nx.to_numpy(p)
- q = nx.to_numpy(q)
- C1 = nx.to_numpy(C10)
- C2 = nx.to_numpy(C20)
- M = nx.to_numpy(M0)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- if G0 is None:
- G0 = p[:, None] * q[None, :]
- else:
- G0 = nx.to_numpy(G0_)
- # Check marginals of G0
- np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
- np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
-
- def f(G):
- return gwloss(constC, hC1, hC2, G)
-
- def df(G):
- return gwggrad(constC, hC1, hC2, G)
-
- T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
-
- fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10)
-
- T0 = nx.from_numpy(T, type_as=C10)
-
- log_fgw['fgw_dist'] = fgw_dist
- log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10)
- log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10)
- log_fgw['T'] = T0
-
- if loss_fun == 'square_loss':
- gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
- gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
- gC1 = nx.from_numpy(gC1, type_as=C10)
- gC2 = nx.from_numpy(gC2, type_as=C10)
- fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
- (log_fgw['u'] - nx.mean(log_fgw['u']),
- log_fgw['v'] - nx.mean(log_fgw['v']),
- alpha * gC1, alpha * gC2, (1 - alpha) * T0))
-
- if log:
- return fgw_dist, log_fgw
- else:
- return fgw_dist
-
-
-def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
- nb_samples_p=None, nb_samples_q=None, std=True, random_state=None):
- r"""
- Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
- with a fixed transport plan :math:`\mathbf{T}`.
-
- The function gives an unbiased approximation of the following equation:
-
- .. math::
-
- GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - `L` : Loss function to account for the misfit between the similarity matrices
- - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}`
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
- Loss function used for the distance, the transport plan does not depend on the loss function
- T : csr or array-like, shape (ns, nt)
- Transport plan matrix, either a sparse csr or a dense matrix
- nb_samples_p : int, optional
- `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}`
- nb_samples_q : int, optional
- `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first
- std : bool, optional
- Standard deviation associated with the prediction of the gromov-wasserstein cost
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- : float
- Gromov-wasserstein cost
-
- References
- ----------
- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
- "Sampled Gromov Wasserstein."
- Machine Learning Journal (MLJ). 2021.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- generator = check_random_state(random_state)
-
- len_p = p.shape[0]
- len_q = q.shape[0]
-
- # It is always better to sample from the biggest distribution first.
- if len_p < len_q:
- p, q = q, p
- len_p, len_q = len_q, len_p
- C1, C2 = C2, C1
- T = T.T
-
- if nb_samples_p is None:
- if nx.issparse(T):
- # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced
- nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p)
- else:
- nb_samples_p = len_p
- else:
- # The number of sample along the first dimension is without replacement.
- nb_samples_p = min(nb_samples_p, len_p)
- if nb_samples_q is None:
- nb_samples_q = 1
- if std:
- nb_samples_q = max(2, nb_samples_q)
-
- index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
- index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
-
- index_i = generator.choice(
- len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
- )
- index_j = generator.choice(
- len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
- )
-
- for i in range(nb_samples_p):
- if nx.issparse(T):
- T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
- T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
- else:
- T_indexi = T[index_i[i], :]
- T_indexj = T[index_j[i], :]
- # For each of the row sampled, the column is sampled.
- index_k[i] = generator.choice(
- len_q,
- size=nb_samples_q,
- p=nx.to_numpy(T_indexi / nx.sum(T_indexi)),
- replace=True
- )
- index_l[i] = generator.choice(
- len_q,
- size=nb_samples_q,
- p=nx.to_numpy(T_indexj / nx.sum(T_indexj)),
- replace=True
- )
-
- list_value_sample = nx.stack([
- loss_fun(
- C1[np.ix_(index_i, index_j)],
- C2[np.ix_(index_k[:, n], index_l[:, n])]
- ) for n in range(nb_samples_q)
- ], axis=2)
-
- if std:
- std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5
- return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p)
- else:
- return nx.mean(list_value_sample)
-
-
-def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
- alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None):
- r"""
- Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe.
- This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations.
-
- The function solves the following optimization problem:
-
- .. math::
- \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
-
- \mathbf{T}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{T} &\geq 0
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
- Loss function used for the distance, the transport plan does not depend on the loss function
- alpha : float
- Step of the Frank-Wolfe algorithm, should be between 0 and 1
- max_iter : int, optional
- Max number of iterations
- threshold_plan : float, optional
- Deleting very small values in the transport plan. If above zero, it violates the marginal constraints.
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Gives the distance estimated and the standard deviation
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- T : array-like, shape (`ns`, `nt`)
- Optimal coupling between the two spaces
-
- References
- ----------
- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
- "Sampled Gromov Wasserstein."
- Machine Learning Journal (MLJ). 2021.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- len_p = p.shape[0]
- len_q = q.shape[0]
-
- generator = check_random_state(random_state)
-
- index = np.zeros(2, dtype=int)
-
- # Initialize with default marginal
- index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
- index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
- T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
-
- best_gw_dist_estimated = np.inf
- for cpt in range(max_iter):
- index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
- T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
- index[1] = generator.choice(
- len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
- )
-
- if alpha == 1:
- T = nx.tocsr(
- emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
- )
- else:
- new_T = nx.tocsr(
- emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
- )
- T = (1 - alpha) * T + alpha * new_T
- # To limit the number of non 0, the values below the threshold are set to 0.
- T = nx.eliminate_zeros(T, threshold=threshold_plan)
-
- if cpt % 10 == 0 or cpt == (max_iter - 1):
- gw_dist_estimated = GW_distance_estimation(
- C1=C1, C2=C2, loss_fun=loss_fun,
- p=p, q=q, T=T, std=False, random_state=generator
- )
-
- if gw_dist_estimated < best_gw_dist_estimated:
- best_gw_dist_estimated = gw_dist_estimated
- best_T = nx.copy(T)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated))
-
- if log:
- log = {}
- log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
- C1=C1, C2=C2, loss_fun=loss_fun,
- p=p, q=q, T=best_T, random_state=generator
- )
- return best_T, log
- return best_T
-
-
-def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
- nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False,
- random_state=None):
- r"""
- Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe.
- This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver.
-
- The function solves the following optimization problem:
-
- .. math::
- \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
- L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
-
- s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
-
- \mathbf{T}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{T} &\geq 0
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
- Loss function used for the distance, the transport plan does not depend on the loss function
- nb_samples_grad : int
- Number of samples to approximate the gradient
- epsilon : float
- Weight of the Kullback-Leibler regularization
- max_iter : int, optional
- Max number of iterations
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Gives the distance estimated and the standard deviation
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- T : array-like, shape (`ns`, `nt`)
- Optimal coupling between the two spaces
-
- References
- ----------
- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
- "Sampled Gromov Wasserstein."
- Machine Learning Journal (MLJ). 2021.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- len_p = p.shape[0]
- len_q = q.shape[0]
-
- generator = check_random_state(random_state)
-
- # The most natural way to define nb_samples_grad is with a single integer.
- if isinstance(nb_samples_grad, int):
- if nb_samples_grad > len_p:
- # As the sampling along the first dimension is done without replacement, the remaining sampling
- # budget is carried over to the second dimension.
- nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p
- else:
- nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1
- else:
- nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad
- T = nx.outer(p, q)
- # continue_loop allows stopping the loop after several successive small modifications of T.
- continue_loop = 0
-
- # The gradient of GW is more complex if the two matrices are not symmetric.
- C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
-
- for cpt in range(max_iter):
- index0 = generator.choice(
- len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False
- )
- Lik = 0
- for i, index0_i in enumerate(index0):
- index1 = generator.choice(
- len_q, size=nb_samples_grad_q,
- p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
- replace=False
- )
- # If the matrices C are not symmetric, the gradient has 2 terms; one of them is chosen at random.
- if (not C_are_symmetric) and generator.rand(1) > 0.5:
- Lik += nx.mean(loss_fun(
- C1[:, [index0[i]] * nb_samples_grad_q][:, None, :],
- C2[:, index1][None, :, :]
- ), axis=2)
- else:
- Lik += nx.mean(loss_fun(
- C1[[index0[i]] * nb_samples_grad_q, :][:, :, None],
- C2[index1, :][:, None, :]
- ), axis=0)
-
- max_Lik = nx.max(Lik)
- if max_Lik == 0:
- continue
- # This division by the max is here to facilitate the choice of epsilon.
- Lik /= max_Lik
-
- if epsilon > 0:
- # Entries of T below exp(-200) get log_T = -inf (i.e. an infinite cost) to avoid taking log(0).
- log_T = nx.log(nx.clip(T, np.exp(-200), 1))
- log_T = nx.where(log_T == -200, -np.inf, log_T)
- Lik = Lik - epsilon * log_T
-
- try:
- new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon)
- except (RuntimeWarning, UserWarning):
- print("Warning catched in Sinkhorn: Return last stable T")
- break
- else:
- new_T = emd(a=p, b=q, M=Lik)
-
- change_T = nx.mean((T - new_T) ** 2)
- if change_T <= 10e-20:
- continue_loop += 1
- if continue_loop > 100: # Maximum number of successive iterations with a small change of T
- T = nx.copy(new_T)
- break
- else:
- continue_loop = 0
-
- if verbose and cpt % 10 == 0:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, change_T))
- T = nx.copy(new_T)
-
- if log:
- log = {}
- log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
- C1=C1, C2=C2, loss_fun=loss_fun,
- p=p, q=q, T=T, random_state=generator
- )
- return T, log
- return T
-
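# Editor's example (not from the removed file): a minimal usage sketch for
# sampled_gromov_wasserstein, assuming it remains importable from ot.gromov after the refactor.
# Note that loss_fun is an element-wise callable here, unlike the string-based solvers below.
import numpy as np
import ot

rng = np.random.RandomState(42)
xs, xt = rng.randn(20, 2), rng.randn(30, 3)   # 20 source / 30 target samples
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)     # intra-domain cost matrices
p, q = ot.unif(20), ot.unif(30)

def square_loss(a, b):
    return (a - b) ** 2

T, log = ot.gromov.sampled_gromov_wasserstein(
    C1, C2, p, q, square_loss,
    nb_samples_grad=10, epsilon=0.1, max_iter=100, log=True, random_state=0)
print(T.shape, log['gw_dist_estimated'])      # coupling of shape (20, 30) and estimated GW value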
-
-def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
- r"""
- Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
-
- The function solves the following optimization problem:
-
- .. math::
- \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
-
- s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
-
- \mathbf{T}^T \mathbf{1} &= \mathbf{q}
-
- \mathbf{T} &\geq 0
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
- - `H`: entropy
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str
- Loss function used for the solver, either 'square_loss' or 'kl_loss'
- epsilon : float
- Regularization term >0
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Record log if True.
-
- Returns
- -------
- T : array-like, shape (`ns`, `nt`)
- Optimal coupling between the two spaces
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- C1, C2, p, q = list_to_array(C1, C2, p, q)
- nx = get_backend(C1, C2, p, q)
-
- T = nx.outer(p, q)
-
- constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
- cpt = 0
- err = 1
-
- if log:
- log = {'err': []}
-
- while (err > tol and cpt < max_iter):
-
- Tprev = T
-
- # compute the gradient
- tens = gwggrad(constC, hC1, hC2, T)
-
- T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
-
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only every
- # 10th iteration
- err = nx.norm(T - Tprev)
-
- if log:
- log['err'].append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format(
- 'It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt += 1
-
- if log:
- log['gw_dist'] = gwloss(constC, hC1, hC2, T)
- return T, log
- else:
- return T
-
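# Editor's example (not from the removed file): a minimal sketch of entropic_gromov_wasserstein
# on two random point clouds, assuming it stays exposed as ot.gromov.entropic_gromov_wasserstein.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(15, 2), rng.randn(25, 4)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
C1, C2 = C1 / C1.max(), C2 / C2.max()         # normalized costs keep epsilon on a comparable scale
p, q = ot.unif(15), ot.unif(25)

T, log = ot.gromov.entropic_gromov_wasserstein(
    C1, C2, p, q, 'square_loss', epsilon=5e-2, max_iter=500, log=True)
print(log['gw_dist'])                         # entropic GW objective at the returned coupling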
-
-def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False):
- r"""
- Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
-
- The function solves the following optimization problem:
-
- .. math::
- GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})
- \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
-
- Where :
-
- - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
- - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
- - :math:`\mathbf{p}`: distribution in the source space
- - :math:`\mathbf{q}`: distribution in the target space
- - `L`: loss function to account for the misfit between the similarity matrices
- - `H`: entropy
-
- Parameters
- ----------
- C1 : array-like, shape (ns, ns)
- Metric cost matrix in the source space
- C2 : array-like, shape (nt, nt)
- Metric cost matrix in the target space
- p : array-like, shape (ns,)
- Distribution in the source space
- q : array-like, shape (nt,)
- Distribution in the target space
- loss_fun : str
- Loss function used for the solver, either 'square_loss' or 'kl_loss'
- epsilon : float
- Regularization term >0
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations
- log : bool, optional
- Record log if True.
-
- Returns
- -------
- gw_dist : float
- Gromov-Wasserstein distance
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- gw, logv = entropic_gromov_wasserstein(
- C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True)
-
- logv['T'] = gw
-
- if log:
- return logv['gw_dist'], logv
- else:
- return logv['gw_dist']
-
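# Editor's example (not from the removed file): entropic_gromov_wasserstein2 wraps the solver
# above and returns only the scalar discrepancy; a small self-contained sketch.
import numpy as np
import ot

rng = np.random.RandomState(1)
xs, xt = rng.randn(10, 2), rng.randn(14, 3)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
C1, C2 = C1 / C1.max(), C2 / C2.max()
gw_val = ot.gromov.entropic_gromov_wasserstein2(
    C1, C2, ot.unif(10), ot.unif(14), 'square_loss', epsilon=5e-2)
print(gw_val)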
-
-def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
- r"""
- Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
-
- The function solves the following optimization problem:
-
- .. math::
-
- \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
-
- Where :
-
- - :math:`\mathbf{C}_s`: metric cost matrix
- - :math:`\mathbf{p}_s`: distribution
-
- Parameters
- ----------
- N : int
- Size of the targeted barycenter
- Cs : list of S array-like of shape (ns,ns)
- Metric cost matrices
- ps : list of S array-like of shape (ns,)
- Sample weights in the `S` spaces
- p : array-like, shape(N,)
- Weights in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights.
- loss_fun : str
- Loss function used for the solver, either 'square_loss' or 'kl_loss'
- epsilon : float
- Regularization term >0
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0)
- verbose : bool, optional
- Print information along iterations.
- log : bool, optional
- Record log if True.
- init_C : array-like, shape (N, N), optional
- Initial value of the :math:`\mathbf{C}` matrix provided by the user. If None (default), a random matrix is used.
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- C : array-like, shape (`N`, `N`)
- Similarity matrix in the barycenter space (permuted arbitrarily)
- log : dict
- Log dictionary of error during iterations. Return only if `log=True` in parameters.
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
- """
- Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
- p = list_to_array(p)
- nx = get_backend(*Cs, *ps, p)
-
- S = len(Cs)
-
- # Initialization of C : random SPD matrix (if not provided by user)
- if init_C is None:
- generator = check_random_state(random_state)
- xalea = generator.randn(N, 2)
- C = dist(xalea, xalea)
- C /= C.max()
- C = nx.from_numpy(C, type_as=p)
- else:
- C = init_C
-
- cpt = 0
- err = 1
-
- error = []
-
- while (err > tol) and (cpt < max_iter):
- Cprev = C
-
- T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
- max_iter, 1e-4, verbose, log=False) for s in range(S)]
- if loss_fun == 'square_loss':
- C = update_square_loss(p, lambdas, T, Cs)
-
- elif loss_fun == 'kl_loss':
- C = update_kl_loss(p, lambdas, T, Cs)
-
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only every
- # 10th iteration
- err = nx.norm(C - Cprev)
- error.append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format(
- 'It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt += 1
-
- if log:
- return C, {"err": error}
- else:
- return C
-
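# Editor's example (not from the removed file): entropic GW barycenter of two random structure
# matrices, assuming ot.gromov.entropic_gromov_barycenters keeps the signature documented above.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs1, xs2 = rng.randn(12, 2), rng.randn(18, 2)
C1, C2 = ot.dist(xs1, xs1), ot.dist(xs2, xs2)
C1, C2 = C1 / C1.max(), C2 / C2.max()
ps = [ot.unif(12), ot.unif(18)]

Cbary = ot.gromov.entropic_gromov_barycenters(
    10, [C1, C2], ps, ot.unif(10), [0.5, 0.5], 'square_loss', epsilon=5e-2,
    max_iter=50, tol=1e-5, random_state=0)
print(Cbary.shape)                            # (10, 10), up to an arbitrary permutation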
-
-def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
- max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
- r"""
- Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
-
- The function solves the following optimization problem with block coordinate descent:
-
- .. math::
-
- \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
-
- Where :
-
- - :math:`\mathbf{C}_s`: metric cost matrix
- - :math:`\mathbf{p}_s`: distribution
-
- Parameters
- ----------
- N : int
- Size of the targeted barycenter
- Cs : list of S array-like of shape (ns, ns)
- Metric cost matrices
- ps : list of S array-like of shape (ns,)
- Sample weights in the `S` spaces
- p : array-like, shape (N,)
- Weights in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights
- loss_fun : str
- Loss function used for the solver, either 'square_loss' or 'kl_loss'
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0).
- verbose : bool, optional
- Print information along iterations.
- log : bool, optional
- Record log if True.
- init_C : array-like, shape (N, N), optional
- Initial value of the :math:`\mathbf{C}` matrix provided by the user. If None (default), a random matrix is used.
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- C : array-like, shape (`N`, `N`)
- Similarity matrix in the barycenter space (permuted arbitrarily)
- log : dict
- Log dictionary of error during iterations. Return only if `log=True` in parameters.
-
- References
- ----------
- .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
- "Gromov-Wasserstein averaging of kernel and distance matrices."
- International Conference on Machine Learning (ICML). 2016.
-
- """
- Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
- p = list_to_array(p)
- nx = get_backend(*Cs, *ps, p)
-
- S = len(Cs)
-
- # Initialization of C : random SPD matrix (if not provided by user)
- if init_C is None:
- generator = check_random_state(random_state)
- xalea = generator.randn(N, 2)
- C = dist(xalea, xalea)
- C /= C.max()
- C = nx.from_numpy(C, type_as=p)
- else:
- C = init_C
-
- cpt = 0
- err = 1
-
- error = []
-
- while(err > tol and cpt < max_iter):
- Cprev = C
-
- T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
- numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)]
- if loss_fun == 'square_loss':
- C = update_square_loss(p, lambdas, T, Cs)
-
- elif loss_fun == 'kl_loss':
- C = update_kl_loss(p, lambdas, T, Cs)
-
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only every
- # 10th iteration
- err = nx.norm(C - Cprev)
- error.append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format(
- 'It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt += 1
-
- if log:
- return C, {"err": error}
- else:
- return C
-
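# Editor's example (not from the removed file): the exact (non-entropic) barycenter solver runs
# a block coordinate descent with full GW couplings; a sketch assuming an unchanged signature.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs1, xs2 = rng.randn(12, 2), rng.randn(18, 2)
C1, C2 = ot.dist(xs1, xs1), ot.dist(xs2, xs2)
C1, C2 = C1 / C1.max(), C2 / C2.max()

Cbary = ot.gromov.gromov_barycenters(
    10, [C1, C2], [ot.unif(12), ot.unif(18)], ot.unif(10), [0.5, 0.5],
    'square_loss', max_iter=100, tol=1e-5, random_state=0)
print(Cbary.shape)                            # (10, 10)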
-
-def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
- p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
- verbose=False, log=False, init_C=None, init_X=None, random_state=None):
- r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
-
- Parameters
- ----------
- N : int
- Desired number of samples of the target barycenter
- Ys: list of array-like, each element has shape (ns,d)
- Features of all samples
- Cs : list of array-like, each element has shape (ns,ns)
- Structure matrices of all samples
- ps : list of array-like, each element has shape (ns,)
- Masses of all samples.
- lambdas : list of float
- List of the `S` spaces' weights
- alpha : float
- Alpha parameter for the fgw distance
- fixed_structure : bool
- Whether to fix the structure of the barycenter during the updates
- fixed_features : bool
- Whether to fix the feature of the barycenter during the updates
- loss_fun : str
- Loss function used for the solver either 'square_loss' or 'kl_loss'
- max_iter : int, optional
- Max number of iterations
- tol : float, optional
- Stop threshold on error (>0).
- verbose : bool, optional
- Print information along iterations.
- log : bool, optional
- Record log if True.
- init_C : array-like, shape (N,N), optional
- Initialization for the barycenters' structure matrix. If not set
- a random init is used.
- init_X : array-like, shape (N,d), optional
- Initialization for the barycenters' features. If not set a
- random init is used.
- random_state : int or RandomState instance, optional
- Fix the seed for reproducibility
-
- Returns
- -------
- X : array-like, shape (`N`, `d`)
- Barycenters' features
- C : array-like, shape (`N`, `N`)
- Barycenters' structure matrix
- log : dict
- Only returned when log=True. It contains the keys:
-
- - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
- - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
-
-
- .. _references-fgw-barycenters:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
- and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- Cs = list_to_array(*Cs)
- ps = list_to_array(*ps)
- Ys = list_to_array(*Ys)
- p = list_to_array(p)
- nx = get_backend(*Cs, *Ys, *ps)
-
- S = len(Cs)
- d = Ys[0].shape[1] # dimension of the node features
- if p is None:
- p = nx.ones(N, type_as=Cs[0]) / N
-
- if fixed_structure:
- if init_C is None:
- raise UndefinedParameter('If C is fixed it must be initialized')
- else:
- C = init_C
- else:
- if init_C is None:
- generator = check_random_state(random_state)
- xalea = generator.randn(N, 2)
- C = dist(xalea, xalea)
- C = nx.from_numpy(C, type_as=ps[0])
- else:
- C = init_C
-
- if fixed_features:
- if init_X is None:
- raise UndefinedParameter('If X is fixed it must be initialized')
- else:
- X = init_X
- else:
- if init_X is None:
- X = nx.zeros((N, d), type_as=ps[0])
- else:
- X = init_X
-
- T = [nx.outer(p, q) for q in ps]
-
- Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
-
- cpt = 0
- err_feature = 1
- err_structure = 1
-
- if log:
- log_ = {}
- log_['err_feature'] = []
- log_['err_structure'] = []
- log_['Ts_iter'] = []
-
- while((err_feature > tol or err_structure > tol) and cpt < max_iter):
- Cprev = C
- Xprev = X
-
- if not fixed_features:
- Ys_temp = [y.T for y in Ys]
- X = update_feature_matrix(lambdas, Ys_temp, T, p).T
-
- Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
-
- if not fixed_structure:
- if loss_fun == 'square_loss':
- T_temp = [t.T for t in T]
- C = update_structure_matrix(p, lambdas, T_temp, Cs)
-
- T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
- numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
-
- # T is N,ns
- err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
- err_structure = nx.norm(C - Cprev)
- if log:
- log_['err_feature'].append(err_feature)
- log_['err_structure'].append(err_structure)
- log_['Ts_iter'].append(T)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format(
- 'It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err_structure))
- print('{:5d}|{:8e}|'.format(cpt, err_feature))
-
- cpt += 1
-
- if log:
- log_['T'] = T # from target to Ys
- log_['p'] = p
- log_['Ms'] = Ms
-
- if log:
- return X, C, log_
- else:
- return X, C
-
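# Editor's example (not from the removed file): FGW barycenter of two small attributed point
# clouds, reusing the coordinates as node features; a sketch assuming the signature above.
import numpy as np
import ot

rng = np.random.RandomState(0)
samples = [rng.randn(n, 2) for n in (8, 12)]
Ys = samples                                  # node features, dimension d = 2
Cs = [ot.dist(x, x) for x in samples]
ps = [ot.unif(x.shape[0]) for x in samples]

X, C, log = ot.gromov.fgw_barycenters(
    6, Ys, Cs, ps, lambdas=[0.5, 0.5], alpha=0.5,
    p=ot.unif(6), max_iter=50, tol=1e-5, log=True, random_state=0)
print(X.shape, C.shape)                       # (6, 2) features and (6, 6) structure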
-
-def update_structure_matrix(p, lambdas, T, Cs):
- r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
-
- It is calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- Masses in the targeted barycenter.
- lambdas : list of float
- List of the `S` spaces' weights.
- T : list of S array-like of shape (ns, N)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
- Cs : list of S array-like, shape (ns, ns)
- Metric cost matrices.
-
- Returns
- -------
- C : array-like, shape (`N`, `N`)
- Updated :math:`\mathbf{C}` matrix.
- """
- p = list_to_array(p)
- T = list_to_array(*T)
- Cs = list_to_array(*Cs)
- nx = get_backend(*Cs, *T, p)
-
- tmpsum = sum([
- lambdas[s] * nx.dot(
- nx.dot(T[s].T, Cs[s]),
- T[s]
- ) for s in range(len(T))
- ])
- ppt = nx.outer(p, p)
- return tmpsum / ppt
-
-
-def update_feature_matrix(lambdas, Ys, Ts, p):
- r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
-
-
- See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
- in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
-
- Parameters
- ----------
- p : array-like, shape (N,)
- masses in the targeted barycenter
- lambdas : list of float
- List of the `S` spaces' weights
- Ts : list of S array-like, shape (N, ns)
- The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
- Ys : list of S array-like, shape (d,ns)
- The features.
-
- Returns
- -------
- X : array-like, shape (`d`, `N`)
-
-
- .. _references-update-feature-matrix:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- p = list_to_array(p)
- Ts = list_to_array(*Ts)
- Ys = list_to_array(*Ys)
- nx = get_backend(*Ys, *Ts, p)
-
- p = 1. / p
- tmpsum = sum([
- lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
- for s in range(len(Ts))
- ])
- return tmpsum
-
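# Editor's note (not from the removed file): the feature update above is the lambdas-weighted
# barycentric projection of each feature matrix through its coupling, renormalized by the
# barycenter masses p. A standalone numpy check of that closed form, with couplings of shape
# (N, ns) and features passed transposed as (d, ns), as in the fgw_barycenters loop:
import numpy as np

d, ns, N = 3, 5, 4
rng = np.random.RandomState(0)
Ys = [rng.randn(d, ns) for _ in range(2)]                 # one (d, ns) feature matrix per space
Ts = [np.full((N, ns), 1. / (N * ns)) for _ in range(2)]  # product couplings of shape (N, ns)
p = np.full(N, 1. / N)
lambdas = [0.5, 0.5]

X = sum(lambdas[s] * Ys[s].dot(Ts[s].T) / p[None, :] for s in range(2))
print(X.shape)                                # (d, N), the shape update_feature_matrix returns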
-
-def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
- tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
- r"""
- Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
-
- .. math::
- \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2
-
- such that, :math:`\forall s \leq S` :
-
- - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
- - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
- - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
- - reg is the regularization coefficient.
-
- The stochastic algorithm used to estimate the graph dictionary atoms is the one proposed in [38].
-
- Parameters
- ----------
- Cs : list of S symmetric array-like, shape (ns, ns)
- List of Metric/Graph cost matrices of variable size (ns, ns).
- D: int
- Number of dictionary atoms to learn
- nt: int
- Number of samples within each dictionary atom
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
- ps : list of S array-like, shape (ns,), optional
- Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
- q : array-like, shape (nt,), optional
- Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
- epochs: int, optional
- Number of epochs used to learn the dictionary. Default is 20.
- batch_size: int, optional
- Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
- learning_rate: float, optional
- Learning rate used for the stochastic gradient descent. Default is 1.
- Cdict_init: list of D array-like with shape (nt, nt), optional
- Used to initialize the dictionary.
- If set to None (Default), the dictionary will be initialized randomly.
- Otherwise Cdict must have shape (D, nt, nt), i.e. match the provided D and nt.
- projection: str , optional
- If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
- Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
- use_log: bool, optional
- If set to True, the evolution of the loss by batches and epochs is tracked. Default is True.
- use_adam_optimizer: bool, optional
- If set to True, the Adam optimizer with default settings is used as an adaptive learning rate strategy.
- Otherwise SGD with a fixed learning rate is performed. Default is True.
- tol_outer : float, optional
- Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
- tol_inner : float, optional
- Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
- max_iter_outer : int, optional
- Maximum number of iterations for the BCD. Default is 20.
- max_iter_inner : int, optional
- Maximum number of iterations for the Conjugate Gradient. Default is 200.
- verbose : bool, optional
- Print the reconstruction loss every epoch. Default is False.
-
- Returns
- -------
-
- Cdict_best_state : D array-like, shape (D,nt,nt)
- Metric/Graph cost matrices composing the dictionary.
- The dictionary leading to the best loss over an epoch is saved and returned.
- log: dict
- If use_log is True, contains loss evolutions by batches and epochs.
- References
- ----------
-
- .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
- "Online Graph Dictionary Learning"
- International Conference on Machine Learning (ICML). 2021.
- """
- # Handle backend of non-optional arguments
- Cs0 = Cs
- nx = get_backend(*Cs0)
- Cs = [nx.to_numpy(C) for C in Cs0]
- dataset_size = len(Cs)
- # Handle backend of optional arguments
- if ps is None:
- ps = [unif(C.shape[0]) for C in Cs]
- else:
- ps = [nx.to_numpy(p) for p in ps]
- if q is None:
- q = unif(nt)
- else:
- q = nx.to_numpy(q)
- if Cdict_init is None:
- # Initialize randomly structures of dictionary atoms based on samples
- dataset_means = [C.mean() for C in Cs]
- Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
- else:
- Cdict = nx.to_numpy(Cdict_init).copy()
- assert Cdict.shape == (D, nt, nt)
-
- if 'symmetric' in projection:
- Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
- if 'nonnegative' in projection:
- Cdict[Cdict < 0.] = 0
- if use_adam_optimizer:
- adam_moments = _initialize_adam_optimizer(Cdict)
-
- log = {'loss_batches': [], 'loss_epochs': []}
- const_q = q[:, None] * q[None, :]
- Cdict_best_state = Cdict.copy()
- loss_best_state = np.inf
- if batch_size > dataset_size:
- batch_size = dataset_size
- iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
-
- for epoch in range(epochs):
- cumulated_loss_over_epoch = 0.
-
- for _ in range(iter_by_epoch):
- # batch sampling
- batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
- cumulated_loss_over_batch = 0.
- unmixings = np.zeros((batch_size, D))
- Cs_embedded = np.zeros((batch_size, nt, nt))
- Ts = [None] * batch_size
-
- for batch_idx, C_idx in enumerate(batch):
- # BCD solver for Gromov-Wasserstein linear unmixing, applied independently to each structure of the sampled batch
- unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing(
- Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner,
- max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
- )
- cumulated_loss_over_batch += current_loss
- cumulated_loss_over_epoch += cumulated_loss_over_batch
-
- if use_log:
- log['loss_batches'].append(cumulated_loss_over_batch)
-
- # Stochastic projected gradient step over dictionary atoms
- grad_Cdict = np.zeros_like(Cdict)
- for batch_idx, C_idx in enumerate(batch):
- shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
- grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
- grad_Cdict *= 2 / batch_size
- if use_adam_optimizer:
- Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments)
- else:
- Cdict -= learning_rate * grad_Cdict
- if 'symmetric' in projection:
- Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
- if 'nonnegative' in projection:
- Cdict[Cdict < 0.] = 0.
-
- if use_log:
- log['loss_epochs'].append(cumulated_loss_over_epoch)
- if loss_best_state > cumulated_loss_over_epoch:
- loss_best_state = cumulated_loss_over_epoch
- Cdict_best_state = Cdict.copy()
- if verbose:
- print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
-
- return nx.from_numpy(Cdict_best_state), log
-
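# Editor's example (not from the removed file): learning a 2-atom GW dictionary from a handful
# of random structures; a minimal sketch assuming the signature documented above is unchanged.
import numpy as np
import ot

rng = np.random.RandomState(0)
samples = [rng.randn(n, 2) for n in (8, 10, 12, 9)]
Cs = [ot.dist(x, x) for x in samples]

Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
    Cs, D=2, nt=6, reg=0.01, epochs=5, batch_size=2, learning_rate=0.1)
print(Cdict.shape)                            # (2, 6, 6): D atoms of size nt x nt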
-
-def _initialize_adam_optimizer(variable):
-
- # Initialization for our numpy implementation of adam optimizer
- atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor
- atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor
- atoms_adam_count = 1
-
- return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count}
-
-
-def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09):
-
- adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad
- adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2)
- unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count'])
- unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count'])
- variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps)
- adam_moments['count'] += 1
-
- return variable, adam_moments
-
-
-def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
- r"""
- Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`.
-
- .. math::
- \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
-
- such that:
-
- - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt.
- - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
- - reg is the regularization coefficient.
-
- The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1.
-
- Parameters
- ----------
- C : array-like, shape (ns, ns)
- Metric/Graph cost matrix.
- Cdict : D array-like, shape (D,nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed C.
- reg : float, optional.
- Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
- p : array-like, shape (ns,), optional
- Distribution in the source space C. Default is None and corresponds to uniform distribution.
- q : array-like, shape (nt,), optional
- Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
- tol_outer : float, optional
- Solver precision for the BCD algorithm.
- tol_inner : float, optional
- Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
- max_iter_outer : int, optional
- Maximum number of iterations for the BCD. Default is 20.
- max_iter_inner : int, optional
- Maximum number of iterations for the Conjugate Gradient. Default is 200.
-
- Returns
- -------
- w: array-like, shape (D,)
- gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary.
- Cembedded: array-like, shape (nt,nt)
- embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
- T: array-like (ns, nt)
- Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})`
- current_loss: float
- reconstruction error
- References
- ----------
-
- .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
- "Online Graph Dictionary Learning"
- International Conference on Machine Learning (ICML). 2021.
- """
- C0, Cdict0 = C, Cdict
- nx = get_backend(C0, Cdict0)
- C = nx.to_numpy(C0)
- Cdict = nx.to_numpy(Cdict0)
- if p is None:
- p = unif(C.shape[0])
- else:
- p = nx.to_numpy(p)
-
- if q is None:
- q = unif(Cdict.shape[-1])
- else:
- q = nx.to_numpy(q)
-
- T = p[:, None] * q[None, :]
- D = len(Cdict)
-
- w = unif(D) # Initialize uniformly the unmixing w
- Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
-
- const_q = q[:, None] * q[None, :]
- # Trackers for BCD convergence
- convergence_criterion = np.inf
- current_loss = 10**15
- outer_count = 0
-
- while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
- previous_loss = current_loss
- # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
- T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs)
- current_loss = log['gw_dist']
- if reg != 0:
- current_loss -= reg * np.sum(w**2)
-
- # 2. Solve linear unmixing problem over w with a fixed transport plan T
- w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing(
- C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T,
- starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs
- )
-
- if previous_loss != 0:
- convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
- else: # handle numerical issues around 0
- convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
- outer_count += 1
-
- return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
-
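# Editor's example (not from the removed file): projecting one structure onto a toy dictionary;
# a sketch assuming Cdict is any symmetric (D, nt, nt) array, e.g. the output of the learner above.
import numpy as np
import ot

rng = np.random.RandomState(0)
x = rng.randn(9, 2)
C = ot.dist(x, x)
Cdict = rng.uniform(0., 1., size=(2, 6, 6))
Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))    # symmetrize the toy atoms

w, Cembedded, T, loss = ot.gromov.gromov_wasserstein_linear_unmixing(
    C, Cdict, reg=0.01, max_iter_outer=10)
print(w, loss)                                # simplex weights over the atoms and reconstruction error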
-
-def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs):
- r"""
- Returns for a fixed admissible transport plan,
- the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})`
-
- .. math::
- \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2
-
-
- Such that:
-
- - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points.
- - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
- - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`.
- - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`.
- - reg is the regularization coefficient.
-
- The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]
-
- Parameters
- ----------
-
- C : array-like, shape (ns, ns)
- Metric/Graph cost matrix.
- Cdict : list of D array-like, shape (nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed C.
- Each matrix in the dictionary must have the same size (nt,nt).
- Cembedded: array-like, shape (nt,nt)
- Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
- w: array-like, shape (D,)
- Linear unmixing of the input structure onto the dictionary
- const_q: array-like, shape (nt,nt)
- product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
- T: array-like, shape (ns,nt)
- fixed transport plan between the input structure and its representation in the dictionary.
- starting_loss : float
- Value of the reconstruction loss for the current unmixing w and transport plan T.
- reg : float, optional.
- Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
-
- Returns
- -------
- w: ndarray (D,)
- optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing.
- Cembedded: ndarray (nt, nt)
- updated embedded structure :math:`\sum_d w_d \mathbf{C_{dict}[d]}`.
- current_loss: float
- value of the reconstruction loss at the returned unmixing.
- """
- convergence_criterion = np.inf
- current_loss = starting_loss
- count = 0
- const_TCT = np.transpose(C.dot(T)).dot(T)
-
- while (convergence_criterion > tol) and (count < max_iter):
-
- previous_loss = current_loss
- # 1) Compute gradient at current point w
- grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
- grad_w -= 2 * reg * w
-
- # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
- min_ = np.min(grad_w)
- x = (grad_w == min_).astype(np.float64)
- x /= np.sum(x)
-
- # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
- gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg)
-
- # 4) Updates: w <-- (1-gamma)*w + gamma*x
- w += gamma * (x - w)
- Cembedded += gamma * Cembedded_diff
- current_loss += a * (gamma**2) + b * gamma
-
- if previous_loss != 0: # note that the loss can be negative if reg > 0
- convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
- else: # handle numerical issues around 0
- convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
- count += 1
-
- return w, Cembedded, current_loss
-
-
-def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs):
- r"""
- Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing
- .. math::
- \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2
-
-
- Such that:
-
- - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
-
- Parameters
- ----------
-
- w : array-like, shape (D,)
- Unmixing.
- grad_w : array-like, shape (D,)
- Gradient of the reconstruction loss with respect to w.
- x: array-like, shape (D,)
- Conditional gradient direction.
- Cdict : list of D array-like, shape (nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed C.
- Each matrix in the dictionary must have the same size (nt,nt).
- Cembedded: array-like, shape (nt,nt)
- Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
- const_q: array-like, shape (nt,nt)
- product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
- const_TCT: array-like, shape (nt, nt)
- :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
- reg : float, optional.
- Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`.
-
- Returns
- -------
- gamma: float
- Optimal value for the line-search step
- a: float
- Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma + c` of the reconstruction loss
- b: float
- Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma + c` of the reconstruction loss
- Cembedded_diff: numpy array, shape (nt, nt)
- Difference between the models evaluated at :math:`\mathbf{x}` and at :math:`\mathbf{w}`.
- """
-
- # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
- Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
- Cembedded_diff = Cembedded_x - Cembedded
- trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
- trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
- a = trace_diffx - trace_diffw
- b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
- if reg != 0:
- a -= reg * np.sum((x - w)**2)
- b -= 2 * reg * np.sum(w * (x - w))
-
- if a > 0:
- gamma = min(1, max(0, - b / (2 * a)))
- elif a + b < 0:
- gamma = 1
- else:
- gamma = 0
-
- return gamma, a, b, Cembedded_diff
-
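# Editor's note (not from the removed file): on the segment [0, 1] the objective restricted to
# the search direction is the quadratic f(gamma) = a*gamma**2 + b*gamma + c. When a > 0 it is
# convex and the clipped vertex -b / (2a) is optimal; when a <= 0 it is concave (or linear), so
# an endpoint is optimal, and since f(1) - f(0) = a + b, gamma = 1 is chosen exactly when a + b < 0.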
-
-def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
- Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
- tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
- r"""
- Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
-
- .. math::
- \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2
-
-
- Such that :math:`\forall s \leq S` :
-
- - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
- - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
- - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
- - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
- - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
- - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
- - reg is the regularization coefficient.
-
-
- The stochastic algorithm used to estimate the attributed graph dictionary atoms is the one proposed in [38].
-
- Parameters
- ----------
- Cs : list of S symmetric array-like, shape (ns, ns)
- List of Metric/Graph cost matrices of variable size (ns,ns).
- Ys : list of S array-like, shape (ns, d)
- List of feature matrix of variable size (ns,d) with d fixed.
- D: int
- Number of dictionary atoms to learn
- nt: int
- Number of samples within each dictionary atom
- alpha : float
- Trade-off parameter of Fused Gromov-Wasserstein
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
- ps : list of S array-like, shape (ns,), optional
- Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
- q : array-like, shape (nt,), optional
- Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
- epochs: int, optional
- Number of epochs used to learn the dictionary. Default is 20.
- batch_size: int, optional
- Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
- learning_rate_C: float, optional
- Learning rate used for the stochastic gradient descent on Cdict. Default is 1.
- learning_rate_Y: float, optional
- Learning rate used for the stochastic gradient descent on Ydict. Default is 1.
- Cdict_init: list of D array-like with shape (nt, nt), optional
- Used to initialize the dictionary structures Cdict.
- If set to None (Default), the dictionary will be initialized randomly.
- Otherwise Cdict must have shape (D, nt, nt), i.e. match the provided D and nt.
- Ydict_init: list of D array-like with shape (nt, d), optional
- Used to initialize the dictionary features Ydict.
- If set to None, the dictionary features will be initialized randomly.
- Otherwise Ydict must have shape (D, nt, d) where d is the feature dimension of the inputs Ys, and also match the provided D and nt.
- projection: str, optional
- If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary
- Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric'
- use_log: bool, optional
- If set to True, the evolution of the loss by batches and epochs is tracked. Default is False.
- use_adam_optimizer: bool, optional
- If set to True, the Adam optimizer with default settings is used as an adaptive learning rate strategy.
- Otherwise SGD with a fixed learning rate is performed. Default is True.
- tol_outer : float, optional
- Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
- tol_inner : float, optional
- Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
- max_iter_outer : int, optional
- Maximum number of iterations for the BCD. Default is 20.
- max_iter_inner : int, optional
- Maximum number of iterations for the Conjugate Gradient. Default is 200.
- verbose : bool, optional
- Print the reconstruction loss every epoch. Default is False.
-
- Returns
- -------
-
- Cdict_best_state : D array-like, shape (D,nt,nt)
- Metric/Graph cost matrices composing the dictionary.
- The dictionary leading to the best loss over an epoch is saved and returned.
- Ydict_best_state : D array-like, shape (D,nt,d)
- Feature matrices composing the dictionary.
- The dictionary leading to the best loss over an epoch is saved and returned.
- log: dict
- If use_log is True, contains the loss evolution by batches and epochs.
- References
- ----------
-
- .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
- "Online Graph Dictionary Learning"
- International Conference on Machine Learning (ICML). 2021.
- """
- Cs0, Ys0 = Cs, Ys
- nx = get_backend(*Cs0, *Ys0)
- Cs = [nx.to_numpy(C) for C in Cs0]
- Ys = [nx.to_numpy(Y) for Y in Ys0]
-
- d = Ys[0].shape[-1]
- dataset_size = len(Cs)
-
- if ps is None:
- ps = [unif(C.shape[0]) for C in Cs]
- else:
- ps = [nx.to_numpy(p) for p in ps]
- if q is None:
- q = unif(nt)
- else:
- q = nx.to_numpy(q)
-
- if Cdict_init is None:
- # Initialize randomly structures of dictionary atoms based on samples
- dataset_means = [C.mean() for C in Cs]
- Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
- else:
- Cdict = nx.to_numpy(Cdict_init).copy()
- assert Cdict.shape == (D, nt, nt)
- if Ydict_init is None:
- # Initialize randomly features of dictionary atoms based on samples distribution by feature component
- dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
- Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
- else:
- Ydict = nx.to_numpy(Ydict_init).copy()
- assert Ydict.shape == (D, nt, d)
-
- if 'symmetric' in projection:
- Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
- if 'nonnegative' in projection:
- Cdict[Cdict < 0.] = 0.
-
- if use_adam_optimizer:
- adam_moments_C = _initialize_adam_optimizer(Cdict)
- adam_moments_Y = _initialize_adam_optimizer(Ydict)
-
- log = {'loss_batches': [], 'loss_epochs': []}
- const_q = q[:, None] * q[None, :]
- diag_q = np.diag(q)
- Cdict_best_state = Cdict.copy()
- Ydict_best_state = Ydict.copy()
- loss_best_state = np.inf
- if batch_size > dataset_size:
- batch_size = dataset_size
- iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
-
- for epoch in range(epochs):
- cumulated_loss_over_epoch = 0.
-
- for _ in range(iter_by_epoch):
-
- # Batch iterations
- batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
- cumulated_loss_over_batch = 0.
- unmixings = np.zeros((batch_size, D))
- Cs_embedded = np.zeros((batch_size, nt, nt))
- Ys_embedded = np.zeros((batch_size, nt, d))
- Ts = [None] * batch_size
-
- for batch_idx, C_idx in enumerate(batch):
- # BCD solver for Fused Gromov-Wasserstein linear unmixing, applied independently to each structure of the sampled batch
- unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing(
- Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q,
- tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner
- )
- cumulated_loss_over_batch += current_loss
- cumulated_loss_over_epoch += cumulated_loss_over_batch
- if use_log:
- log['loss_batches'].append(cumulated_loss_over_batch)
-
- # Stochastic projected gradient step over dictionary atoms
- grad_Cdict = np.zeros_like(Cdict)
- grad_Ydict = np.zeros_like(Ydict)
-
- for batch_idx, C_idx in enumerate(batch):
- shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
- shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx])
- grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
- grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :]
- grad_Cdict *= 2 / batch_size
- grad_Ydict *= 2 / batch_size
-
- if use_adam_optimizer:
- Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C)
- Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y)
- else:
- Cdict -= learning_rate_C * grad_Cdict
- Ydict -= learning_rate_Y * grad_Ydict
-
- if 'symmetric' in projection:
- Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
- if 'nonnegative' in projection:
- Cdict[Cdict < 0.] = 0.
-
- if use_log:
- log['loss_epochs'].append(cumulated_loss_over_epoch)
- if loss_best_state > cumulated_loss_over_epoch:
- loss_best_state = cumulated_loss_over_epoch
- Cdict_best_state = Cdict.copy()
- Ydict_best_state = Ydict.copy()
- if verbose:
- print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
-
- return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log
-
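# Editor's example (not from the removed file): learning a 2-atom FGW dictionary from attributed
# structures, reusing node coordinates as features; a minimal sketch assuming the signature above.
import numpy as np
import ot

rng = np.random.RandomState(0)
samples = [rng.randn(n, 2) for n in (8, 10, 12, 9)]
Cs = [ot.dist(x, x) for x in samples]
Ys = samples                                  # features of dimension d = 2

Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
    Cs, Ys, D=2, nt=6, alpha=0.5, reg=0.01, epochs=5, batch_size=2,
    learning_rate_C=0.1, learning_rate_Y=0.1)
print(Cdict.shape, Ydict.shape)               # (2, 6, 6) structures and (2, 6, 2) features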
-
-def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs):
- r"""
- Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`
-
- .. math::
- \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
-
- such that:
-
- - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
- - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
- - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
- - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C}`
- - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
- - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
- - reg is the regularization coefficient.
-
- The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6.
-
- Parameters
- ----------
- C : array-like, shape (ns, ns)
- Metric/Graph cost matrix.
- Y : array-like, shape (ns, d)
- Feature matrix.
- Cdict : D array-like, shape (D,nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
- Ydict : D array-like, shape (D,nt,d)
- Feature matrices composing the dictionary on which to embed (C,Y).
- alpha: float,
- Trade-off parameter of Fused Gromov-Wasserstein.
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
- p : array-like, shape (ns,), optional
- Distribution in the source space C. Default is None and corresponds to uniform distribution.
- q : array-like, shape (nt,), optional
- Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
- tol_outer : float, optional
- Solver precision for the BCD algorithm.
- tol_inner : float, optional
- Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
- max_iter_outer : int, optional
- Maximum number of iterations for the BCD. Default is 20.
- max_iter_inner : int, optional
- Maximum number of iterations for the Conjugate Gradient. Default is 200.
-
- Returns
- -------
- w: array-like, shape (D,)
- fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary.
- Cembedded: array-like, shape (nt,nt)
- embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
- Yembedded: array-like, shape (nt,d)
- embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
- T: array-like (ns,nt)
- Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`.
- current_loss: float
- reconstruction error
- References
- ----------
-
- .. [38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty.
- "Online Graph Dictionary Learning"
- International Conference on Machine Learning (ICML). 2021.
- """
- C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict
- nx = get_backend(C0, Y0, Cdict0, Ydict0)
- C = nx.to_numpy(C0)
- Y = nx.to_numpy(Y0)
- Cdict = nx.to_numpy(Cdict0)
- Ydict = nx.to_numpy(Ydict0)
-
- if p is None:
- p = unif(C.shape[0])
- else:
- p = nx.to_numpy(p)
- if q is None:
- q = unif(Cdict.shape[-1])
- else:
- q = nx.to_numpy(q)
-
- T = p[:, None] * q[None, :]
- D = len(Cdict)
- d = Y.shape[-1]
- w = unif(D) # Initialize with uniform weights
- ns = C.shape[-1]
- nt = Cdict.shape[-1]
-
- # modeling (C,Y)
- Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
- Yembedded = np.sum(w[:, None, None] * Ydict, axis=0)
-
- # constants depending on q
- const_q = q[:, None] * q[None, :]
- diag_q = np.diag(q)
- # Trackers for BCD convergence
- convergence_criterion = np.inf
- current_loss = 10**15
- outer_count = 0
- Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix
-
- while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
- previous_loss = current_loss
-
- # 1. Solve FGW transport between (C,Y,p) and the embedded model (\sum_d w_d Cdict[d], \sum_d w_d Ydict[d], q) fixing the unmixing w
- Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T)
- M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features
- T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True)
- current_loss = log['fgw_dist']
- if reg != 0:
- current_loss -= reg * np.sum(w**2)
-
- # 2. Solve linear unmixing problem over w with a fixed transport plan T
- w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w,
- T, p, q, const_q, diag_q, current_loss, alpha, reg,
- tol=tol_inner, max_iter=max_iter_inner, **kwargs)
- if previous_loss != 0:
- convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
- else:
- convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
- outer_count += 1
-
- return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
-
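# Editor's example (not from the removed file): unmixing one attributed structure on a toy FGW
# dictionary; a sketch assuming (Cdict, Ydict) have shapes (D, nt, nt) and (D, nt, d) as documented.
import numpy as np
import ot

rng = np.random.RandomState(0)
x = rng.randn(9, 2)
C, Y = ot.dist(x, x), x
Cdict = rng.uniform(0., 1., size=(2, 6, 6))
Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
Ydict = rng.randn(2, 6, 2)

w, Cemb, Yemb, T, loss = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
    C, Y, Cdict, Ydict, alpha=0.5, reg=0.01, max_iter_outer=10)
print(w.shape, Cemb.shape, Yemb.shape, T.shape)   # (2,), (6, 6), (6, 2), (9, 6)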
-
-def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs):
- r"""
- Returns for a fixed admissible transport plan,
- the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})`
-
- .. math::
- \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2
-
- Such that :
-
- - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
- - :math:`\mathbf{w} \geq \mathbf{0}_D`
-
- Where :
-
- - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
- - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
- - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
- - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
- - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}`
- - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space.
- - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}`
- - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
- - reg is the regularization coefficient.
-
- The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7.
-
- Parameters
- ----------
-
- C : array-like, shape (ns, ns)
- Metric/Graph cost matrix.
- Y : array-like, shape (ns, d)
- Feature matrix.
- Cdict : list of D array-like, shape (nt,nt)
- Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
- Each matrix in the dictionary must have the same size (nt,nt).
- Ydict : list of D array-like, shape (nt,d)
- Feature matrices composing the dictionary on which to embed (C,Y).
- Each matrix in the dictionary must have the same size (nt,d).
- Cembedded: array-like, shape (nt,nt)
- Embedded structure of (C,Y) onto the dictionary
- Yembedded: array-like, shape (nt,d)
- Embedded features of (C,Y) onto the dictionary
- w: array-like, shape (n_D,)
- Linear unmixing of (C,Y) onto (Cdict,Ydict)
- const_q: array-like, shape (nt,nt)
- product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution.
- diag_q: array-like, shape (nt,nt)
- diagonal matrix with values of q on the diagonal.
- T: array-like, shape (ns,nt)
- fixed transport plan between (C,Y) and its model
- p : array-like, shape (ns,)
- Distribution in the source space (C,Y).
- q : array-like, shape (nt,)
- Distribution in the embedding space depicted by the dictionary.
- alpha: float,
- Trade-off parameter of Fused Gromov-Wasserstein.
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w.
-
- Returns
- -------
- w: ndarray (D,)
- linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing.
- """
- convergence_criterion = np.inf
- current_loss = starting_loss
- count = 0
- const_TCT = np.transpose(C.dot(T)).dot(T)
- ones_ns_d = np.ones(Y.shape)
-
- while (convergence_criterion > tol) and (count < max_iter):
- previous_loss = current_loss
-
- # 1) Compute gradient at current point w
- # structure
- grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
- # feature
- grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2))
- grad_w -= reg * w
- grad_w *= 2
-
- # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
- min_ = np.min(grad_w)
- x = (grad_w == min_).astype(np.float64)
- x /= np.sum(x)
-
- # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
- gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg)
-
- # 4) Updates: w <-- (1-gamma)*w + gamma*x
- w += gamma * (x - w)
- Cembedded += gamma * Cembedded_diff
- Yembedded += gamma * Yembedded_diff
- current_loss += a * (gamma**2) + b * gamma
-
- if previous_loss != 0:
- convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
- else:
- convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
- count += 1
-
- return w, Cembedded, Yembedded, current_loss
-
-
-def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs):
- r"""
- Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing
- .. math::
- \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2
-
-
- Such that :
-
- - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
-
- Parameters
- ----------
-
- w : array-like, shape (D,)
- Unmixing.
- grad_w : array-like, shape (D, D)
- Gradient of the reconstruction loss with respect to w.
- x: array-like, shape (D,)
- Conditional gradient direction.
- Y: array-like, shape (ns,d)
- Feature matrix of the input space
- Cdict : list of D array-like, shape (nt, nt)
- Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
- Each matrix in the dictionary must have the same size (nt,nt).
- Ydict : list of D array-like, shape (nt, d)
- Feature matrices composing the dictionary on which to embed (C,Y).
- Each matrix in the dictionary must have the same size (nt,d).
- Cembedded: array-like, shape (nt, nt)
- Embedded structure of (C,Y) onto the dictionary
- Yembedded: array-like, shape (nt, d)
- Embedded features of (C,Y) onto the dictionary
- T: array-like, shape (ns, nt)
- Fixed transport plan between (C,Y) and its current model.
- const_q: array-like, shape (nt,nt)
- product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
- const_TCT: array-like, shape (nt, nt)
- :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
- ones_ns_d: array-like, shape (ns, d)
- :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations.
- alpha: float,
- Trade-off parameter of Fused Gromov-Wasserstein.
- reg : float, optional
- Coefficient of the negative quadratic regularization used to promote sparsity of w.
-
- Returns
- -------
- gamma: float
- Optimal value for the line-search step
- a: float
- Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
- b: float
- Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
- Cembedded_diff: numpy array, shape (nt, nt)
- Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
- Yembedded_diff: numpy array, shape (nt, nt)
- Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`.
- """
- # polynomial coefficients from quadratic objective (with respect to w) on structures
- Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
- Cembedded_diff = Cembedded_x - Cembedded
- trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
- trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
- # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss
- a_gw = trace_diffx - trace_diffw
- b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
-
- # polynomial coefficient from quadratic objective (with respect to w) on features
- Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0)
- Yembedded_diff = Yembedded_x - Yembedded
- # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss
- a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T)
- b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T)))
-
- a = alpha * a_gw + (1 - alpha) * a_w
- b = alpha * b_gw + (1 - alpha) * b_w
- if reg != 0:
- a -= reg * np.sum((x - w)**2)
- b -= 2 * reg * np.sum(w * (x - w))
- if a > 0:
- gamma = min(1, max(0, -b / (2 * a)))
- elif a + b < 0:
- gamma = 1
- else:
- gamma = 0
-
- return gamma, a, b, Cembedded_diff, Yembedded_diff
diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py
new file mode 100644
index 0000000..6184edf
--- /dev/null
+++ b/ot/gromov/__init__.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+"""
+Solvers related to Gromov-Wasserstein problems.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Cedric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+# All submodules and packages
+from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
+ update_square_loss, update_kl_loss,
+ init_matrix_semirelaxed)
+from ._gw import (gromov_wasserstein, gromov_wasserstein2,
+ fused_gromov_wasserstein, fused_gromov_wasserstein2,
+ solve_gromov_linesearch, gromov_barycenters, fgw_barycenters,
+ update_structure_matrix, update_feature_matrix)
+from ._bregman import (entropic_gromov_wasserstein,
+ entropic_gromov_wasserstein2,
+ entropic_gromov_barycenters)
+from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein,
+ sampled_gromov_wasserstein)
+from ._semirelaxed import (semirelaxed_gromov_wasserstein,
+ semirelaxed_gromov_wasserstein2,
+ semirelaxed_fused_gromov_wasserstein,
+ semirelaxed_fused_gromov_wasserstein2,
+ solve_semirelaxed_gromov_linesearch)
+from ._dictionary import (gromov_wasserstein_dictionary_learning,
+ gromov_wasserstein_linear_unmixing,
+ fused_gromov_wasserstein_dictionary_learning,
+ fused_gromov_wasserstein_linear_unmixing)
+
+
+__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
+ 'update_square_loss', 'update_kl_loss', 'init_matrix_semirelaxed',
+ 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
+ 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
+ 'fgw_barycenters', 'update_structure_matrix', 'update_feature_matrix',
+ 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
+ 'entropic_gromov_barycenters', 'GW_distance_estimation',
+ 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein',
+ 'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2',
+ 'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2',
+ 'solve_semirelaxed_gromov_linesearch', 'gromov_wasserstein_dictionary_learning',
+ 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning',
+ 'fused_gromov_wasserstein_linear_unmixing']
diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py
new file mode 100644
index 0000000..b0cccfb
--- /dev/null
+++ b/ot/gromov/_bregman.py
@@ -0,0 +1,348 @@
+# -*- coding: utf-8 -*-
+"""
+Bregman projections solvers for entropic Gromov-Wasserstein
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+from ..bregman import sinkhorn
+from ..utils import dist, list_to_array, check_random_state
+from ..backend import get_backend
+
+from ._utils import init_matrix, gwloss, gwggrad
+from ._utils import update_square_loss, update_kl_loss
+
+
+def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg\min}_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
+
+    .. note:: If the inner solver `ot.sinkhorn` does not converge, the
+ optimal coupling :math:`\mathbf{T}` returned by this function does not
+ necessarily satisfy the marginal constraints
+ :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and
+ :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned
+ Gromov-Wasserstein loss does not necessarily satisfy distance
+ properties and may be negative.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : string
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+        Whether C1 and C2 are assumed to be symmetric.
+        If left to its default None value, a symmetry test will be conducted.
+        If set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ if G0 is None:
+ nx = get_backend(p, q, C1, C2)
+ G0 = nx.outer(p, q)
+ else:
+ nx = get_backend(p, q, C1, C2, G0)
+ T = G0
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx)
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+ if not symmetric:
+ constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx)
+ cpt = 0
+ err = 1
+
+ if log:
+ log = {'err': []}
+
+ while (err > tol and cpt < max_iter):
+
+ Tprev = T
+
+ # compute the gradient
+ if symmetric:
+ tens = gwggrad(constC, hC1, hC2, T, nx)
+ else:
+ tens = 0.5 * (gwggrad(constC, hC1, hC2, T, nx) + gwggrad(constCt, hC1t, hC2t, T, nx))
+ T = sinkhorn(p, q, tens, epsilon, method='sinkhorn')
+
+ if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+ err = nx.norm(T - Tprev)
+
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ log['gw_dist'] = gwloss(constC, hC1, hC2, T, nx)
+ return T, log
+ else:
+ return T
+
+
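+# A minimal runnable sketch of the solver above (illustrative only: the toy
+# point clouds, their sizes and the epsilon value are assumptions, not part of
+# the API). It only runs when the module is executed directly.
+if __name__ == '__main__':
+    import numpy as np
+    rng = np.random.RandomState(0)
+    xs, xt = rng.randn(5, 2), rng.randn(6, 2)
+    C1, C2 = dist(xs, xs), dist(xt, xt)  # `dist` is imported at the top of this module
+    C1, C2 = C1 / C1.max(), C2 / C2.max()
+    p, q = np.full(5, 1 / 5), np.full(6, 1 / 6)
+    T = entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=1e-1)
+    print(T.sum())  # a valid coupling sums to 1
+
+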
+def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, symmetric=None, G0=None,
+ max_iter=1000, tol=1e-9, verbose=False, log=False):
+ r"""
+ Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})
+ \mathbf{T}_{i,j} \mathbf{T}_{k,l} - \epsilon(H(\mathbf{T}))
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+ - `H`: entropy
+
+    .. note:: If the inner solver `ot.sinkhorn` does not converge, the
+ optimal coupling :math:`\mathbf{T}` returned by this function does not
+ necessarily satisfy the marginal constraints
+ :math:`\mathbf{T}\mathbf{1}=\mathbf{p}` and
+ :math:`\mathbf{T}^T\mathbf{1}=\mathbf{q}`. So the returned
+ Gromov-Wasserstein loss does not necessarily satisfy distance
+ properties and may be negative.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ epsilon : float
+ Regularization term >0
+ symmetric : bool, optional
+        Whether C1 and C2 are assumed to be symmetric.
+        If left to its default None value, a symmetry test will be conducted.
+        If set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Record log if True.
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ gw, logv = entropic_gromov_wasserstein(
+ C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, tol, verbose, log=True)
+
+ logv['T'] = gw
+
+ if log:
+ return logv['gw_dist'], logv
+ else:
+ return logv['gw_dist']
+
+
+def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, symmetric=True,
+ max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None, random_state=None):
+ r"""
+ Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
+
+ The function solves the following optimization problem:
+
+ .. math::
+
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
+
+ Where :
+
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
+
+ Parameters
+ ----------
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S array-like of shape (ns,ns)
+ Metric cost matrices
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape(N,)
+ Weights in the targeted barycenter
+ lambdas : list of float
+ List of the `S` spaces' weights.
+    loss_fun : str
+        Loss function used for the solver, either 'square_loss' or 'kl_loss'
+    epsilon : float
+        Regularization term >0
+    symmetric : bool, optional.
+        Whether the structures are assumed to be symmetric. Default value is True.
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+    init_C : array-like, shape (N, N), optional
+        Initial value of the :math:`\mathbf{C}` matrix provided by the user.
+        If None (default), a random initialization is used.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ C : array-like, shape (`N`, `N`)
+        Similarity matrix in the barycenter space (permuted arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+ """
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
+
+ S = len(Cs)
+
+ # Initialization of C : random SPD matrix (if not provided by user)
+ if init_C is None:
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
+ else:
+ C = init_C
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while (err > tol) and (cpt < max_iter):
+ Cprev = C
+
+ T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None,
+ max_iter, 1e-4, verbose, log=False) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+            # we can speed up the process by checking for the error only
+            # every 10th iteration
+ err = nx.norm(C - Cprev)
+ error.append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ return C, {"err": error}
+ else:
+ return C
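+
+
+# A minimal runnable sketch of the barycenter solver above (illustrative only:
+# the toy graphs, weights and epsilon value are assumptions, not part of the
+# API). It only runs when the module is executed directly.
+if __name__ == '__main__':
+    import numpy as np
+    rng = np.random.RandomState(0)
+    Cs, ps = [], []
+    for n in [5, 6]:
+        x = rng.randn(n, 2)
+        Cx = dist(x, x)
+        Cs.append(Cx / Cx.max())
+        ps.append(np.full(n, 1 / n))
+    p_bary = np.full(4, 1 / 4)
+    C_bary = entropic_gromov_barycenters(4, Cs, ps, p_bary, [0.5, 0.5],
+                                         'square_loss', epsilon=1e-1,
+                                         max_iter=20, random_state=0)
+    print(C_bary.shape)  # (4, 4)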
diff --git a/ot/gromov/_dictionary.py b/ot/gromov/_dictionary.py
new file mode 100644
index 0000000..5b32671
--- /dev/null
+++ b/ot/gromov/_dictionary.py
@@ -0,0 +1,1008 @@
+# -*- coding: utf-8 -*-
+"""
+(Fused) Gromov-Wasserstein dictionary learning.
+"""
+
+# Author: Rémi Flamary <remi.flamary@unice.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from ..utils import unif
+from ..backend import get_backend
+from ._gw import gromov_wasserstein, fused_gromov_wasserstein
+
+
+def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2
+
+ such that, :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+    - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - reg is the regularization coefficient.
+
+    The stochastic algorithm used for estimating the graph dictionary atoms is the one proposed in [38]_
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns, ns).
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+        Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+        Number of epochs used to learn the dictionary. Default is 20.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate: float, optional
+ Learning rate used for the stochastic gradient descent. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary.
+ If set to None (Default), the dictionary will be initialized randomly.
+        Otherwise Cdict must have shape (D, nt, nt), i.e. match the provided dimensions.
+    projection: str, optional
+        If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary.
+        Otherwise the set of atoms is :math:`R^{nt \times nt}`. Default is 'nonnegative_symmetric'.
+    use_log: bool, optional
+        If set to True, the loss evolution by batches and epochs is tracked. Default is True.
+    use_adam_optimizer: bool, optional
+        If set to True, the Adam optimizer with default settings is used as adaptive learning rate strategy.
+        Otherwise, SGD with a fixed learning rate is performed. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+        If use_log is True, contains the loss evolution by batches and epochs.
+
+    References
+    ----------
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+ """
+ # Handle backend of non-optional arguments
+ Cs0 = Cs
+ nx = get_backend(*Cs0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ dataset_size = len(Cs)
+ # Handle backend of optional arguments
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ symmetric = True
+ else:
+ symmetric = False
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0
+ if use_adam_optimizer:
+ adam_moments = _initialize_adam_optimizer(Cdict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ Cdict_best_state = Cdict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+ # batch sampling
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+                # BCD solver for Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner,
+ max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ if use_adam_optimizer:
+ Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments)
+ else:
+ Cdict -= learning_rate * grad_Cdict
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ if verbose:
+ print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), log
+
+
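+# Illustrative usage sketch for the dictionary learning procedure above, kept
+# as a comment because it relies on helpers defined further below in this
+# module. The toy dataset, dictionary size and epoch count are assumptions,
+# not recommended settings:
+#
+#     rng = np.random.RandomState(0)
+#     Cs = []
+#     for n in [6, 7, 8]:
+#         x = rng.randn(n, 2)
+#         C = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
+#         Cs.append(C / C.max())
+#     Cdict, log = gromov_wasserstein_dictionary_learning(
+#         Cs, D=2, nt=4, epochs=2, batch_size=3, learning_rate=0.1)
+#     # Cdict has shape (2, 4, 4); log['loss_epochs'] tracks the loss.
+
+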
+def _initialize_adam_optimizer(variable):
+
+ # Initialization for our numpy implementation of adam optimizer
+ atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor
+ atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor
+ atoms_adam_count = 1
+
+ return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count}
+
+
+def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09):
+
+ adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad
+ adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2)
+ unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count'])
+ unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count'])
+ variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps)
+ adam_moments['count'] += 1
+
+ return variable, adam_moments
+
+
+def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=None, **kwargs):
+ r"""
+ Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`.
+
+ .. math::
+ \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+ such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_ , algorithm 1.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ T: array-like (ns, nt)
+ Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})`
+ current_loss: float
+ reconstruction error
+
+    References
+    ----------
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+ """
+ C0, Cdict0 = C, Cdict
+ nx = get_backend(C0, Cdict0)
+ C = nx.to_numpy(C0)
+ Cdict = nx.to_numpy(Cdict0)
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+
+ w = unif(D) # Initialize uniformly the unmixing w
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+
+ const_q = q[:, None] * q[None, :]
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+ # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
+ T, log = gromov_wasserstein(
+ C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T,
+ max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., log=True, armijo=False, symmetric=symmetric, **kwargs)
+ current_loss = log['gw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing(
+ C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T,
+ starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs
+ )
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
+
+
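+# Illustrative usage sketch for the linear unmixing above, kept as a comment
+# because it relies on helpers defined further below in this module. The toy
+# graph and the random 2-atom dictionary are assumptions:
+#
+#     rng = np.random.RandomState(0)
+#     x = rng.randn(6, 2)
+#     C = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
+#     Cdict = rng.uniform(size=(2, 4, 4))
+#     Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))  # symmetric atoms
+#     w, Cembedded, T, loss = gromov_wasserstein_linear_unmixing(C, Cdict, reg=0.)
+#     # w lies on the simplex: w.sum() == 1 and w >= 0
+
+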
+def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs):
+ r"""
+ Returns for a fixed admissible transport plan,
+ the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points.
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights.
+ - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`.
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]_
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ w: array-like, shape (D,)
+ Linear unmixing of the input structure onto the dictionary
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between the input structure and its representation in the dictionary.
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ reg : float, optional.
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+
+ previous_loss = current_loss
+ # 1) Compute gradient at current point w
+ grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ grad_w -= 2 * reg * w
+
+ # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+        if previous_loss != 0:  # note that the loss can be negative if reg > 0
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else: # handle numerical issues around 0
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-15)
+ count += 1
+
+ return w, Cembedded, current_loss
+
+
+def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing
+
+    .. math::
+ \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that:
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+    grad_w : array-like, shape (D,)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed C.
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+    const_TCT: array-like, shape (nt, nt)
+        :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+    reg : float, optional.
+        Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`.
+
+    Returns
+    -------
+    gamma: float
+        Optimal value for the line-search step
+    a: float
+        Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma + c` of the reconstruction loss
+    b: float
+        Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma + c` of the reconstruction loss
+    Cembedded_diff: numpy array, shape (nt, nt)
+        Difference between the embedded structures evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ """
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ a = trace_diffx - trace_diffw
+ b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+
+ if a > 0:
+ gamma = min(1, max(0, - b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff
+
+
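+# Illustrative check of the step rule used in the line-search above, kept as a
+# comment (the coefficients are arbitrary assumptions). Minimizing
+# a * gamma**2 + b * gamma over [0, 1] gives the clipped stationary point
+# -b / (2 * a) when a > 0, and otherwise the best endpoint:
+#
+#     gammas = np.linspace(0, 1, 1001)
+#     for a, b in [(2., -1.), (-1., 0.5), (1., 3.)]:
+#         if a > 0:
+#             gamma = min(1, max(0, -b / (2 * a)))
+#         elif a + b < 0:
+#             gamma = 1
+#         else:
+#             gamma = 0
+#         assert np.isclose(gamma, gammas[np.argmin(a * gammas ** 2 + b * gammas)])
+
+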
+def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
+ Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
+ tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+ r"""
+ Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
+
+ .. math::
+ \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2
+
+
+ Such that :math:`\forall s \leq S` :
+
+ - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w_s} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
+ - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}`
+    - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+
+    The stochastic algorithm used for estimating the attributed graph dictionary atoms is the one proposed in [38]_
+
+ Parameters
+ ----------
+ Cs : list of S symmetric array-like, shape (ns, ns)
+ List of Metric/Graph cost matrices of variable size (ns,ns).
+ Ys : list of S array-like, shape (ns, d)
+ List of feature matrix of variable size (ns,d) with d fixed.
+ D: int
+ Number of dictionary atoms to learn
+ nt: int
+ Number of samples within each dictionary atoms
+ alpha : float
+ Trade-off parameter of Fused Gromov-Wasserstein
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ ps : list of S array-like, shape (ns,), optional
+        Distribution in each source space C of Cs. Default is None and corresponds to uniform distributions.
+ q : array-like, shape (nt,), optional
+ Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions.
+ epochs: int, optional
+        Number of epochs used to learn the dictionary. Default is 20.
+ batch_size: int, optional
+ Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32.
+ learning_rate_C: float, optional
+ Learning rate used for the stochastic gradient descent on Cdict. Default is 1.
+ learning_rate_Y: float, optional
+ Learning rate used for the stochastic gradient descent on Ydict. Default is 1.
+ Cdict_init: list of D array-like with shape (nt, nt), optional
+ Used to initialize the dictionary structures Cdict.
+ If set to None (Default), the dictionary will be initialized randomly.
+        Otherwise Cdict must have shape (D, nt, nt), i.e. match the provided dimensions.
+ Ydict_init: list of D array-like with shape (nt, d), optional
+ Used to initialize the dictionary features Ydict.
+ If set to None, the dictionary features will be initialized randomly.
+        Otherwise Ydict must have shape (D, nt, d), where d is the feature dimension of the inputs Ys.
+    projection: str, optional
+        If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary.
+        Otherwise the set of atoms is :math:`R^{nt \times nt}`. Default is 'nonnegative_symmetric'.
+    use_log: bool, optional
+        If set to True, the loss evolution by batches and epochs is tracked. Default is False.
+    use_adam_optimizer: bool, optional
+        If set to True, the Adam optimizer with default settings is used as adaptive learning rate strategy.
+        Otherwise, SGD with a fixed learning rate is performed. Default is True.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+ verbose : bool, optional
+ Print the reconstruction loss every epoch. Default is False.
+
+ Returns
+ -------
+
+ Cdict_best_state : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ Ydict_best_state : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary.
+ The dictionary leading to the best loss over an epoch is saved and returned.
+ log: dict
+        If use_log is True, contains the loss evolution by batches and epochs.
+
+    References
+    ----------
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+ """
+ Cs0, Ys0 = Cs, Ys
+ nx = get_backend(*Cs0, *Ys0)
+ Cs = [nx.to_numpy(C) for C in Cs0]
+ Ys = [nx.to_numpy(Y) for Y in Ys0]
+
+ d = Ys[0].shape[-1]
+ dataset_size = len(Cs)
+
+ if ps is None:
+ ps = [unif(C.shape[0]) for C in Cs]
+ else:
+ ps = [nx.to_numpy(p) for p in ps]
+ if q is None:
+ q = unif(nt)
+ else:
+ q = nx.to_numpy(q)
+
+ if Cdict_init is None:
+ # Initialize randomly structures of dictionary atoms based on samples
+ dataset_means = [C.mean() for C in Cs]
+ Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+ else:
+ Cdict = nx.to_numpy(Cdict_init).copy()
+ assert Cdict.shape == (D, nt, nt)
+ if Ydict_init is None:
+ # Initialize randomly features of dictionary atoms based on samples distribution by feature component
+ dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
+ Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+ else:
+ Ydict = nx.to_numpy(Ydict_init).copy()
+ assert Ydict.shape == (D, nt, d)
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ symmetric = True
+ else:
+ symmetric = False
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_adam_optimizer:
+ adam_moments_C = _initialize_adam_optimizer(Cdict)
+ adam_moments_Y = _initialize_adam_optimizer(Ydict)
+
+ log = {'loss_batches': [], 'loss_epochs': []}
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ loss_best_state = np.inf
+ if batch_size > dataset_size:
+ batch_size = dataset_size
+ iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0)
+
+ for epoch in range(epochs):
+ cumulated_loss_over_epoch = 0.
+
+ for _ in range(iter_by_epoch):
+
+ # Batch iterations
+ batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+ cumulated_loss_over_batch = 0.
+ unmixings = np.zeros((batch_size, D))
+ Cs_embedded = np.zeros((batch_size, nt, nt))
+ Ys_embedded = np.zeros((batch_size, nt, d))
+ Ts = [None] * batch_size
+
+ for batch_idx, C_idx in enumerate(batch):
+                # BCD solver for Fused Gromov-Wasserstein linear unmixing used independently on each structure of the sampled batch
+ unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing(
+ Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q,
+ tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner, symmetric=symmetric, **kwargs
+ )
+ cumulated_loss_over_batch += current_loss
+ cumulated_loss_over_epoch += cumulated_loss_over_batch
+ if use_log:
+ log['loss_batches'].append(cumulated_loss_over_batch)
+
+ # Stochastic projected gradient step over dictionary atoms
+ grad_Cdict = np.zeros_like(Cdict)
+ grad_Ydict = np.zeros_like(Ydict)
+
+ for batch_idx, C_idx in enumerate(batch):
+ shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx])
+ shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx])
+ grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :]
+ grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :]
+ grad_Cdict *= 2 / batch_size
+ grad_Ydict *= 2 / batch_size
+
+ if use_adam_optimizer:
+ Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C)
+ Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y)
+ else:
+ Cdict -= learning_rate_C * grad_Cdict
+ Ydict -= learning_rate_Y * grad_Ydict
+
+ if 'symmetric' in projection:
+ Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1)))
+ if 'nonnegative' in projection:
+ Cdict[Cdict < 0.] = 0.
+
+ if use_log:
+ log['loss_epochs'].append(cumulated_loss_over_epoch)
+ if loss_best_state > cumulated_loss_over_epoch:
+ loss_best_state = cumulated_loss_over_epoch
+ Cdict_best_state = Cdict.copy()
+ Ydict_best_state = Ydict.copy()
+ if verbose:
+ print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch)
+
+ return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log
+
+
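+# Illustrative usage sketch for the fused dictionary learning procedure above,
+# kept as a comment because it relies on helpers defined further below in this
+# module. The toy attributed graphs and settings are assumptions, not
+# recommended values:
+#
+#     rng = np.random.RandomState(0)
+#     Cs, Ys = [], []
+#     for n in [6, 7, 8]:
+#         x = rng.randn(n, 2)
+#         C = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
+#         Cs.append(C / C.max())
+#         Ys.append(rng.randn(n, 3))  # 3-dimensional node features
+#     Cdict, Ydict, log = fused_gromov_wasserstein_dictionary_learning(
+#         Cs, Ys, D=2, nt=4, alpha=0.5, epochs=2, batch_size=3)
+#     # Cdict has shape (2, 4, 4) and Ydict has shape (2, 4, 3).
+
+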
+def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5),
+ tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, symmetric=True, **kwargs):
+ r"""
+ Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`
+
+ .. math::
+ \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2
+
+    such that:
+
+    - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+    - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d.
+    - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C}`
+    - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38]_, algorithm 6.
+
+ Parameters
+ ----------
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : D array-like, shape (D,nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Ydict : D array-like, shape (D,nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0.
+ p : array-like, shape (ns,), optional
+ Distribution in the source space C. Default is None and corresponds to uniform distribution.
+ q : array-like, shape (nt,), optional
+ Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution.
+ tol_outer : float, optional
+ Solver precision for the BCD algorithm.
+ tol_inner : float, optional
+ Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`.
+ max_iter_outer : int, optional
+ Maximum number of iterations for the BCD. Default is 20.
+ max_iter_inner : int, optional
+ Maximum number of iterations for the Conjugate Gradient. Default is 200.
+
+ Returns
+ -------
+ w: array-like, shape (D,)
+ fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary.
+ Cembedded: array-like, shape (nt,nt)
+ embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ Yembedded: array-like, shape (nt,d)
+ embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
+ T: array-like (ns,nt)
+ Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`.
+ current_loss: float
+ reconstruction error
+
+    References
+    ----------
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+ """
+ C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict
+ nx = get_backend(C0, Y0, Cdict0, Ydict0)
+ C = nx.to_numpy(C0)
+ Y = nx.to_numpy(Y0)
+ Cdict = nx.to_numpy(Cdict0)
+ Ydict = nx.to_numpy(Ydict0)
+
+ if p is None:
+ p = unif(C.shape[0])
+ else:
+ p = nx.to_numpy(p)
+ if q is None:
+ q = unif(Cdict.shape[-1])
+ else:
+ q = nx.to_numpy(q)
+
+ T = p[:, None] * q[None, :]
+ D = len(Cdict)
+ d = Y.shape[-1]
+ w = unif(D) # Initialize with uniform weights
+ ns = C.shape[-1]
+ nt = Cdict.shape[-1]
+
+ # modeling (C,Y)
+ Cembedded = np.sum(w[:, None, None] * Cdict, axis=0)
+ Yembedded = np.sum(w[:, None, None] * Ydict, axis=0)
+
+ # constants depending on q
+ const_q = q[:, None] * q[None, :]
+ diag_q = np.diag(q)
+ # Trackers for BCD convergence
+ convergence_criterion = np.inf
+ current_loss = 10**15
+ outer_count = 0
+ Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix
+
+ while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer):
+ previous_loss = current_loss
+
+ # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w
+ Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T)
+ M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T)  # squared Euclidean distance matrix between features
+ T, log = fused_gromov_wasserstein(
+ M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha,
+ max_iter=max_iter_inner, tol_rel=tol_inner, tol_abs=0., armijo=False, G0=T, log=True, symmetric=symmetric, **kwargs)
+ current_loss = log['fgw_dist']
+ if reg != 0:
+ current_loss -= reg * np.sum(w**2)
+
+ # 2. Solve linear unmixing problem over w with a fixed transport plan T
+ w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w,
+ T, p, q, const_q, diag_q, current_loss, alpha, reg,
+ tol=tol_inner, max_iter=max_iter_inner, **kwargs)
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ outer_count += 1
+
+ return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss)
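# A minimal usage sketch for the unmixing routine above. It assumes the public name is
# ot.gromov.fused_gromov_wasserstein_linear_unmixing (inferred from this module); the
# inputs below are random toy data, not from the library.
import numpy as np
import ot

rng = np.random.RandomState(0)
Y = rng.randn(12, 3)                                  # node features of the graph to unmix
C = ot.dist(Y, Y)                                     # its structure matrix
atoms = [rng.randn(6, 3) for _ in range(2)]           # two dictionary atoms with 6 nodes each
Cdict = np.stack([ot.dist(x, x) for x in atoms])      # (2, 6, 6) structure atoms
Ydict = np.stack(atoms)                               # (2, 6, 3) feature atoms

w, Cemb, Yemb, T, loss = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
    C, Y, Cdict, Ydict, alpha=0.5, reg=0.)
print(w)  # nonnegative unmixing weights summing to 1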
+
+
+def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs):
+ r"""
+ Returns for a fixed admissible transport plan,
+ the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d \mathbf{Y_{dict}[d]}, \mathbf{q})`
+
+ .. math::
+ \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2
+
+ Such that :
+
+ - :math:`\mathbf{w}^\top \mathbf{1}_D = 1`
+ - :math:`\mathbf{w} \geq \mathbf{0}_D`
+
+ Where :
+
+ - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns.
+ - :math:`\mathbf{Y}` is a (ns,d) feature matrix of variable size ns and fixed dimension d.
+ - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of fixed size nt.
+ - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D feature matrices of fixed size nt and fixed dimension d.
+ - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C}`
+ - :math:`\mathbf{q}` is the target distribution assigned to every structure in the embedding space.
+ - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}`
+ - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein
+ - reg is the regularization coefficient.
+
+ The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38]_, algorithm 7.
+
+ Parameters
+ ----------
+
+ C : array-like, shape (ns, ns)
+ Metric/Graph cost matrix.
+ Y : array-like, shape (ns, d)
+ Feature matrix.
+ Cdict : list of D array-like, shape (nt,nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt,d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt,nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt,d)
+ Embedded features of (C,Y) onto the dictionary
+ w: array-like, shape (D,)
+ Linear unmixing of (C,Y) onto (Cdict,Ydict)
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution.
+ diag_q: array-like, shape (nt,nt)
+ diagonal matrix with values of q on the diagonal.
+ T: array-like, shape (ns,nt)
+ fixed transport plan between (C,Y) and its model
+ p : array-like, shape (ns,)
+ Distribution in the source space (C,Y).
+ q : array-like, shape (nt,)
+ Distribution in the embedding space depicted by the dictionary.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ w: ndarray (D,)
+ linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given the transport plan from the previous unmixing.
+ Cembedded: ndarray (nt,nt)
+ Updated embedded structure :math:`\sum_d w_d\mathbf{C_{dict}[d]}`.
+ Yembedded: ndarray (nt,d)
+ Updated embedded features :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`.
+ current_loss: float
+ Updated reconstruction error.
+ """
+ convergence_criterion = np.inf
+ current_loss = starting_loss
+ count = 0
+ const_TCT = np.transpose(C.dot(T)).dot(T)
+ ones_ns_d = np.ones(Y.shape)
+
+ while (convergence_criterion > tol) and (count < max_iter):
+ previous_loss = current_loss
+
+ # 1) Compute gradient at current point w
+ # structure
+ grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2))
+ # feature
+ grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2))
+ grad_w -= reg * w
+ grad_w *= 2
+
+ # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w
+ min_ = np.min(grad_w)
+ x = (grad_w == min_).astype(np.float64)
+ x /= np.sum(x)
+
+ # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c
+ gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg)
+
+ # 4) Updates: w <-- (1-gamma)*w + gamma*x
+ w += gamma * (x - w)
+ Cembedded += gamma * Cembedded_diff
+ Yembedded += gamma * Yembedded_diff
+ current_loss += a * (gamma**2) + b * gamma
+
+ if previous_loss != 0:
+ convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss)
+ else:
+ convergence_criterion = abs(previous_loss - current_loss) / 10**(-12)
+ count += 1
+
+ return w, Cembedded, Yembedded, current_loss
+
+
+def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs):
+ r"""
+ Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing
+ .. math::
+ \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2
+
+
+ Such that :
+
+ - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}`
+
+ Parameters
+ ----------
+
+ w : array-like, shape (D,)
+ Unmixing.
+ grad_w : array-like, shape (D,)
+ Gradient of the reconstruction loss with respect to w.
+ x: array-like, shape (D,)
+ Conditional gradient direction.
+ Y: array-like, shape (ns,d)
+ Feature matrix of the input space
+ Cdict : list of D array-like, shape (nt, nt)
+ Metric/Graph cost matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,nt).
+ Ydict : list of D array-like, shape (nt, d)
+ Feature matrices composing the dictionary on which to embed (C,Y).
+ Each matrix in the dictionary must have the same size (nt,d).
+ Cembedded: array-like, shape (nt, nt)
+ Embedded structure of (C,Y) onto the dictionary
+ Yembedded: array-like, shape (nt, d)
+ Embedded features of (C,Y) onto the dictionary
+ T: array-like, shape (ns, nt)
+ Fixed transport plan between (C,Y) and its current model.
+ const_q: array-like, shape (nt,nt)
+ product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations.
+ const_TCT: array-like, shape (nt, nt)
+ :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations.
+ ones_ns_d: array-like, shape (ns, d)
+ :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations.
+ alpha: float,
+ Trade-off parameter of Fused Gromov-Wasserstein.
+ reg : float, optional
+ Coefficient of the negative quadratic regularization used to promote sparsity of w.
+
+ Returns
+ -------
+ gamma: float
+ Optimal value for the line-search step
+ a: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ b: float
+ Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss
+ Cembedded_diff: numpy array, shape (nt, nt)
+ Difference between the structure matrices of the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ Yembedded_diff: numpy array, shape (nt, d)
+ Difference between the feature matrices of the models evaluated in :math:`\mathbf{x}` and in :math:`\mathbf{w}`.
+ """
+ # polynomial coefficients from quadratic objective (with respect to w) on structures
+ Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0)
+ Cembedded_diff = Cembedded_x - Cembedded
+ trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q)
+ trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q)
+ # Constant factors appearing in the factorization a*gamma^2 + b*gamma + c of the Gromov-Wasserstein part of the reconstruction loss
+ a_gw = trace_diffx - trace_diffw
+ b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT))
+
+ # polynomial coefficient from quadratic objective (with respect to w) on features
+ Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0)
+ Yembedded_diff = Yembedded_x - Yembedded
+ # Constant factors appearing in the factorization a*gamma^2 + b*gamma + c of the feature (Wasserstein) part of the reconstruction loss
+ a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T)
+ b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T)))
+
+ a = alpha * a_gw + (1 - alpha) * a_w
+ b = alpha * b_gw + (1 - alpha) * b_w
+ if reg != 0:
+ a -= reg * np.sum((x - w)**2)
+ b -= 2 * reg * np.sum(w * (x - w))
+ if a > 0:
+ gamma = min(1, max(0, -b / (2 * a)))
+ elif a + b < 0:
+ gamma = 1
+ else:
+ gamma = 0
+
+ return gamma, a, b, Cembedded_diff, Yembedded_diff
diff --git a/ot/gromov/_estimators.py b/ot/gromov/_estimators.py
new file mode 100644
index 0000000..0a29a91
--- /dev/null
+++ b/ot/gromov/_estimators.py
@@ -0,0 +1,425 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein and Fused-Gromov-Wasserstein stochastic estimators.
+"""
+
+# Author: Rémi Flamary <remi.flamary@unice.fr>
+# Tanguy Kerdoncuff <tanguy.kerdoncuff@laposte.net>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from ..bregman import sinkhorn
+from ..utils import list_to_array, check_random_state
+from ..lp import emd_1d, emd
+from ..backend import get_backend
+
+
+def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
+ nb_samples_p=None, nb_samples_q=None, std=True, random_state=None):
+ r"""
+ Returns an approximation of the gromov-wasserstein cost between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+ with a fixed transport plan :math:`\mathbf{T}`.
+
+ The function gives an unbiased approximation of the following equation:
+
+ .. math::
+
+ GW = \sum_{i,j,k,l} L(\mathbf{C_{1}}_{i,k}, \mathbf{C_{2}}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - `L` : Loss function to account for the misfit between the similarity matrices
+ - :math:`\mathbf{T}`: Matrix with marginal :math:`\mathbf{p}` and :math:`\mathbf{q}`
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ T : csr or array-like, shape (ns, nt)
+ Transport plan matrix, either a sparse csr or a dense matrix
+ nb_samples_p : int, optional
+ `nb_samples_p` is the number of samples (without replacement) along the first dimension of :math:`\mathbf{T}`
+ nb_samples_q : int, optional
+ `nb_samples_q` is the number of samples along the second dimension of :math:`\mathbf{T}`, for each sample along the first
+ std : bool, optional
+ Standard deviation associated with the prediction of the gromov-wasserstein cost
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ : float
+ Gromov-Wasserstein cost estimate. If `std` is True, its standard deviation is also returned.
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ generator = check_random_state(random_state)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ # It is always better to sample from the biggest distribution first.
+ if len_p < len_q:
+ p, q = q, p
+ len_p, len_q = len_q, len_p
+ C1, C2 = C2, C1
+ T = T.T
+
+ if nb_samples_p is None:
+ if nx.issparse(T):
+ # If T is sparse, it probably means that PoGroW was used, thus the number of samples is reduced
+ nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p)
+ else:
+ nb_samples_p = len_p
+ else:
+ # Samples along the first dimension are drawn without replacement.
+ nb_samples_p = min(nb_samples_p, len_p)
+ if nb_samples_q is None:
+ nb_samples_q = 1
+ if std:
+ nb_samples_q = max(2, nb_samples_q)
+
+ index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+ index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+
+ index_i = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
+ index_j = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
+
+ for i in range(nb_samples_p):
+ if nx.issparse(T):
+ T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
+ T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
+ else:
+ T_indexi = T[index_i[i], :]
+ T_indexj = T[index_j[i], :]
+ # For each sampled row, a column is sampled.
+ index_k[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=nx.to_numpy(T_indexi / nx.sum(T_indexi)),
+ replace=True
+ )
+ index_l[i] = generator.choice(
+ len_q,
+ size=nb_samples_q,
+ p=nx.to_numpy(T_indexj / nx.sum(T_indexj)),
+ replace=True
+ )
+
+ list_value_sample = nx.stack([
+ loss_fun(
+ C1[np.ix_(index_i, index_j)],
+ C2[np.ix_(index_k[:, n], index_l[:, n])]
+ ) for n in range(nb_samples_q)
+ ], axis=2)
+
+ if std:
+ std_value = nx.sum(nx.std(list_value_sample, axis=2) ** 2) ** 0.5
+ return nx.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p)
+ else:
+ return nx.mean(list_value_sample)
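# A minimal usage sketch (assuming the estimator is re-exported as
# ot.gromov.GW_distance_estimation): estimate the GW cost of a fixed independent
# coupling between two random point clouds.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(30, 2), rng.randn(40, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
p, q = ot.unif(30), ot.unif(40)
T = np.outer(p, q)                 # independent coupling with marginals p and q

def square_loss(a, b):             # elementwise loss applied to the sampled entries
    return (a - b) ** 2

gw_mean, gw_std = ot.gromov.GW_distance_estimation(
    C1, C2, p, q, square_loss, T, random_state=0)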
+
+
+def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times PN^2)` time complexity with `P` the number of Sinkhorn iterations.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ alpha : float
+ Step of the Frank-Wolfe algorithm, should be between 0 and 1
+ max_iter : int, optional
+ Max number of iterations
+ threshold_plan : float, optional
+ Threshold below which entries of the transport plan are set to zero. If above zero, the marginal constraints are only approximately satisfied.
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Gives the distance estimated and the standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ index = np.zeros(2, dtype=int)
+
+ # Initialize with default marginal
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
+ index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
+ T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
+
+ best_gw_dist_estimated = np.inf
+ for cpt in range(max_iter):
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
+ T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
+ index[1] = generator.choice(
+ len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
+ )
+
+ if alpha == 1:
+ T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ else:
+ new_T = nx.tocsr(
+ emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False)
+ )
+ T = (1 - alpha) * T + alpha * new_T
+ # To limit the number of non 0, the values below the threshold are set to 0.
+ T = nx.eliminate_zeros(T, threshold=threshold_plan)
+
+ if cpt % 10 == 0 or cpt == (max_iter - 1):
+ gw_dist_estimated = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, std=False, random_state=generator
+ )
+
+ if gw_dist_estimated < best_gw_dist_estimated:
+ best_gw_dist_estimated = gw_dist_estimated
+ best_T = nx.copy(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated))
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=best_T, random_state=generator
+ )
+ return best_T, log
+ return best_T
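# A hypothetical call of the pointwise solver above (assuming it is exposed as
# ot.gromov.pointwise_gromov_wasserstein); alpha=1 replaces the plan at every
# Frank-Wolfe step, and log=True also returns the estimated cost.
import numpy as np
import ot

rng = np.random.RandomState(42)
xs, xt = rng.randn(25, 2), rng.randn(25, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
p, q = ot.unif(25), ot.unif(25)

def square_loss(a, b):
    return (a - b) ** 2

T, log = ot.gromov.pointwise_gromov_wasserstein(
    C1, C2, p, q, square_loss, alpha=1, max_iter=50, log=True, random_state=0)
print(log['gw_dist_estimated'])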
+
+
+def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
+ nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False,
+ random_state=None):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` using a 1-stochastic Frank-Wolfe.
+ This method has a :math:`\mathcal{O}(\mathrm{max\_iter} \times N \log(N))` time complexity by relying on the 1D Optimal Transport solver.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : function: :math:`\mathbb{R} \times \mathbb{R} \mapsto \mathbb{R}`
+ Loss function used for the distance; the transport plan does not depend on the loss function
+ nb_samples_grad : int
+ Number of samples to approximate the gradient
+ epsilon : float
+ Weight of the Kullback-Leibler regularization
+ max_iter : int, optional
+ Max number of iterations
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ Gives the distance estimated and the standard deviation
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Optimal coupling between the two spaces
+
+ References
+ ----------
+ .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+ "Sampled Gromov Wasserstein."
+ Machine Learning Journal (MLJ). 2021.
+
+ """
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ len_p = p.shape[0]
+ len_q = q.shape[0]
+
+ generator = check_random_state(random_state)
+
+ # The most natural way to define nb_sample is with a simple integer.
+ if isinstance(nb_samples_grad, int):
+ if nb_samples_grad > len_p:
+ # As the sampling along the first dimension is done without replacement, the remainder is carried
+ # over to the second dimension.
+ nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1
+ else:
+ nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad
+ T = nx.outer(p, q)
+ # continue_loop allows stopping the loop after several successive small modifications of T.
+ continue_loop = 0
+
+ # The gradient of GW is more complex if the two matrices are not symmetric.
+ C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
+
+ for cpt in range(max_iter):
+ index0 = generator.choice(
+ len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False
+ )
+ Lik = 0
+ for i, index0_i in enumerate(index0):
+ index1 = generator.choice(
+ len_q, size=nb_samples_grad_q,
+ p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
+ replace=False
+ )
+ # If the matrices C are not symmetric, the gradient has two terms, so one of them is chosen at random.
+ if (not C_are_symmetric) and generator.rand(1) > 0.5:
+ Lik += nx.mean(loss_fun(
+ C1[:, [index0[i]] * nb_samples_grad_q][:, None, :],
+ C2[:, index1][None, :, :]
+ ), axis=2)
+ else:
+ Lik += nx.mean(loss_fun(
+ C1[[index0[i]] * nb_samples_grad_q, :][:, :, None],
+ C2[index1, :][:, None, :]
+ ), axis=0)
+
+ max_Lik = nx.max(Lik)
+ if max_Lik == 0:
+ continue
+ # This division by the max is here to facilitate the choice of epsilon.
+ Lik /= max_Lik
+
+ if epsilon > 0:
+ # Set the log to -infinity for entries below exp(-200) to avoid taking the log of 0.
+ log_T = nx.log(nx.clip(T, np.exp(-200), 1))
+ log_T = nx.where(log_T == -200, -np.inf, log_T)
+ Lik = Lik - epsilon * log_T
+
+ try:
+ new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon)
+ except (RuntimeWarning, UserWarning):
+ print("Warning catched in Sinkhorn: Return last stable T")
+ break
+ else:
+ new_T = emd(a=p, b=q, M=Lik)
+
+ change_T = nx.mean((T - new_T) ** 2)
+ if change_T <= 10e-20:
+ continue_loop += 1
+ if continue_loop > 100: # Number max of low modifications of T
+ T = nx.copy(new_T)
+ break
+ else:
+ continue_loop = 0
+
+ if verbose and cpt % 10 == 0:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, change_T))
+ T = nx.copy(new_T)
+
+ if log:
+ log = {}
+ log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(
+ C1=C1, C2=C2, loss_fun=loss_fun,
+ p=p, q=q, T=T, random_state=generator
+ )
+ return T, log
+ return T
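# A hypothetical call of the sampled solver above (assuming it is exposed as
# ot.gromov.sampled_gromov_wasserstein); epsilon > 0 adds an entropic term and
# switches the inner solver to Sinkhorn.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(30, 2), rng.randn(30, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
p, q = ot.unif(30), ot.unif(30)

def square_loss(a, b):
    return (a - b) ** 2

T, log = ot.gromov.sampled_gromov_wasserstein(
    C1, C2, p, q, square_loss, nb_samples_grad=10, epsilon=0.1, max_iter=100,
    log=True, random_state=0)
print(log['gw_dist_estimated'], log['gw_dist_std'])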
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py
new file mode 100644
index 0000000..c6e4076
--- /dev/null
+++ b/ot/gromov/_gw.py
@@ -0,0 +1,978 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein and Fused-Gromov-Wasserstein conditional gradient solvers.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from ..utils import dist, UndefinedParameter, list_to_array
+from ..optim import cg, line_search_armijo, solve_1d_linesearch_quad
+from ..utils import check_random_state
+from ..backend import get_backend, NumpyBackend
+
+from ._utils import init_matrix, gwloss, gwggrad
+from ._utils import update_square_loss, update_kl_loss
+
+
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{GW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+ .. note:: All computations in the conjugate gradient solver are done with
+ numpy to limit memory overhead.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'
+ symmetric : bool, optional
+ Whether C1 and C2 are to be assumed symmetric or not.
+ If left to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True the step of the line-search is found via an Armijo line search. Else closed form is used.
+ If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Coupling between the two spaces that minimizes:
+
+ :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
+ log : dict
+ Convergence information and loss.
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
+ mathematics 11.4 (2011): 417-487.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ p, q = list_to_array(p, q)
+ p0, q0, C10, C20 = p, q, C1, C2
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, G0_)
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ if symmetric is None:
+ symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
+
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
+ # cg for GW is implemented using numpy on CPU
+ np_ = NumpyBackend()
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G, np_)
+
+ if symmetric:
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G, np_)
+ else:
+ constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_)
+
+ def df(G):
+ return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
+ if loss_fun == 'kl_loss':
+ armijo = True # there is no closed form line-search with KL
+
+ if armijo:
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
+ else:
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=np_, **kwargs)
+ if log:
+ res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+ log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
+ else:
+ return nx.from_numpy(cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
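# A minimal usage sketch for the conditional-gradient solver above (assuming it is
# exposed as ot.gromov.gromov_wasserstein): couple two random point clouds through
# their normalized distance matrices.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 3), rng.randn(30, 3)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
p, q = ot.unif(20), ot.unif(30)

T, log = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', log=True)
print(T.shape, log['gw_dist'])  # (20, 30) coupling and its GW cost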
+
+
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ GW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - :math:`\mathbf{q}`: distribution in the target space
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices (C1, C2) and weights (p, q) for quadratic loss using the gradients from [38]_.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+ .. note:: All computations in the conjugate gradient solver are done with
+ numpy to limit memory overhead.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ q : array-like, shape (nt,)
+ Distribution in the target space.
+ loss_fun : str
+ loss function used for the solver either 'square_loss' or 'kl_loss'
+ symmetric : bool, optional
+ Whether C1 and C2 are to be assumed symmetric or not.
+ If left to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ armijo : bool, optional
+ If True the step of the line-search is found via an Armijo line search. Else closed form is used.
+ If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ gw_dist : float
+ Gromov-Wasserstein distance
+ log : dict
+ convergence information and coupling matrix
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
+ metric approach to object matching. Foundations of computational
+ mathematics 11.4 (2011): 417-487.
+
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ # simple get_backend as the full one will be handled in gromov_wasserstein
+ nx = get_backend(C1, C2)
+
+ T, log_gw = gromov_wasserstein(
+ C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0,
+ max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
+
+ log_gw['T'] = T
+ gw = log_gw['gw_dist']
+
+ if loss_fun == 'square_loss':
+ gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
+ gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
+ gw = nx.set_gradients(gw, (p, q, C1, C2),
+ (log_gw['u'] - nx.mean(log_gw['u']),
+ log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
+
+ if log:
+ return gw, log_gw
+ else:
+ return gw
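# A sketch of the differentiability mentioned above, under the assumption that the
# PyTorch backend is installed: with 'square_loss', gromov_wasserstein2 exposes
# gradients with respect to the structure matrices through the backend.
import torch
import ot

xs = torch.randn(15, 2, dtype=torch.float64)
xt = torch.randn(20, 2, dtype=torch.float64)
C1 = ot.dist(xs, xs)
C2 = ot.dist(xt, xt)
C1.requires_grad_(True)
C2.requires_grad_(True)
p = torch.full((15,), 1. / 15, dtype=torch.float64)
q = torch.full((20,), 1. / 20, dtype=torch.float64)

gw = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss')
gw.backward()
print(C1.grad.shape, C2.grad.shape)  # gradients wrt the two structure matrices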
+
+
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
+ armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Computes the FGW transport between two graphs (see :ref:`[24] <references-fused-gromov-wasserstein>`)
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+ .. note:: All computations in the conjugate gradient solver are done with
+ numpy to limit memory overhead.
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ q : array-like, shape (nt,)
+ Distribution in the target space
+ loss_fun : str, optional
+ Loss function used for the solver
+ symmetric : bool, optional
+ Whether C1 and C2 are to be assumed symmetric or not.
+ If left to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ armijo : bool, optional
+ If True the step of the line-search is found via an Armijo line search. Else closed form is used.
+ If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ log : bool, optional
+ record log if True
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ gamma : array-like, shape (`ns`, `nt`)
+ Optimal transportation matrix for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-fused-gromov-wasserstein:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ p, q = list_to_array(p, q)
+ p0, q0, C10, C20, M0 = p, q, C1, C2, M
+ if G0 is None:
+ nx = get_backend(p0, q0, C10, C20, M0)
+ else:
+ G0_ = G0
+ nx = get_backend(p0, q0, C10, C20, M0, G0_)
+
+ p = nx.to_numpy(p)
+ q = nx.to_numpy(q)
+ C1 = nx.to_numpy(C10)
+ C2 = nx.to_numpy(C20)
+ M = nx.to_numpy(M0)
+
+ if symmetric is None:
+ symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
+
+ if G0 is None:
+ G0 = p[:, None] * q[None, :]
+ else:
+ G0 = nx.to_numpy(G0_)
+ # Check marginals of G0
+ np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
+ np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)
+ # cg for GW is implemented using numpy on CPU
+ np_ = NumpyBackend()
+
+ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, np_)
+
+ def f(G):
+ return gwloss(constC, hC1, hC2, G, np_)
+
+ if symmetric:
+ def df(G):
+ return gwggrad(constC, hC1, hC2, G, np_)
+ else:
+ constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, np_)
+
+ def df(G):
+ return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))
+
+ if loss_fun == 'kl_loss':
+ armijo = True # there is no closed form line-search with KL
+
+ if armijo:
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
+ else:
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs)
+ if log:
+ res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+ log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10)
+ log['u'] = nx.from_numpy(log['u'], type_as=C10)
+ log['v'] = nx.from_numpy(log['v'], type_as=C10)
+ return nx.from_numpy(res, type_as=C10), log
+ else:
+ return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10)
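# A minimal FGW usage sketch (assuming the solver is exposed as
# ot.gromov.fused_gromov_wasserstein): two toy attributed graphs whose node features
# live in R^2, with alpha balancing the feature and structure terms.
import numpy as np
import ot

rng = np.random.RandomState(1)
ys, yt = rng.randn(20, 2), rng.randn(30, 2)   # node features of both graphs
C1, C2 = ot.dist(ys, ys), ot.dist(yt, yt)     # structure matrices
M = ot.dist(ys, yt)                           # feature cost across domains
p, q = ot.unif(20), ot.unif(30)

T, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, alpha=0.5, log=True)
print(log['fgw_dist'])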
+
+
+def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric=None, alpha=0.5,
+ armijo=False, G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Computes the FGW distance between two graphs see (see :ref:`[24] <references-fused-gromov-wasserstein2>`)
+
+ .. math::
+ \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma}^T \mathbf{1} &= \mathbf{q}
+
+ \mathbf{\gamma} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as
+ discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. But the algorithm uses the C++ CPU backend
+ which can lead to copy overhead on GPU arrays.
+ .. note:: All computations in the conjugate gradient solver are done with
+ numpy to limit memory overhead.
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices (C1, C2, M) and weights (p, q) for quadratic loss using the gradients from [38]_.
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ q : array-like, shape (nt,)
+ Distribution in the target space.
+ loss_fun : str, optional
+ Loss function used for the solver.
+ symmetric : bool, optional
+ Whether C1 and C2 are to be assumed symmetric or not.
+ If left to its default None value, a symmetry test will be conducted.
+ Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ armijo : bool, optional
+ If True the step of the line-search is found via an Armijo line search.
+ Else closed form is used. If there are convergence issues use False.
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ log : bool, optional
+ Record log if True.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ Parameters can be directly passed to the ot.optim.cg solver.
+
+ Returns
+ -------
+ fgw-distance : float
+ Fused gromov wasserstein distance for the given parameters.
+ log : dict
+ Log dictionary return only if log==True in parameters.
+
+
+ .. _references-fused-gromov-wasserstein2:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+ .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+ Graph Dictionary Learning, International Conference on Machine Learning
+ (ICML), 2021.
+
+ .. [47] Chowdhury, S., & Mémoli, F. (2019). The gromov–wasserstein
+ distance between networks and stable network invariants.
+ Information and Inference: A Journal of the IMA, 8(4), 757-787.
+ """
+ nx = get_backend(C1, C2, M)
+
+ T, log_fgw = fused_gromov_wasserstein(
+ M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True,
+ max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
+
+ fgw_dist = log_fgw['fgw_dist']
+ log_fgw['T'] = T
+
+ if loss_fun == 'square_loss':
+ gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
+ gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
+ fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
+ (log_fgw['u'] - nx.mean(log_fgw['u']),
+ log_fgw['v'] - nx.mean(log_fgw['v']),
+ alpha * gC1, alpha * gC2, (1 - alpha) * T))
+
+ if log:
+ return fgw_dist, log_fgw
+ else:
+ return fgw_dist
+
+
+def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
+ alpha_min=None, alpha_max=None, nx=None, **kwargs):
+ """
+ Solve the linesearch in the FW iterations
+
+ Parameters
+ ----------
+
+ G : array-like, shape(ns,nt)
+ The transport map at a given iteration of the FW
+ deltaG : array-like (ns,nt)
+ Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
+ cost_G : float
+ Value of the cost at `G`
+ C1 : array-like (ns,ns)
+ Structure matrix in the source domain.
+ C2 : array-like (nt,nt)
+ Structure matrix in the target domain.
+ M : array-like (ns,nt)
+ Cost matrix between the features.
+ reg : float
+ Regularization parameter.
+ alpha_min : float, optional
+ Minimum value for alpha
+ alpha_max : float, optional
+ Maximum value for alpha
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the input arrays.
+ Returns
+ -------
+ alpha : float
+ The optimal step size of the FW
+ fc : int
+ Number of function calls. Unused here.
+ cost_G : float
+ The value of the cost for the next iteration
+
+
+ .. _references-solve-linesearch:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ if nx is None:
+ G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
+
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(G, deltaG, C1, C2)
+ else:
+ nx = get_backend(G, deltaG, C1, C2, M)
+
+ dot = nx.dot(nx.dot(C1, deltaG), C2.T)
+ a = -2 * reg * nx.sum(dot * deltaG)
+ b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))
+
+ alpha = solve_1d_linesearch_quad(a, b)
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
+
+ # the new cost is deduced from the line search quadratic function
+ cost_G = cost_G + a * (alpha ** 2) + b * alpha
+
+ return alpha, 1, cost_G
+
+
+def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, symmetric=True, armijo=False,
+ max_iter=1000, tol=1e-9, verbose=False, log=False,
+ init_C=None, random_state=None, **kwargs):
+ r"""
+ Returns the gromov-wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
+
+ The function solves the following optimization problem with block coordinate descent:
+
+ .. math::
+
+ \mathbf{C} = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)
+
+ Where :
+
+ - :math:`\mathbf{C}_s`: metric cost matrix
+ - :math:`\mathbf{p}_s`: distribution
+
+ Parameters
+ ----------
+ N : int
+ Size of the targeted barycenter
+ Cs : list of S array-like of shape (ns, ns)
+ Metric cost matrices
+ ps : list of S array-like of shape (ns,)
+ Sample weights in the `S` spaces
+ p : array-like, shape (N,)
+ Weights in the targeted barycenter
+ lambdas : list of float
+ List of the `S` spaces' weights
+ loss_fun : str
+ Loss function used for the solver, either 'square_loss' or 'kl_loss'
+ symmetric : bool, optional.
+ Whether the structures are to be assumed symmetric or not. Default value is True.
+ If set to True (resp. False), the structure matrices will be assumed symmetric (resp. asymmetric).
+ update : callable
+ function(:math:`\mathbf{p}`, lambdas, :math:`\mathbf{T}`, :math:`\mathbf{Cs}`) that updates
+ :math:`\mathbf{C}` according to a specific Kernel with the `S` :math:`\mathbf{T}_s` couplings
+ calculated at each iteration
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on relative error (>0)
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+ init_C : array-like, shape (N,N), optional
+ Initial value for the :math:`\mathbf{C}` matrix provided by the user. If None, a random initialization is used.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ C : array-like, shape (`N`, `N`)
+ Similarity matrix in the barycenter space (permuted arbitrarily)
+ log : dict
+ Log dictionary of error during iterations. Return only if `log=True` in parameters.
+
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *ps, p)
+
+ S = len(Cs)
+
+ # Initialization of C : random symmetric distance matrix (if not provided by user)
+ if init_C is None:
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
+ C = dist(xalea, xalea)
+ C /= C.max()
+ C = nx.from_numpy(C, type_as=p)
+ else:
+ C = init_C
+
+ if loss_fun == 'kl_loss':
+ armijo = True
+
+ cpt = 0
+ err = 1
+
+ error = []
+
+ while (err > tol and cpt < max_iter):
+ Cprev = C
+
+ T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo,
+ max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
+ if loss_fun == 'square_loss':
+ C = update_square_loss(p, lambdas, T, Cs)
+
+ elif loss_fun == 'kl_loss':
+ C = update_kl_loss(p, lambdas, T, Cs)
+
+ if cpt % 10 == 0:
+ # we can speed up the process by checking for the error only every
+ # 10th iteration
+ err = nx.norm(C - Cprev)
+ error.append(err)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+
+ if log:
+ return C, {"err": error}
+ else:
+ return C
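# A minimal barycenter sketch (assuming the function is exposed as
# ot.gromov.gromov_barycenters): a 10-node GW barycenter of two random structure
# matrices of different sizes.
import numpy as np
import ot

rng = np.random.RandomState(0)
x1, x2 = rng.randn(12, 2), rng.randn(18, 2)
Cs = [ot.dist(x1, x1), ot.dist(x2, x2)]
ps = [ot.unif(12), ot.unif(18)]
N = 10

C_bary = ot.gromov.gromov_barycenters(
    N, Cs, ps, ot.unif(N), lambdas=[0.5, 0.5], loss_fun='square_loss',
    max_iter=100, random_state=0)
print(C_bary.shape)  # (10, 10)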
+
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+ p=None, loss_fun='square_loss', armijo=False, symmetric=True, max_iter=100, tol=1e-9,
+ verbose=False, log=False, init_C=None, init_X=None, random_state=None, **kwargs):
+ r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] <references-fgw-barycenters>`
+
+ Parameters
+ ----------
+ N : int
+ Desired number of samples of the target barycenter
+ Ys: list of array-like, each element has shape (ns,d)
+ Features of all samples
+ Cs : list of array-like, each element has shape (ns,ns)
+ Structure matrices of all samples
+ ps : list of array-like, each element has shape (ns,)
+ Masses of all samples.
+ lambdas : list of float
+ List of the `S` spaces' weights
+ alpha : float
+ Alpha parameter for the fgw distance
+ fixed_structure : bool
+ Whether to fix the structure of the barycenter during the updates
+ fixed_features : bool
+ Whether to fix the feature of the barycenter during the updates
+ loss_fun : str
+ Loss function used for the solver either 'square_loss' or 'kl_loss'
+ symmetric : bool, optional
+ Whether the structures are to be assumed symmetric or not. Default value is True.
+ If set to True (resp. False), the structure matrices will be assumed symmetric (resp. asymmetric).
+ max_iter : int, optional
+ Max number of iterations
+ tol : float, optional
+ Stop threshold on relative error (>0)
+ verbose : bool, optional
+ Print information along iterations.
+ log : bool, optional
+ Record log if True.
+ init_C : array-like, shape (N,N), optional
+ Initialization for the barycenters' structure matrix. If not set
+ a random init is used.
+ init_X : array-like, shape (N,d), optional
+ Initialization for the barycenters' features. If not set a
+ random init is used.
+ random_state : int or RandomState instance, optional
+ Fix the seed for reproducibility
+
+ Returns
+ -------
+ X : array-like, shape (`N`, `d`)
+ Barycenters' features
+ C : array-like, shape (`N`, `N`)
+ Barycenters' structure matrix
+ log : dict
+ Only returned when log=True. It contains the keys:
+
+ - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
+ - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
+
+
+ .. _references-fgw-barycenters:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ Cs = list_to_array(*Cs)
+ ps = list_to_array(*ps)
+ Ys = list_to_array(*Ys)
+ p = list_to_array(p)
+ nx = get_backend(*Cs, *Ys, *ps)
+
+ S = len(Cs)
+ d = Ys[0].shape[1]  # dimension of the node features
+ if p is None:
+ p = nx.ones(N, type_as=Cs[0]) / N
+
+ if fixed_structure:
+ if init_C is None:
+ raise UndefinedParameter('If C is fixed it must be initialized')
+ else:
+ C = init_C
+ else:
+ if init_C is None:
+ generator = check_random_state(random_state)
+ xalea = generator.randn(N, 2)
+ C = dist(xalea, xalea)
+ C = nx.from_numpy(C, type_as=ps[0])
+ else:
+ C = init_C
+
+ if fixed_features:
+ if init_X is None:
+ raise UndefinedParameter('If X is fixed it must be initialized')
+ else:
+ X = init_X
+ else:
+ if init_X is None:
+ X = nx.zeros((N, d), type_as=ps[0])
+ else:
+ X = init_X
+
+ T = [nx.outer(p, q) for q in ps]
+
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
+
+ if loss_fun == 'kl_loss':
+ armijo = True
+
+ cpt = 0
+ err_feature = 1
+ err_structure = 1
+
+ if log:
+ log_ = {}
+ log_['err_feature'] = []
+ log_['err_structure'] = []
+ log_['Ts_iter'] = []
+
+ while ((err_feature > tol or err_structure > tol) and cpt < max_iter):
+ Cprev = C
+ Xprev = X
+
+ if not fixed_features:
+ Ys_temp = [y.T for y in Ys]
+ X = update_feature_matrix(lambdas, Ys_temp, T, p).T
+
+ Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
+
+ if not fixed_structure:
+ if loss_fun == 'square_loss':
+ T_temp = [t.T for t in T]
+ C = update_structure_matrix(p, lambdas, T_temp, Cs)
+
+ T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
+ max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
+
+ # T is N,ns
+ err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
+ err_structure = nx.norm(C - Cprev)
+ if log:
+ log_['err_feature'].append(err_feature)
+ log_['err_structure'].append(err_structure)
+ log_['Ts_iter'].append(T)
+
+ if verbose:
+ if cpt % 200 == 0:
+ print('{:5s}|{:12s}'.format(
+ 'It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err_structure))
+ print('{:5d}|{:8e}|'.format(cpt, err_feature))
+
+ cpt += 1
+
+ if log:
+ log_['T'] = T # from target to Ys
+ log_['p'] = p
+ log_['Ms'] = Ms
+
+ if log:
+ return X, C, log_
+ else:
+ return X, C
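# A minimal FGW-barycenter sketch (assuming the function is exposed as
# ot.gromov.fgw_barycenters): an 8-node barycenter of two random attributed graphs.
import numpy as np
import ot

rng = np.random.RandomState(3)
Ys = [rng.randn(12, 2), rng.randn(18, 2)]
Cs = [ot.dist(Y, Y) for Y in Ys]
ps = [ot.unif(12), ot.unif(18)]
N = 8

X, C = ot.gromov.fgw_barycenters(
    N, Ys, Cs, ps, lambdas=[0.5, 0.5], alpha=0.5, p=ot.unif(N),
    loss_fun='square_loss', max_iter=50, random_state=0)
print(X.shape, C.shape)  # (8, 2) features and (8, 8) structure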
+
+
+def update_structure_matrix(p, lambdas, T, Cs):
+ r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings.
+
+ It is calculated at each iteration
+
+ Parameters
+ ----------
+ p : array-like, shape (N,)
+ Masses in the targeted barycenter.
+ lambdas : list of float
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns, N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape (ns, ns)
+ Metric cost matrices.
+
+ Returns
+ -------
+ C : array-like, shape (`nt`, `nt`)
+ Updated :math:`\mathbf{C}` matrix.
+ """
+ p = list_to_array(p)
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ nx = get_backend(*Cs, *T, p)
+
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+ return tmpsum / ppt
+
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
+ r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.
+
+
+ See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
+ in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration
+
+ Parameters
+ ----------
+ p : array-like, shape (N,)
+ masses in the targeted barycenter
+ lambdas : list of float
+ List of the `S` spaces' weights
+ Ts : list of S array-like, shape (N, ns)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Ys : list of S array-like, shape (d,ns)
+ The features.
+
+ Returns
+ -------
+ X : array-like, shape (`d`, `N`)
+ Updated feature matrix of the barycenter.
+
+
+ .. _references-update-feature-matrix:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+ """
+ p = list_to_array(p)
+ Ts = list_to_array(*Ts)
+ Ys = list_to_array(*Ys)
+ nx = get_backend(*Ys, *Ts, p)
+
+ p = 1. / p
+ tmpsum = sum([
+ lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
+ for s in range(len(Ts))
+ ])
+ return tmpsum
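For intuition, here is a NumPy-only sketch (toy data, not part of the patch) of the two block updates implemented above, written for a single input graph with the coupling stored as an (N, ns) matrix, as in the barycenter loop:

    import numpy as np

    rng = np.random.RandomState(0)
    N, ns, d = 4, 6, 2
    p = np.full(N, 1. / N)                             # barycenter node weights
    T = rng.rand(N, ns)
    T *= p[:, None] / T.sum(axis=1, keepdims=True)     # coupling with first marginal p
    C_s = rng.rand(ns, ns); C_s = (C_s + C_s.T) / 2    # input structure matrix
    Y_s = rng.randn(ns, d)                             # input node features

    C_new = (T @ C_s @ T.T) / np.outer(p, p)           # structure update (square loss)
    X_new = (T @ Y_s) / p[:, None]                     # feature update (barycentric projection)
    print(C_new.shape, X_new.shape)                    # (4, 4) and (4, 2)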
diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py
new file mode 100644
index 0000000..638bb1c
--- /dev/null
+++ b/ot/gromov/_semirelaxed.py
@@ -0,0 +1,543 @@
+# -*- coding: utf-8 -*-
+"""
+Semi-relaxed Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers.
+"""
+
+# Author: Rémi Flamary <remi.flamary@unice.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+
+
+from ..utils import list_to_array, unif
+from ..optim import semirelaxed_cg, solve_1d_linesearch_quad
+from ..backend import get_backend
+
+from ._utils import init_matrix_semirelaxed, gwloss, gwggrad
+
+
+def semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Returns the semi-relaxed Gromov-Wasserstein transport plan from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{srGW} = \mathop{\arg \min}_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+
+ - `L`: loss function to account for the misfit between the similarity matrices
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However, the intermediate steps of the
+ conditional-gradient solver are not differentiable.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ loss_fun : str
+ loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Whether C1 and C2 are assumed to be symmetric.
+ If left to its default None value, a symmetry test will be conducted.
+ If set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ T : array-like, shape (`ns`, `nt`)
+ Coupling between the two spaces that minimizes:
+
+ :math:`\sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}`
+ log : dict
+ Convergence information and loss.
+
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if loss_fun == 'kl_loss':
+ raise NotImplementedError()
+ p = list_to_array(p)
+ if G0 is None:
+ nx = get_backend(p, C1, C2)
+ else:
+ nx = get_backend(p, C1, C2, G0)
+
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+ if G0 is None:
+ q = unif(C2.shape[0], type_as=p)
+ G0 = nx.outer(p, q)
+ else:
+ q = nx.sum(G0, 0)
+ # Check first marginal of G0
+ np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
+
+ constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
+
+ ones_p = nx.ones(p.shape[0], type_as=p)
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwloss(constC + marginal_product, hC1, hC2, G, nx)
+
+ if symmetric:
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwggrad(constC + marginal_product, hC1, hC2, G, nx)
+ else:
+ constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t))
+ marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2))
+ return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))
+
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs)
+
+ if log:
+ res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+ log['srgw_dist'] = log['loss'][-1]
+ return res, log
+ else:
+ return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+
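A minimal usage sketch (not part of the patch, and assuming the new solver is re-exported as ot.gromov.semirelaxed_gromov_wasserstein): only the source marginal p is prescribed, while the second marginal of the returned coupling is learned by the solver.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(10, 2), rng.randn(15, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    p = ot.unif(10)
    T, log = ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, 'square_loss', log=True)
    print(np.allclose(T.sum(axis=1), p))               # True: first marginal is fixed to p
    print(T.sum(axis=0))                               # free second marginal found by the solver
    print(log['srgw_dist'])                            # value of the srGW objective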
+
+def semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', symmetric=None, log=False, G0=None,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Returns the semi-relaxed Gromov-Wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}`
+
+ The function solves the following optimization problem:
+
+ .. math::
+ srGW = \min_\mathbf{T} \quad \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} \mathbf{T}_{k,l}
+
+ s.t. \ \mathbf{T} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{T} &\geq 0
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{p}`: distribution in the source space
+ - `L`: loss function to account for the misfit between the similarity
+ matrices
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices (C1, C2) but not yet wrt the weights p.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However, the intermediate steps of the
+ conditional-gradient solver are not differentiable.
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ loss_fun : str
+ loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Whether C1 and C2 are assumed to be symmetric.
+ If left to its default None value, a symmetry test will be conducted.
+ If set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ srgw : float
+ Semi-relaxed Gromov-Wasserstein divergence
+ log : dict
+ convergence information and Coupling matrix
+
+ References
+ ----------
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ nx = get_backend(p, C1, C2)
+
+ T, log_srgw = semirelaxed_gromov_wasserstein(
+ C1, C2, p, loss_fun, symmetric, log=True, G0=G0,
+ max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
+
+ q = nx.sum(T, 0)
+ log_srgw['T'] = T
+ srgw = log_srgw['srgw_dist']
+
+ if loss_fun == 'square_loss':
+ # closed-form gradients of the srGW value wrt (C1, C2), attached via the envelope theorem
+ gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
+ gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
+ srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))
+
+ if log:
+ return srgw, log_srgw
+ else:
+ return srgw
+
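Sketch of the differentiability mentioned above (an illustration only, assuming the optional PyTorch backend is installed): gradients propagate to the structure matrices through the closed-form expressions set with nx.set_gradients.

    import torch
    import ot

    xs = torch.randn(8, 2, dtype=torch.float64)
    xt = torch.randn(11, 2, dtype=torch.float64)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    C1.requires_grad_(True)
    C2.requires_grad_(True)
    p = ot.unif(8, type_as=C1)
    val = ot.gromov.semirelaxed_gromov_wasserstein2(C1, C2, p)
    val.backward()                                     # gradients wrt C1 and C2, not wrt p
    print(C1.grad.shape, C2.grad.shape)                # torch.Size([8, 8]) torch.Size([11, 11])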
+
+def semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Computes the semi-relaxed FGW transport between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`)
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F +
+ \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{\gamma}_{i,j} \mathbf{\gamma}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` source weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However, the intermediate steps of the
+ conditional-gradient solver are not differentiable.
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein>`
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space
+ p : array-like, shape (ns,)
+ Distribution in the source space
+ loss_fun : str
+ loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Whether C1 and C2 are assumed to be symmetric.
+ If left to its default None value, a symmetry test will be conducted.
+ If set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ log : bool, optional
+ record log if True
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ parameters can be directly passed to the ot.optim.cg solver
+
+ Returns
+ -------
+ gamma : array-like, shape (`ns`, `nt`)
+ Optimal transportation matrix for the given parameters.
+ log : dict
+ Log dictionary returned only if `log` is True.
+
+
+ .. _references-semirelaxed-fused-gromov-wasserstein:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if loss_fun == 'kl_loss':
+ raise NotImplementedError()
+
+ p = list_to_array(p)
+ if G0 is None:
+ nx = get_backend(p, C1, C2, M)
+ else:
+ nx = get_backend(p, C1, C2, M, G0)
+
+ if symmetric is None:
+ symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
+
+ if G0 is None:
+ q = unif(C2.shape[0], type_as=p)
+ G0 = nx.outer(p, q)
+ else:
+ q = nx.sum(G0, 0)
+ # Check marginals of G0
+ np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
+
+ constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
+
+ ones_p = nx.ones(p.shape[0], type_as=p)
+
+ def f(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwloss(constC + marginal_product, hC1, hC2, G, nx)
+
+ if symmetric:
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t))
+ return gwggrad(constC + marginal_product, hC1, hC2, G, nx)
+ else:
+ constCt, hC1t, hC2t, fC2 = init_matrix_semirelaxed(C1.T, C2.T, p, loss_fun, nx)
+
+ def df(G):
+ qG = nx.sum(G, 0)
+ marginal_product_1 = nx.outer(ones_p, nx.dot(qG, fC2t))
+ marginal_product_2 = nx.outer(ones_p, nx.dot(qG, fC2))
+ return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))
+
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return solve_semirelaxed_gromov_linesearch(
+ G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs)
+
+ if log:
+ res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+ log['srfgw_dist'] = log['loss'][-1]
+ return res, log
+ else:
+ return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)
+
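Usage sketch for the fused variant (same assumptions as the srGW sketch above, not part of the patch): the feature cost M contributes a linear term with weight (1 - alpha) on top of the structure term with weight alpha.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    xs, xt = rng.randn(10, 2), rng.randn(12, 2)        # node embeddings defining the structures
    ys, yt = rng.randn(10, 3), rng.randn(12, 3)        # node features
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    M = ot.dist(ys, yt)                                # feature cost matrix
    p = ot.unif(10)
    T, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, alpha=0.3, log=True)
    print(T.shape, log['srfgw_dist'])                  # (10, 12) coupling and srFGW value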
+
+def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False,
+ max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs):
+ r"""
+ Computes the semi-relaxed FGW divergence between two graphs (see :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein2>`)
+
+ .. math::
+ \min_\gamma \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F + \alpha \sum_{i,j,k,l}
+ L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{\gamma}_{i,j} \mathbf{\gamma}_{k,l}
+
+ s.t. \ \mathbf{\gamma} \mathbf{1} &= \mathbf{p}
+
+ \mathbf{\gamma} &\geq 0
+
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`\mathbf{p}` source weights (sum to 1)
+ - `L` is a loss function to account for the misfit between the similarity matrices
+
+ The algorithm used for solving the problem is conditional gradient as
+ discussed in :ref:`[48] <references-semirelaxed-fused-gromov-wasserstein2>`
+
+ Note that when using backends, this loss function is differentiable wrt the
+ matrices (C1, C2) but not yet wrt the weights p.
+
+ .. note:: This function is backend-compatible and will work on arrays
+ from all compatible backends. However, the intermediate steps of the
+ conditional-gradient solver are not differentiable.
+
+ Parameters
+ ----------
+ M : array-like, shape (ns, nt)
+ Metric cost matrix between features across domains
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix representative of the structure in the source space.
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix representative of the structure in the target space.
+ p : array-like, shape (ns,)
+ Distribution in the source space.
+ loss_fun : str, optional
+ loss function used for the solver, either 'square_loss' or 'kl_loss'.
+ 'kl_loss' is not implemented yet and will raise an error.
+ symmetric : bool, optional
+ Whether C1 and C2 are assumed to be symmetric.
+ If left to its default None value, a symmetry test will be conducted.
+ If set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
+ alpha : float, optional
+ Trade-off parameter (0 < alpha < 1)
+ G0: array-like, shape (ns,nt), optional
+ If None the initial transport plan of the solver is pq^T.
+ Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
+ log : bool, optional
+ Record log if True.
+ max_iter : int, optional
+ Max number of iterations
+ tol_rel : float, optional
+ Stop threshold on relative error (>0)
+ tol_abs : float, optional
+ Stop threshold on absolute error (>0)
+ **kwargs : dict
+ Parameters can be directly passed to the ot.optim.cg solver.
+
+ Returns
+ -------
+ srfgw-divergence : float
+ Semi-relaxed fused Gromov-Wasserstein divergence for the given parameters.
+ log : dict
+ Log dictionary returned only if `log` is True.
+
+
+ .. _references-semirelaxed-fused-gromov-wasserstein2:
+ References
+ ----------
+ .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+ and Courty Nicolas "Optimal Transport for structured data with
+ application on graphs", International Conference on Machine Learning
+ (ICML). 2019.
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ nx = get_backend(p, C1, C2, M)
+
+ T, log_fgw = semirelaxed_fused_gromov_wasserstein(
+ M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True,
+ max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs)
+ q = nx.sum(T, 0)
+ srfgw_dist = log_fgw['srfgw_dist']
+ log_fgw['T'] = T
+
+ if loss_fun == 'square_loss':
+ # closed-form gradients of the srFGW value wrt (C1, C2, M), attached via the envelope theorem
+ gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
+ gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T))
+
+ if log:
+ return srfgw_dist, log_fgw
+ else:
+ return srfgw_dist
+
+
+def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
+ M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs):
+ """
+ Solve the line search in the Frank-Wolfe (FW) iterations
+
+ Parameters
+ ----------
+
+ G : array-like, shape(ns,nt)
+ The transport map at a given iteration of the FW
+ deltaG : array-like (ns,nt)
+ Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
+ cost_G : float
+ Value of the cost at `G`
+ C1 : array-like (ns,ns)
+ Structure matrix in the source domain.
+ C2 : array-like (nt,nt)
+ Structure matrix in the target domain.
+ ones_p : array-like (ns,)
+ Array of ones of size ns
+ M : array-like (ns,nt)
+ Cost matrix between the features.
+ reg : float
+ Regularization parameter.
+ alpha_min : float, optional
+ Minimum value for alpha
+ alpha_max : float, optional
+ Maximum value for alpha
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the input arrays.
+ Returns
+ -------
+ alpha : float
+ The optimal step size of the FW
+ fc : int
+ Number of function calls (always 1 here; kept for compatibility with other line searches)
+ cost_G : float
+ The value of the cost for the next iteration
+
+ References
+ ----------
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if nx is None:
+ G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)
+
+ if isinstance(M, int) or isinstance(M, float):
+ nx = get_backend(G, deltaG, C1, C2)
+ else:
+ nx = get_backend(G, deltaG, C1, C2, M)
+
+ qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0)
+ dot = nx.dot(nx.dot(C1, deltaG), C2.T)
+ C2t_square = C2.T ** 2
+ dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square)
+ dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square)
+ a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG)
+ b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG))
+ alpha = solve_1d_linesearch_quad(a, b)
+ if alpha_min is not None or alpha_max is not None:
+ alpha = np.clip(alpha, alpha_min, alpha_max)
+
+ # the new cost can be deduced from the line search quadratic function
+ cost_G = cost_G + a * (alpha ** 2) + b * alpha
+
+ return alpha, 1, cost_G
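The step size above minimizes the 1-D quadratic g(t) = cost_G + a*t**2 + b*t on [0, 1]. A self-contained re-derivation sketch of that closed form (an illustration, not the ot.optim.solve_1d_linesearch_quad helper itself):

    import numpy as np

    def quad_argmin_01(a, b):
        # minimize f(t) = a * t**2 + b * t over t in [0, 1]
        if a > 0:                                      # convex: clip the stationary point
            return float(np.clip(-b / (2. * a), 0., 1.))
        return 1. if a + b < 0. else 0.                # otherwise the best endpoint (f(0)=0, f(1)=a+b)

    print(quad_argmin_01(2., -1.))                     # 0.25
    print(quad_argmin_01(-1., -1.))                    # 1.0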
diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py
new file mode 100644
index 0000000..e842250
--- /dev/null
+++ b/ot/gromov/_utils.py
@@ -0,0 +1,413 @@
+# -*- coding: utf-8 -*-
+"""
+Gromov-Wasserstein and Fused-Gromov-Wasserstein utils.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+# Rémi Flamary <remi.flamary@unice.fr>
+# Titouan Vayer <titouan.vayer@irisa.fr>
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
+#
+# License: MIT License
+
+
+from ..utils import list_to_array
+from ..backend import get_backend
+
+
+def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None):
+ r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation
+
+ Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
+ selected loss function as the loss function of the Gromov-Wasserstein discrepancy.
+
+ The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{T}`: A coupling between those two spaces
+
+ The square-loss function :math:`L(a, b) = |a - b|^2` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a^2
+
+ f_2(b) &= b^2
+
+ h_1(a) &= a
+
+ h_2(b) &= 2b
+
+ The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a \log(a) - a
+
+ f_2(b) &= b
+
+ h_1(a) &= a
+
+ h_2(b) &= \log(b)
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Probability distribution in the source space
+ q : array-like, shape (nt,)
+ Probability distribution in the target space
+ loss_fun : str, optional
+ Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss')
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the input arrays.
+ Returns
+ -------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+
+
+ .. _references-init-matrix:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ if nx is None:
+ C1, C2, p, q = list_to_array(C1, C2, p, q)
+ nx = get_backend(C1, C2, p, q)
+
+ if loss_fun == 'square_loss':
+ def f1(a):
+ return (a**2)
+
+ def f2(b):
+ return (b**2)
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return 2 * b
+ elif loss_fun == 'kl_loss':
+ def f1(a):
+ return a * nx.log(a + 1e-15) - a
+
+ def f2(b):
+ return b
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return nx.log(b + 1e-15)
+
+ constC1 = nx.dot(
+ nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
+ nx.ones((1, len(q)), type_as=q)
+ )
+ constC2 = nx.dot(
+ nx.ones((len(p), 1), type_as=p),
+ nx.dot(nx.reshape(q, (1, -1)), f2(C2).T)
+ )
+ constC = constC1 + constC2
+ hC1 = h1(C1)
+ hC2 = h2(C2)
+
+ return constC, hC1, hC2
+
+
+def tensor_product(constC, hC1, hC2, T, nx=None):
+ r"""Return the tensor for Gromov-Wasserstein fast computation
+
+ The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-tensor-product>`
+
+ Parameters
+ ----------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the input arrays.
+ Returns
+ -------
+ tens : array-like, shape (`ns`, `nt`)
+ :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result
+
+
+ .. _references-tensor-product:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ if nx is None:
+ constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
+ nx = get_backend(constC, hC1, hC2, T)
+
+ A = - nx.dot(
+ nx.dot(hC1, T), hC2.T
+ )
+ tens = constC + A
+ # tens -= tens.min()
+ return tens
+
+
+def gwloss(constC, hC1, hC2, T, nx=None):
+ r"""Return the Loss for Gromov-Wasserstein
+
+ The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`
+
+ Parameters
+ ----------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the input arrays.
+ Returns
+ -------
+ loss : float
+ Gromov Wasserstein loss
+
+
+ .. _references-gwloss:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+
+ tens = tensor_product(constC, hC1, hC2, T, nx)
+ if nx is None:
+ tens, T = list_to_array(tens, T)
+ nx = get_backend(tens, T)
+
+ return nx.sum(tens * T)
+
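Numerical check sketch of the factorization used above (not part of the patch, and assuming init_matrix and gwloss are re-exported at the ot.gromov level as in the accompanying __init__ changes): for the square loss, gwloss reproduces the naive quadruple sum of Eq. (6).

    import numpy as np
    import ot
    from ot.gromov import init_matrix, gwloss

    rng = np.random.RandomState(0)
    ns, nt = 5, 6
    C1, C2 = ot.dist(rng.randn(ns, 2)), ot.dist(rng.randn(nt, 2))
    p, q = ot.unif(ns), ot.unif(nt)
    T = np.outer(p, q)                                 # any coupling with marginals p and q

    naive = sum((C1[i, k] - C2[j, l]) ** 2 * T[i, j] * T[k, l]
                for i in range(ns) for j in range(nt)
                for k in range(ns) for l in range(nt))
    constC, hC1, hC2 = init_matrix(C1, C2, p, q, 'square_loss')
    print(np.allclose(naive, gwloss(constC, hC1, hC2, T)))   # True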
+
+def gwggrad(constC, hC1, hC2, T, nx=None):
+ r"""Return the gradient for Gromov-Wasserstein
+
+ The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`
+
+ Parameters
+ ----------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6)
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ T : array-like, shape (ns, nt)
+ Current value of transport matrix :math:`\mathbf{T}`
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the input arrays.
+ Returns
+ -------
+ grad : array-like, shape (`ns`, `nt`)
+ Gromov Wasserstein gradient
+
+
+ .. _references-gwggrad:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ """
+ return 2 * tensor_product(constC, hC1, hC2,
+ T, nx) # [12] Prop. 2 misses a 2 factor
+
+
+def update_square_loss(p, lambdas, T, Cs):
+ r"""
+ Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
+ couplings calculated at each iteration
+
+ Parameters
+ ----------
+ p : array-like, shape (N,)
+ Masses in the targeted barycenter.
+ lambdas : list of float
+ List of the `S` spaces' weights.
+ T : list of S array-like of shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape(ns,ns)
+ Metric cost matrices.
+
+ Returns
+ -------
+ C : array-like, shape (`N`, `N`)
+ Updated :math:`\mathbf{C}` matrix.
+ """
+ T = list_to_array(*T)
+ Cs = list_to_array(*Cs)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
+
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+
+ return tmpsum / ppt
+
+
+def update_kl_loss(p, lambdas, T, Cs):
+ r"""
+ Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
+
+
+ Parameters
+ ----------
+ p : array-like, shape (N,)
+ Weights in the targeted barycenter.
+ lambdas : list of float
+ List of the `S` spaces' weights
+ T : list of S array-like of shape (ns,N)
+ The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
+ Cs : list of S array-like, shape(ns,ns)
+ Metric cost matrices.
+
+ Returns
+ -------
+ C : array-like, shape (`N`, `N`)
+ Updated :math:`\mathbf{C}` matrix.
+ """
+ Cs = list_to_array(*Cs)
+ T = list_to_array(*T)
+ p = list_to_array(p)
+ nx = get_backend(p, *T, *Cs)
+
+ tmpsum = sum([
+ lambdas[s] * nx.dot(
+ nx.dot(T[s].T, Cs[s]),
+ T[s]
+ ) for s in range(len(T))
+ ])
+ ppt = nx.outer(p, p)
+
+ return nx.exp(tmpsum / ppt)
+
+
+def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
+ r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation
+
+ Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
+ selected loss function as the loss function of the semi-relaxed Gromov-Wasserstein discrepancy.
+
+ The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
+ and adapted to the semi-relaxed problem where the second marginal is not a constant anymore.
+
+ Where :
+
+ - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
+ - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
+ - :math:`\mathbf{T}`: A coupling between those two spaces
+
+ The square-loss function :math:`L(a, b) = |a - b|^2` is read as :
+
+ .. math::
+
+ L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)
+
+ \mathrm{with} \ f_1(a) &= a^2
+
+ f_2(b) &= b^2
+
+ h_1(a) &= a
+
+ h_2(b) &= 2b
+
+ Parameters
+ ----------
+ C1 : array-like, shape (ns, ns)
+ Metric cost matrix in the source space
+ C2 : array-like, shape (nt, nt)
+ Metric cost matrix in the target space
+ p : array-like, shape (ns,)
+ Probability distribution in the source space
+ loss_fun : str, optional
+ Name of the loss function to use: only 'square_loss' is supported (default='square_loss')
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the input arrays.
+ Returns
+ -------
+ constC : array-like, shape (ns, nt)
+ Constant :math:`\mathbf{C}` matrix in Eq. (6) adapted to srGW
+ hC1 : array-like, shape (ns, ns)
+ :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
+ hC2 : array-like, shape (nt, nt)
+ :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
+ fC2t: array-like, shape (nt, nt)
+ :math:`\mathbf{f2}(\mathbf{C2})^\top` matrix in Eq. (6)
+
+
+ .. _references-init-matrix:
+ References
+ ----------
+ .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+ "Gromov-Wasserstein averaging of kernel and distance matrices."
+ International Conference on Machine Learning (ICML). 2016.
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2022.
+ """
+ if nx is None:
+ C1, C2, p = list_to_array(C1, C2, p)
+ nx = get_backend(C1, C2, p)
+
+ if loss_fun == 'square_loss':
+ def f1(a):
+ return (a**2)
+
+ def f2(b):
+ return (b**2)
+
+ def h1(a):
+ return a
+
+ def h2(b):
+ return 2 * b
+
+ constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
+ nx.ones((1, C2.shape[0]), type_as=p))
+
+ hC1 = h1(C1)
+ hC2 = h2(C2)
+ fC2t = f2(C2).T
+ return constC, hC1, hC2, fC2t
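Sketch of how the semi-relaxed solvers rebuild the q-dependent part of the cost at each iteration from the current coupling (illustrative snippet, same importability assumption as above):

    import numpy as np
    import ot
    from ot.gromov import init_matrix_semirelaxed, gwloss

    rng = np.random.RandomState(0)
    C1, C2 = ot.dist(rng.randn(6, 2)), ot.dist(rng.randn(9, 2))
    p = ot.unif(6)
    constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, 'square_loss')

    G = np.outer(p, ot.unif(9))                        # current coupling, first marginal fixed to p
    qG = G.sum(axis=0)                                 # its free second marginal
    marginal_product = np.outer(np.ones(6), qG @ fC2t)
    print(gwloss(constC + marginal_product, hC1, hC2, G))    # srGW objective value at G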
diff --git a/ot/helpers/pre_build_helpers.py b/ot/helpers/pre_build_helpers.py
index 93ecd6a..2930036 100644
--- a/ot/helpers/pre_build_helpers.py
+++ b/ot/helpers/pre_build_helpers.py
@@ -4,34 +4,14 @@ import os
import sys
import glob
import tempfile
-import setuptools # noqa
import subprocess
-from distutils.dist import Distribution
-from distutils.sysconfig import customize_compiler
-from numpy.distutils.ccompiler import new_compiler
-from numpy.distutils.command.config_compiler import config_cc
+from setuptools.command.build_ext import customize_compiler, new_compiler
def _get_compiler():
- """Get a compiler equivalent to the one that will be used to build POT
- Handles compiler specified as follows:
- - python setup.py build_ext --compiler=<compiler>
- - CC=<compiler> python setup.py build_ext
- """
- dist = Distribution({'script_name': os.path.basename(sys.argv[0]),
- 'script_args': sys.argv[1:],
- 'cmdclass': {'config_cc': config_cc}})
-
- cmd_opts = dist.command_options.get('build_ext')
- if cmd_opts is not None and 'compiler' in cmd_opts:
- compiler = cmd_opts['compiler'][1]
- else:
- compiler = None
-
- ccompiler = new_compiler(compiler=compiler)
+ ccompiler = new_compiler()
customize_compiler(ccompiler)
-
return ccompiler
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index 8a1f9ac..b56f060 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -18,6 +18,7 @@
#include <iostream>
#include <vector>
+#include <cstdint>
typedef unsigned int node_id_type;
@@ -28,8 +29,8 @@ enum ProblemType {
MAX_ITER_REACHED
};
-int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
-int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
+int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
+int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 2bdc172..4aa5a6e 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -20,11 +20,11 @@
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
- double* alpha, double* beta, double *cost, int maxIter) {
+ double* alpha, double* beta, double *cost, uint64_t maxIter) {
// beware M and C are stored in row major C style!!!
using namespace lemon;
- int n, m, cur;
+ uint64_t n, m, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
@@ -51,15 +51,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Define the graph
- std::vector<int> indI(n), indJ(m);
+ std::vector<uint64_t> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter);
// Set supply and demand, don't account for 0 values (faster)
cur=0;
- for (int i=0; i<n1; i++) {
+ for (uint64_t i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
@@ -70,7 +70,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Demand is actually negative supply...
cur=0;
- for (int i=0; i<n2; i++) {
+ for (uint64_t i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
@@ -79,12 +79,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
}
- net.supplyMap(&weights1[0], n, &weights2[0], m);
+ net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
// Set the cost of each edge
int64_t idarc = 0;
- for (int i=0; i<n; i++) {
- for (int j=0; j<m; j++) {
+ for (uint64_t i=0; i<n; i++) {
+ for (uint64_t j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(idarc), val);
++idarc;
@@ -95,7 +95,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm
int ret=net.run();
- int i, j;
+ uint64_t i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
@@ -122,11 +122,11 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
- double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
+ double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) {
// beware M and C are stored in row major C style!!!
using namespace lemon_omp;
- int n, m, cur;
+ uint64_t n, m, cur;
typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);
@@ -153,15 +153,15 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Define the graph
- std::vector<int> indI(n), indJ(m);
+ std::vector<uint64_t> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter, numThreads);
// Set supply and demand, don't account for 0 values (faster)
cur=0;
- for (int i=0; i<n1; i++) {
+ for (uint64_t i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
@@ -172,7 +172,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Demand is actually negative supply...
cur=0;
- for (int i=0; i<n2; i++) {
+ for (uint64_t i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
@@ -181,12 +181,12 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
}
- net.supplyMap(&weights1[0], n, &weights2[0], m);
+ net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
// Set the cost of each edge
int64_t idarc = 0;
- for (int i=0; i<n; i++) {
- for (int j=0; j<m; j++) {
+ for (uint64_t i=0; i<n; i++) {
+ for (uint64_t j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(idarc), val);
++idarc;
@@ -197,7 +197,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
// Solve the problem with the network simplex algorithm
int ret=net.run();
- int i, j;
+ uint64_t i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 390c32d..2ff02ab 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Solvers for the original linear program OT problem
+Solvers for the original linear program OT problem.
"""
@@ -20,16 +20,17 @@ from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
-from .solver_1d import emd_1d, emd2_1d, wasserstein_1d
+from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d,
+ binary_search_circle, wasserstein_circle,
+ semidiscrete_wasserstein2_unif_circle)
from ..utils import dist, list_to_array
from ..utils import parmap
from ..backend import get_backend
-
-
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
- 'emd_1d', 'emd2_1d', 'wasserstein_1d']
+ 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter',
+ 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle']
def check_number_threads(numThreads):
@@ -232,6 +233,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
@@ -391,6 +394,8 @@ def emd2(a, b, M, processes=1,
If this behaviour is unwanted, please make sure to provide a
floating point input.
+ .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
@@ -483,6 +488,11 @@ def emd2(a, b, M, processes=1,
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
+ # ensure that same mass
+ np.testing.assert_almost_equal(a.sum(0),
+ b.sum(0, keepdims=True), err_msg='a and b vectors must have the same sum')
+ b = b * a.sum(0) / b.sum(0, keepdims=True)
+
asel = a != 0
numThreads = check_number_threads(numThreads)
@@ -517,8 +527,8 @@ def emd2(a, b, M, processes=1,
log['warning'] = result_code_string
log['result_code'] = result_code
cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
- (a0, b0, M0), (log['u'] - nx.mean(log['u']),
- log['v'] - nx.mean(log['v']), G))
+ (a0, b0, M0), (log['u'] - nx.mean(log['u']),
+ log['v'] - nx.mean(log['v']), G))
return [cost, log]
else:
def f(b):
@@ -572,18 +582,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
where :
- :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
- - the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i`
- - the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations
+ - `measures_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
+ - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
- :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
- This problem is considered in :ref:`[1] <references-free-support-barycenter>` (Algorithm 2).
+ This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
There are two differences with the following codes:
- we do not optimize over the weights
- we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
- :ref:`[1] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+ :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
implementation of the fixed-point algorithm of
- :ref:`[2] <references-free-support-barycenter>` proposed in the continuous setting.
+ :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
Parameters
----------
@@ -623,13 +633,13 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
.. _references-free-support-barycenter:
References
----------
- .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+ .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
- .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+ .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
"""
- nx = get_backend(*measures_locations,*measures_weights,X_init)
+ nx = get_backend(*measures_locations, *measures_weights, X_init)
iter_count = 0
@@ -637,9 +647,9 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
- b = nx.ones((k,),type_as=X_init) / k
+ b = nx.ones((k,), type_as=X_init) / k
if weights is None:
- weights = nx.ones((N,),type_as=X_init) / N
+ weights = nx.ones((N,), type_as=X_init) / N
X = X_init
@@ -650,15 +660,14 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
while (displacement_square_norm > stopThr and iter_count < numItermax):
- T_sum = nx.zeros((k, d),type_as=X_init)
-
+ T_sum = nx.zeros((k, d), type_as=X_init)
- for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
+ for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
- T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i)
+ T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)
- displacement_square_norm = nx.sum((T_sum - X)**2)
+ displacement_square_norm = nx.sum((T_sum - X) ** 2)
if log:
displacement_square_norms.append(displacement_square_norm)
@@ -675,3 +684,111 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
else:
return X
+
+def generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary, Y_init=None, b=None, weights=None,
+ numItermax=100, stopThr=1e-7, verbose=False, log=None, numThreads=1, eps=0):
+ r"""
+ Solves the free support generalised Wasserstein barycenter problem: finding a barycenter (a discrete measure with
+ a fixed amount of points of uniform weights) whose respective projections fit the input measures.
+ More formally:
+
+ .. math::
+ \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma)
+
+ where :
+
+ - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d`
+ - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter
+ - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}`
+ - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex)
+ - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations
+ - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex)
+ - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}`
+
+ As shown by :ref:`[42] <references-generalized-free-support-barycenter>`,
+ this problem can be re-written as a Wasserstein Barycenter problem,
+ which we solve using the free support method :ref:`[20] <references-generalized-free-support-barycenter>`
+ (Algorithm 2).
+
+ Parameters
+ ----------
+ X_list : list of p (k_i,d_i) array-like
+ Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space
+ (:math:`k_i` can be different for each element of the list)
+ a_list : list of p (k_i,) array-like
+ Measure weights: each element is a vector (k_i) on the simplex
+ P_list : list of p (d_i,d) array-like
+ Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}`
+ n_samples_bary : int
+ Number of barycenter points
+ Y_init : (n_samples_bary,d) array-like
+ Initialization of the support locations (on `n_samples_bary` atoms) of the barycenter
+ b : (n_samples_bary,) array-like
+ Initialization of the weights of the barycenter measure (on the simplex)
+ weights : (p,) array-like
+ Initialization of the coefficients of the barycenter (on the simplex)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+ If compiled with OpenMP, chooses the number of threads to parallelize.
+ "max" selects the highest number possible.
+ eps : float, optional
+ Stability coefficient for the change-of-variables matrix inversion
+ If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix
+ inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense)
+
+
+ Returns
+ -------
+ Y : (n_samples_bary,d) array-like
+ Support locations (on n_samples_bary atoms) of the barycenter
+
+
+ .. _references-generalized-free-support-barycenter:
+ References
+ ----------
+ .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+ .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021.
+
+ """
+ nx = get_backend(*X_list, *a_list, *P_list)
+ d = P_list[0].shape[1]
+ p = len(X_list)
+
+ if weights is None:
+ weights = nx.ones(p, type_as=X_list[0]) / p
+
+ # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB)
+ A = eps * nx.eye(d, type_as=X_list[0]) # if eps nonzero: will force the invertibility of A
+ for (P_i, lambda_i) in zip(P_list, weights):
+ A = A + lambda_i * P_i.T @ P_i
+ B = nx.inv(nx.sqrtm(A))
+
+ Z_list = [x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)] # change of variables -> (WB) problem on Z
+
+ if Y_init is None:
+ Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0])
+
+ if b is None:
+ b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary # not optimised
+
+ out = free_support_barycenter(Z_list, a_list, Y_init, b, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, numThreads=numThreads)
+
+ if log: # unpack
+ Y, log_dict = out
+ else:
+ Y = out
+ log_dict = None
+ Y = Y @ B.T # return to the Generalised WB formulation
+
+ if log:
+ return Y, log_dict
+ else:
+ return Y
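Usage sketch (not part of the patch) with two 1-D measures viewed as projections of an unknown 2-D barycenter onto the coordinate axes, assuming the function is reachable as ot.lp.generalized_free_support_barycenter:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_list = [rng.randn(30, 1), rng.randn(40, 1)]      # the two projected point clouds
    a_list = [ot.unif(30), ot.unif(40)]                # their weights (uniform)
    P_list = [np.array([[1., 0.]]), np.array([[0., 1.]])]   # projections from R^2 to each axis
    Y = ot.lp.generalized_free_support_barycenter(X_list, a_list, P_list, n_samples_bary=20)
    print(Y.shape)                                     # (20, 2) barycenter support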
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index fbf3c0e..361ad0f 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -80,7 +80,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
if weights is None:
weights = np.ones(A.shape[1]) / A.shape[1]
else:
- assert(len(weights) == A.shape[1])
+ assert len(weights) == A.shape[1]
n_distributions = A.shape[1]
n = A.shape[0]
diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx
index 42e08f4..e5cec89 100644
--- a/ot/lp/emd_wrap.pyx
+++ b/ot/lp/emd_wrap.pyx
@@ -14,13 +14,14 @@ from ..utils import dist
cimport cython
cimport libc.math as math
+from libc.stdint cimport uint64_t
import warnings
cdef extern from "EMD.h":
- int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
- int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
+ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
+ int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
@@ -39,7 +40,7 @@ def check_result(result_code):
@cython.boundscheck(False)
@cython.wraparound(False)
-def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
+def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix
@@ -75,7 +76,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
target histogram
M : (ns,nt) numpy.ndarray, float64
loss matrix
- max_iter : int
+ max_iter : uint64_t
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h
index 3b46b9b..9612a8a 100644
--- a/ot/lp/network_simplex_simple.h
+++ b/ot/lp/network_simplex_simple.h
@@ -233,7 +233,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -242,7 +242,7 @@ namespace lemon {
{
// Reset data structures
reset();
- max_iter=maxiters;
+ max_iter = maxiters;
}
/// The type of the flow amounts, capacity bounds and supply values
@@ -293,7 +293,7 @@ namespace lemon {
private:
- size_t max_iter;
+ uint64_t max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
typedef std::vector<int> IntVector;
@@ -1427,14 +1427,12 @@ namespace lemon {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- size_t iter_number=0;
+ uint64_t iter_number = 0;
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
if(max_iter > 0 && ++iter_number>=max_iter&&max_iter>0){
- char errMess[1000];
- sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
- std::cerr << errMess;
+ // max iterations hit
retVal = MAX_ITER_REACHED;
break;
}
diff --git a/ot/lp/network_simplex_simple_omp.h b/ot/lp/network_simplex_simple_omp.h
index 87e4c05..890b7ab 100644
--- a/ot/lp/network_simplex_simple_omp.h
+++ b/ot/lp/network_simplex_simple_omp.h
@@ -41,8 +41,8 @@
#undef EPSILON
#undef _EPSILON
#undef MAX_DEBUG_ITER
-#define EPSILON std::numeric_limits<Cost>::epsilon()*10
-#define _EPSILON 1e-8
+#define EPSILON std::numeric_limits<Cost>::epsilon()
+#define _EPSILON 1e-14
#define MAX_DEBUG_ITER 100000
/// \ingroup min_cost_flow_algs
@@ -67,7 +67,7 @@
//#include "core.h"
//#include "lmath.h"
-#ifdef OMP
+#ifdef _OPENMP
#include <omp.h>
#endif
#include <cmath>
@@ -244,7 +244,7 @@ namespace lemon_omp {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
- NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) :
+ NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters = 0, int numThreads=-1) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
@@ -254,7 +254,7 @@ namespace lemon_omp {
// Reset data structures
reset();
max_iter = maxiters;
-#ifdef OMP
+#ifdef _OPENMP
if (max_threads < 0) {
max_threads = omp_get_max_threads();
}
@@ -317,7 +317,7 @@ namespace lemon_omp {
private:
- size_t max_iter;
+ uint64_t max_iter;
int num_threads;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
@@ -513,7 +513,7 @@ namespace lemon_omp {
int j;
#pragma omp parallel
{
-#ifdef OMP
+#ifdef _OPENMP
int t = omp_get_thread_num();
#else
int t = 0;
@@ -1563,7 +1563,7 @@ namespace lemon_omp {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;
- size_t iter_number = 0;
+ uint64_t iter_number = 0;
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) {
@@ -1610,9 +1610,7 @@ namespace lemon_omp {
} else {
- char errMess[1000];
- sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
- std::cerr << errMess;
+ // max iters
retVal = MAX_ITER_REACHED;
break;
}
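
Note: on the Python side this iteration cap is reported through the solver log rather than printed to stderr. A minimal sketch of checking for it (assuming the usual `numItermax`/`log` arguments of `ot.emd` and its `'warning'` log key, as used elsewhere in this changeset):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(50), ot.unif(50)
    M = ot.dist(rng.randn(50, 2), rng.randn(50, 2))

    # deliberately tiny iteration budget to illustrate the early-stop path
    G, log = ot.emd(a, b, M, numItermax=10, log=True)
    if log['warning'] is not None:
        print("network simplex stopped early:", log['warning'])
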
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 43763a9..bcfc920 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -53,7 +53,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
distributions
.. math:
- OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq
+ OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq
It is formally the p-Wasserstein distance raised to the power p.
We do so in a vectorized way by first building the individual quantile functions then integrating them.
@@ -129,7 +129,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ
diff_quantiles = nx.abs(u_quantiles - v_quantiles)
if p == 1:
- return nx.sum(delta * nx.abs(diff_quantiles), axis=0)
+ return nx.sum(delta * diff_quantiles, axis=0)
return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
@@ -365,3 +365,628 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
log_emd = {'G': G}
return cost, log_emd
return cost
+
+
+def roll_cols(M, shifts):
+ r"""
+ Utility function that circularly shifts the columns of each row of a 2d matrix
+
+ Parameters
+ ----------
+ M : (nr, nc) ndarray
+ Matrix to shift
+ shifts: int or (nr,) ndarray
+ Number of positions to circularly shift each row to the right
+
+ Returns
+ -------
+ Shifted array
+
+ Examples
+ --------
+ >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
+ >>> roll_cols(M, 2)
+ array([[2, 3, 1],
+ [5, 6, 4],
+ [8, 9, 7]])
+ >>> roll_cols(M, np.array([[1],[2],[1]]))
+ array([[3, 1, 2],
+ [5, 6, 4],
+ [9, 7, 8]])
+
+ References
+ ----------
+ https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
+ """
+ nx = get_backend(M)
+
+ n_rows, n_cols = M.shape
+
+ arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1))
+ arange2 = (arange1 - shifts) % n_cols
+
+ return nx.take_along_axis(M, arange2, 1)
+
+
+def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
+ r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ dCp: array-like, shape (n_batch, 1)
+ The batched right derivative
+ dCm: array-like, shape (n_batch, 1)
+ The batched left derivative
+
+ References
+ ---------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ n = u_values.shape[-1]
+ m_batch, m = v_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warning related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ # quantiles of F_u evaluated in F_v^\theta
+ u_index = nx.searchsorted(u_cdf, v_cdf_theta)
+ u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)
+
+ # Deal with 1
+ u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1)
+ u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warning related to non-contiguous arrays
+ u_cdfm = u_cdfm.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+
+ u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right")
+ u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1)
+
+ dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1)
+
+ dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p)
+ - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1)
+
+ return dCp.reshape(-1, 1), dCm.reshape(-1, 1)
+
+
+def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p):
+ r""" Computes the the cost (Equation (6.2) of [1])
+
+ Parameters
+ ----------
+ theta: array-like, shape (n_batch, n)
+ Cuts on the circle
+ u_values: array-like, shape (n_batch, n)
+ locations of the first empirical distribution
+ v_values: array-like, shape (n_batch, n)
+ locations of the second empirical distribution
+ u_cdf: array-like, shape (n_batch, n)
+ cdf of the first empirical distribution
+ v_cdf: array-like, shape (n_batch, n)
+ cdf of the second empirical distribution
+ p: float, optional = 2
+ Power p used for computing the Wasserstein distance
+
+ Returns
+ -------
+ ot_cost: array-like, shape (n_batch,)
+ OT cost evaluated at theta
+
+ References
+ ---------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ """
+ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+ v_values = nx.copy(v_values)
+
+ m_batch, m = v_values.shape
+ n_batch, n = u_values.shape
+
+ v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+ mask_p = v_cdf_theta >= 0
+ mask_n = v_cdf_theta < 0
+
+ v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+ v_values[mask_p] += nx.floor(theta)[mask_p]
+
+ if nx.any(mask_n) and nx.any(mask_p):
+ v_cdf_theta[mask_n] += 1
+
+ # Put negative values at the end
+ v_cdf_theta2 = nx.copy(v_cdf_theta)
+ v_cdf_theta2[mask_n] = np.inf
+ shift = (-nx.argmin(v_cdf_theta2, axis=-1))
+
+ v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+ v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+
+ # Compute abscissa (merged cdf values)
+ cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1)
+ cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)])
+
+ delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1]
+
+ if nx.__name__ == 'torch':
+ # this is to ensure the best performance for torch searchsorted
+ # and avoid a warning related to non-contiguous arrays
+ u_cdf = u_cdf.contiguous()
+ v_cdf_theta = v_cdf_theta.contiguous()
+ cdf_axis = cdf_axis.contiguous()
+
+ # Compute icdf
+ u_index = nx.searchsorted(u_cdf, cdf_axis)
+ u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1)
+
+ v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
+ v_index = nx.searchsorted(v_cdf_theta, cdf_axis)
+ v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1)
+
+ if p == 1:
+ ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1)
+ else:
+ ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1)
+
+ return ot_cost
+
+
+def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True,
+ log=False):
+ r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ the values are taken modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ where:
+
+ - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC
+ Lp : int, optional
+ Upper bound dC
+ tm: float, optional
+ Lower bound theta
+ tp: float, optional
+ Upper bound theta
+ eps: float, optional
+ Stopping condition
+ require_sort: bool, optional
+ If True, sort the values.
+ log: bool, optional
+ If True, returns also the optimal theta
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+ log: dict, optional
+ log dictionary returned only if log==True in parameters
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> binary_search_circle(u.T, v.T, p=1)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ u_cdf = nx.cumsum(u_weights, 0).T
+ v_cdf = nx.cumsum(v_weights, 0).T
+
+ u_values = u_values.T
+ v_values = v_values.T
+
+ L = max(Lm, Lp)
+
+ tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tm = nx.tile(tm, (1, m))
+ tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
+ tp = nx.tile(tp, (1, m))
+ tc = (tm + tp) / 2
+
+ done = nx.zeros((u_values.shape[0], m))
+
+ cpt = 0
+ while nx.any(1 - done):
+ cpt += 1
+
+ dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+ done = ((dCp * dCm) <= 0) * 1
+
+ mask = ((tp - tm) < eps / L) * (1 - done)
+
+ if nx.any(mask):
+ # can probably be improved by computing only relevant values
+ dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p)
+ dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p)
+ Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+ Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
+
+ mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
+ tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0]
+ done[nx.prod(mask, axis=-1) > 0] = 1
+ elif nx.any(1 - done):
+ tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0]
+ tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0]
+ tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2
+
+ w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
+
+ if log:
+ return w, {"optimal_theta": tc[:, 0]}
+ return w
+
+
+def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True):
+ r"""Computes the 1-Wasserstein distance on the circle using the level median [45].
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ the values are taken modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates
+ using e.g. the atan2 function.
+ The function runs on backend but tensorflow is not supported.
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein1_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ """
+
+ if u_weights is not None and v_weights is not None:
+ nx = get_backend(u_values, v_values, u_weights, v_weights)
+ else:
+ nx = get_backend(u_values, v_values)
+
+ n = u_values.shape[0]
+ m = v_values.shape[0]
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+ if len(v_values.shape) == 1:
+ v_values = nx.reshape(v_values, (m, 1))
+
+ if u_values.shape[1] != v_values.shape[1]:
+ raise ValueError(
+ "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1],
+ v_values.shape[1]))
+
+ u_values = u_values % 1
+ v_values = v_values % 1
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+ if v_weights is None:
+ v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
+ elif v_weights.ndim != v_values.ndim:
+ v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+ if require_sort:
+ u_sorter = nx.argsort(u_values, 0)
+ u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+ v_sorter = nx.argsort(v_values, 0)
+ v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+ u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+ v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+ # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+ values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0)
+
+ cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0)
+ cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0)
+
+ values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1)
+ delta = values_sorted[1:, ...] - values_sorted[:-1, ...]
+ weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0)
+
+ sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5
+ sum_weights[sum_weights < 0] = np.inf
+ inds = nx.argmin(sum_weights, axis=0)
+
+ levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0)
+
+ return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0)
+
+
+def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
+ Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True):
+ r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or
+ the binary search algorithm proposed in [44] otherwise.
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ the values are taken modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ General loss returned:
+
+ .. math::
+ OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+ For p=1, [45]
+
+ .. math::
+ W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ u_values : ndarray, shape (n, ...)
+ samples in the source domain (coordinates on [0,1[)
+ v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[)
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+ v_weights : ndarray, shape (n, ...), optional
+ samples weights in the target domain
+ p : float, optional (default=1)
+ Power p used for computing the Wasserstein distance
+ Lm : int, optional
+ Lower bound dC. For p>1.
+ Lp : int, optional
+ Upper bound dC. For p>1.
+ tm: float, optional
+ Lower bound theta. For p>1.
+ tp: float, optional
+ Upper bound theta. For p>1.
+ eps: float, optional
+ Stopping condition. For p>1.
+ require_sort: bool, optional
+ If True, sort the values.
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]])%1
+ >>> v = np.array([[0.4,0.5,0.7]])%1
+ >>> wasserstein_circle(u.T, v.T)
+ array([0.1])
+
+ References
+ ----------
+ .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+ .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+ """
+ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+ if p == 1:
+ return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort)
+
+ return binary_search_circle(u_values, v_values, u_weights, v_weights,
+ p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps,
+ require_sort=require_sort)
+
+
+def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
+ r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1`
+ Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+ takes the value modulo 1.
+ If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+ using e.g. the atan2 function.
+
+ .. math::
+ W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12}
+
+ where:
+
+ - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}`
+
+ For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},
+
+ using e.g. ot.utils.get_coordinate_circle(x)
+
+ Parameters
+ ----------
+ u_values: ndarray, shape (n, ...)
+ Samples
+ u_weights : ndarray, shape (n, ...), optional
+ samples weights in the source domain
+
+ Returns
+ -------
+ loss: float
+ Cost associated to the optimal transportation
+
+ Examples
+ --------
+ >>> x0 = np.array([[0], [0.2], [0.4]])
+ >>> semidiscrete_wasserstein2_unif_circle(x0)
+ array([0.02111111])
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+
+ if u_weights is not None:
+ nx = get_backend(u_values, u_weights)
+ else:
+ nx = get_backend(u_values)
+
+ n = u_values.shape[0]
+
+ u_values = u_values % 1
+
+ if len(u_values.shape) == 1:
+ u_values = nx.reshape(u_values, (n, 1))
+
+ if u_weights is None:
+ u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
+ elif u_weights.ndim != u_values.ndim:
+ u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+
+ u_values = nx.sort(u_values, 0)
+ u_cdf = nx.cumsum(u_weights, 0)
+ u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)])
+
+ cpt1 = nx.sum(u_weights * u_values**2, axis=0)
+ u_mean = nx.sum(u_weights * u_values, axis=0)
+
+ ns = 1 - u_weights - 2 * u_cdf[:-1]
+ cpt2 = nx.sum(u_values * u_weights * ns, axis=0)
+
+ return cpt1 - u_mean**2 + cpt2 + 1 / 12
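
A small usage sketch of the new circle solvers (sample values are illustrative only; the functions are also re-exported at the package top level):

    import numpy as np
    import ot

    u = np.array([0.2, 0.5, 0.8]) % 1
    v = np.array([0.4, 0.5, 0.7]) % 1

    # p=1 uses the level-median formula, p>1 falls back to the binary search
    w1 = ot.wasserstein_circle(u, v, p=1)
    w2 = ot.binary_search_circle(u, v, p=2)

    # closed-form W2 between the samples and the uniform distribution on S^1
    w2_unif = ot.semidiscrete_wasserstein2_unif_circle(u)
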
diff --git a/ot/optim.py b/ot/optim.py
index 5a1d605..58e5596 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""
-Generic solvers for regularized OT
+Generic solvers for regularized OT or its semi-relaxed version.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-#
+# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
# License: MIT License
import numpy as np
@@ -27,7 +27,7 @@ with warnings.catch_warnings():
def line_search_armijo(
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
- alpha0=0.99, alpha_min=None, alpha_max=None
+ alpha0=0.99, alpha_min=None, alpha_max=None, nx=None, **kwargs
):
r"""
Armijo linesearch function that works with matrices
@@ -35,6 +35,9 @@ def line_search_armijo(
Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
armijo conditions.
+ .. note:: If the loss function f returns a float (resp. a 1d array) then
+ the returned alpha and fa are float (resp. 1d arrays).
+
Parameters
----------
f : callable
@@ -45,7 +48,7 @@ def line_search_armijo(
descent direction
gfk : array-like
gradient of `f` at :math:`x_k`
- old_fval : float
+ old_fval : float or 1d array
loss value at :math:`x_k`
args : tuple, optional
arguments given to `f`
@@ -57,138 +60,97 @@ def line_search_armijo(
minimum value for alpha
alpha_max : float, optional
maximum value for alpha
-
+ nx : backend, optional
+ If left to its default value None, the backend is inferred from the input arrays.
Returns
-------
- alpha : float
+ alpha : float or 1d array
step that satisfy armijo conditions
fc : int
nb of function call
- fa : float
+ fa : float or 1d array
loss value at step alpha
"""
-
- xk, pk, gfk = list_to_array(xk, pk, gfk)
- nx = get_backend(xk, pk)
+ if nx is None:
+ xk, pk, gfk = list_to_array(xk, pk, gfk)
+ xk0, pk0 = xk, pk
+ nx = get_backend(xk0, pk0)
+ else:
+ xk0, pk0 = xk, pk
if len(xk.shape) == 0:
xk = nx.reshape(xk, (-1,))
+ xk = nx.to_numpy(xk)
+ pk = nx.to_numpy(pk)
+ gfk = nx.to_numpy(gfk)
+
fc = [0]
def phi(alpha1):
+ # The callable function operates on nx backend
fc[0] += 1
- return f(xk + alpha1 * pk, *args)
+ alpha10 = nx.from_numpy(alpha1)
+ fval = f(xk0 + alpha10 * pk0, *args)
+ if type(fval) is float:
+ # plain float: skip nx.to_numpy, which may look for .cpu or .gpu attributes
+ return fval
+ else:
+ return nx.to_numpy(fval)
if old_fval is None:
phi0 = phi(0.)
- else:
+ elif type(old_fval) is float:
+ # plain float: skip nx.to_numpy, which may look for .cpu or .gpu attributes
phi0 = old_fval
+ else:
+ phi0 = nx.to_numpy(old_fval)
- derphi0 = nx.sum(pk * gfk) # Quickfix for matrices
+ derphi0 = np.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
if alpha is None:
- return 0., fc[0], phi0
+ return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
else:
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
- return float(alpha), fc[0], phi1
+ return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)
-def solve_linesearch(
- cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None,
- reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None
-):
- """
- Solve the linesearch in the FW iterations
-
- Parameters
- ----------
- cost : method
- Cost in the FW for the linesearch
- G : array-like, shape(ns,nt)
- The transport map at a given iteration of the FW
- deltaG : array-like (ns,nt)
- Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
- Mi : array-like (ns,nt)
- Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
- f_val : float
- Value of the cost at `G`
- armijo : bool, optional
- If True the steps of the line-search is found via an armijo research. Else closed form is used.
- If there is convergence issues use False.
- C1 : array-like (ns,ns), optional
- Structure matrix in the source domain. Only used and necessary when armijo=False
- C2 : array-like (nt,nt), optional
- Structure matrix in the target domain. Only used and necessary when armijo=False
- reg : float, optional
- Regularization parameter. Only used and necessary when armijo=False
- Gc : array-like (ns,nt)
- Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
- constC : array-like (ns,nt)
- Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
- M : array-like (ns,nt), optional
- Cost matrix between the features. Only used and necessary when armijo=False
- alpha_min : float, optional
- Minimum value for alpha
- alpha_max : float, optional
- Maximum value for alpha
-
- Returns
- -------
- alpha : float
- The optimal step size of the FW
- fc : int
- nb of function call. Useless here
- f_val : float
- The value of the cost for the next iteration
+def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None,
+ numItermax=200, stopThr=1e-9,
+ stopThr2=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized OT problem or its semi-relaxed version with
+ conditional gradient or generalized conditional gradient depending on the
+ provided linear program solver.
+ The function solves the following optimization problem if set as a conditional gradient:
- .. _references-solve-linesearch:
- References
- ----------
- .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
- """
- if armijo:
- alpha, fc, f_val = line_search_armijo(
- cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
- )
- else: # requires symetric matrices
- G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
- if isinstance(M, int) or isinstance(M, float):
- nx = get_backend(G, deltaG, C1, C2, constC)
- else:
- nx = get_backend(G, deltaG, C1, C2, constC, M)
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg_1} \cdot f(\gamma)
- dot = nx.dot(nx.dot(C1, deltaG), C2)
- a = -2 * reg * nx.sum(dot * deltaG)
- b = nx.sum((M + reg * constC) * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2) * deltaG))
- c = cost(G)
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
- alpha = solve_1d_linesearch_quad(a, b, c)
- if alpha_min is not None or alpha_max is not None:
- alpha = np.clip(alpha, alpha_min, alpha_max)
- fc = None
- f_val = cost(G + alpha * deltaG)
+ \gamma^T \mathbf{1} &= \mathbf{b} \quad \text{(optional constraint)}
- return alpha, fc, f_val
+ \gamma &\geq 0
+ where :
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
-def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
- stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
- r"""
- Solve the general regularized OT problem with conditional gradient
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
- The function solves the following optimization problem:
+ The function solves the following optimization problem if set a generalized conditional gradient:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
- \mathrm{reg} \cdot f(\gamma)
+ \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
@@ -197,29 +159,39 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
\gamma &\geq 0
where :
- - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- - :math:`f` is the regularization term (and `df` is its gradient)
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
-
- The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
a : array-like, shape (ns,)
samples weights in the source domain
b : array-like, shape (nt,)
- samples in the target domain
+ samples weights in the target domain
M : array-like, shape (ns, nt)
loss matrix
- reg : float
+ f : function
+ Regularization function taking a transportation matrix as argument
+ df: function
+ Gradient of the regularization function taking a transportation matrix as argument
+ reg1 : float
Regularization term >0
+ reg2 : float,
+ Entropic Regularization term >0. Ignored if set to None.
+ lp_solver: function,
+ linear program solver for direction finding of the (generalized) conditional gradient.
+ If set to emd, the general regularized OT problem is solved with cg.
+ If set to lp_semi_relaxed_OT, the general regularized semi-relaxed OT problem is solved with cg.
+ If set to sinkhorn, the general regularized OT problem is solved with generalized cg.
+ line_search: function,
+ Function to find the optimal step. Currently used instances are:
+ line_search_armijo (generic solver), solve_gromov_linesearch for the (F)GW problem,
+ solve_semirelaxed_gromov_linesearch for the sr(F)GW problem, and gcg_linesearch for the generalized cg.
G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
- numItermaxEmd : int, optional
- Max number of iterations for emd
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
@@ -240,16 +212,20 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
.. _references-cg:
+ .. _references-gcg:
References
----------
.. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
See Also
--------
ot.lp.emd : Unregularized optimal ransport
ot.bregman.sinkhorn : Entropic regularized optimal transport
-
"""
a, b, M, G0 = list_to_array(a, b, M, G0)
if isinstance(M, int) or isinstance(M, float):
@@ -265,42 +241,45 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
if G0 is None:
G = nx.outer(a, b)
else:
- G = G0
-
- def cost(G):
- return nx.sum(M * G) + reg * f(G)
+ # to not change G0 in place.
+ G = nx.copy(G0)
- f_val = cost(G)
+ if reg2 is None:
+ def cost(G):
+ return nx.sum(M * G) + reg1 * f(G)
+ else:
+ def cost(G):
+ return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G))
+ cost_G = cost(G)
if log:
- log['loss'].append(f_val)
+ log['loss'].append(cost_G)
it = 0
if verbose:
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0))
while loop:
it += 1
- old_fval = f_val
-
+ old_cost_G = cost_G
# problem linearization
- Mi = M + reg * df(G)
+ Mi = M + reg1 * df(G)
+
+ if not (reg2 is None):
+ Mi = Mi + reg2 * (1 + nx.log(G))
# set M positive
- Mi += nx.min(Mi)
+ Mi = Mi + nx.min(Mi)
# solve linear program
- Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)
+ Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs)
+ # line search
deltaG = Gc - G
- # line search
- alpha, fc, f_val = solve_linesearch(
- cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,
- alpha_min=0., alpha_max=1., **kwargs
- )
+ alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs)
G = G + alpha * deltaG
@@ -308,29 +287,197 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
if it >= numItermax:
loop = 0
- abs_delta_fval = abs(f_val - old_fval)
- relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
+ abs_delta_cost_G = abs(cost_G - old_cost_G)
+ relative_delta_cost_G = abs_delta_cost_G / abs(cost_G)
+ if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2:
loop = 0
if log:
- log['loss'].append(f_val)
+ log['loss'].append(cost_G)
if verbose:
if it % 20 == 0:
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
+ print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G))
if log:
- log.update(logemd)
+ log.update(innerlog_)
return G, log
else:
return G
+def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
+ numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9,
+ verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized OT problem with conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma^T \mathbf{1} &= \mathbf{b}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (ns,)
+ samples weights in the source domain
+ b : array-like, shape (nt,)
+ samples weights in the target domain
+ M : array-like, shape (ns, nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ G0 : array-like, shape (ns,nt), optional
+ initial guess (default is indep joint density)
+ line_search: function,
+ Function to find the optimal step.
+ Default is line_search_armijo.
+ numItermax : int, optional
+ Max number of iterations
+ numItermaxEmd : int, optional
+ Max number of iterations for emd
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ **kwargs : dict
+ Parameters for linesearch
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-cg:
+ References
+ ----------
+
+ .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized optimal transport
+ ot.bregman.sinkhorn : Entropic regularized optimal transport
+
+ """
+
+ def lp_solver(a, b, M, **kwargs):
+ return emd(a, b, M, numItermaxEmd, log=True)
+
+ return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
+
+
+def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
+ numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
+ r"""
+ Solve the general regularized and semi-relaxed OT problem with conditional gradient
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+ \mathrm{reg} \cdot f(\gamma)
+
+ s.t. \ \gamma \mathbf{1} &= \mathbf{a}
+
+ \gamma &\geq 0
+ where :
+
+ - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
+ - :math:`f` is the regularization term (and `df` is its gradient)
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-cg>`
+
+
+ Parameters
+ ----------
+ a : array-like, shape (ns,)
+ samples weights in the source domain
+ b : array-like, shape (nt,)
+ currently estimated samples weights in the target domain
+ M : array-like, shape (ns, nt)
+ loss matrix
+ reg : float
+ Regularization term >0
+ G0 : array-like, shape (ns,nt), optional
+ initial guess (default is indep joint density)
+ line_search: function,
+ Function to find the optimal step.
+ Default is the armijo line-search.
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on the relative variation (>0)
+ stopThr2 : float, optional
+ Stop threshold on the absolute variation (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ **kwargs : dict
+ Parameters for linesearch
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-semirelaxed-cg:
+ References
+ ----------
+
+ .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
+ "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
+ International Conference on Learning Representations (ICLR), 2021.
+
+ """
+
+ nx = get_backend(a, b)
+
+ def lp_solver(a, b, Mi, **kwargs):
+ # get minimum by rows as binary mask
+ Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1)))
+ Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1))
+ # return by default an empty inner_log
+ return Gc, {}
+
+ return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr,
+ stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
+
+
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
- numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
+ numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
r"""
Solve the general regularized OT problem with the generalized conditional gradient
@@ -403,81 +550,18 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
ot.optim.cg : conditional gradient
"""
- a, b, M, G0 = list_to_array(a, b, M, G0)
- nx = get_backend(a, b, M)
-
- loop = 1
-
- if log:
- log = {'loss': []}
-
- if G0 is None:
- G = nx.outer(a, b)
- else:
- G = G0
-
- def cost(G):
- return nx.sum(M * G) + reg1 * nx.sum(G * nx.log(G)) + reg2 * f(G)
-
- f_val = cost(G)
- if log:
- log['loss'].append(f_val)
-
- it = 0
-
- if verbose:
- print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
-
- while loop:
-
- it += 1
- old_fval = f_val
-
- # problem linearization
- Mi = M + reg2 * df(G)
-
- # solve linear program with Sinkhorn
- # Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
- Gc = sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax)
- deltaG = Gc - G
-
- # line search
- dcost = Mi + reg1 * (1 + nx.log(G)) # ??
- alpha, fc, f_val = line_search_armijo(
- cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1.
- )
+ def lp_solver(a, b, Mi, **kwargs):
+ return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs)
- G = G + alpha * deltaG
+ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
+ return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)
- # test convergence
- if it >= numItermax:
- loop = 0
+ return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0,
+ numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)
- abs_delta_fval = abs(f_val - old_fval)
- relative_delta_fval = abs_delta_fval / abs(f_val)
- if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
- loop = 0
-
- if log:
- log['loss'].append(f_val)
-
- if verbose:
- if it % 20 == 0:
- print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
- 'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
- print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
-
- if log:
- return G, log
- else:
- return G
-
-
-def solve_1d_linesearch_quad(a, b, c):
+def solve_1d_linesearch_quad(a, b):
r"""
For any convex or non-convex 1d quadratic function `f`, solve the following problem:
@@ -487,7 +571,7 @@ def solve_1d_linesearch_quad(a, b, c):
Parameters
----------
- a,b,c : float
+ a,b : float or tensors (1,)
The coefficients of the quadratic function
Returns
@@ -495,15 +579,11 @@ def solve_1d_linesearch_quad(a, b, c):
x : float
The optimal value which leads to the minimal cost
"""
- f0 = c
- df0 = b
- f1 = a + f0 + df0
-
if a > 0: # convex
- minimum = min(1, max(0, np.divide(-b, 2.0 * a)))
+ minimum = min(1., max(0., -b / (2.0 * a)))
return minimum
else: # non convex
- if f0 > f1:
- return 1
+ if a + b < 0:
+ return 1.
else:
- return 0
+ return 0.
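
The refactor keeps the public `ot.optim.cg` signature while routing it through `generic_conditional_gradient`. A minimal sketch of calling it with a squared-L2 regularizer (the regularizer below is only an example, not part of the library):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    n = 20
    a, b = ot.unif(n), ot.unif(n)
    M = ot.dist(rng.randn(n, 2), rng.randn(n, 2))

    # quadratic regularization f(G) = 0.5 * ||G||_F^2 and its gradient
    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    G = ot.optim.cg(a, b, M, reg=0.1, f=f, df=df)
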
diff --git a/ot/partial.py b/ot/partial.py
index 0a9e450..bf4119d 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -8,6 +8,8 @@ Partial OT solvers
import numpy as np
from .lp import emd
+from .backend import get_backend
+from .utils import list_to_array
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
@@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass
"""
- if np.sum(a) > 1 or np.sum(b) > 1:
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
+ if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15: # 1e-15 for numerical errors
raise ValueError("Problem infeasible. Check that a and b are in the "
"simplex")
if reg_m is None:
- reg_m = np.max(M) + 1
- if reg_m < -np.max(M):
- return np.zeros((len(a), len(b)))
+ reg_m = float(nx.max(M)) + 1
+ if reg_m < -nx.max(M):
+ return nx.zeros((len(a), len(b)), type_as=M)
+
+ a0, b0, M0 = a, b, M
+ # convert to numpy
+ a, b, M = nx.to_numpy(a, b, M)
eps = 1e-20
M = np.asarray(M, dtype=np.float64)
@@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
gamma = np.zeros((len(a), len(b)))
gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies]
+ # convert back to backend
+ gamma = nx.from_numpy(gamma, type_as=M0)
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['cost'] = np.sum(gamma * M)
+ log_emd['cost'] = nx.sum(gamma * M0)
+ log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0)
+ log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0)
+
if log:
return gamma, log_emd
else:
@@ -250,32 +266,52 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
entropic regularization parameter
"""
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
+ dim_a, dim_b = M.shape
+ if len(a) == 0:
+ a = nx.ones(dim_a, type_as=a) / dim_a
+ if len(b) == 0:
+ b = nx.ones(dim_b, type_as=b) / dim_b
+
if m is None:
return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
elif m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- elif m > np.min((np.sum(a), np.sum(b))):
+ elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
- b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
- a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
- M_extended = np.zeros((len(a_extended), len(b_extended)))
- M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2
- M_extended[:len(a), :len(b)] = M
+ b_extension = nx.ones(nb_dummies, type_as=b) * (nx.sum(a) - m) / nb_dummies
+ b_extended = nx.concatenate((b, b_extension))
+ a_extension = nx.ones(nb_dummies, type_as=a) * (nx.sum(b) - m) / nb_dummies
+ a_extended = nx.concatenate((a, a_extension))
+ M_extension = nx.ones((nb_dummies, nb_dummies), type_as=M) * nx.max(M) * 2
+ M_extended = nx.concatenate(
+ (nx.concatenate((M, nx.zeros((M.shape[0], M_extension.shape[1]))), axis=1),
+ nx.concatenate((nx.zeros((M_extension.shape[0], M.shape[1])), M_extension), axis=1)),
+ axis=0
+ )
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
**kwargs)
+
+ gamma = gamma[:len(a), :len(b)]
+
if log_emd['warning'] is not None:
raise ValueError("Error in the EMD resolution: try to increase the"
" number of dummy points")
- log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)])
+ log_emd['partial_w_dist'] = nx.sum(M * gamma)
+ log_emd['u'] = log_emd['u'][:len(a)]
+ log_emd['v'] = log_emd['v'][:len(b)]
if log:
- return gamma[:len(a), :len(b)], log_emd
+ return gamma, log_emd
else:
- return gamma[:len(a), :len(b)]
+ return gamma
def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
@@ -360,14 +396,18 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
NeurIPS.
"""
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
+
partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True,
**kwargs)
log_w['T'] = partial_gw
if log:
- return np.sum(partial_gw * M), log_w
+ return nx.sum(partial_gw * M), log_w
else:
- return np.sum(partial_gw * M)
+ return nx.sum(partial_gw * M)
def gwgrad_partial(C1, C2, T):
@@ -809,60 +849,64 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
ot.partial.partial_wasserstein: exact Partial Wasserstein
"""
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+
+ nx = get_backend(a, b, M)
dim_a, dim_b = M.shape
- dx = np.ones(dim_a, dtype=np.float64)
- dy = np.ones(dim_b, dtype=np.float64)
+ dx = nx.ones(dim_a, type_as=a)
+ dy = nx.ones(dim_b, type_as=b)
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=a) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=b) / dim_b
if m is None:
- m = np.min((np.sum(a), np.sum(b))) * 1.0
+ m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0
if m < 0:
raise ValueError("Problem infeasible. Parameter m should be greater"
" than 0.")
- if m > np.min((np.sum(a), np.sum(b))):
+ if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
raise ValueError("Problem infeasible. Parameter m should lower or"
" equal than min(|a|_1, |b|_1).")
log_e = {'err': []}
- # Next 3 lines equivalent to K=np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
- np.multiply(K, m / np.sum(K), out=K)
+ if type(a) == type(b) == type(M) == np.ndarray:
+ # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+ np.multiply(K, m / np.sum(K), out=K)
+ else:
+ K = nx.exp(-M / reg)
+ K = K * m / nx.sum(K)
err, cpt = 1, 0
- q1 = np.ones(K.shape)
- q2 = np.ones(K.shape)
- q3 = np.ones(K.shape)
+ q1 = nx.ones(K.shape, type_as=K)
+ q2 = nx.ones(K.shape, type_as=K)
+ q3 = nx.ones(K.shape, type_as=K)
while (err > stopThr and cpt < numItermax):
Kprev = K
K = K * q1
- K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K)
+ K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K)
q1 = q1 * Kprev / K1
K1prev = K1
K1 = K1 * q2
- K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy)))
+ K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy)))
q2 = q2 * K1prev / K2
K2prev = K2
K2 = K2 * q3
- K = K2 * (m / np.sum(K2))
+ K = K2 * (m / nx.sum(K2))
q3 = q3 * K2prev / K
- if np.any(np.isnan(K)) or np.any(np.isinf(K)):
+ if nx.any(nx.isnan(K)) or nx.any(nx.isinf(K)):
print('Warning: numerical errors at iteration', cpt)
break
if cpt % 10 == 0:
- err = np.linalg.norm(Kprev - K)
+ err = nx.norm(Kprev - K)
if log:
log_e['err'].append(err)
if verbose:
@@ -872,7 +916,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,
print('{:5d}|{:8e}|'.format(cpt, err))
cpt = cpt + 1
- log_e['partial_w_dist'] = np.sum(M * K)
+ log_e['partial_w_dist'] = nx.sum(M * K)
if log:
return K, log_e
else:
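
With the backend conversions above, the partial solvers now accept and return tensors from any supported backend. A short sketch with torch, assuming it is installed:

    import torch
    import ot

    n = 10
    a = torch.ones(n, dtype=torch.float64) / n
    b = torch.ones(n, dtype=torch.float64) / n
    M = torch.rand(n, n, dtype=torch.float64)

    # the returned plan keeps the input type (here a torch tensor)
    K = ot.partial.entropic_partial_wasserstein(a, b, M, reg=1.0, m=0.5)
    print(type(K))
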
diff --git a/ot/sliced.py b/ot/sliced.py
index cf2d3be..077ff0b 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -12,7 +12,8 @@ Sliced OT Distances
import numpy as np
from .backend import get_backend, NumpyBackend
-from .utils import list_to_array
+from .utils import list_to_array, get_coordinate_circle
+from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle
def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None):
@@ -107,7 +108,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
--------
>>> n_samples_a = 20
- >>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0
@@ -147,6 +147,8 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
if projections is None:
projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
+ else:
+ n_projections = projections.shape[1]
X_s_projections = nx.dot(X_s, projections)
X_t_projections = nx.dot(X_t, projections)
@@ -206,7 +208,6 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
--------
>>> n_samples_a = 20
- >>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0
@@ -256,3 +257,183 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
if log:
return res, {"projections": projections, "projected_emds": projected_emd}
return res
+
+
+def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
+ p=2, seed=None, log=False):
+ r"""
+ Compute the spherical sliced-Wasserstein discrepancy.
+
+ .. math::
+ SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}}
+
+ where:
+
+ - :math:`P^U_\# \mu` stands for the pushforward of :math:`\mu` by the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}`
+
+ The function runs on backend but tensorflow is not supported.
+
+ Parameters
+ ----------
+ X_s: ndarray, shape (n_samples_a, dim)
+ Samples in the source domain
+ X_t: ndarray, shape (n_samples_b, dim)
+ Samples in the target domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ b : ndarray, shape (n_samples_b,), optional
+ samples weights in the target domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ p: float, optional (default=2)
+ Power p used for computing the spherical sliced Wasserstein
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_sphere returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Spherical Sliced Wasserstein Cost
+ log: dict, optional
+ log dictionary return only if log==True in parameters
+
+ Examples
+ --------
+ >>> n_samples_a = 20
+ >>> X = np.random.normal(0., 1., (n_samples_a, 5))
+ >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
+ >>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
+ 0.0
+
+ References
+ ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+ if a is not None and b is not None:
+ nx = get_backend(X_s, X_t, a, b)
+ else:
+ nx = get_backend(X_s, X_t)
+
+ n, d = X_s.shape
+ m, _ = X_t.shape
+
+ if X_s.shape[1] != X_t.shape[1]:
+ raise ValueError(
+ "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
+ X_t.shape[1]))
+ if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("X_s is not on the sphere.")
+ if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("Xt is not on the sphere.")
+
+ # Uniforms and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
+
+ # Projection on S^1
+ # Projection on plane
+ Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1))
+
+ # Projection on sphere
+ Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+ Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True))
+
+ # Get coordinates on [0,1[
+ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
+ Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m))
+
+ projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p)
+ res = nx.mean(projected_emd) ** (1 / p)
+
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
+
+
+def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False):
+ r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution.
+
+ .. math::
+ SSW_2(\mu_n, \nu)
+
+ where
+
+ - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}`
+ - :math:`\nu=\mathrm{Unif}(S^1)`
+
+ Parameters
+ ----------
+ X_s: ndarray, shape (n_samples_a, dim)
+ Samples in the source domain
+ a : ndarray, shape (n_samples_a,), optional
+ samples weights in the source domain
+ n_projections : int, optional
+ Number of projections used for the Monte-Carlo approximation
+ seed: int or RandomState or None, optional
+ Seed used for random number generator
+ log: bool, optional
+ if True, sliced_wasserstein_sphere_unif returns the projections used and their associated EMD.
+
+ Returns
+ -------
+ cost: float
+ Spherical Sliced Wasserstein Cost
+ log: dict, optional
+        log dictionary returned only if log==True in parameters
+
+ Examples
+    --------
+ >>> np.random.seed(42)
+ >>> x0 = np.random.randn(500,3)
+ >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True))
+ >>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42)
+    >>> np.allclose(ssw, 0.01734, atol=1e-3)
+ True
+
+    References
+    ----------
+ .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """
+ if a is not None:
+ nx = get_backend(X_s, a)
+ else:
+ nx = get_backend(X_s)
+
+ n, d = X_s.shape
+
+ if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
+ raise ValueError("X_s is not on the sphere.")
+
+    # Uniform and independent samples on the Stiefel manifold V_{d,2}
+ if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+ Z = seed.randn(n_projections, d, 2)
+ else:
+ if seed is not None:
+ nx.seed(seed)
+ Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+ projections, _ = nx.qr(Z)
+
+ # Projection on S^1
+ # Projection on plane
+ Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+ # Projection on sphere
+ Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+ # Get coordinates on [0,1[
+ Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
+
+ projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a)
+ res = nx.mean(projected_emd) ** (1 / 2)
+
+ if log:
+ return res, {"projections": projections, "projected_emds": projected_emd}
+ return res
diff --git a/ot/smooth.py b/ot/smooth.py
index 6855005..8e0ef38 100644
--- a/ot/smooth.py
+++ b/ot/smooth.py
@@ -44,6 +44,7 @@ Original code from https://github.com/mblondel/smooth-ot/
import numpy as np
from scipy.optimize import minimize
+from .backend import get_backend
def projection_simplex(V, z=1, axis=None):
@@ -511,6 +512,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
"""
+ nx = get_backend(a, b, M)
+
if reg_type.lower() in ['l2', 'squaredl2']:
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
@@ -518,15 +521,19 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
else:
raise NotImplementedError('Unknown regularization')
+ a0, b0, M0 = a, b, M
+    # convert to numpy
+ a, b, M = nx.to_numpy(a, b, M)
+
# solve dual
alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
tol=stopThr, verbose=verbose)
# reconstruct transport matrix
- G = get_plan_from_dual(alpha, beta, M, regul)
+ G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), type_as=M0)
if log:
- log = {'alpha': alpha, 'beta': beta, 'res': res}
+ log = {'alpha': nx.from_numpy(alpha, type_as=a0), 'beta': nx.from_numpy(beta, type_as=b0), 'res': res}
return G, log
else:
return G
diff --git a/ot/solvers.py b/ot/solvers.py
new file mode 100644
index 0000000..0294d71
--- /dev/null
+++ b/ot/solvers.py
@@ -0,0 +1,347 @@
+# -*- coding: utf-8 -*-
+"""
+General OT solvers with unified API
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+from .utils import OTResult
+from .lp import emd2
+from .backend import get_backend
+from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
+from .bregman import sinkhorn_log
+from .partial import partial_wasserstein_lagrange
+from .smooth import smooth_ot_dual
+
+
+def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
+ unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None,
+ potentials_init=None, tol=None, verbose=False):
+    r"""Solve the discrete optimal transport problem and return an :any:`OTResult` object
+
+ The function solves the following general optimal transport problem
+
+ .. math::
+ \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
+ \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
+ \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+ The regularization is selected with :any:`reg` (:math:`\lambda_r`) and :any:`reg_type`. By
+ default ``reg=None`` and there is no regularization. The unbalanced marginal
+ penalization can be selected with :any:`unbalanced` (:math:`\lambda_u`) and
+ :any:`unbalanced_type`. By default ``unbalanced=None`` and the function
+ solves the exact optimal transport problem (respecting the marginals).
+
+ Parameters
+ ----------
+ M : array_like, shape (dim_a, dim_b)
+ Loss matrix
+ a : array-like, shape (dim_a,), optional
+ Samples weights in the source domain (default is uniform)
+ b : array-like, shape (dim_b,), optional
+        Samples weights in the target domain (default is uniform)
+ reg : float, optional
+ Regularization weight :math:`\lambda_r`, by default None (no reg., exact
+ OT)
+ reg_type : str, optional
+        Type of regularization :math:`R`, either "KL", "L2" or "entropy", by default "KL"
+ unbalanced : float, optional
+ Unbalanced penalization weight :math:`\lambda_u`, by default None
+ (balanced OT)
+ unbalanced_type : str, optional
+        Type of unbalanced penalization function :math:`U`, either "KL", "L2" or "TV", by default "KL"
+ n_threads : int, optional
+ Number of OMP threads for exact OT solver, by default 1
+ max_iter : int, optional
+        Maximum number of iterations, by default None (default values in each solver)
+ plan_init : array_like, shape (dim_a, dim_b), optional
+ Initialization of the OT plan for iterative methods, by default None
+ potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
+ Initialization of the OT dual potentials for iterative methods, by default None
+    tol : float, optional
+        Tolerance for solution precision, by default None (default values in each solver)
+ verbose : bool, optional
+ Print information in the solver, by default False
+
+ Returns
+ -------
+ res : OTResult()
+ Result of the optimization problem. The information can be obtained as follows:
+
+ - res.plan : OT plan :math:`\mathbf{T}`
+ - res.potentials : OT dual potentials
+ - res.value : Optimal value of the optimization problem
+ - res.value_linear : Linear OT loss with the optimal OT plan
+
+ See :any:`OTResult` for more information.
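+
+    For instance, a minimal sketch of how the result can be queried (attributes
+    that were not computed by the selected solver raise ``NotImplementedError``):
+
+    .. code-block:: python
+
+        res = ot.solve(M, a, b)   # exact OT by default
+        T = res.plan              # dense OT plan
+        value = res.value         # objective value (with regularization terms if any)
+        u, v = res.potentials     # dual potentials
+        ma, mb = res.marginals    # marginals of the plan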
+
+ Notes
+ -----
+
+ The following methods are available for solving the OT problems:
+
+ - **Classical exact OT problem** (default parameters):
+
+ .. math::
+ \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F
+
+ s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}
+
+ \mathbf{T}^T \mathbf{1} = \mathbf{b}
+
+ \mathbf{T} \geq 0
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ res = ot.solve(M, a, b)
+
+ - **Entropic regularized OT** (when ``reg!=None``):
+
+ .. math::
+ \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})
+
+ s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}
+
+ \mathbf{T}^T \mathbf{1} = \mathbf{b}
+
+ \mathbf{T} \geq 0
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+ # default is ``"KL"`` regularization (``reg_type="KL"``)
+ res = ot.solve(M, a, b, reg=1.0)
+ # or for original Sinkhorn paper formulation [2]
+ res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
+
+ - **Quadratic regularized OT** (when ``reg!=None`` and ``reg_type="L2"``):
+
+ .. math::
+ \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})
+
+ s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}
+
+ \mathbf{T}^T \mathbf{1} = \mathbf{b}
+
+ \mathbf{T} \geq 0
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+            res = ot.solve(M, a, b, reg=1.0, reg_type='L2')
+
+ - **Unbalanced OT** (when ``unbalanced!=None``):
+
+ .. math::
+ \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+            # default is ``"KL"``
+            res = ot.solve(M, a, b, unbalanced=1.0)
+            # quadratic unbalanced OT
+            res = ot.solve(M, a, b, unbalanced=1.0, unbalanced_type='L2')
+            # TV = partial OT
+            res = ot.solve(M, a, b, unbalanced=1.0, unbalanced_type='TV')
+
+
+    - **Regularized unbalanced OT** (when ``unbalanced!=None`` and ``reg!=None``):
+
+ .. math::
+ \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
+
+ can be solved with the following code:
+
+ .. code-block:: python
+
+            # default is ``"KL"`` for both
+            res = ot.solve(M, a, b, reg=1.0, unbalanced=1.0)
+            # quadratic unbalanced OT with KL regularization
+            res = ot.solve(M, a, b, reg=1.0, unbalanced=1.0, unbalanced_type='L2')
+            # both quadratic
+            res = ot.solve(M, a, b, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2')
+
+
+ .. _references-solve:
+ References
+ ----------
+
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
+
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
+
+ .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
+ Optimal Transport. Proceedings of the Twenty-First International
+ Conference on Artificial Intelligence and Statistics (AISTATS).
+
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
+ A., & Peyré, G. (2019, April). Interpolating between optimal transport
+ and MMD using Sinkhorn divergences. In The 22nd International Conference
+ on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+
+ """
+
+ # detect backend
+ arr = [M]
+ if a is not None:
+ arr.append(a)
+ if b is not None:
+ arr.append(b)
+ nx = get_backend(*arr)
+
+ # create uniform weights if not given
+ if a is None:
+ a = nx.ones(M.shape[0], type_as=M) / M.shape[0]
+ if b is None:
+ b = nx.ones(M.shape[1], type_as=M) / M.shape[1]
+
+ # default values for solutions
+ potentials = None
+ value = None
+ value_linear = None
+ plan = None
+ status = None
+
+ if reg is None or reg == 0: # exact OT
+
+ if unbalanced is None: # Exact balanced OT
+
+ # default values for EMD solver
+ if max_iter is None:
+ max_iter = 1000000
+
+ value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads)
+
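+            # the emd2 log (return_matrix=True) provides the dual potentials,
+            # the dense OT plan and a possible convergence warning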
+ value = value_linear
+ potentials = (log['u'], log['v'])
+ plan = log['G']
+ status = log["warning"] if log["warning"] is not None else 'Converged'
+
+ elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT
+
+ # default values for exact unbalanced OT
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-12
+
+ plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced,
+ div=unbalanced_type.lower(), numItermax=max_iter,
+ stopThr=tol, log=True,
+ verbose=verbose, G0=plan_init)
+
+ value_linear = log['cost']
+
+ if unbalanced_type.lower() == 'kl':
+ value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))
+ else:
+ err_a = nx.sum(plan, 1) - a
+ err_b = nx.sum(plan, 0) - b
+ value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2)
+
+ elif unbalanced_type.lower() == 'tv':
+
+ if max_iter is None:
+ max_iter = 1000000
+
+ plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2, log=True, numItermax=max_iter)
+
+ value_linear = nx.sum(M * plan)
+ err_a = nx.sum(plan, 1) - a
+ err_b = nx.sum(plan, 0) - b
+ value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) +
+ nx.sum(nx.abs(err_b))))
+
+ else:
+            raise NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))
+
+ else: # regularized OT
+
+ if unbalanced is None: # Balanced regularized OT
+
+ if reg_type.lower() in ['entropy', 'kl']:
+
+ # default values for sinkhorn
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-9
+
+ plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
+ stopThr=tol, log=True,
+ verbose=verbose)
+
+ value_linear = nx.sum(M * plan)
+
+ if reg_type.lower() == 'entropy':
+ value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
+ else:
+ value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :])
+
+ potentials = (log['log_u'], log['log_v'])
+
+ elif reg_type.lower() == 'l2':
+
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-9
+
+ plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter, stopThr=tol, log=True, verbose=verbose)
+
+ value_linear = nx.sum(M * plan)
+ value = value_linear + reg * nx.sum(plan**2)
+ potentials = (log['alpha'], log['beta'])
+
+ else:
+                raise NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))
+
+ else: # unbalanced AND regularized OT
+
+ if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':
+
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-9
+
+ plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)
+
+ value_linear = nx.sum(M * plan)
+
+ value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))
+
+ potentials = (log['logu'], log['logv'])
+
+ elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']:
+
+ if max_iter is None:
+ max_iter = 1000
+ if tol is None:
+ tol = 1e-12
+
+ plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)
+
+ value_linear = nx.sum(M * plan)
+
+ value = log['loss']
+
+ else:
+                raise NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))
+
+ res = OTResult(potentials=potentials, value=value,
+ value_linear=value_linear, plan=plan, status=status, backend=nx)
+
+ return res
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 90c920c..a71a0dd 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -10,6 +10,9 @@ Regularized Unbalanced OT solvers
from __future__ import division
import warnings
+import numpy as np
+from scipy.optimize import minimize, Bounds
+
from .backend import get_backend
from .utils import list_to_array
# from .utils import unif, dist
@@ -269,7 +272,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
- Solve the entropic regularization unbalanced optimal transport problem and return the loss
+ Solve the entropic regularization unbalanced optimal transport problem and
+ return the OT plan
The function solves the following optimization problem:
@@ -734,7 +738,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -882,7 +886,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
if weights is None:
weights = nx.ones(n_hists, type_as=A) / n_hists
else:
- assert(len(weights) == A.shape[1])
+ assert (len(weights) == A.shape[1])
if log:
log = {'err': []}
@@ -1252,3 +1256,182 @@ def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000,
return log_mm['cost'], log_mm
else:
return log_mm['cost']
+
+
+def _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl'):
+ """
+ return the loss function (scipy.optimize compatible) for regularized
+ unbalanced OT
+ """
+
+ m, n = M.shape
+
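+    # divergence helpers; note that kl is implemented as sum(p * log(p / q)),
+    # without the mass-difference terms of the generalized KL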
+ def kl(p, q):
+ return np.sum(p * np.log(p / q + 1e-16))
+
+ def reg_l2(G):
+ return np.sum((G - a[:, None] * b[None, :])**2) / 2
+
+ def grad_l2(G):
+ return G - a[:, None] * b[None, :]
+
+ def reg_kl(G):
+ return kl(G, a[:, None] * b[None, :])
+
+ def grad_kl(G):
+ return np.log(G / (a[:, None] * b[None, :]) + 1e-16) + 1
+
+ def reg_entropy(G):
+ return kl(G, 1)
+
+ def grad_entropy(G):
+ return np.log(G + 1e-16) + 1
+
+ if reg_div == 'kl':
+ reg_fun = reg_kl
+ grad_reg_fun = grad_kl
+ elif reg_div == 'entropy':
+ reg_fun = reg_entropy
+ grad_reg_fun = grad_entropy
+ else:
+ reg_fun = reg_l2
+ grad_reg_fun = grad_l2
+
+ def marg_l2(G):
+ return 0.5 * np.sum((G.sum(1) - a)**2) + 0.5 * np.sum((G.sum(0) - b)**2)
+
+ def grad_marg_l2(G):
+ return np.outer((G.sum(1) - a), np.ones(n)) + np.outer(np.ones(m), (G.sum(0) - b))
+
+ def marg_kl(G):
+ return kl(G.sum(1), a) + kl(G.sum(0), b)
+
+ def grad_marg_kl(G):
+ return np.outer(np.log(G.sum(1) / a + 1e-16) + 1, np.ones(n)) + np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16) + 1)
+
+ if regm_div == 'kl':
+ regm_fun = marg_kl
+ grad_regm_fun = grad_marg_kl
+ else:
+ regm_fun = marg_l2
+ grad_regm_fun = grad_marg_l2
+
+ def _func(G):
+ G = G.reshape((m, n))
+
+ # compute loss
+ val = np.sum(G * M) + reg * reg_fun(G) + reg_m * regm_fun(G)
+
+ # compute gradient
+ grad = M + reg * grad_reg_fun(G) + reg_m * grad_regm_fun(G)
+
+ return val, grad.ravel()
+
+ return _func
+
+
+def lbfgsb_unbalanced(a, b, M, reg, reg_m, reg_div='kl', regm_div='kl', G0=None, numItermax=1000,
+ stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False):
+ r"""
+ Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B.
+ The function solves the following optimization problem:
+
+ .. math::
+        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
+        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{a}\mathbf{b}^T) +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
+        \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b})
+
+ s.t.
+ \gamma \geq 0
+
+ where:
+
+ - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ unbalanced distributions
+ - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence
+
+    The problem is solved with the L-BFGS-B algorithm from scipy.optimize
+
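+    A minimal usage sketch (the values of ``reg`` and ``reg_m`` below are
+    illustrative only; the returned ``G`` is the OT plan):
+
+    .. code-block:: python
+
+        import numpy as np
+        import ot
+
+        a = np.array([.5, .5])
+        b = np.array([.5, .5])
+        M = np.array([[1., 36.], [9., 4.]])
+        G = ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=1.0, reg_m=5.0,
+                                            reg_div='kl', regm_div='kl')
+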
+ Parameters
+ ----------
+ a : array-like (dim_a,)
+ Unnormalized histogram of dimension `dim_a`
+ b : array-like (dim_b,)
+ Unnormalized histogram of dimension `dim_b`
+ M : array-like (dim_a, dim_b)
+ loss matrix
+ reg: float
+ regularization term (>=0)
+ reg_m: float
+ Marginal relaxation term >= 0
+ reg_div: string, optional
+ Divergence used for regularization.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+    regm_div: string, optional
+ Divergence to quantify the difference between the marginals.
+ Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
+ G0: array-like (dim_a, dim_b)
+ Initialization of the transport matrix
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+ Returns
+ -------
+    G : array-like (dim_a, dim_b)
+        Optimal transport plan between :math:`\mathbf{a}` and :math:`\mathbf{b}`
+ log : dict
+ log dictionary returned only if `log` is `True`
+
+ Examples
+ --------
+ >>> import ot
+ >>> import numpy as np
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[1., 36.],[9., 4.]]
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2)
+ 0.25
+ >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2)
+ 0.57
+
+ References
+ ----------
+ .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
+ Unbalanced optimal transport through non-negative penalized
+        linear regression. NeurIPS.
+
+    See Also
+ --------
+ ot.lp.emd2 : Unregularized OT loss
+ ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
+ """
+ nx = get_backend(M, a, b)
+
+ M0 = M
+    # convert to numpy
+ a, b, M = nx.to_numpy(a, b, M)
+
+ if G0 is not None:
+ G0 = nx.to_numpy(G0)
+ else:
+ G0 = np.zeros(M.shape)
+
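+    # build a scipy-compatible objective that returns both the loss value and
+    # the flattened gradient (used with jac=True below)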
+ _func = _get_loss_unbalanced(a, b, M, reg, reg_m, reg_div, regm_div)
+
+ res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf),
+ tol=stopThr, options=dict(maxiter=numItermax, disp=verbose))
+
+ G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0)
+
+ if log:
+ log = {'loss': nx.from_numpy(res.fun, type_as=M0), 'res': res}
+ return G, log
+ else:
+ return G
diff --git a/ot/utils.py b/ot/utils.py
index a23ce7e..3423a7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend, Backend
+from .backend import get_backend, Backend, NumpyBackend
__time_tic_toc = time.time()
@@ -232,9 +232,11 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None):
if not get_backend(x1, x2).__name__ == 'numpy':
raise NotImplementedError()
else:
- if metric.endswith("minkowski"):
+ if isinstance(metric, str) and metric.endswith("minkowski"):
return cdist(x1, x2, metric=metric, p=p, w=w)
- return cdist(x1, x2, metric=metric, w=w)
+ if w is not None:
+ return cdist(x1, x2, metric=metric, w=w)
+ return cdist(x1, x2, metric=metric)
def dist0(n, method='lin_square'):
@@ -373,6 +375,36 @@ def check_random_state(seed):
' instance'.format(seed))
+def get_coordinate_circle(x):
+    r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the angular
+    coordinate expressed in turns (in :math:`[0, 1[`).
+
+ .. math::
+ u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+ Parameters
+ ----------
+ x: ndarray, shape (n, 2)
+ Samples on the circle with ambient coordinates
+
+ Returns
+ -------
+ x_t: ndarray, shape (n,)
+ Coordinates on [0,1[
+
+ Examples
+ --------
+ >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi)
+ >>> x1, y1 = np.cos(u), np.sin(u)
+ >>> x = np.concatenate([x1, y1]).T
+ >>> get_coordinate_circle(x)
+ array([0.2, 0.5, 0.8])
+ """
+ nx = get_backend(x)
+ x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
+ return x_t
+
+
class deprecated(object):
r"""Decorator to mark a function or class as deprecated.
@@ -609,3 +641,203 @@ class UndefinedParameter(Exception):
"""
pass
+
+
+class OTResult:
+ def __init__(self, potentials=None, value=None, value_linear=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):
+
+ self._potentials = potentials
+ self._value = value
+ self._value_linear = value_linear
+ self._plan = plan
+ self._log = log
+ self._sparse_plan = sparse_plan
+ self._lazy_plan = lazy_plan
+ self._backend = backend if backend is not None else NumpyBackend()
+ self._status = status
+
+ # I assume that other solvers may return directly
+ # some primal objects?
+ # In the code below, let's define the main quantities
+ # that may be of interest to users.
+ # An OT solver returns an object that inherits from OTResult
+ # (e.g. SinkhornOTResult) and implements the relevant
+ # methods (e.g. "plan" and "lazy_plan" but not "sparse_plan", etc.).
+ # log is a dictionary containing potential information about the solver
+
+ # Dual potentials --------------------------------------------
+
+ def __repr__(self):
+ s = 'OTResult('
+ if self._value is not None:
+ s += 'value={},'.format(self._value)
+ if self._value_linear is not None:
+ s += 'value_linear={},'.format(self._value_linear)
+ if self._plan is not None:
+ s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)
+
+ if s[-1] != '(':
+ s = s[:-1] + ')'
+ else:
+ s = s + ')'
+ return s
+
+ @property
+ def potentials(self):
+ """Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
+
+ This pair of arrays has the same shape, numerical type
+ and properties as the input weights "a" and "b".
+ """
+ if self._potentials is not None:
+ return self._potentials
+ else:
+ raise NotImplementedError()
+
+ @property
+ def potential_a(self):
+ """First dual potential, associated to the "source" measure "a"."""
+ if self._potentials is not None:
+ return self._potentials[0]
+ else:
+ raise NotImplementedError()
+
+ @property
+ def potential_b(self):
+ """Second dual potential, associated to the "target" measure "b"."""
+ if self._potentials is not None:
+ return self._potentials[1]
+ else:
+ raise NotImplementedError()
+
+ # Transport plan -------------------------------------------
+ @property
+ def plan(self):
+ """Transport plan, encoded as a dense array."""
+ # N.B.: We may catch out-of-memory errors and suggest
+ # the use of lazy_plan or sparse_plan when appropriate.
+
+ if self._plan is not None:
+ return self._plan
+ else:
+ raise NotImplementedError()
+
+ @property
+ def sparse_plan(self):
+ """Transport plan, encoded as a sparse array."""
+ if self._sparse_plan is not None:
+ return self._sparse_plan
+ elif self._plan is not None:
+ return self._backend.tocsr(self._plan)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def lazy_plan(self):
+ """Transport plan, encoded as a symbolic KeOps LazyTensor."""
+ raise NotImplementedError()
+
+ # Loss values --------------------------------
+
+ @property
+ def value(self):
+ """Full transport cost, including possible regularization terms."""
+ if self._value is not None:
+ return self._value
+ else:
+ raise NotImplementedError()
+
+ @property
+ def value_linear(self):
+ """The "minimal" transport cost, i.e. the product between the transport plan and the cost."""
+ if self._value_linear is not None:
+ return self._value_linear
+ else:
+ raise NotImplementedError()
+
+ # Marginal constraints -------------------------
+ @property
+ def marginals(self):
+ """Marginals of the transport plan: should be very close to "a" and "b"
+ for balanced OT."""
+ if self._plan is not None:
+ return self.marginal_a, self.marginal_b
+ else:
+ raise NotImplementedError()
+
+ @property
+ def marginal_a(self):
+ """First marginal of the transport plan, with the same shape as "a"."""
+ if self._plan is not None:
+ return self._backend.sum(self._plan, 1)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def marginal_b(self):
+ """Second marginal of the transport plan, with the same shape as "b"."""
+ if self._plan is not None:
+ return self._backend.sum(self._plan, 0)
+ else:
+ raise NotImplementedError()
+
+ @property
+ def status(self):
+ """Optimization status of the solver."""
+ if self._status is not None:
+ return self._status
+ else:
+ raise NotImplementedError()
+
+ # Barycentric mappings -------------------------
+ # Return the displacement vectors as an array
+ # that has the same shape as "xa"/"xb" (for samples)
+ # or "a"/"b" * D (for images)?
+
+ @property
+ def a_to_b(self):
+ """Displacement vectors from the first to the second measure."""
+ raise NotImplementedError()
+
+ @property
+ def b_to_a(self):
+ """Displacement vectors from the second to the first measure."""
+ raise NotImplementedError()
+
+ # # Wasserstein barycenters ----------------------
+ # @property
+ # def masses(self):
+ # """Masses for the Wasserstein barycenter."""
+ # raise NotImplementedError()
+
+ # @property
+ # def samples(self):
+ # """Sample locations for the Wasserstein barycenter."""
+ # raise NotImplementedError()
+
+ # Miscellaneous --------------------------------
+
+ @property
+ def citation(self):
+ """Appropriate citation(s) for this result, in plain text and BibTex formats."""
+
+ # The string below refers to the POT library:
+ # successor methods may concatenate the relevant references
+ # to the original definitions, solvers and underlying numerical backends.
+ return """POT library:
+
+ POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
+ Website: https://pythonot.github.io/
+ Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;
+
+ @article{flamary2021pot,
+ author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
+ title = {{POT}: {Python} {Optimal} {Transport}},
+ journal = {Journal of Machine Learning Research},
+ year = {2021},
+ volume = {22},
+ number = {78},
+ pages = {1-8},
+ url = {http://jmlr.org/papers/v22/20-451.html}
+ }
+ """
diff --git a/ot/weak.py b/ot/weak.py
index f7d5b23..7364e68 100644
--- a/ot/weak.py
+++ b/ot/weak.py
@@ -18,7 +18,7 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=
.. math::
- \gamma = \mathop{\arg \min}_\gamma \quad \|X_a-diag(1/a)\gammaX_b\|_F^2
+ \gamma = \mathop{\arg \min}_\gamma \quad \sum_i \mathbf{a}_i \left(\mathbf{X^a}_i - \frac{1}{\mathbf{a}_i} \sum_j \gamma_{ij} \mathbf{X^b}_j \right)^2
s.t. \ \gamma \mathbf{1} = \mathbf{a}
@@ -28,7 +28,7 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=
where :
- - :math:`X_a` :math:`X_b` are the sample matrices.
+ - :math:`X^a` and :math:`X^b` are the sample matrices.
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
@@ -49,6 +49,8 @@ def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=
Source histogram (uniform weight if empty list)
b : (nt,) array-like, float
Target histogram (uniform weight if empty list))
+ G0 : (ns,nt) array-like, float
+ initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numItermaxEmd : int, optional